Tinkering Notes [51]: ViT-Based MNIST Handwritten Digit Recognition, Training and Inference
Abstract
This article is an end-to-end exercise in image classification on the MNIST dataset using the Vision Transformer (ViT) architecture. The model is assembled modularly from Patch Embedding, learnable position embeddings, a CLS token, Multi-Head Self-Attention, and an MLP classification head, and is trained for 20 epochs with an AdamW + CosineAnnealingLR strategy on CPU/GPU under the PyTorch framework. The core value of this exercise is validating the complete engineering pipeline of building a ViT from scratch, training it, and running inference.
Statement
A human is the first author of this article, and 龙虾 (Lobster) is the corresponding author. This article contains AI-generated content.
Introduction
Introduction to ViT
[https://zhuanlan.zhihu.com/p/445122996]
[https://github.com/google-research/vision_transformer]
[https://arxiv.org/abs/2010.11929]
[https://arxiv.org/abs/2105.01601]
[https://arxiv.org/abs/2106.10270]
[https://arxiv.org/abs/2106.01548]
[https://arxiv.org/abs/2111.07991]
[https://arxiv.org/abs/2203.08065]
Pretraining only brings a clear improvement when the dataset is very large; without pretrained weights the model starts from randomly initialized weights, with none of that prior knowledge.
Model overview (quoted from the ViT paper): "we split an image into fixed-size patches, linearly embed each of them, add position embeddings, and feed the resulting sequence of vectors to a standard Transformer encoder. In order to perform classification, we use the standard approach of adding an extra learnable 'classification token' to the sequence."
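To make that pipeline concrete for MNIST, here is a minimal sketch of the patch embedding front end (not the article's exact model; the patch size of 7 and embedding dimension of 64 are illustrative assumptions for a 28x28 grayscale input):

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Minimal sketch: split a 28x28 MNIST image into 7x7 patches, linearly embed them,
    prepend a learnable CLS token and add learnable position embeddings."""

    def __init__(self, img_size=28, patch_size=7, in_chans=1, dim=64):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2  # 16 patches
        # A strided conv is a common way to implement "split into patches + linear embed".
        self.proj = nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, dim))

    def forward(self, x):                        # x: (B, 1, 28, 28)
        x = self.proj(x)                         # (B, dim, 4, 4)
        x = x.flatten(2).transpose(1, 2)         # (B, 16, dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)           # (B, 17, dim)
        return x + self.pos_embed                # add learnable position embeddings

tokens = PatchEmbedding()(torch.randn(2, 1, 28, 28))
print(tokens.shape)  # torch.Size([2, 17, 64])
```

The resulting (B, 17, dim) sequence is then fed through the Transformer encoder blocks, and the CLS token's output vector goes to the MLP classification head.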
Introduction to AdamW
[https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py]
[https://github.com/pytorch/pytorch/blob/main/torch/optim/adam.py]
AdamW is an optimization algorithm, an improved variant of the Adam optimizer, designed specifically to fix the coupling between weight decay and the gradient update.
When classic Adam applies L2 regularization (weight decay), the decay term is folded into the gradient and thus into the adaptive moment estimates (first and second moments), so the effect of weight decay gets distorted by the adaptive learning-rate scaling. AdamW decouples the weight decay from the gradient-based update step, so it behaves the same way weight decay does when paired with SGD.
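The difference fits in a few lines. Below is a simplified, illustrative single-step sketch (bias correction omitted for brevity); it mirrors the `decoupled_weight_decay` branch in the PyTorch source quoted afterwards, where `AdamW` is implemented as `Adam(..., decoupled_weight_decay=True)`:

```python
import torch

def adam_like_step(param, grad, m, v, lr=1e-3, beta1=0.9, beta2=0.999,
                   eps=1e-8, weight_decay=1e-2, decoupled=True):
    """Simplified single update step, for illustration only (no bias correction)."""
    if weight_decay != 0:
        if decoupled:
            # AdamW: decay acts on the parameters directly, outside the moment estimates.
            param.mul_(1 - lr * weight_decay)
        else:
            # Classic Adam + L2: decay is folded into the gradient, so it is later
            # rescaled by the adaptive denominator sqrt(v) + eps.
            grad = grad.add(param, alpha=weight_decay)
    m.mul_(beta1).add_(grad, alpha=1 - beta1)             # first moment
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)   # second moment
    param.addcdiv_(m, v.sqrt().add_(eps), value=-lr)
    return param, m, v

p = torch.ones(3)
adam_like_step(p, torch.full((3,), 0.1), torch.zeros(3), torch.zeros(3))
```

In the coupled case the decay term passes through the sqrt(v) + eps denominator, so parameters with large gradient variance are decayed less; in the decoupled case every parameter is shrunk by the same factor 1 - lr * weight_decay each step.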
# mypy: allow-untyped-defs
from typing import cast
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_fused_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_get_value,
_maximize_doc,
_params_doc,
_stack_if_compiling,
_to_scalar,
_use_grad_for_differentiable,
_view_as_real,
DeviceDict,
DeviceDtypeDict,
Optimizer,
ParamsT,
)
__all__ = ["Adam", "adam"]
class Adam(Optimizer):
def __init__(
self,
params: ParamsT,
lr: float | Tensor = 1e-3,
betas: tuple[float | Tensor, float | Tensor] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
amsgrad: bool = False,
*,
foreach: bool | None = None,
maximize: bool = False,
capturable: bool = False,
differentiable: bool = False,
fused: bool | None = None,
decoupled_weight_decay: bool = False,
) -> None:
if isinstance(lr, Tensor):
if foreach and not capturable:
raise ValueError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
if lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not (
(isinstance(betas[0], float) and isinstance(betas[1], float))
or (isinstance(betas[0], Tensor) and isinstance(betas[1], Tensor))
):
raise ValueError("betas must be either both floats or both Tensors")
if isinstance(betas[0], Tensor):
if not capturable and foreach:
raise ValueError(
"betas[0] as a Tensor is not supported for capturable=False and foreach=True"
)
if betas[0].numel() != 1:
raise ValueError("Tensor betas[0] must be 1-element")
if isinstance(betas[1], Tensor):
if not capturable and foreach:
raise ValueError(
"betas[1] as a Tensor is not supported for capturable=False and foreach=True"
)
if betas[1].numel() != 1:
raise ValueError("Tensor betas[1] must be 1-element")
betas = tuple(map(_to_scalar, betas))
defaults = {
"lr": lr,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
"amsgrad": amsgrad,
"maximize": maximize,
"foreach": foreach,
"capturable": capturable,
"differentiable": differentiable,
"fused": fused,
"decoupled_weight_decay": decoupled_weight_decay,
}
super().__init__(params, defaults)
if fused:
if differentiable:
raise RuntimeError("`fused` does not support `differentiable`")
self._step_supports_amp_scaling = True
# TODO(crcrpar): [low prec params & their higher prec copy]
# Support AMP with FP16/BF16 model params which would need
# higher prec copy of params to do update math in higher prec to
# alleviate the loss of information.
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("amsgrad", False)
group.setdefault("maximize", False)
group.setdefault("foreach", None)
group.setdefault("capturable", False)
group.setdefault("differentiable", False)
group.setdefault("decoupled_weight_decay", False)
fused = group.setdefault("fused", None)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = (
torch.tensor(
step_val,
dtype=_get_scalar_dtype(is_fused=fused),
device=p.device,
)
if group["capturable"] or group["fused"]
else torch.tensor(step_val, dtype=_get_scalar_dtype())
)
def _init_group(
self,
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
):
has_complex = False
for p in group["params"]:
if p.grad is not None:
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please consider SparseAdam instead"
)
grads.append(p.grad)
state = self.state[p]
# Lazy state initialization
if len(state) == 0:
if group["fused"]:
_device_dtype_check_for_fused(p)
# note(crcrpar): [special device hosting for step]
# Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA.
state["step"] = (
torch.zeros(
(),
dtype=_get_scalar_dtype(is_fused=group["fused"]),
device=p.device,
)
if group["capturable"] or group["fused"]
else torch.tensor(0.0, dtype=_get_scalar_dtype())
)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if group["amsgrad"]:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
if group["amsgrad"]:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
if group["differentiable"] and state["step"].requires_grad:
raise RuntimeError(
"`requires_grad` is not supported for `step` in differentiable mode"
)
# Foreach without capturable does not support a tensor lr
if (
group["foreach"]
and torch.is_tensor(group["lr"])
and not group["capturable"]
):
raise RuntimeError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
state_steps.append(state["step"])
return has_complex
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
self._accelerator_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
exp_avgs: list[Tensor] = []
exp_avg_sqs: list[Tensor] = []
max_exp_avg_sqs: list[Tensor] = []
state_steps: list[Tensor] = []
beta1, beta2 = group["betas"]
has_complex = self._init_group(
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
)
adam(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=group["amsgrad"],
has_complex=has_complex,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"],
maximize=group["maximize"],
foreach=group["foreach"],
capturable=group["capturable"],
differentiable=group["differentiable"],
fused=group["fused"],
grad_scale=getattr(self, "grad_scale", None),
found_inf=getattr(self, "found_inf", None),
decoupled_weight_decay=group["decoupled_weight_decay"],
)
return loss
Adam.__doc__ = (
r"""Implements Adam algorithm.
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
\text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
&\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
\:\textit{maximize}, \: \epsilon \text{ (epsilon)} \\
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
v_0\leftarrow 0 \text{ (second moment)},\: v_0^{max}\leftarrow 0 \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
&\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
&\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
&\hspace{5mm}\textbf{if} \: amsgrad \\
&\hspace{10mm} v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max},v_t) \\
&\hspace{10mm}\widehat{v_t} \leftarrow v_t^{max}/\big(1-\beta_2^t \big) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
"""
+ rf"""
Args:
{_params_doc}
lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
is not yet supported for all our implementations. Please use a float
LR if you are not also specifying fused=True or capturable=True.
betas (tuple[float | Tensor, float | Tensor], optional):
coefficients used for computing running averages of gradient and
its square. If a tensor is provided, must be 1-element. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
decoupled_weight_decay (bool, optional): if True, this optimizer is
equivalent to AdamW and the algorithm will not accumulate weight
decay in the momentum nor variance. (default: False)
amsgrad (bool, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
{_foreach_doc}
{_maximize_doc}
{_capturable_doc}
{_differentiable_doc}
{_fused_doc}
.. Note::
A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
)
def _single_tensor_adam(
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
max_exp_avg_sqs: list[Tensor],
state_steps: list[Tensor],
grad_scale: Tensor | None,
found_inf: Tensor | None,
*,
amsgrad: bool,
has_complex: bool,
beta1: float | Tensor,
beta2: float | Tensor,
lr: float | Tensor,
weight_decay: float,
eps: float,
maximize: bool,
capturable: bool,
differentiable: bool,
decoupled_weight_decay: bool,
) -> None:
if grad_scale is not None or found_inf is not None:
raise AssertionError("Expected grad_scale and found_inf to be None")
if torch.jit.is_scripting():
# this assert is due to JIT being dumb and not realizing that the ops below
# have overloads to handle both float and Tensor lrs, so we just assert it's
# a float since most people using JIT are using floats
if not isinstance(lr, float):
raise AssertionError(f"Expected lr to be a float, but got {type(lr)}")
if not isinstance(beta1, float):
raise AssertionError(f"Expected beta1 to be a float, but got {type(beta1)}")
if not isinstance(beta2, float):
raise AssertionError(f"Expected beta2 to be a float, but got {type(beta2)}")
else:
lr = _to_scalar(lr)
beta1 = _to_scalar(beta1)
beta2 = _to_scalar(beta2)
# We only shuffle around the beta when it is a Tensor, otherwise, we prefer
# treating it as a scalar.
# Note: ensure type declaration is under conditional check for isinstance
# or else torchscript will get cranky about the DeviceDict type.
if isinstance(beta1, Tensor):
beta1_dict: DeviceDtypeDict | None = {(beta1.device, beta1.dtype): beta1}
else:
beta1_dict = None
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch.compiler.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
if not (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
):
raise AssertionError(
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
# update step
step_t += 1
if weight_decay != 0:
if decoupled_weight_decay:
# Perform stepweight decay
param.mul_(1 - lr * weight_decay)
else:
# Nested if is necessary to bypass jitscript rules
if differentiable and isinstance(weight_decay, Tensor):
if weight_decay.requires_grad:
grad = grad.addcmul_(param.clone(), weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
if torch.is_complex(param):
grad = torch.view_as_real(grad)
exp_avg = torch.view_as_real(exp_avg)
exp_avg_sq = torch.view_as_real(exp_avg_sq)
if amsgrad:
max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
param = torch.view_as_real(param)
device = param.device
if beta1_dict is not None:
dtype = param.dtype # type: ignore[union-attr]
# cast to workaround https://github.com/pytorch/pytorch/issues/140601
key = (device, dtype)
if key not in beta1_dict:
beta1_dict[key] = beta1.to( # type: ignore[union-attr]
device=device, dtype=dtype, non_blocking=True
)
device_beta1: float | Tensor = beta1_dict[key]
else:
device_beta1 = beta1
# Decay the first and second moment running average coefficient
exp_avg.lerp_(grad, 1 - device_beta1)
# Nested if is necessary to bypass jitscript rules
if differentiable and isinstance(beta2, Tensor):
if beta2.requires_grad:
# Using lerp to only use 2 operations bc addcmul's value cannot be a tensor
# Showing equivalence of differentiable path and nondifferentiable path
# expavg * b2 + grad^2 * (1-b2)
# add expavg * (1-b2) - expavg * (1-b2) = 0
# expavg * b2 + expavg * (1-b2) - expavg * (1-b2) + grad^2 * (1-b2)
# expavg - expavg * (1-b2) + grad^2 * (1-b2)
# expavg + (grad^2 - expavg) * (1-b2)
# expavg.lerp(grad^2, 1-beta2)
exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
else:
exp_avg_sq.mul_(beta2).addcmul_(
grad, grad, value=cast(float, 1 - beta2)
)
else:
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # type: ignore[arg-type]
if capturable or differentiable:
step = step_t
# Nested if is necessary to bypass jitscript rules
if differentiable and isinstance(beta1, Tensor):
if beta1.requires_grad:
bias_correction1 = 1 - beta1 ** step.clone()
else:
bias_correction1 = 1 - beta1**step
else:
bias_correction1 = 1 - beta1**step
# Nested if is necessary to bypass jitscript rules
if differentiable and isinstance(beta2, Tensor):
if beta2.requires_grad:
bias_correction2 = 1 - beta2 ** step.clone()
else:
bias_correction2 = 1 - beta2**step
else:
bias_correction2 = 1 - beta2**step
step_size = lr / bias_correction1
step_size_neg = step_size.neg()
bias_correction2_sqrt = bias_correction2.sqrt()
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
if differentiable:
max_exp_avg_sq = max_exp_avg_sqs[i].clone()
else:
max_exp_avg_sq = max_exp_avg_sqs[i]
max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))
# Uses the max. for normalizing running avg. of gradient
# Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
# (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
denom = (
max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
).add_(eps / step_size_neg)
else:
denom = (
exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
).add_(eps / step_size_neg)
if differentiable:
param.addcdiv_(exp_avg.clone(), denom)
else:
param.addcdiv_(exp_avg, denom)
else:
step = _get_value(step_t)
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
step_size = lr / bias_correction1
bias_correction2_sqrt = bias_correction2**0.5
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
# Use the max. for normalizing running avg. of gradient
denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
else:
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
param.addcdiv_(exp_avg, denom, value=-step_size) # type: ignore[arg-type]
# Lastly, switch back to complex view
if amsgrad and torch.is_complex(params[i]):
max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])
def _multi_tensor_adam(
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
max_exp_avg_sqs: list[Tensor],
state_steps: list[Tensor],
grad_scale: Tensor | None,
found_inf: Tensor | None,
*,
amsgrad: bool,
has_complex: bool,
beta1: float | Tensor,
beta2: float | Tensor,
lr: float | Tensor,
weight_decay: float,
eps: float,
maximize: bool,
capturable: bool,
differentiable: bool,
decoupled_weight_decay: bool,
) -> None:
if len(params) == 0:
return
if isinstance(lr, Tensor):
if not capturable:
raise RuntimeError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
if lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if isinstance(beta1, Tensor):
if not capturable:
raise ValueError(
"beta1 as a Tensor is not supported for capturable=False and foreach=True"
)
if beta1.numel() != 1:
raise ValueError("Tensor beta1 must be 1-element")
if isinstance(beta2, Tensor):
if not capturable:
raise ValueError(
"beta2 as a Tensor is not supported for capturable=False and foreach=True"
)
if beta2.numel() != 1:
raise ValueError("Tensor beta2 must be 1-element")
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch.compiler.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
)
if not all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps, strict=True)
):
raise AssertionError(
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
if grad_scale is not None or found_inf is not None:
raise AssertionError("Expected grad_scale and found_inf to be None")
if differentiable:
raise AssertionError("_foreach ops don't support autograd")
lr = _to_scalar(lr)
beta1 = _to_scalar(beta1)
beta2 = _to_scalar(beta2)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item]
)
# We only shuffle around the beta when it is a Tensor and on CUDA, otherwise, we prefer
# treating it as a scalar.
beta1_dict: DeviceDict | None = ( # type: ignore[attr-defined]
{beta1.device: beta1}
if isinstance(beta1, Tensor) and str(beta1.device) != "cpu"
else None
)
for (
device_params_,
device_grads_,
device_exp_avgs_,
device_exp_avg_sqs_,
device_max_exp_avg_sqs_,
device_state_steps_,
), _ in grouped_tensors.values():
device_params = cast(list[Tensor], device_params_)
device_grads = cast(list[Tensor], device_grads_)
device_exp_avgs = cast(list[Tensor], device_exp_avgs_)
device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_)
device_state_steps = cast(list[Tensor], device_state_steps_)
device = device_params[0].device
if beta1_dict is not None and device not in beta1_dict:
beta1_dict[device] = beta1.to(device=device, non_blocking=True) # type: ignore[union-attr, attr-defined]
device_beta1 = beta1_dict[device] if beta1_dict else beta1
# Handle complex parameters
if has_complex:
if amsgrad:
device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
_view_as_real(
device_params,
device_grads,
device_exp_avgs,
device_exp_avg_sqs,
device_max_exp_avg_sqs,
)
else:
_view_as_real(
device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
)
if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(device_state_steps, 1)
if weight_decay != 0:
if decoupled_weight_decay:
# Perform stepweight decay
torch._foreach_mul_(device_params, 1 - lr * weight_decay)
else:
# Reuse the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
# Decay the first and second moment running average coefficient
# Use device beta1 if beta1 is a tensor to ensure all
# tensors are on the same device
torch._foreach_lerp_(
device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
# Due to the strictness of the _foreach_addcmul API, we can't have a single
# tensor scalar as the scalar arg (only python number is supported there)
# as a result, separate out the value mul
# Filed https://github.com/pytorch/pytorch/issues/139795
if isinstance(beta2, torch.Tensor):
scaled_device_grads = torch._foreach_mul(device_grads, 1 - beta2) # type: ignore[assignment]
value = 1.0
else:
scaled_device_grads = device_grads # type: ignore[assignment]
value = 1 - beta2
torch._foreach_addcmul_(
device_exp_avg_sqs, scaled_device_grads, device_grads, value
)
# Delete the local intermediate(s) since they won't be used anymore to save on peak memory
del device_grads
del scaled_device_grads
bias_correction1: tuple[Tensor, ...] | list[Tensor]
bias_correction2: tuple[Tensor, ...] | list[Tensor]
bias_correction2_sqrt: tuple[Tensor, ...] | list[Tensor]
if capturable:
bias_correction1 = torch._foreach_pow(beta1, device_state_steps) # type: ignore[arg-type]
bias_correction2 = torch._foreach_pow(beta2, device_state_steps) # type: ignore[arg-type]
# foreach_sub doesn't allow a scalar as the first arg
torch._foreach_sub_(bias_correction1, 1)
torch._foreach_sub_(bias_correction2, 1)
# we do not negate bias_correction1 as it'll need to be negated later anyway
torch._foreach_neg_(bias_correction2)
# foreach_div doesn't allow a scalar as the first arg
torch._foreach_div_(bias_correction1, lr)
torch._foreach_reciprocal_(bias_correction1)
torch._foreach_sqrt_(bias_correction2)
# Re-assign for clarity as we maintain minimal intermediates: we'll have
# step_size = - lr / (1 - beta1 ^ t) where t = num_steps
# bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
step_size = bias_correction1
bias_correction2_sqrt = bias_correction2
if amsgrad:
device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
# Maintains the maximum of all 2nd moment running avg. till now
torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) # type: ignore[assignment]
# Set intermediate to the max. for normalizing running avg. of gradient when amsgrad
exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
else:
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
torch._foreach_add_(exp_avg_sq_sqrt, eps)
torch._foreach_div_(exp_avg_sq_sqrt, step_size)
# at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
else:
bias_correction1 = [
1 - beta1 ** _get_value(step) for step in device_state_steps
]
bias_correction2 = [
1 - beta2 ** _get_value(step) for step in device_state_steps
]
step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])
bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2] # type: ignore[arg-type]
if amsgrad:
device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
# Maintains the maximum of all 2nd moment running avg. till now
torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)
# Use the max. for normalizing running avg. of gradient
exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
else:
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
torch._foreach_add_(exp_avg_sq_sqrt, eps)
torch._foreach_addcdiv_(
device_params,
device_exp_avgs,
exp_avg_sq_sqrt,
step_size, # type: ignore[arg-type]
)
def _fused_adam(
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
max_exp_avg_sqs: list[Tensor],
state_steps: list[Tensor],
grad_scale: Tensor | None,
found_inf: Tensor | None,
*,
amsgrad: bool,
has_complex: bool, # Needed for consistency.
beta1: float | Tensor,
beta2: float | Tensor,
lr: float | Tensor,
weight_decay: float,
eps: float,
maximize: bool,
capturable: bool, # Needed for consistency.
differentiable: bool,
decoupled_weight_decay: bool,
) -> None:
if not params:
return
if differentiable:
raise RuntimeError("Adam with fused=True does not support differentiable=True")
beta1 = _to_scalar(beta1)
beta2 = _to_scalar(beta2)
grad_scale_dict: DeviceDict = (
{grad_scale.device: grad_scale} if grad_scale is not None else {}
)
found_inf_dict: DeviceDict = (
{found_inf.device: found_inf} if found_inf is not None else {}
)
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
# treating it as a scalar.
lr_dict: DeviceDict | None = (
{lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item]
)
for (device, _), (
(
device_params_,
device_grads_,
device_exp_avgs_,
device_exp_avg_sqs_,
device_max_exp_avg_sqs,
device_state_steps_,
),
_,
) in grouped_tensors.items():
device_params = cast(list[Tensor], device_params_)
device_grads = cast(list[Tensor], device_grads_)
device_exp_avgs = cast(list[Tensor], device_exp_avgs_)
device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_)
device_state_steps = cast(list[Tensor], device_state_steps_)
device_grad_scale, device_found_inf = None, None
if grad_scale is not None:
device_grad_scale = grad_scale_dict.setdefault(
device, grad_scale.to(device, non_blocking=True)
)
if found_inf is not None:
device_found_inf = found_inf_dict.setdefault(
device, found_inf.to(device, non_blocking=True)
)
if lr_dict is not None and device not in lr_dict:
lr_dict[device] = lr.to(device=device, non_blocking=True) # type: ignore[union-attr]
lr = lr_dict[device]
torch._foreach_add_(device_state_steps, 1)
func = torch._fused_adam_ if not decoupled_weight_decay else torch._fused_adamw_
# pyrefly: ignore [no-matching-overload]
func(
device_params,
device_grads,
device_exp_avgs,
device_exp_avg_sqs,
device_max_exp_avg_sqs, # type: ignore[arg-type]
device_state_steps,
amsgrad=amsgrad,
lr=lr, # type: ignore[arg-type]
beta1=beta1,
beta2=beta2,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
grad_scale=device_grad_scale,
found_inf=device_found_inf,
)
if device_found_inf is not None:
torch._foreach_sub_(
device_state_steps, [device_found_inf] * len(device_state_steps)
)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
def adam(
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
max_exp_avg_sqs: list[Tensor],
state_steps: list[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: bool | None = None,
capturable: bool = False,
differentiable: bool = False,
fused: bool | None = None,
grad_scale: Tensor | None = None,
found_inf: Tensor | None = None,
has_complex: bool = False,
decoupled_weight_decay: bool = False,
*,
amsgrad: bool,
beta1: float | Tensor,
beta2: float | Tensor,
lr: float | Tensor,
weight_decay: float,
eps: float,
maximize: bool,
) -> None:
r"""Functional API that performs Adam algorithm computation.
See :class:`~torch.optim.Adam` for details.
"""
# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
if fused is None and foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
# Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
if foreach and isinstance(lr, Tensor) and not capturable:
foreach = False
if fused is None:
fused = False
if foreach is None:
foreach = False
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch.compiler.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if fused and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with fused optimizers")
if fused and not torch.jit.is_scripting():
func = _fused_adam
elif foreach and not torch.jit.is_scripting():
func = _multi_tensor_adam
else:
func = _single_tensor_adam
func(
params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
has_complex=has_complex,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,
grad_scale=grad_scale,
found_inf=found_inf,
decoupled_weight_decay=decoupled_weight_decay,
)
# mypy: allow-untyped-defs
from torch import Tensor
from .adam import Adam, adam
from .optimizer import (
_capturable_doc,
_differentiable_doc,
_foreach_doc,
_fused_doc,
_maximize_doc,
_params_doc,
ParamsT,
)
__all__ = ["AdamW", "adamw"]
class AdamW(Adam):
def __init__(
self,
params: ParamsT,
lr: float | Tensor = 1e-3,
betas: tuple[float | Tensor, float | Tensor] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 1e-2,
amsgrad: bool = False,
*,
maximize: bool = False,
foreach: bool | None = None,
capturable: bool = False,
differentiable: bool = False,
fused: bool | None = None,
) -> None:
super().__init__(
params,
lr,
betas,
eps,
weight_decay,
amsgrad,
foreach=foreach,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,
fused=fused,
decoupled_weight_decay=True,
)
# Preserve decoupled_weight_decay from AdamW for backwards compatibility. The following
# guarantees that decoupled_weight_decay will always be True for loading any state into
# AdamW
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group["decoupled_weight_decay"] = True
AdamW.__doc__ = (
r"""Implements AdamW algorithm, where weight decay does not accumulate in the momentum nor variance.
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
\text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
\: \epsilon \text{ (epsilon)} \\
&\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
\: \textit{maximize} \\
&\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
\text{ ( second moment)}, \: v_0^{max}\leftarrow 0 \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
&\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
&\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
&\hspace{5mm}\textbf{if} \: amsgrad \\
&\hspace{10mm} v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max},v_t) \\
&\hspace{10mm}\widehat{v_t} \leftarrow v_t^{max}/\big(1-\beta_2^t \big) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
&\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_.
"""
+ rf"""
Args:
{_params_doc}
lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
is not yet supported for all our implementations. Please use a float
LR if you are not also specifying fused=True or capturable=True.
betas (tuple[float | Tensor, float | Tensor], optional):
coefficients used for computing running averages of gradient and
its square. If a tensor is provided, must be 1-element. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
amsgrad (bool, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
{_maximize_doc}
{_foreach_doc}
{_capturable_doc}
{_differentiable_doc}
{_fused_doc}
.. Note::
A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
)
# @_disable_dynamo_if_unsupported logic occurs in the decorator that's applied to F.adam
def adamw(
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
max_exp_avg_sqs: list[Tensor],
state_steps: list[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: bool | None = None,
capturable: bool = False,
differentiable: bool = False,
fused: bool | None = None,
grad_scale: Tensor | None = None,
found_inf: Tensor | None = None,
has_complex: bool = False,
*,
amsgrad: bool,
beta1: float | Tensor,
beta2: float | Tensor,
lr: float | Tensor,
weight_decay: float,
eps: float,
maximize: bool,
) -> None:
r"""Functional API that performs AdamW algorithm computation.
See :class:`~torch.optim.AdamW` for details.
"""
adam(
params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
foreach=foreach,
capturable=capturable,
differentiable=differentiable,
fused=fused,
grad_scale=grad_scale,
found_inf=found_inf,
has_complex=has_complex,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
decoupled_weight_decay=True,
)
Introduction to CosineAnnealingLR
[https://github.com/pytorch/pytorch/blob/v2.11.0/torch/optim/lr_scheduler.py]
CosineAnnealingLR is a learning-rate scheduling strategy in which the learning rate decays from its initial value to a minimum along a cosine curve, mimicking the temperature decrease in simulated annealing.
Unlike step decay (StepLR) or exponential decay (ExponentialLR), CosineAnnealingLR varies the learning rate smoothly and periodically: it keeps the learning rate large early in training for fast convergence, then gradually reduces it later to fine-tune on the loss surface.
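The schedule follows the standard closed form η_t = η_min + ½(η_max - η_min)(1 + cos(πt / T_max)). As a minimal usage sketch of the AdamW + CosineAnnealingLR pairing used in this article (the model, data, and hyperparameters below are placeholders, not the article's actual ViT code), note the call order: `optimizer.step()` first, then `scheduler.step()` once per epoch:

```python
import torch
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Placeholder model and data, just to show the training-loop wiring over 20 epochs.
model = nn.Linear(28 * 28, 10)
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
scheduler = CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
criterion = nn.CrossEntropyLoss()

for epoch in range(20):
    x, y = torch.randn(8, 28 * 28), torch.randint(0, 10, (8,))  # one fake batch per epoch
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()      # update parameters first ...
    scheduler.step()      # ... then anneal the learning rate (once per epoch)
    print(f"epoch {epoch:02d}  lr={scheduler.get_last_lr()[0]:.2e}  loss={loss.item():.4f}")
```

Calling `scheduler.step()` before `optimizer.step()` triggers the warning visible in the scheduler source quoted below and skips the first value of the schedule.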
# mypy: allow-untyped-defs
r"""Learning Rate Scheduler."""
from __future__ import annotations
import math
import types
import warnings
from bisect import bisect_right
from collections import Counter
from functools import partial, wraps
from typing import Any, cast, Literal, SupportsFloat, TYPE_CHECKING, TypedDict
from typing_extensions import override, Self
from weakref import ref
from torch import inf, Tensor
from .optimizer import _to_scalar, Optimizer
if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Sequence
__all__ = [
"LambdaLR",
"MultiplicativeLR",
"StepLR",
"MultiStepLR",
"ConstantLR",
"LinearLR",
"ExponentialLR",
"SequentialLR",
"CosineAnnealingLR",
"ChainedScheduler",
"ReduceLROnPlateau",
"CyclicLR",
"CosineAnnealingWarmRestarts",
"OneCycleLR",
"PolynomialLR",
"LRScheduler",
]
EPOCH_DEPRECATION_WARNING = (
"The epoch parameter in `scheduler.step()` was not necessary and is being "
"deprecated where possible. Please use `scheduler.step()` to step the "
"scheduler. During the deprecation, if epoch is different from None, the "
"closed form is used instead of the new chainable form, where available. "
"Please open an issue if you are unable to replicate your use case: "
"https://github.com/pytorch/pytorch/issues/new/choose."
)
def _format_param(name: str, optimizer: Optimizer, param):
"""Return correctly formatted lr/momentum for each param group."""
def _copy(_param):
return _param.clone() if isinstance(_param, Tensor) else _param
if isinstance(param, (list, tuple)):
if len(param) != len(optimizer.param_groups):
raise ValueError(
f"{name} must have the same length as optimizer.param_groups. "
f"{name} has {len(param)} values, param_groups has {len(optimizer.param_groups)}."
)
else:
param = [param] * len(optimizer.param_groups)
return list(map(_copy, param))
def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
"""Create a list containing group[key] for each optimizer param_group.
Prevents aliasing when group[key] could be a Tensor.
Raises a KeyError when group[key] does not exist.
"""
return [
group[key].clone() if isinstance(group[key], Tensor) else group[key]
for group in optimizer.param_groups
]
def _update_param_group_val(
param_group: dict[str, Any], key: str, val: float | Tensor
) -> None:
"""Set param_group[key] to val without aliasing or assignment when they're
both tensors. Raises a KeyError if param_group[key] does not exist.
"""
if isinstance(param_group[key], Tensor):
param_group[key].fill_(_to_scalar(val))
else:
param_group[key] = val
class LRScheduler:
r"""Base class for all learning rate schedulers.
Subclasses implement :meth:`get_lr` and optionally override :meth:`step` to
define scheduling behavior.
Args:
optimizer (Optimizer): The optimizer this scheduler will adjust the
learning rates of.
last_epoch (int): Index of the last epoch seen by the scheduler. Use
``-1`` (default) to initialize the scheduler. Only use a non-default
value when restoring this scheduler from a saved checkpoint.
.. warning::
Initializing a scheduler overwrites its optimizer's
``param_group["lr"]``\s. When restoring a checkpoint, initialize the
scheduler **before** calling your optimizer's
:meth:`~torch.optim.Optimizer.load_state_dict` to avoid overwriting the
loaded learning rates.
"""
_get_lr_called_within_step: bool = False
_is_initial: bool = False
def __init__(
self,
optimizer: Optimizer,
last_epoch: int = -1,
) -> None: # noqa: D107
# Attach optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
self.optimizer = optimizer
# Initialize epoch and base learning rates
if last_epoch == -1:
for group in optimizer.param_groups:
initial_lr = group["lr"]
if isinstance(initial_lr, Tensor):
initial_lr = initial_lr.clone()
group.setdefault("initial_lr", initial_lr)
else:
for i, group in enumerate(optimizer.param_groups):
if "initial_lr" not in group:
raise KeyError(
f"param 'initial_lr' is not specified in param_groups[{i}] when resuming scheduler with last_epoch >= 0.\n"
"This typically happens when:\n"
"1. You're trying to resume training from a checkpoint but haven't properly loaded the optimizer state\n"
"2. You're using last_epoch >= 0 for a fresh training run (not recommended)"
)
self.base_lrs: list[float | Tensor] = _param_groups_val_list(
optimizer, "initial_lr"
)
self.last_epoch = last_epoch
# Following https://github.com/pytorch/pytorch/issues/20124
# We would like to ensure that `lr_scheduler.step()` is called after
# `optimizer.step()`
def patch_track_step_called(opt: Optimizer):
if hasattr(opt.step, "_wrapped_by_lr_sched"):
# we've already patched
return opt.step
def wrap_step(step_fn):
opt_ref = ref(self.optimizer)
func = step_fn.__func__
@wraps(func)
def wrapper(*args, **kwargs):
opt = opt_ref()
opt._opt_called = True # type: ignore[union-attr]
return func.__get__(opt, opt.__class__)(*args, **kwargs)
wrapper._wrapped_by_lr_sched = True # type: ignore[attr-defined]
return wrapper
opt.step = wrap_step(opt.step) # type: ignore[method-assign]
patch_track_step_called(self.optimizer)
self._initial_step()
def _initial_step(self) -> None:
"""Initialize step counts and perform a step."""
self._step_count = 0
with _initial_mode(self):
self.step()
def state_dict(self) -> dict[str, Any]:
"""Return the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in ``self.__dict__`` which
is not the optimizer.
"""
return {
key: value for key, value in self.__dict__.items() if key != "optimizer"
}
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Load the scheduler's state.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def get_last_lr(self) -> list[float | Tensor]:
r"""Get the most recent learning rates computed by this scheduler.
Returns:
list[float | Tensor]: A :class:`list` of learning rates with entries
for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`, with the same types as
their ``group["lr"]``\s.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
"""
# We always update self._last_lr with _param_groups_val_list, so it's a
# .clone() of the group["lr"]s. If we didn't do this, the user could
# corrupt their learning rates by modifying the outputs in place.
return self._last_lr
def get_lr(self) -> list[float | Tensor]:
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
"""
raise NotImplementedError
def step(self, epoch: int | None = None) -> None:
"""Step the scheduler.
Args:
epoch (int, optional):
.. deprecated:: 1.4
If provided, sets :attr:`last_epoch` to ``epoch`` and uses
:meth:`_get_closed_form_lr` if it is available. This is not
universally supported. Use :meth:`step` without arguments
instead.
.. note::
Call this method after calling the optimizer's
:meth:`~torch.optim.Optimizer.step`.
"""
# Raise a warning if old pattern is detected
# https://github.com/pytorch/pytorch/issues/20124
if self._step_count == 1:
if not hasattr(self.optimizer.step, "_wrapped_by_lr_sched"):
warnings.warn(
"Seems like `optimizer.step()` has been overridden after learning rate scheduler "
"initialization. Please, make sure to call `optimizer.step()` before "
"`lr_scheduler.step()`. See more details at "
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
UserWarning,
stacklevel=2,
)
# Just check if there were two first lr_scheduler.step() calls before optimizer.step()
elif not getattr(self.optimizer, "_opt_called", False):
warnings.warn(
"Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
"In PyTorch 1.1.0 and later, you should call them in the opposite order: "
"`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
"will result in PyTorch skipping the first value of the learning rate schedule. "
"See more details at "
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
UserWarning,
stacklevel=2,
)
self._step_count += 1
if epoch is not None:
warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning, stacklevel=2)
self._update_lr(epoch)
def _update_lr(self, epoch: int | None = None) -> None:
with _enable_get_lr_call(self):
if epoch is None:
self.last_epoch += 1
values = self.get_lr()
else:
self.last_epoch = epoch
if hasattr(self, "_get_closed_form_lr"):
values = cast(list[float | Tensor], self._get_closed_form_lr())
else:
values = self.get_lr()
for param_group, lr in zip(self.optimizer.param_groups, values, strict=True):
_update_param_group_val(param_group, "lr", lr)
self._last_lr: list[float | Tensor] = _param_groups_val_list(
self.optimizer, "lr"
)
def _warn_get_lr_called_within_step(lr_scheduler: LRScheduler) -> None:
if not lr_scheduler._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.",
UserWarning,
stacklevel=2,
)
# Including _LRScheduler for backwards compatibility
# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
class _LRScheduler(LRScheduler):
pass
class _enable_get_lr_call:
def __init__(self, o: LRScheduler) -> None:
self.o = o
def __enter__(self) -> Self:
self.o._get_lr_called_within_step = True
return self
def __exit__(self, type, value, traceback) -> None:
self.o._get_lr_called_within_step = False
class _initial_mode:
def __init__(self, o: LRScheduler) -> None:
self.o = o
def __enter__(self):
self.o._is_initial = True
def __exit__(self, type, value, traceback):
self.o._is_initial = False
class LambdaLR(LRScheduler):
"""Sets the initial learning rate.
The learning rate of each parameter group is set to the initial lr
times a given function. When last_epoch=-1, sets initial lr as lr.
Args:
optimizer (Optimizer): Wrapped optimizer.
lr_lambda (function or list): A function which computes a multiplicative
factor given an integer parameter epoch, or a list of such
functions, one for each group in optimizer.param_groups.
last_epoch (int): The index of last epoch. Default: -1.
Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer has two groups.
>>> num_epochs = 100
>>> lambda1 = lambda epoch: epoch // 30
>>> lambda2 = lambda epoch: 0.95**epoch
>>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
>>> for epoch in range(num_epochs):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
>>>
>>> # Alternatively, you can use a single lambda function for all groups.
>>> scheduler = LambdaLR(opt, lr_lambda=lambda epoch: epoch // 30)
>>> for epoch in range(num_epochs):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/LambdaLR.png
"""
def __init__(
self,
optimizer: Optimizer,
lr_lambda: Callable[[int], float] | list[Callable[[int], float]],
last_epoch: int = -1,
) -> None: # noqa: D107
self.optimizer = optimizer
self.lr_lambdas: list[Callable[[int], float]]
if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
else:
if len(lr_lambda) != len(optimizer.param_groups):
raise ValueError(
f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
)
self.lr_lambdas = list(lr_lambda)
super().__init__(optimizer, last_epoch)
@override
def state_dict(self) -> dict[str, Any]:
"""Return the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in ``self.__dict__`` which is not the optimizer.
The learning rate lambda functions will only be saved if they are callable objects
and not if they are functions or lambdas.
When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
"""
state_dict = {
key: value
for key, value in self.__dict__.items()
if key not in ("optimizer", "lr_lambdas")
}
state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas)
for idx, fn in enumerate(self.lr_lambdas):
if not isinstance(fn, types.FunctionType):
# pyrefly: ignore [unsupported-operation]
state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
return state_dict
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Load the scheduler's state.
When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
lr_lambdas = state_dict.pop("lr_lambdas")
self.__dict__.update(state_dict)
# Restore state_dict keys in order to prevent side effects
# https://github.com/pytorch/pytorch/issues/32756
state_dict["lr_lambdas"] = lr_lambdas
for idx, fn in enumerate(lr_lambdas):
if fn is not None:
self.lr_lambdas[idx].__dict__.update(fn)
@override
def get_lr(self) -> list[float | Tensor]:
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
Scales the :attr:`base_lrs` by the outputs of the :attr:`lr_lambdas` at
:attr:`last_epoch`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
"""
_warn_get_lr_called_within_step(self)
return [
base_lr * lmbda(self.last_epoch)
for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs, strict=True)
]
class MultiplicativeLR(LRScheduler):
"""Multiply the learning rate of each parameter group by the factor given in the specified function.
When last_epoch=-1, set initial lr as lr.
Args:
optimizer (Optimizer): Wrapped optimizer.
lr_lambda (function or list): A function which computes a multiplicative
factor given an integer parameter epoch, or a list of such
functions, one for each group in optimizer.param_groups.
last_epoch (int): The index of last epoch. Default: -1.
Example:
>>> # xdoctest: +SKIP
>>> lmbda = lambda epoch: 0.95
>>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/MultiplicativeLR.png
"""
def __init__(
self,
optimizer: Optimizer,
lr_lambda: Callable[[int], float] | list[Callable[[int], float]],
last_epoch: int = -1,
) -> None: # noqa: D107
self.optimizer = optimizer
self.lr_lambdas: list[Callable[[int], float]]
if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
else:
if len(lr_lambda) != len(optimizer.param_groups):
raise ValueError(
f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
)
self.lr_lambdas = list(lr_lambda)
for lr_lambda in self.lr_lambdas:
if not callable(lr_lambda):
raise TypeError(
f"lr_lambda should be a function, but got {type(lr_lambda).__name__}"
)
super().__init__(optimizer, last_epoch)
@override
def state_dict(self) -> dict[str, Any]:
"""Return the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in ``self.__dict__`` which
is not the optimizer.
The learning rate lambda functions will only be saved if they are callable objects
and not if they are functions or lambdas.
"""
state_dict = {
key: value
for key, value in self.__dict__.items()
if key not in ("optimizer", "lr_lambdas")
}
state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas)
for idx, fn in enumerate(self.lr_lambdas):
if not isinstance(fn, types.FunctionType):
# pyrefly: ignore [unsupported-operation]
state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
return state_dict
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Load the scheduler's state.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
lr_lambdas = state_dict.pop("lr_lambdas")
self.__dict__.update(state_dict)
# Restore state_dict keys in order to prevent side effects
# https://github.com/pytorch/pytorch/issues/32756
state_dict["lr_lambdas"] = lr_lambdas
for idx, fn in enumerate(lr_lambdas):
if fn is not None:
self.lr_lambdas[idx].__dict__.update(fn)
@override
def get_lr(self) -> list[float | Tensor]:
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
Scales the current ``group["lr"]``\s in each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` by the outputs of the
:attr:`lr_lambdas` at :attr:`last_epoch`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
"""
_warn_get_lr_called_within_step(self)
if not self._is_initial:
return [
group["lr"] * lmbda(self.last_epoch)
for lmbda, group in zip(
self.lr_lambdas, self.optimizer.param_groups, strict=True
)
]
else:
return _param_groups_val_list(self.optimizer, "lr")
class StepLR(LRScheduler):
"""Decays the learning rate of each parameter group by gamma every step_size epochs.
Notice that such decay can happen simultaneously with other changes to the learning rate
from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
Args:
optimizer (Optimizer): Wrapped optimizer.
step_size (int): Period of learning rate decay.
gamma (float): Multiplicative factor of learning rate decay.
Default: 0.1.
last_epoch (int): The index of last epoch. Default: -1.
Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05 if epoch < 30
>>> # lr = 0.005 if 30 <= epoch < 60
>>> # lr = 0.0005 if 60 <= epoch < 90
>>> # ...
>>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/StepLR.png
"""
def __init__(
self,
optimizer: Optimizer,
step_size: int,
gamma: float = 0.1,
last_epoch: int = -1,
) -> None: # noqa: D107
self.step_size = step_size
self.gamma = gamma
super().__init__(optimizer, last_epoch)
@override
def get_lr(self) -> list[float | Tensor]:
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
If the current epoch is a non-zero multiple of :attr:`step_size`, we
scale the current ``group["lr"]``\s in the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
"""
_warn_get_lr_called_within_step(self)
if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
return _param_groups_val_list(self.optimizer, "lr")
return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
def _get_closed_form_lr(self) -> list[float | Tensor]:
r"""Compute learning rates for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
a closed-form formula.
Uses :attr:`base_lrs` to compute learning rates. This method is called
when an epoch is passed to :meth:`step`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
"""
return [
base_lr * self.gamma ** (self.last_epoch // self.step_size)
for base_lr in self.base_lrs
]
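# --- Illustrative sketch, not part of the PyTorch source above. ---
# StepLR's closed form is lr(epoch) = base_lr * gamma ** (epoch // step_size).
# Pure-Python check with hypothetical base_lr; step_size/gamma mirror the docstring.
base_lr, step_size, gamma = 1e-3, 30, 0.1
for epoch in (0, 29, 30, 59, 60):
    print(epoch, base_lr * gamma ** (epoch // step_size))
# -> 1e-3 for epochs 0..29, 1e-4 for epochs 30..59, 1e-5 from epoch 60 onwards.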
class MultiStepLR(LRScheduler):
"""Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones.
Notice that such decay can happen simultaneously with other changes to the learning rate
from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
Args:
optimizer (Optimizer): Wrapped optimizer.
milestones (list): List of epoch indices. Must be increasing.
gamma (float): Multiplicative factor of learning rate decay.
Default: 0.1.
last_epoch (int): The index of last epoch. Default: -1.
Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05 if epoch < 30
>>> # lr = 0.005 if 30 <= epoch < 80
>>> # lr = 0.0005 if epoch >= 80
>>> scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/MultiStepLR.png
"""
def __init__(
self,
optimizer: Optimizer,
milestones: Iterable[int],
gamma: float = 0.1,
last_epoch: int = -1,
) -> None: # noqa: D107
self.milestones = Counter(milestones)
self.gamma = gamma
super().__init__(optimizer, last_epoch)
@override
def get_lr(self) -> list[float | Tensor]:
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
If the current epoch is in :attr:`milestones`, decays the
``group["lr"]``\s in the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
.. note::
If the current epoch appears in :attr:`milestones` ``n`` times, we
scale by :attr:`gamma` to the power of ``n``
"""
_warn_get_lr_called_within_step(self)
if self.last_epoch not in self.milestones:
return _param_groups_val_list(self.optimizer, "lr")
return [
group["lr"] * self.gamma ** self.milestones[self.last_epoch]
for group in self.optimizer.param_groups
]
def _get_closed_form_lr(self):
r"""Compute learning rates for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
a closed-form formula.
Uses :attr:`base_lrs` to compute learning rates. This method is called
when an epoch is passed to :meth:`step`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
"""
milestones = sorted(self.milestones.elements())
return [
base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
for base_lr in self.base_lrs
]
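# --- Illustrative sketch, not part of the PyTorch source above. ---
# MultiStepLR's closed form counts how many milestones lie at or before the
# current epoch via bisect_right. Hypothetical values mirroring the docstring.
from bisect import bisect_right
base_lr, milestones, gamma = 0.05, [30, 80], 0.1
for epoch in (0, 29, 30, 79, 80, 99):
    print(epoch, base_lr * gamma ** bisect_right(milestones, epoch))
# -> 0.05 before epoch 30, 0.005 for epochs 30..79, 0.0005 from epoch 80 on.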
class ConstantLR(LRScheduler):
"""Multiply the learning rate of each parameter group by a small constant factor.
The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters.
Notice that such multiplication of the small constant factor can
happen simultaneously with other changes to the learning rate from outside this scheduler.
When last_epoch=-1, sets initial lr as lr.
Args:
optimizer (Optimizer): Wrapped optimizer.
factor (float): The number we multiply learning rate until the milestone. Default: 1./3.
total_iters (int): The number of steps that the scheduler multiplies the learning rate by the factor.
Default: 5.
last_epoch (int): The index of the last epoch. Default: -1.
Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.025 if epoch == 0
>>> # lr = 0.025 if epoch == 1
>>> # lr = 0.025 if epoch == 2
>>> # lr = 0.025 if epoch == 3
>>> # ...
>>> # lr = 0.05 if epoch >= 40
>>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=40)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/ConstantLR.png
"""
def __init__(
self,
optimizer: Optimizer,
factor: float = 1.0 / 3,
total_iters: int = 5,
last_epoch: int = -1,
) -> None: # noqa: D107
if factor > 1.0 or factor < 0:
raise ValueError(
"Constant multiplicative factor expected to be between 0 and 1."
)
self.factor = factor
self.total_iters = total_iters
super().__init__(optimizer, last_epoch)
@override
def get_lr(self) -> list[float | Tensor]:
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
When :attr:`last_epoch` is 0, this method scales the ``group["lr"]``\s
in each of the optimizer's :attr:`~torch.optim.Optimizer.param_groups`
by :attr:`factor`. Once :attr:`total_iters` is reached, it undoes this,
scaling by ``1 / factor``.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
"""
_warn_get_lr_called_within_step(self)
if self.last_epoch == 0:
return [group["lr"] * self.factor for group in self.optimizer.param_groups]
if self.last_epoch != self.total_iters:
return _param_groups_val_list(self.optimizer, "lr")
return [
group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups
]
def _get_closed_form_lr(self):
r"""Compute learning rates for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
a closed-form formula.
Uses :attr:`base_lrs` to compute learning rates. This method is called
when an epoch is passed to :meth:`step`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
"""
return [
base_lr
* (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
for base_lr in self.base_lrs
]
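# --- Illustrative sketch, not part of the PyTorch source above. ---
# ConstantLR scales the lr by `factor` until total_iters, then restores it.
# Minimal usage with a stand-in parameter; all values here are hypothetical.
import torch
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.05)
sched = torch.optim.lr_scheduler.ConstantLR(opt, factor=0.5, total_iters=4)
for epoch in range(6):
    opt.step()
    print(epoch, sched.get_last_lr())  # 0.025 for epochs 0..3, 0.05 afterwards
    sched.step()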
class LinearLR(LRScheduler):
"""Decays the learning rate of each parameter group by linearly changing small multiplicative factor.
The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters.
Notice that such decay can happen simultaneously with other changes to the learning rate
from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
Args:
optimizer (Optimizer): Wrapped optimizer.
start_factor (float): The number we multiply learning rate in the first epoch.
The multiplication factor changes towards end_factor in the following epochs.
Default: 1./3.
end_factor (float): The number we multiply learning rate at the end of linear changing
process. Default: 1.0.
total_iters (int): The number of iterations that multiplicative factor reaches to 1.
Default: 5.
last_epoch (int): The index of the last epoch. Default: -1.
Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.003687 if epoch == 0
>>> # lr = 0.004875 if epoch == 1
>>> # lr = 0.006062 if epoch == 2
>>> # lr = 0.00725 if epoch == 3
>>> # ...
>>> # lr = 0.05 if epoch >= 40
>>> scheduler = LinearLR(optimizer, start_factor=0.05, total_iters=40)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/LinearLR.png
"""
def __init__(
self,
optimizer: Optimizer,
start_factor: float = 1.0 / 3,
end_factor: float = 1.0,
total_iters: int = 5,
last_epoch: int = -1,
) -> None: # noqa: D107
if start_factor > 1.0 or start_factor <= 0:
raise ValueError(
"Starting multiplicative factor expected to be greater than 0 and less or equal to 1."
)
if end_factor > 1.0 or end_factor < 0:
raise ValueError(
"Ending multiplicative factor expected to be between 0 and 1."
)
self.start_factor = start_factor
self.end_factor = end_factor
self.total_iters = total_iters
super().__init__(optimizer, last_epoch)
@override
def get_lr(self) -> list[float | Tensor]:
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
Scales the ``group["lr"]``\s in the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` such that successive steps
interpolate linearly from :attr:`start_factor` up to :attr:`end_factor`
across :attr:`total_iters` steps.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
"""
_warn_get_lr_called_within_step(self)
if self.last_epoch == 0:
return [
group["lr"] * self.start_factor for group in self.optimizer.param_groups
]
if self._is_initial or self.last_epoch > self.total_iters:
return _param_groups_val_list(self.optimizer, "lr")
return [
group["lr"]
* (
1.0
+ (self.end_factor - self.start_factor)
/ (
self.total_iters * self.start_factor
+ (self.last_epoch - 1) * (self.end_factor - self.start_factor)
)
)
for group in self.optimizer.param_groups
]
def _get_closed_form_lr(self):
r"""Compute learning rates for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
a closed-form formula.
Uses :attr:`base_lrs` to compute learning rates. This method is called
when an epoch is passed to :meth:`step`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
"""
return [
base_lr
* (
self.start_factor
+ (self.end_factor - self.start_factor)
* min(self.total_iters, self.last_epoch)
/ self.total_iters
)
for base_lr in self.base_lrs
]
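# --- Illustrative sketch, not part of the PyTorch source above. ---
# LinearLR's closed form interpolates the multiplicative factor linearly from
# start_factor to end_factor over total_iters steps (hypothetical values).
base_lr, start, end, total = 0.05, 0.5, 1.0, 4
for epoch in range(6):
    factor = start + (end - start) * min(total, epoch) / total
    print(epoch, base_lr * factor)
# -> 0.025, 0.03125, 0.0375, 0.04375, then 0.05 from epoch 4 onwards.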
class ExponentialLR(LRScheduler):
"""Decays the learning rate of each parameter group by gamma every epoch.
When last_epoch=-1, sets initial lr as lr.
Args:
optimizer (Optimizer): Wrapped optimizer.
gamma (float): Multiplicative factor of learning rate decay.
last_epoch (int): The index of last epoch. Default: -1.
Example:
>>> # xdoctest: +SKIP
>>> scheduler = ExponentialLR(optimizer, gamma=0.95)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/ExponentialLR.png
"""
def __init__(
self,
optimizer: Optimizer,
gamma: float,
last_epoch: int = -1,
) -> None: # noqa: D107
self.gamma = gamma
super().__init__(optimizer, last_epoch)
@override
def get_lr(self) -> list[float | Tensor]:
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
Multiplies the current ``group["lr"]``\s in the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` by :attr:`gamma`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
"""
_warn_get_lr_called_within_step(self)
# when loading from a checkpoint, we don't want _initial_step (called from the constructor)
# to update the lr one more step ahead of itself.
if self._is_initial:
return _param_groups_val_list(self.optimizer, "lr")
return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
def _get_closed_form_lr(self):
r"""Compute learning rates for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
a closed-form formula.
Uses :attr:`base_lrs` to compute learning rates. This method is called
when an epoch is passed to :meth:`step`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
"""
return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
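# --- Illustrative sketch, not part of the PyTorch source above. ---
# ExponentialLR multiplies the lr by gamma every epoch, so
# lr(epoch) = base_lr * gamma ** epoch (base_lr and gamma below are hypothetical).
base_lr, gamma = 1e-3, 0.95
print([round(base_lr * gamma ** e, 6) for e in range(5)])
# -> roughly [0.001, 0.00095, 0.0009025, 0.000857, 0.000815]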
class SequentialLR(LRScheduler):
"""Contains a list of schedulers expected to be called sequentially during the optimization process.
Specifically, the schedulers will be called according to the milestone points, which should provide exact
intervals by which each scheduler should be called at a given epoch.
Args:
optimizer (Optimizer): Wrapped optimizer.
schedulers (list): List of chained schedulers.
milestones (list): List of integers that reflects milestone points.
last_epoch (int): The index of last epoch. Default: -1.
Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.005 if epoch == 0
>>> # lr = 0.005 if epoch == 1
>>> # lr = 0.005 if epoch == 2
>>> # ...
>>> # lr = 0.05 if epoch == 20
>>> # lr = 0.045 if epoch == 21
>>> # lr = 0.0405 if epoch == 22
>>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
>>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
>>> scheduler = SequentialLR(
... optimizer,
... schedulers=[scheduler1, scheduler2],
... milestones=[20],
... )
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/SequentialLR.png
"""
def __init__(
self,
optimizer: Optimizer,
schedulers: list[LRScheduler],
milestones: list[int],
last_epoch: int = -1,
) -> None: # noqa: D107
if len(schedulers) < 1:
raise ValueError(
f"{self.__class__.__name__} expects at least one scheduler, but got no scheduler."
)
for scheduler_idx, scheduler in enumerate(schedulers):
if not hasattr(scheduler, "optimizer"):
raise TypeError(
f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
)
if isinstance(scheduler, ReduceLROnPlateau):
raise ValueError(
f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
"requires additional kwargs to be specified when calling `step`, "
f"but got one at index {scheduler_idx} in the given schedulers sequence."
)
if optimizer != scheduler.optimizer:
raise ValueError(
f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, "
f"which is different from {optimizer.__class__.__name__}."
)
if len(milestones) != len(schedulers) - 1:
raise ValueError(
"Sequential Schedulers expects number of schedulers provided to be one more "
f"than the number of milestone points, but got number of schedulers {len(schedulers)} and the "
f"number of milestones to be equal to {len(milestones)}"
)
self._schedulers = schedulers
self._milestones = milestones
self.last_epoch = last_epoch + 1
self.optimizer = optimizer
# Reset learning rates back to initial values
for group in self.optimizer.param_groups:
_update_param_group_val(group, "lr", group["initial_lr"])
# "Undo" the step performed by other schedulers
self.recursive_undo()
# Perform the initial step for only the first scheduler
self._schedulers[0]._initial_step()
self._last_lr = schedulers[0].get_last_lr()
def recursive_undo(self, sched=None) -> None:
"""
Recursively undo any step performed by the initialisation of
schedulers.
"""
scheds = self if sched is None else sched
if hasattr(scheds, "_schedulers"):
for s in scheds._schedulers:
self.recursive_undo(s)
elif hasattr(scheds, "last_epoch"):
scheds.last_epoch -= 1
def step(self) -> None: # type: ignore[override]
"""Perform a step."""
self.last_epoch += 1
idx = bisect_right(self._milestones, self.last_epoch)
scheduler = self._schedulers[idx]
if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
scheduler._update_lr(0)
else:
scheduler.step()
self._last_lr = scheduler.get_last_lr()
@override
def state_dict(self) -> dict[str, Any]:
"""Return the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in ``self.__dict__`` which
is not the optimizer.
The wrapped scheduler states will also be saved.
"""
state_dict = {
key: value
for key, value in self.__dict__.items()
if key not in ("optimizer", "_schedulers")
}
state_dict["_schedulers"] = [None] * len(self._schedulers)
for idx, s in enumerate(self._schedulers):
# pyrefly: ignore [unsupported-operation]
state_dict["_schedulers"][idx] = s.state_dict()
return state_dict
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Load the scheduler's state.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
_schedulers = state_dict.pop("_schedulers")
self.__dict__.update(state_dict)
# Restore state_dict keys in order to prevent side effects
# https://github.com/pytorch/pytorch/issues/32756
state_dict["_schedulers"] = _schedulers
for idx, s in enumerate(_schedulers):
self._schedulers[idx].load_state_dict(s)
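# --- Illustrative sketch, not part of the PyTorch source above. ---
# SequentialLR hands control from one scheduler to the next at the milestone
# epochs; a common pattern is a short linear warmup followed by cosine decay.
# The tiny parameter stands in for a real model; all values are hypothetical.
import torch
opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.1, total_iters=5)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=15)
sched = torch.optim.lr_scheduler.SequentialLR(opt, schedulers=[warmup, cosine], milestones=[5])
for _ in range(20):
    opt.step()
    sched.step()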
class PolynomialLR(LRScheduler):
"""Decays the learning rate of each parameter group using a polynomial function in the given total_iters.
When last_epoch=-1, sets initial lr as lr.
Args:
optimizer (Optimizer): Wrapped optimizer.
total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5.
power (float): The power of the polynomial. Default: 1.0.
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.0490 if epoch == 0
>>> # lr = 0.0481 if epoch == 1
>>> # lr = 0.0472 if epoch == 2
>>> # ...
>>> # lr = 0.0 if epoch >= 50
>>> scheduler = PolynomialLR(optimizer, total_iters=50, power=0.9)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/PolynomialLR.png
"""
def __init__(
self,
optimizer: Optimizer,
total_iters: int = 5,
power: float = 1.0,
last_epoch: int = -1,
) -> None: # noqa: D107
self.total_iters = total_iters
self.power = power
super().__init__(optimizer, last_epoch)
@override
def get_lr(self) -> list[float | Tensor]:
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
Scales the ``group["lr"]``\s in the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` such that the learning rates
follow
.. math::
\texttt{base\_lr} \cdot \left(1 - \frac{\texttt{last\_epoch}}
{\texttt{total\_iters}} \right)^\texttt{power}
Returns the current learning rates unchanged after :attr:`total_iters`
is reached.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
"""
_warn_get_lr_called_within_step(self)
if self._is_initial or self.last_epoch > self.total_iters:
return _param_groups_val_list(self.optimizer, "lr")
decay_factor = (
(1.0 - self.last_epoch / self.total_iters)
/ (1.0 - (self.last_epoch - 1) / self.total_iters)
) ** self.power
return [group["lr"] * decay_factor for group in self.optimizer.param_groups]
def _get_closed_form_lr(self) -> list[float | Tensor]:
r"""Compute learning rates for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
a closed-form formula.
Uses :attr:`base_lrs` to compute learning rates. This method is called
when an epoch is passed to :meth:`step`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
"""
return [
(
base_lr
* (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters)
** self.power
)
for base_lr in self.base_lrs
]
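# --- Illustrative sketch, not part of the PyTorch source above. ---
# PolynomialLR's closed form is lr(epoch) = base_lr * (1 - min(epoch, T)/T) ** power,
# reaching 0 at total_iters. Hypothetical base_lr, total_iters and power.
base_lr, total_iters, power = 0.05, 50, 0.9
for epoch in (1, 25, 49, 50, 60):
    print(epoch, base_lr * (1 - min(epoch, total_iters) / total_iters) ** power)
# -> decays smoothly toward 0.0, which is reached at epoch 50 and kept afterwards.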
class CosineAnnealingLR(LRScheduler):
r"""
Set the learning rate of each parameter group using a cosine annealing schedule.
The learning rate is updated recursively using:
.. math::
\eta_{t+1} = \eta_{\min} + (\eta_t - \eta_{\min}) \cdot
\frac{1 + \cos\left(\frac{(T_{cur}+1) \pi}{T_{max}}\right)}
{1 + \cos\left(\frac{T_{cur} \pi}{T_{max}}\right)}
This implements a recursive approximation of the closed-form schedule proposed in
`SGDR: Stochastic Gradient Descent with Warm Restarts`_:
.. math::
\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min}) \left(
1 + \cos\left(\frac{T_{cur} \pi}{T_{max}}\right) \right)
where:
- :math:`\eta_t` is the learning rate at step :math:`t`
- :math:`T_{cur}` is the number of epochs since the last restart
- :math:`T_{max}` is the maximum number of epochs in a cycle
Note:
Although SGDR includes periodic restarts, this implementation performs cosine annealing
**without restarts**, so :math:`T_{cur} = t` and increases monotonically with each call
to :meth:`step`.
Args:
optimizer (Optimizer): Wrapped optimizer.
T_max (int): Maximum number of iterations.
eta_min (float): Minimum learning rate. Default: 0.
last_epoch (int): The index of the last epoch. Default: -1.
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
Example:
>>> # xdoctest: +SKIP
>>> num_epochs = 100
>>> scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
>>> for epoch in range(num_epochs):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/CosineAnnealingLR.png
"""
def __init__(
self,
optimizer: Optimizer,
T_max: int,
eta_min: float = 0.0,
last_epoch: int = -1,
) -> None: # noqa: D107
self.T_max = T_max
self.eta_min = eta_min
super().__init__(optimizer, last_epoch)
@override
def get_lr(self) -> list[float | Tensor]:
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
Scales the ``group["lr"]``\s in the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` such that their learning
rates approximate
.. math::
\texttt{eta\_min} + \frac{1}{2} (\texttt{base\_lr} -
\texttt{eta\_min}) \left(1 + \cos\left(\pi \cdot
\frac{\texttt{last\_epoch}}{\texttt{T\_max}}\right) \right)
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
"""
_warn_get_lr_called_within_step(self)
if self._is_initial:
return _param_groups_val_list(self.optimizer, "lr")
elif self._step_count == 1 and self.last_epoch > 0:
return [
self.eta_min
+ (base_lr - self.eta_min)
* (1 + math.cos((self.last_epoch) * math.pi / self.T_max))
/ 2
for base_lr, group in zip(
self.base_lrs, self.optimizer.param_groups, strict=True
)
]
elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
return [
group["lr"]
+ (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
for base_lr, group in zip(
self.base_lrs, self.optimizer.param_groups, strict=True
)
]
return [
(1 + math.cos(math.pi * self.last_epoch / self.T_max))
/ (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max))
* (group["lr"] - self.eta_min)
+ self.eta_min
for group in self.optimizer.param_groups
]
def _get_closed_form_lr(self) -> list[float | Tensor]:
r"""Compute learning rates for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` at :attr:`last_epoch` using
a closed-form formula.
Uses :attr:`base_lrs` to compute learning rates. This method is called
when an epoch is passed to :meth:`step`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
"""
return [
self.eta_min
+ (base_lr - self.eta_min)
* (1 + math.cos(math.pi * self.last_epoch / self.T_max))
/ 2
for base_lr in self.base_lrs
]
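# --- Illustrative sketch, not part of the PyTorch source above. ---
# CosineAnnealingLR is the scheduler paired with AdamW for the 20-epoch ViT/MNIST
# training described in this article. The parameter below is a stand-in for the
# real ViT model; the base lr, weight_decay and eta_min values are hypothetical.
import torch
model_params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(model_params, lr=3e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
for epoch in range(20):
    # ... one epoch of training: forward, loss, backward, optimizer.step() ...
    optimizer.step()
    scheduler.step()
    print(epoch, scheduler.get_last_lr())  # lr follows half a cosine from 3e-4 down toward 1e-6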
class ChainedScheduler(LRScheduler):
"""Chains a list of learning rate schedulers.
Takes in a sequence of chainable learning rate schedulers and calls their
step() functions consecutively in just one call to step().
Args:
schedulers (sequence): sequence of chained schedulers.
optimizer (Optimizer, optional): Wrapped optimizer. Default: None.
Example:
>>> # xdoctest: +SKIP
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05 if epoch == 0
>>> # lr = 0.0450 if epoch == 1
>>> # lr = 0.0405 if epoch == 2
>>> # ...
>>> # lr = 0.00675 if epoch == 19
>>> # lr = 0.06078 if epoch == 20
>>> # lr = 0.05470 if epoch == 21
>>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
>>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
>>> scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer)
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/ChainedScheduler.png
"""
def __init__(
self, schedulers: Sequence[LRScheduler], optimizer: Optimizer | None = None
) -> None: # noqa: D107
if len(schedulers) < 1:
raise ValueError(
f"{self.__class__.__name__} expects at least one scheduler to be chained, but got no scheduler."
)
optimizer = optimizer or schedulers[0].optimizer
for scheduler_idx, scheduler in enumerate(schedulers):
if not hasattr(scheduler, "optimizer"):
raise TypeError(
f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
)
if isinstance(scheduler, ReduceLROnPlateau):
raise ValueError(
f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
"requires additional kwargs to be specified when calling `step`, "
f"but got one at index {scheduler_idx} in the given schedulers sequence."
)
if optimizer != scheduler.optimizer:
raise ValueError(
f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, "
f"which is different from {optimizer.__class__.__name__}."
)
self._schedulers = schedulers
self.optimizer = optimizer
self._last_lr = _param_groups_val_list(self._schedulers[-1].optimizer, "lr")
def step(self) -> None: # type: ignore[override]
"""Perform a step."""
for scheduler in self._schedulers:
scheduler.step()
self._last_lr = _param_groups_val_list(self._schedulers[-1].optimizer, "lr")
@override
def state_dict(self) -> dict[str, Any]:
"""Return the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in ``self.__dict__`` which
is not the optimizer.
The wrapped scheduler states will also be saved.
"""
state_dict = {
key: value
for key, value in self.__dict__.items()
if key not in ("optimizer", "_schedulers")
}
state_dict["_schedulers"] = [None] * len(self._schedulers)
for idx, s in enumerate(self._schedulers):
# pyrefly: ignore [unsupported-operation]
state_dict["_schedulers"][idx] = s.state_dict()
return state_dict
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Load the scheduler's state.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
_schedulers = state_dict.pop("_schedulers")
self.__dict__.update(state_dict)
# Restore state_dict keys in order to prevent side effects
# https://github.com/pytorch/pytorch/issues/32756
state_dict["_schedulers"] = _schedulers
for idx, s in enumerate(_schedulers):
self._schedulers[idx].load_state_dict(s)
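# --- Illustrative sketch, not part of the PyTorch source above. ---
# ChainedScheduler applies every wrapped scheduler at each step, so their
# factors multiply; here a constant warmup factor is combined with exponential
# decay. All values are hypothetical.
import torch
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
s1 = torch.optim.lr_scheduler.ConstantLR(opt, factor=0.1, total_iters=2)
s2 = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.9)
chained = torch.optim.lr_scheduler.ChainedScheduler([s1, s2])
for _ in range(4):
    opt.step()
    chained.step()
    print(chained.get_last_lr())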
class ReduceLROnPlateau(LRScheduler):
"""Reduce learning rate when a metric has stopped improving.
Models often benefit from reducing the learning rate by a factor
of 2-10 once learning stagnates. This scheduler reads a metrics
quantity and if no improvement is seen for a 'patience' number
of epochs, the learning rate is reduced.
Args:
optimizer (Optimizer): Wrapped optimizer.
mode (str): One of `min`, `max`. In `min` mode, lr will
be reduced when the quantity monitored has stopped
decreasing; in `max` mode it will be reduced when the
quantity monitored has stopped increasing. Default: 'min'.
factor (float): Factor by which the learning rate will be
reduced. new_lr = lr * factor. Default: 0.1.
patience (int): The number of allowed epochs with no improvement after
which the learning rate will be reduced.
For example, consider the case of having no patience (`patience = 0`).
In the first epoch, a baseline is established and is always considered good as there's no previous baseline.
In the second epoch, if the performance is worse than the baseline,
we have what is considered an intolerable epoch.
Since the count of intolerable epochs (1) is greater than the patience level (0),
the learning rate is reduced at the end of this epoch.
From the third epoch onwards, the learning rate continues to be reduced at the end of each epoch
if the performance is worse than the baseline. If the performance improves or remains the same,
the learning rate is not adjusted.
Default: 10.
threshold (float): Threshold for measuring the new optimum,
to only focus on significant changes. Default: 1e-4.
threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
dynamic_threshold = best * ( 1 + threshold ) in 'max'
mode or best * ( 1 - threshold ) in `min` mode.
In `abs` mode, dynamic_threshold = best + threshold in
`max` mode or best - threshold in `min` mode. Default: 'rel'.
cooldown (int): Number of epochs to wait before resuming
normal operation after lr has been reduced. Default: 0.
min_lr (float or list): A scalar or a list of scalars. A
lower bound on the learning rate of all param groups
or each group respectively. Default: 0.
eps (float): Minimal decay applied to lr. If the difference
between new and old lr is smaller than eps, the update is
ignored. Default: 1e-8.
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = ReduceLROnPlateau(optimizer, "min")
>>> for epoch in range(10):
>>> train(...)
>>> val_loss = validate(...)
>>> # Note that step should be called after validate()
>>> scheduler.step(val_loss)
.. image:: ../scripts/lr_scheduler_images/ReduceLROnPlateau.png
"""
def __init__(
self,
optimizer: Optimizer,
mode: Literal["min", "max"] = "min",
factor: float = 0.1,
patience: int = 10,
threshold: float = 1e-4,
threshold_mode: Literal["rel", "abs"] = "rel",
cooldown: int = 0,
min_lr: list[float] | float = 0,
eps: float = 1e-8,
) -> None: # noqa: D107
if factor >= 1.0:
raise ValueError("Factor should be < 1.0.")
self.factor = factor
# Attach optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
self.optimizer = optimizer
if isinstance(min_lr, (list, tuple)):
if len(min_lr) != len(optimizer.param_groups):
raise ValueError(
f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}"
)
self.default_min_lr = None
self.min_lrs = list(min_lr)
else:
# pyrefly: ignore [bad-assignment]
self.default_min_lr = min_lr
self.min_lrs = [min_lr] * len(optimizer.param_groups)
self.patience = patience
self.cooldown = cooldown
self.eps = eps
self.last_epoch = 0
self._last_lr = _param_groups_val_list(self.optimizer, "lr")
self._init_is_better(
mode=mode, threshold=threshold, threshold_mode=threshold_mode
)
self._reset()
def _reset(self) -> None:
"""Reset num_bad_epochs counter and cooldown counter."""
self.best = self.mode_worse
self.cooldown_counter = 0
self.num_bad_epochs = 0
def step(self, metrics: SupportsFloat, epoch=None) -> None: # type: ignore[override]
"""Perform a step."""
# convert `metrics` to float, in case it's a zero-dim Tensor
current = float(metrics)
if epoch is None:
epoch = self.last_epoch + 1
else:
warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning, stacklevel=2)
self.last_epoch = epoch
if self._is_better(current, self.best):
self.best = current
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1
if self.in_cooldown:
self.cooldown_counter -= 1
self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
if self.num_bad_epochs > self.patience:
self._reduce_lr(epoch)
self.cooldown_counter = self.cooldown
self.num_bad_epochs = 0
self._last_lr = _param_groups_val_list(self.optimizer, "lr")
def _reduce_lr(self, epoch) -> None:
if len(self.optimizer.param_groups) != len(self.min_lrs):
if self.default_min_lr is None:
raise RuntimeError(
"The number of param groups in the `optimizer` "
f"({len(self.optimizer.param_groups)}) differs "
f"from when `ReduceLROnPlateau` was initialized "
f"({len(self.min_lrs)}), usually due to a new "
"param group being added to the optimizer. Please "
"modify the `min_lrs` field to match the length "
"of the `optimizer` param groups."
)
else:
# pyrefly: ignore [bad-assignment]
self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups)
for i, param_group in enumerate(self.optimizer.param_groups):
old_lr = float(param_group["lr"])
new_lr = max(old_lr * self.factor, self.min_lrs[i])
if old_lr - new_lr > self.eps:
_update_param_group_val(param_group, "lr", new_lr)
@property
def in_cooldown(self): # noqa: D102
return self.cooldown_counter > 0
def _is_better(self, a, best): # noqa: D102
if self.mode == "min" and self.threshold_mode == "rel":
rel_epsilon = 1.0 - self.threshold
return a < best * rel_epsilon
elif self.mode == "min" and self.threshold_mode == "abs":
return a < best - self.threshold
elif self.mode == "max" and self.threshold_mode == "rel":
rel_epsilon = self.threshold + 1.0
return a > best * rel_epsilon
else: # mode == 'max' and epsilon_mode == 'abs':
return a > best + self.threshold
def _init_is_better(self, mode, threshold, threshold_mode) -> None:
if mode not in {"min", "max"}:
raise ValueError("mode " + mode + " is unknown!")
if threshold_mode not in {"rel", "abs"}:
raise ValueError("threshold mode " + threshold_mode + " is unknown!")
# the worse value for the chosen mode
if mode == "min":
self.mode_worse = inf
else: # mode == 'max':
self.mode_worse = -inf
self.mode = mode
self.threshold = threshold
self.threshold_mode = threshold_mode
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Load the scheduler's state."""
self.__dict__.update(state_dict)
self._init_is_better(
mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode
)
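# --- Illustrative sketch, not part of the PyTorch source above. ---
# ReduceLROnPlateau is driven by a validation metric rather than the epoch
# counter, so the metric is passed to step(); the loss values are hypothetical.
import torch
opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=2)
fake_val_losses = [1.0, 0.9, 0.9, 0.91, 0.92, 0.93]  # stops improving after the second epoch
for loss in fake_val_losses:
    opt.step()
    sched.step(loss)
print(opt.param_groups[0]["lr"])  # reduced from 1e-3 to 5e-4 once patience is exceeded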
class CyclicLR(LRScheduler):
r"""Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR).
The policy cycles the learning rate between two boundaries with a constant frequency,
as detailed in the paper `Cyclical Learning Rates for Training Neural Networks`_.
The distance between the two boundaries can be scaled on a per-iteration
or per-cycle basis.
Cyclical learning rate policy changes the learning rate after every batch.
`step` should be called after a batch has been used for training.
This class has three built-in policies, as put forth in the paper:
* "triangular": A basic triangular cycle without amplitude scaling.
* "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
* "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
at each cycle iteration.
This implementation was adapted from the github repo: `bckenstler/CLR`_
Args:
optimizer (Optimizer): Wrapped optimizer.
base_lr (float or list): Initial learning rate which is the
lower boundary in the cycle for each parameter group.
max_lr (float or list): Upper learning rate boundaries in the cycle
for each parameter group. Functionally,
it defines the cycle amplitude (max_lr - base_lr).
The lr at any cycle is the sum of base_lr
and some scaling of the amplitude; therefore
max_lr may not actually be reached depending on
scaling function.
step_size_up (int): Number of training iterations in the
increasing half of a cycle. Default: 2000
step_size_down (int): Number of training iterations in the
decreasing half of a cycle. If step_size_down is None,
it is set to step_size_up. Default: None
mode (str): One of {triangular, triangular2, exp_range}.
Values correspond to policies detailed above.
If scale_fn is not None, this argument is ignored.
Default: 'triangular'
gamma (float): Constant in 'exp_range' scaling function:
gamma**(cycle iterations)
Default: 1.0
scale_fn (function): Custom scaling policy defined by a single
argument lambda function, where
0 <= scale_fn(x) <= 1 for all x >= 0.
If specified, then 'mode' is ignored.
Default: None
scale_mode (str): {'cycle', 'iterations'}.
Defines whether scale_fn is evaluated on
cycle number or cycle iterations (training
iterations since start of cycle).
Default: 'cycle'
cycle_momentum (bool): If ``True``, momentum is cycled inversely
to learning rate between 'base_momentum' and 'max_momentum'.
Default: True
base_momentum (float or list): Lower momentum boundaries in the cycle
for each parameter group. Note that momentum is cycled inversely
to learning rate; at the peak of a cycle, momentum is
'base_momentum' and learning rate is 'max_lr'.
Default: 0.8
max_momentum (float or list): Upper momentum boundaries in the cycle
for each parameter group. Functionally,
it defines the cycle amplitude (max_momentum - base_momentum).
The momentum at any cycle is the difference of max_momentum
and some scaling of the amplitude; therefore
base_momentum may not actually be reached depending on
scaling function. Note that momentum is cycled inversely
to learning rate; at the start of a cycle, momentum is 'max_momentum'
and learning rate is 'base_lr'
Default: 0.9
last_epoch (int): The index of the last batch. This parameter is used when
resuming a training job. Since `step()` should be invoked after each
batch instead of after each epoch, this number represents the total
number of *batches* computed, not the total number of epochs computed.
When last_epoch=-1, the schedule is started from the beginning.
Default: -1
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.CyclicLR(
... optimizer,
... base_lr=0.01,
... max_lr=0.1,
... step_size_up=10,
... )
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
>>> for batch in data_loader:
>>> train_batch(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/CyclicLR.png
.. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
.. _bckenstler/CLR: https://github.com/bckenstler/CLR
"""
def __init__(
self,
optimizer: Optimizer,
base_lr: float | list[float],
max_lr: float | list[float],
step_size_up: int = 2000,
step_size_down: int | None = None,
mode: Literal["triangular", "triangular2", "exp_range"] = "triangular",
gamma: float = 1.0,
scale_fn: Callable[[float], float] | None = None,
scale_mode: Literal["cycle", "iterations"] = "cycle",
cycle_momentum: bool = True,
base_momentum: float = 0.8,
max_momentum: float = 0.9,
last_epoch: int = -1,
) -> None: # noqa: D107
# Attach optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
self.optimizer = optimizer
base_lrs = _format_param("base_lr", optimizer, base_lr)
if last_epoch == -1:
for lr, group in zip(base_lrs, optimizer.param_groups, strict=True):
_update_param_group_val(group, "lr", lr)
self.max_lrs = _format_param("max_lr", optimizer, max_lr)
# pyrefly: ignore [bad-assignment]
step_size_up = float(step_size_up)
step_size_down = (
# pyrefly: ignore [bad-assignment]
float(step_size_down) if step_size_down is not None else step_size_up
)
# pyrefly: ignore [unsupported-operation]
self.total_size = step_size_up + step_size_down
self.step_ratio = step_size_up / self.total_size
if mode not in ["triangular", "triangular2", "exp_range"] and scale_fn is None:
raise ValueError("mode is invalid and scale_fn is None")
self.mode = mode
self.gamma = gamma
self._scale_fn_ref: Callable[[float], float]
self._scale_fn_custom = scale_fn
self.scale_mode = scale_mode
self._init_scale_fn()
self.cycle_momentum = cycle_momentum
if cycle_momentum:
if (
"momentum" not in optimizer.defaults
and "betas" not in optimizer.defaults
):
raise ValueError(
"optimizer must support momentum or beta1 with `cycle_momentum` option enabled"
)
self.use_beta1 = "betas" in self.optimizer.defaults
self.base_momentums = _format_param(
"base_momentum", optimizer, base_momentum
)
self.max_momentums = _format_param("max_momentum", optimizer, max_momentum)
if last_epoch == -1:
for m_momentum, b_momentum, group in zip(
self.max_momentums,
self.base_momentums,
optimizer.param_groups,
strict=True,
):
if self.use_beta1:
group["betas"] = (m_momentum, *group["betas"][1:])
else:
group["momentum"] = m_momentum
group["max_momentum"] = m_momentum
group["base_momentum"] = b_momentum
super().__init__(optimizer, last_epoch)
self.base_lrs = base_lrs
def _init_scale_fn(self) -> None:
if self._scale_fn_custom is not None:
return
if self.mode == "triangular":
self._scale_fn_ref = self._triangular_scale_fn
self.scale_mode = "cycle"
elif self.mode == "triangular2":
self._scale_fn_ref = self._triangular2_scale_fn
self.scale_mode = "cycle"
elif self.mode == "exp_range":
self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma)
self.scale_mode = "iterations"
def scale_fn(self, x) -> float:
"""Get the scaling policy."""
if self._scale_fn_custom is not None:
return self._scale_fn_custom(x)
else:
return self._scale_fn_ref(x) # static method
@staticmethod
def _triangular_scale_fn(x: float) -> float:
return 1.0
@staticmethod
def _triangular2_scale_fn(x: float) -> float:
return 1 / (2.0 ** (x - 1))
@staticmethod
def _exp_range_scale_fn(gamma: float, x: float) -> float:
return gamma**x
@override
def get_lr(self) -> list[float | Tensor]:
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
Advances each ``group["lr"]`` in the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` along a cycle between the
group's ``base_lr`` and ``max_lr`` using :meth:`scale_fn`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
.. note::
This method treats :attr:`last_epoch` as the index of the previous
batch.
.. note::
When :attr:`cycle_momentum` is ``True``, this method has a side
effect of updating the optimizer's momentum.
"""
_warn_get_lr_called_within_step(self)
cycle = math.floor(1 + self.last_epoch / self.total_size)
x = 1.0 + self.last_epoch / self.total_size - cycle
if x <= self.step_ratio:
scale_factor = x / self.step_ratio
else:
scale_factor = (x - 1) / (self.step_ratio - 1)
lrs = []
for base_lr, max_lr in zip(self.base_lrs, self.max_lrs, strict=True):
base_height = (max_lr - base_lr) * scale_factor
if self.scale_mode == "cycle":
lr = base_lr + base_height * self.scale_fn(cycle)
else:
lr = base_lr + base_height * self.scale_fn(self.last_epoch)
lrs.append(lr)
if self.cycle_momentum:
momentums = []
for base_momentum, max_momentum in zip(
self.base_momentums, self.max_momentums, strict=True
):
base_height = (max_momentum - base_momentum) * scale_factor
if self.scale_mode == "cycle":
momentum = max_momentum - base_height * self.scale_fn(cycle)
else:
momentum = max_momentum - base_height * self.scale_fn(
self.last_epoch
)
momentums.append(momentum)
for param_group, momentum in zip(
self.optimizer.param_groups, momentums, strict=True
):
if self.use_beta1:
param_group["betas"] = (momentum, *param_group["betas"][1:])
else:
param_group["momentum"] = momentum
return lrs
@override
def state_dict(self) -> dict[str, Any]: # noqa: D102
"""Return the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in ``self.__dict__`` which
is not the optimizer.
The learning rate lambda functions will only be saved if they are callable objects
and not if they are functions or lambdas.
When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
"""
state = super().state_dict()
# We are dropping the `_scale_fn_ref` attribute because it is a
# `weakref.WeakMethod` and can't be pickled.
state.pop("_scale_fn_ref", None)
fn = state.pop("_scale_fn_custom")
state["_scale_fn_custom"] = None
if fn is not None and not isinstance(fn, types.FunctionType):
# The _scale_fn_custom will only be saved if it is a callable object
# and not if it is a function or lambda.
state["_scale_fn_custom"] = fn.__dict__.copy()
return state
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Load the scheduler's state."""
fn = state_dict.pop("_scale_fn_custom")
super().load_state_dict(state_dict)
if fn is not None:
self._scale_fn_custom.__dict__.update(fn)
self._init_scale_fn()
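# --- Illustrative sketch, not part of the PyTorch source above. ---
# CyclicLR steps once per batch, not per epoch, sweeping the lr between base_lr
# and max_lr; momentum is cycled in the opposite direction when the optimizer
# supports it. All values here are hypothetical.
import torch
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.01, momentum=0.9)
sched = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=0.01, max_lr=0.1, step_size_up=4)
lrs = []
for _ in range(16):  # pretend these are 16 training batches
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])
print(lrs)  # rises from 0.01 toward 0.1 over 4 steps, then falls back, and repeats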
class CosineAnnealingWarmRestarts(LRScheduler):
r"""Set the learning rate of each parameter group using a cosine annealing schedule.
The :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
is the number of epochs since the last restart and :math:`T_{i}` is the number
of epochs between two warm restarts in SGDR:
.. math::
\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
\cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
It has been proposed in
`SGDR: Stochastic Gradient Descent with Warm Restarts`_.
Args:
optimizer (Optimizer): Wrapped optimizer.
T_0 (int): Number of iterations until the first restart.
T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1.
eta_min (float, optional): Minimum learning rate. Default: 0.
last_epoch (int, optional): The index of the last epoch. Default: -1.
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
... optimizer, T_0=20
... )
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/CosineAnnealingWarmRestarts.png
"""
def __init__(
self,
optimizer: Optimizer,
T_0: int,
T_mult: int = 1,
eta_min: float = 0.0,
last_epoch: int = -1,
) -> None: # noqa: D107
if T_0 <= 0 or not isinstance(T_0, int):
raise ValueError(f"Expected positive integer T_0, but got {T_0}")
if T_mult < 1 or not isinstance(T_mult, int):
raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")
if not isinstance(eta_min, (float, int)):
raise ValueError(
f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}"
)
self.T_0 = T_0
self.T_i = T_0
self.T_mult = T_mult
self.eta_min = eta_min
self.T_cur = last_epoch
super().__init__(optimizer, last_epoch)
@override
def get_lr(self) -> list[float | Tensor]:
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
Computes learning rates for the optimizer's
:attr:`~torch.optim.Optimizer.param_groups` following:
.. math::
\texttt{eta\_min} + \frac{1}{2}(\texttt{base\_lr} -
\texttt{eta\_min})\left(1 + \cos\left(\pi \cdot
\frac{\texttt{T\_cur}}{\texttt{T\_i}}\right)\right)
Where :attr:`T_cur` is the number of epochs since the last restart and
:attr:`T_i` is the number of epochs between two restarts. Both
:attr:`T_cur` and :attr:`T_i` are updated in :meth:`step`, and
:attr:`T_i` becomes :attr:`T_mult` times larger after each restart.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
"""
_warn_get_lr_called_within_step(self)
return [
self.eta_min
+ (base_lr - self.eta_min)
* (1 + math.cos(math.pi * self.T_cur / self.T_i))
/ 2
for base_lr in self.base_lrs
]
@override
def step(self, epoch=None) -> None:
"""Step could be called after every batch update.
Example:
>>> # xdoctest: +SKIP("Undefined vars")
>>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
>>> iters = len(dataloader)
>>> for epoch in range(20):
>>> for i, sample in enumerate(dataloader):
>>> inputs, labels = sample['inputs'], sample['labels']
>>> optimizer.zero_grad()
>>> outputs = net(inputs)
>>> loss = criterion(outputs, labels)
>>> loss.backward()
>>> optimizer.step()
>>> scheduler.step(epoch + i / iters)
This function can be called in an interleaved way.
Example:
>>> # xdoctest: +SKIP("Undefined vars")
>>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
>>> for epoch in range(20):
>>> scheduler.step()
>>> scheduler.step(26)
>>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
"""
if epoch is None and self.last_epoch < 0:
epoch = 0
if epoch is None:
epoch = self.last_epoch + 1
self.T_cur = self.T_cur + 1
if self.T_cur >= self.T_i:
self.T_cur = self.T_cur % self.T_i
self.T_i = self.T_i * self.T_mult
else:
if epoch < 0:
raise ValueError(f"Expected non-negative epoch, but got {epoch}")
if epoch >= self.T_0:
if self.T_mult == 1:
self.T_cur = epoch % self.T_0
else:
n = int(
math.log(
(epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult
)
)
self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (
self.T_mult - 1
)
self.T_i = self.T_0 * self.T_mult ** (n)
else:
self.T_i = self.T_0
self.T_cur = epoch
self.last_epoch = math.floor(epoch)
with _enable_get_lr_call(self):
for param_group, lr in zip(
self.optimizer.param_groups, self.get_lr(), strict=True
):
_update_param_group_val(param_group, "lr", lr)
self._last_lr = _param_groups_val_list(self.optimizer, "lr")
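# --- Illustrative sketch, not part of the PyTorch source above. ---
# CosineAnnealingWarmRestarts restarts the cosine cycle every T_i epochs and, as
# the docstring above shows, may be stepped fractionally within an epoch.
# The parameter and the batch count are hypothetical stand-ins.
import torch
opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=5, T_mult=2)
iters_per_epoch = 10  # hypothetical number of batches per epoch
for epoch in range(10):
    for i in range(iters_per_epoch):
        opt.step()
        sched.step(epoch + i / iters_per_epoch)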
class _SchedulePhase(TypedDict):
end_step: float
start_lr: str
end_lr: str
start_momentum: str
end_momentum: str
class OneCycleLR(LRScheduler):
r"""Sets the learning rate of each parameter group according to the 1cycle learning rate policy.
The 1cycle policy anneals the learning rate from an initial learning rate to some maximum
learning rate and then from that maximum learning rate to some minimum learning rate much
lower than the initial learning rate.
This policy was initially described in the paper `Super-Convergence:
Very Fast Training of Neural Networks Using Large Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every batch.
`step` should be called after a batch has been used for training.
This scheduler is not chainable.
Note also that the total number of steps in the cycle can be determined in one
of two ways (listed in order of precedence):
#. A value for total_steps is explicitly provided.
#. A number of epochs (epochs) and a number of steps per epoch
(steps_per_epoch) are provided.
In this case, the number of total steps is inferred by
total_steps = epochs * steps_per_epoch
You must either provide a value for total_steps or provide a value for both
epochs and steps_per_epoch.
The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
claims that "unpublished work has shown even better results by using only two phases". To
mimic the behaviour of the original paper instead, set ``three_phase=True``.
Args:
optimizer (Optimizer): Wrapped optimizer.
max_lr (float or list): Upper learning rate boundaries in the cycle
for each parameter group.
total_steps (int): The total number of steps in the cycle. Note that
if a value is not provided here, then it must be inferred by providing
a value for epochs and steps_per_epoch.
Default: None
epochs (int): The number of epochs to train for. This is used along
with steps_per_epoch in order to infer the total number of steps in the cycle
if a value for total_steps is not provided.
Default: None
steps_per_epoch (int): The number of steps per epoch to train for. This is
used along with epochs in order to infer the total number of steps in the
cycle if a value for total_steps is not provided.
Default: None
pct_start (float): The percentage of the cycle (in number of steps) spent
increasing the learning rate.
Default: 0.3
anneal_strategy (str): {'cos', 'linear'}
Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
linear annealing.
Default: 'cos'
cycle_momentum (bool): If ``True``, momentum is cycled inversely
to learning rate between 'base_momentum' and 'max_momentum'.
Default: True
base_momentum (float or list): Lower momentum boundaries in the cycle
for each parameter group. Note that momentum is cycled inversely
to learning rate; at the peak of a cycle, momentum is
'base_momentum' and learning rate is 'max_lr'.
Default: 0.85
max_momentum (float or list): Upper momentum boundaries in the cycle
for each parameter group. Functionally,
it defines the cycle amplitude (max_momentum - base_momentum).
Note that momentum is cycled inversely
to learning rate; at the start of a cycle, momentum is 'max_momentum'
and learning rate is 'base_lr'
Default: 0.95
div_factor (float): Determines the initial learning rate via
initial_lr = max_lr/div_factor
Default: 25
final_div_factor (float): Determines the minimum learning rate via
min_lr = initial_lr/final_div_factor
Default: 1e4
three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
learning rate according to 'final_div_factor' instead of modifying the second
phase (the first two phases will be symmetrical about the step indicated by
'pct_start').
last_epoch (int): The index of the last batch. This parameter is used when
resuming a training job. Since `step()` should be invoked after each
batch instead of after each epoch, this number represents the total
number of *batches* computed, not the total number of epochs computed.
When last_epoch=-1, the schedule is started from the beginning.
Default: -1
Example:
>>> # xdoctest: +SKIP
>>> data_loader = torch.utils.data.DataLoader(...)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.OneCycleLR(
... optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10
... )
>>> for epoch in range(10):
>>> for batch in data_loader:
>>> train_batch(...)
>>> optimizer.step()
>>> scheduler.step()
.. image:: ../scripts/lr_scheduler_images/OneCycleLR.png
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
https://arxiv.org/abs/1708.07120
"""
def __init__(
self,
optimizer: Optimizer,
max_lr: float | list[float],
total_steps: int | None = None,
epochs: int | None = None,
steps_per_epoch: int | None = None,
pct_start: float = 0.3,
anneal_strategy: Literal["cos", "linear"] = "cos",
cycle_momentum: bool = True,
base_momentum: float | list[float] = 0.85,
max_momentum: float | list[float] = 0.95,
div_factor: float = 25.0,
final_div_factor: float = 1e4,
three_phase: bool = False,
last_epoch: int = -1,
) -> None: # noqa: D107
# Validate optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
self.optimizer = optimizer
# Validate total_steps
if total_steps is not None:
if total_steps <= 0 or not isinstance(total_steps, int):
raise ValueError(
f"Expected positive integer total_steps, but got {total_steps}"
)
self.total_steps = total_steps
elif epochs is not None and steps_per_epoch is not None:
if not isinstance(epochs, int) or epochs <= 0:
raise ValueError(f"Expected positive integer epochs, but got {epochs}")
if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
raise ValueError(
f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}"
)
self.total_steps = epochs * steps_per_epoch
else:
raise ValueError(
"You must define either total_steps OR (epochs AND steps_per_epoch)"
)
self._schedule_phases: list[_SchedulePhase]
if three_phase:
self._schedule_phases = [
{
"end_step": float(pct_start * self.total_steps) - 1,
"start_lr": "initial_lr",
"end_lr": "max_lr",
"start_momentum": "max_momentum",
"end_momentum": "base_momentum",
},
{
"end_step": float(2 * pct_start * self.total_steps) - 2,
"start_lr": "max_lr",
"end_lr": "initial_lr",
"start_momentum": "base_momentum",
"end_momentum": "max_momentum",
},
{
"end_step": self.total_steps - 1,
"start_lr": "initial_lr",
"end_lr": "min_lr",
"start_momentum": "max_momentum",
"end_momentum": "max_momentum",
},
]
else:
self._schedule_phases = [
{
"end_step": float(pct_start * self.total_steps) - 1,
"start_lr": "initial_lr",
"end_lr": "max_lr",
"start_momentum": "max_momentum",
"end_momentum": "base_momentum",
},
{
"end_step": self.total_steps - 1,
"start_lr": "max_lr",
"end_lr": "min_lr",
"start_momentum": "base_momentum",
"end_momentum": "max_momentum",
},
]
# Validate pct_start
if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
raise ValueError(
f"Expected float between 0 and 1 pct_start, but got {pct_start}"
)
# Validate anneal_strategy
if anneal_strategy not in ["cos", "linear"]:
raise ValueError(
f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}"
)
else:
self._anneal_func_type = anneal_strategy
# Initialize learning rate variables
max_lrs = _format_param("max_lr", self.optimizer, max_lr)
if last_epoch == -1:
for idx, group in enumerate(self.optimizer.param_groups):
group["initial_lr"] = max_lrs[idx] / div_factor
group["max_lr"] = max_lrs[idx]
group["min_lr"] = group["initial_lr"] / final_div_factor
# Initialize momentum variables
self.cycle_momentum = cycle_momentum
if self.cycle_momentum:
if (
"momentum" not in self.optimizer.defaults
and "betas" not in self.optimizer.defaults
):
raise ValueError(
"optimizer must support momentum or beta1 with `cycle_momentum` option enabled"
)
self.use_beta1 = "betas" in self.optimizer.defaults
max_momentums = _format_param("max_momentum", optimizer, max_momentum)
base_momentums = _format_param("base_momentum", optimizer, base_momentum)
if last_epoch == -1:
for m_momentum, b_momentum, group in zip(
max_momentums, base_momentums, optimizer.param_groups, strict=True
):
if self.use_beta1:
group["betas"] = (m_momentum, *group["betas"][1:])
else:
group["momentum"] = m_momentum
group["max_momentum"] = m_momentum
group["base_momentum"] = b_momentum
super().__init__(optimizer, last_epoch)
def _anneal_func(self, *args, **kwargs):
if hasattr(self, "_anneal_func_type"):
if self._anneal_func_type == "cos":
return self._annealing_cos(*args, **kwargs)
elif self._anneal_func_type == "linear":
return self._annealing_linear(*args, **kwargs)
else:
raise ValueError(f"Unknown _anneal_func_type: {self._anneal_func_type}")
else:
# For BC
return self.anneal_func(*args, **kwargs) # type: ignore[attr-defined]
@staticmethod
def _annealing_cos(start, end, pct):
"""Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
cos_out = math.cos(math.pi * pct) + 1
return end + (start - end) / 2.0 * cos_out
@staticmethod
def _annealing_linear(start, end, pct):
"""Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
return (end - start) * pct + start
@override
def get_lr(self) -> list[float | Tensor]:
r"""Compute the next learning rate for each of the optimizer's
:attr:`~torch.optim.Optimizer.param_groups`.
Finds the appropriate :attr:`_schedule_phases` entry for the current
step and interpolates between its ``start_lr`` and ``end_lr`` using
:meth:`_anneal_func`.
Returns:
list[float | Tensor]: A :class:`list` of learning rates for each of
the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
same types as their current ``group["lr"]``\s.
.. note::
If you're trying to inspect the most recent learning rate, use
:meth:`get_last_lr()` instead.
.. note::
The returned :class:`~torch.Tensor`\s are copies, and never alias
the optimizer's ``group["lr"]``\s.
.. note::
When :attr:`cycle_momentum` is ``True``, this method has a side
effect of updating the optimizer's momentum.
"""
_warn_get_lr_called_within_step(self)
lrs = []
step_num = self.last_epoch
if step_num > self.total_steps:
raise ValueError(
f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}"
)
for group in self.optimizer.param_groups:
start_step = 0.0
for i, phase in enumerate(self._schedule_phases):
end_step = phase["end_step"]
if step_num <= end_step or i == len(self._schedule_phases) - 1:
pct = (step_num - start_step) / (end_step - start_step)
computed_lr = self._anneal_func(
group[phase["start_lr"]], group[phase["end_lr"]], pct
)
if self.cycle_momentum:
computed_momentum = self._anneal_func(
group[phase["start_momentum"]],
group[phase["end_momentum"]],
pct,
)
break
start_step = phase["end_step"]
lrs.append(computed_lr) # type: ignore[possibly-undefined]
if self.cycle_momentum:
if self.use_beta1:
group["betas"] = (computed_momentum, *group["betas"][1:]) # type: ignore[possibly-undefined]
else:
group["momentum"] = computed_momentum # type: ignore[possibly-undefined]
return lrs
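以上是 PyTorch 官方 OneCycleLR 的实现源码。本文训练脚本实际采用的是 CosineAnnealingLR(见下文 train.py);若想换成 OneCycleLR,要注意它按 batch 而非按 epoch 调用 step()。下面给出一个假设性的最小示意(用占位线性层代替 ViT,步数按本文 54000 张训练图 / batch 128 估算),只用于观察学习率随训练步数的变化:

```python
import torch
import torch.nn as nn

# 占位线性层, 仅用于演示调度器行为, 并非本文的 ViT 模型
model = nn.Linear(49, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
epochs, steps_per_epoch = 20, 422          # 422 ≈ 54000 张训练图 / batch_size 128
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-3, epochs=epochs, steps_per_epoch=steps_per_epoch
)

lrs = []
for _ in range(epochs * steps_per_epoch):
    optimizer.step()                       # 真实训练中这里是 forward / backward / step
    scheduler.step()                       # OneCycleLR 必须每个 batch 调一次
    lrs.append(scheduler.get_last_lr()[0])

print(f"起始 lr={lrs[0]:.2e}, 峰值 lr={max(lrs):.2e}, 末端 lr={lrs[-1]:.2e}")
```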
ViTAX-从零拆解 ViT 的 JAX/NNX 实现
[https://maurocomi.com/blog/vit.html]
Seeing the World Through Transformers
Step 1: Patchifying the Image, From Pixels to Patches
Step 2: Patch Embedding
Step 3: Positional Embeddings
Step 4: The CLS Token
Step 5: Attention and the Transformer Encoder
Step 6: Classification Head
Training the ViT
Conclusion
一、文章定位
这是一篇从零手写 Vision Transformer (ViT) 的教程,使用 JAX + NNX(Google 推荐的 JAX 新一代深度学习框架)实现,面向希望深入理解 ViT 内部机制的读者。文章强调直觉理解,每个组件都解释了"为什么"以及代码实现。
二、ViT 核心架构六步走
| 步骤 | 组件 | 作用 |
|---|---|---|
| 1 | 图像分块 (Patchify) | 用 einops 将图像从 (B,H,W,C) 重排为 (B, N, P²C),把图像切成固定大小的 patch 序列 |
| 2 | Patch Embedding | 通过可学习的 nnx.Linear 层,将每个展平的 patch 投影到高维嵌入空间 |
| 3 | 位置嵌入 (Positional Embedding) | 由于 Transformer 是置换不变的,需添加可学习的位置向量来保留空间信息 |
| 4 | CLS Token | 借鉴 BERT,在序列开头添加一个特殊的分类 token,通过自注意力聚合全局信息 |
| 5 | Transformer Encoder | 堆叠多个 Encoder Block,每个包含:LayerNorm → 多头自注意力 → 残差连接 → LayerNorm → MLP → 残差连接 |
| 6 | 分类头 (Classification Head) | 取 CLS token 的最终输出,经线性层映射到类别数,得到分类 logits |
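对应上表第 1 步,下面用几行代码演示 einops 分块(仅为示意:原文操作的是 JAX 数组,这里改用 PyTorch 张量;einops 不在本文 pyproject 依赖中,需额外安装):

```python
import torch
from einops import rearrange

# (B, H, W, C) 布局与原文一致; 28×28 灰度图, 7×7 patch
x = torch.randn(2, 28, 28, 1)
patches = rearrange(x, "b (h p1) (w p2) c -> b (h w) (p1 p2 c)", p1=7, p2=7)
print(patches.shape)   # torch.Size([2, 16, 49]): 16 个 patch, 每个展平为 49 维
```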
三、关键实现细节
- 分块函数:使用 einops.rearrange 优雅地实现图像到 patch 序列的转换,并用 @jax.jit 编译加速
- 多头自注意力:
  - 每个 token 生成 Q/K/V 三个向量
  - 通过缩放点积注意力计算相关性:scores = (Q·Kᵀ) / √d_k
  - 多组注意力头并行工作,捕捉不同维度的特征关系
- 残差连接:解决深层网络的梯度消失问题,提供梯度回传的"捷径"
- LayerNorm:采用无参数版本(论文表明参数并非总是有益)
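上面的缩放点积注意力公式可以写成如下最小示意(原文为 JAX 实现,这里为与本文其余代码一致改用 PyTorch,形状按本文配置取 4 头、序列长 17、每头 24 维):

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    """q, k, v: (B, heads, N, d_k); 返回 (B, heads, N, d_k)."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)   # (B, heads, N, N)
    weights = scores.softmax(dim=-1)                     # 每行归一化为注意力权重
    return weights @ v

# 按本文配置: 4 个头, 序列长 17 (16 patch + 1 CLS), 每头 24 维
q = k = v = torch.randn(2, 4, 17, 24)
print(scaled_dot_product_attention(q, k, v).shape)       # torch.Size([2, 4, 17, 24])
```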
四、训练流程
使用 Hugging Face 的 snacks 数据集作为示例:
- 数据预处理:加载数据集、调整尺寸、转换为 JAX 数组格式
- 损失函数:交叉熵损失(Cross-Entropy)
- 训练步骤:
  - nnx.value_and_grad 自动计算损失和梯度
  - nnx.Optimizer(AdamW) 更新参数
  - nnx.jit 编译加速
  - nnx.scan 高效遍历批次数据
- 验证:定期在验证集上评估模型性能
五、NNX vs 传统 JAX/Flax
- NNX 提供了类似 PyTorch 的面向对象编程模型,同时保留 JAX 的函数式核心优势
- 需要显式处理随机数生成器(rngs),确保可复现性
- 状态管理更直观:optimizer.update() 会直接修改模型状态,无需像纯函数式 JAX 那样手动传递状态
六、结论与展望
文章构建了一个功能完整的基线 ViT,涵盖了从图像分块到分类的完整流程。同时指出,自 2020 年原始论文发表以来,ViT 已有诸多改进方向:
- 混合架构(Hybrid ViTs,结合 CNN 的 overlapping patches)
- 更高效的注意力机制
- 二维/相对位置编码方案
- 先进的训练策略(如数据增强、知识蒸馏等)
ViT 不仅成为图像分类的 SOTA 骨干网络,还广泛应用于目标检测、语义分割,以及生成模型(如 Diffusion Transformers / DiT)中。
二、工程设计
2.1 从图像到序列:Patchify
ViT 的核心思想是将图像视为"视觉句子",每个 Patch 是一个"词"。对于 28×28 的 MNIST 图像,使用 7×7 的 Patch 尺寸,可得到:
Num Patches = (28 / 7) × (28 / 7) = 4 × 4 = 16
Patch Dim = 7 × 7 × 1 = 49
通过 nn.Conv2d(kernel_size=7, stride=7) 一步完成 Patchify + Linear Projection,输出 (B, 16, 96)。
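这一步"Conv2d 即 Patchify + 线性投影"可以用几行代码自检输出形状(仅为示意,完整实现见下文 vit_model.py 的 PatchEmbed):

```python
import torch
import torch.nn as nn

proj = nn.Conv2d(1, 96, kernel_size=7, stride=7)    # in_chans=1, embed_dim=96
x = torch.randn(2, 1, 28, 28)                        # (B, C, H, W)
y = proj(x)                                           # (2, 96, 4, 4)
tokens = y.flatten(2).transpose(1, 2)                 # (2, 16, 96)
print(tokens.shape)
```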
2.2 可学习位置编码
与 BERT 类似,ViT 使用可学习的位置嵌入(Learnable Positional Embedding)。序列长度为 16 patches + 1 CLS token = 17,每个位置对应一个 96-dim 向量,随机初始化并在训练中更新。
为什么不使用固定 sinusoidal?ViTAX 指出:可学习嵌入允许模型针对特定任务优化位置表示,在小型数据集上通常表现更好。
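下面是可学习位置编码与固定 sinusoidal 编码的最小对比示意(仅演示两者的构造方式与形状,本文模型只使用前者):

```python
import math
import torch
import torch.nn as nn

seq_len, dim = 17, 96                                  # 16 patches + 1 CLS token, 96 维

# 可学习位置编码: 随机初始化, 随训练更新(本文采用的方案)
learnable_pos = nn.Parameter(torch.zeros(1, seq_len, dim))
nn.init.normal_(learnable_pos, std=0.02)

# 固定 sinusoidal 位置编码: 不参与训练, 仅作对比
pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)                       # (17, 1)
div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
sinusoidal_pos = torch.zeros(seq_len, dim)
sinusoidal_pos[:, 0::2] = torch.sin(pos * div)
sinusoidal_pos[:, 1::2] = torch.cos(pos * div)
print(learnable_pos.shape, sinusoidal_pos.shape)       # (1, 17, 96), (17, 96)
```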
2.3 CLS Token:全局信息的"海绵"
在序列最前方添加一个可学习的 [CLS] Token。经过多层 Self-Attention 后,该 Token 通过与所有 Patch Embedding 交互,聚合全局图像信息。最终取 CLS Token 的输出送入分类头,而非对所有 Patch 做平均池化。
2.4 Multi-Head Self-Attention
每个 Encoder Block 包含:
LayerNorm -> MultiHeadAttention -> Residual
LayerNorm -> MLP (GELU) -> Residual
- Pre-LayerNorm:在子层输入前做归一化,训练更稳定。
- Residual Connection:防止梯度消失,使 16 层深层网络可训练。
- GELU 激活:比 ReLU 更平滑,在 Transformer 中成为事实标准。
2.5 分类头
self.head = nn.Linear(embed_dim, num_classes) # 96 -> 10
三、MNIST 超参数调优
参考 Keras 在 MNIST 上的 ViT 实验数据,本方案选择以下配置:
| 参数 | 取值 | 调优依据 |
|---|---|---|
| Patch Size | 7 | 28×28 图像下,7×7 在精度与速度间最佳平衡 |
| Embed Dim | 96 | Keras 实验:96-dim 可达 98.94% |
| Depth | 16 | 深层网络提升表征能力 |
| Num Heads | 4 | 96-dim 下每头 24-dim,计算高效 |
| MLP Ratio | 4.0 | 隐藏层 384,标准配置 |
| Batch Size | 128 | CPU/GPU 兼容 |
| LR | 1e-3 | AdamW 默认值 |
| Weight Decay | 1e-4 | 防止过拟合 |
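按上表配置实例化模型即可核对参数规模(vit_model.py 见下文"工程"部分):

```python
from vit_model import ViT, count_parameters   # vit_model.py 见下文

model = ViT(img_size=28, patch_size=7, in_chans=1, num_classes=10,
            embed_dim=96, depth=16, num_heads=4, mlp_ratio=4.0, dropout=0.1)
print(f"参数量: {count_parameters(model):,}")   # 训练报告中为 1,797,130
```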
数据增强策略:
- RandomRotation(7°):模拟不同书写角度
- RandomAffine(translate=0.1):微小位移
- 禁用水平翻转:避免 6/9 混淆
- Normalize((0.1307,), (0.3081,)):MNIST 标准统计值
工程
mnist数据集
[https://github.com/sunsided/mnist]
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
the IDX file format is a simple format for vectors and multidimensional matrices of various numerical types. The basic format is
magic number
size in dimension 0
size in dimension 1
size in dimension 2
.....
size in dimension N
data
The magic number is an integer (MSB first). The first 2 bytes are always 0.
The third byte codes the type of the data:
- 0x08: unsigned byte
- 0x09: signed byte
- 0x0B: short (2 bytes)
- 0x0C: int (4 bytes)
- 0x0D: float (4 bytes)
- 0x0E: double (8 bytes)
The 4-th byte codes the number of dimensions of the vector/matrix: 1 for vectors, 2 for matrices....
The sizes in each dimension are 4-byte integers (MSB first, high endian, like in most non-Intel processors).
The data is stored like in a C array, i.e. the index in the last dimension changes the fastest.
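按上述 IDX 格式说明,可以用 gzip + struct + numpy 写一个最小的读取函数做自检(仅为示意;本文训练实际通过 torchvision.datasets.MNIST 自动下载解析,下面的文件路径为假设):

```python
import gzip
import struct
import numpy as np

def read_idx(path):
    """按上文描述的 IDX 格式读取文件, 返回 numpy 数组."""
    with gzip.open(path, "rb") as f:
        zero, dtype_code, ndim = struct.unpack(">HBB", f.read(4))   # 前 2 字节恒为 0
        assert zero == 0 and dtype_code == 0x08                     # MNIST 为 unsigned byte
        shape = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))   # 各维大小, 大端 4 字节整数
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)

# 假设文件已下载到 ./mnist/ 目录(文件名见上)
images = read_idx("./mnist/train-images-idx3-ubyte.gz")   # (60000, 28, 28)
labels = read_idx("./mnist/train-labels-idx1-ubyte.gz")   # (60000,)
print(images.shape, labels.shape)
```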
pytorch代码
结构
src/inference.py
src/train.py
src/utils.py
src/vit_model.py
launch.py
pyproject.toml
[project]
name = "vit-mnist"
version = "0.1.0"
description = "Vision Transformer on MNIST"
requires-python = ">=3.13"
dependencies = [
"torch>=2.9.0",
"torchvision>=0.24.0",
"numpy",
"matplotlib",
]
[project.optional-dependencies]
dev = ["pytest"]
vit_model.py
"""
ViT Model for MNIST Classification (PyTorch)
Reference:
- ViTAX: Building a Vision Transformer from Scratch (JAX/NNX)
- Keras Example: Image classification with Vision Transformer
"""
import math
import torch
import torch.nn as nn
class PatchEmbed(nn.Module):
"""Image to Patch Embedding using Conv2d."""
def __init__(self, img_size=28, patch_size=7, in_chans=1, embed_dim=96):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(
in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
# x: (B, C, H, W) -> (B, embed_dim, H//p, W//p) -> (B, N, embed_dim)
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x
class TransformerBlock(nn.Module):
"""Standard Transformer Encoder Block with Pre-LayerNorm."""
def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1):
super().__init__()
self.norm1 = nn.LayerNorm(dim, eps=1e-6)
self.attn = nn.MultiheadAttention(
dim, num_heads, dropout=dropout, batch_first=True
)
self.norm2 = nn.LayerNorm(dim, eps=1e-6)
hidden_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x):
# Pre-norm + MHA + Residual
x_norm = self.norm1(x)
attn_out, _ = self.attn(x_norm, x_norm, x_norm)
x = x + attn_out
# Pre-norm + MLP + Residual
x = x + self.mlp(self.norm2(x))
return x
class ViT(nn.Module):
"""Vision Transformer for MNIST."""
def __init__(
self,
img_size=28,
patch_size=7,
in_chans=1,
num_classes=10,
embed_dim=96,
depth=16,
num_heads=4,
mlp_ratio=4.0,
dropout=0.1,
):
super().__init__()
self.num_classes = num_classes
self.embed_dim = embed_dim
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
# CLS token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# Positional embedding (learnable)
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=dropout)
# Transformer Encoder blocks
self.blocks = nn.Sequential(*[
TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
self.head = nn.Linear(embed_dim, num_classes)
# Initialize weights
self._init_weights()
def _init_weights(self):
nn.init.normal_(self.cls_token, std=0.02)
nn.init.normal_(self.pos_embed, std=0.02)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x) # (B, N, embed_dim)
# Concatenate CLS token
cls_tokens = self.cls_token.expand(B, -1, -1) # (B, 1, embed_dim)
x = torch.cat((cls_tokens, x), dim=1) # (B, N+1, embed_dim)
# Add positional embedding
x = x + self.pos_embed
x = self.pos_drop(x)
# Transformer blocks
x = self.blocks(x)
# LayerNorm + CLS head
x = self.norm(x)
cls_output = x[:, 0] # (B, embed_dim)
logits = self.head(cls_output) # (B, num_classes)
return logits
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
if __name__ == "__main__":
model = ViT()
dummy = torch.randn(2, 1, 28, 28)
out = model(dummy)
print("Output shape:", out.shape)
print("Params:", count_parameters(model))
utils.py
"""
Utilities: data loading, metrics, and TODO.md auto-update.
"""
import os
import re
from datetime import datetime
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
def get_mnist_loaders(data_dir="./mnist", batch_size=128, num_workers=0):
"""Create MNIST train/val/test DataLoaders."""
train_transform = transforms.Compose([
transforms.RandomRotation(degrees=7),
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
])
full_train = datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
# Separate view of the training set with the deterministic transform, used only for validation.
# (random_split returns Subsets that share one underlying dataset, so overriding
# val_dataset.dataset.transform would silently disable augmentation for training too.)
full_train_eval = datasets.MNIST(root=data_dir, train=True, download=True, transform=test_transform)
test_dataset = datasets.MNIST(root=data_dir, train=False, download=True, transform=test_transform)
# Split train into train/val (90/10); reuse the same seed so both views get identical indices
train_size = int(0.9 * len(full_train))
val_size = len(full_train) - train_size
train_dataset, _ = random_split(
full_train, [train_size, val_size],
generator=torch.Generator().manual_seed(42)
)
_, val_dataset = random_split(
full_train_eval, [train_size, val_size],
generator=torch.Generator().manual_seed(42)
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
return train_loader, val_loader, test_loader
def accuracy(output, target, topk=(1,)):
"""Compute top-k accuracy."""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def evaluate(model, loader, criterion, device):
model.eval()
total_loss = 0.0
total_correct_1 = 0
total_correct_5 = 0
total_samples = 0
with torch.no_grad():
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
total_loss += loss.item() * images.size(0)
acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
total_correct_1 += acc1.item() * images.size(0) / 100.0
total_correct_5 += acc5.item() * images.size(0) / 100.0
total_samples += images.size(0)
avg_loss = total_loss / total_samples
avg_acc1 = total_correct_1 / total_samples * 100.0
avg_acc5 = total_correct_5 / total_samples * 100.0
return avg_loss, avg_acc1, avg_acc5
def update_todo_md(epoch, train_loss, val_acc, is_best, todo_path="./TODO.md"):
"""Auto-update TODO.md after each epoch."""
if not os.path.exists(todo_path):
return
with open(todo_path, "r", encoding="utf-8") as f:
content = f.read()
# Pattern to match the epoch row
pattern = rf"(\| {epoch} \| )⬜ 未开始( \| — \| — \| — \| — \|)"
replacement = rf"\1✅ 已完成\2"
content_new = re.sub(pattern, replacement, content)
# If row was already updated or not found, try a more flexible replacement
if content_new == content:
old_row = f"| {epoch} | ⬜ 未开始 | — | — | — | — |"
new_row = f"| {epoch} | ✅ 已完成 | {train_loss:.4f} | {val_acc:.2%} | {'✅' if is_best else '—'} | {datetime.now().strftime('%Y-%m-%d %H:%M')} |"
content_new = content.replace(old_row, new_row)
if content_new != content:
with open(todo_path, "w", encoding="utf-8") as f:
f.write(content_new)
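utils.py 就绪后,可以用下面几行快速自检三个 DataLoader 的批次形状与样本数(假设本机可联网下载 MNIST):

```python
from utils import get_mnist_loaders

train_loader, val_loader, test_loader = get_mnist_loaders(batch_size=128)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)   # torch.Size([128, 1, 28, 28]) torch.Size([128])
# 54000 / 6000 / 10000: 90/10 划分后的训练集、验证集与官方测试集
print(len(train_loader.dataset), len(val_loader.dataset), len(test_loader.dataset))
```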
train.py
"""
Training script for ViT on MNIST.
Auto-updates TODO.md after each epoch.
"""
import os
import sys
import argparse
import time
import torch
import torch.nn as nn
from tqdm import tqdm
from vit_model import ViT, count_parameters
from utils import get_mnist_loaders, evaluate, update_todo_md
def train_one_epoch(model, loader, criterion, optimizer, device, grad_clip=1.0):
model.train()
total_loss = 0.0
total_correct = 0
total_samples = 0
pbar = tqdm(loader, desc="Training", leave=False)
for images, labels in pbar:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
if grad_clip > 0:
nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
optimizer.step()
total_loss += loss.item() * images.size(0)
_, pred = outputs.max(1)
total_correct += pred.eq(labels).sum().item()
total_samples += images.size(0)
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
avg_loss = total_loss / total_samples
avg_acc = total_correct / total_samples * 100.0
return avg_loss, avg_acc
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--weight_decay", type=float, default=1e-4)
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--device", type=str, default="auto")
parser.add_argument("--save_dir", type=str, default="./checkpoints")
parser.add_argument("--todo_path", type=str, default="./TODO.md")
args = parser.parse_args()
if args.device == "auto":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(args.device)
print(f"Device: {device}")
os.makedirs(args.save_dir, exist_ok=True)
# Data
train_loader, val_loader, test_loader = get_mnist_loaders(
data_dir="./mnist", batch_size=args.batch_size
)
# Model
model = ViT(
img_size=28,
patch_size=7,
in_chans=1,
num_classes=10,
embed_dim=96,
depth=16,
num_heads=4,
mlp_ratio=4.0,
dropout=0.1,
).to(device)
print(f"Model parameters: {count_parameters(model):,}")
# Loss / Optimizer / Scheduler
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=args.epochs
)
best_acc = 0.0
best_epoch = 0
print("\n========== Training Start ==========")
total_train_time = 0.0
for epoch in range(1, args.epochs + 1):
print(f"\nEpoch [{epoch}/{args.epochs}]")
epoch_start = time.time()
train_loss, train_acc = train_one_epoch(
model, train_loader, criterion, optimizer, device, args.grad_clip
)
epoch_train_time = time.time() - epoch_start
total_train_time += epoch_train_time
val_loss, val_acc, val_acc5 = evaluate(model, val_loader, criterion, device)
scheduler.step()
is_best = val_acc > best_acc
if is_best:
best_acc = val_acc
best_epoch = epoch
save_path = os.path.join(args.save_dir, "best_model.pt")
torch.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"best_acc": best_acc,
}, save_path)
print(f" -> New best model saved (val_acc={val_acc:.2f}%)")
print(f" Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
print(f" Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Val Top-5: {val_acc5:.2f}%")
# Auto-update TODO.md
update_todo_md(epoch, train_loss, val_acc, is_best, todo_path=args.todo_path)
# Final test
print("\n========== Final Test ==========")
checkpoint = torch.load(os.path.join(args.save_dir, "best_model.pt"), map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
test_loss, test_acc, test_acc5 = evaluate(model, test_loader, criterion, device)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Acc: {test_acc:.2f}%")
print(f"Test Top-5: {test_acc5:.2f}%")
print(f"Best Val Acc: {best_acc:.2f}% @ Epoch {best_epoch}")
# Print timing summary
print("\n========== Timing Summary ==========")
print(f"Device: {device}")
print(f"Total epochs: {args.epochs}")
print(f"Total training time: {total_train_time:.2f}s")
print(f"Avg time per epoch: {total_train_time/args.epochs:.2f}s")
# Save timing to file
timing_file = os.path.join(args.save_dir, f"timing_{device.type}.txt")
with open(timing_file, "w") as f:
f.write(f"device: {device}\n")
f.write(f"epochs: {args.epochs}\n")
f.write(f"batch_size: {args.batch_size}\n")
f.write(f"total_train_time: {total_train_time:.2f}\n")
f.write(f"avg_time_per_epoch: {total_train_time/args.epochs:.2f}\n")
f.write(f"best_val_acc: {best_acc:.2f}\n")
f.write(f"test_acc: {test_acc:.2f}\n")
print(f"Timing saved to {timing_file}")
if __name__ == "__main__":
main()
inference.py
"""
Inference script for ViT on MNIST.
Supports single prediction, batch visualization, and ONNX export.
"""
import os
import argparse
import random
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from vit_model import ViT
from utils import get_mnist_loaders
def load_model(checkpoint_path, device):
model = ViT(
img_size=28, patch_size=7, in_chans=1, num_classes=10,
embed_dim=96, depth=16, num_heads=4, mlp_ratio=4.0, dropout=0.0,
).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
return model
def predict_single(model, image_tensor, device):
image_tensor = image_tensor.unsqueeze(0).to(device)
with torch.no_grad():
logits = model(image_tensor)
probs = torch.softmax(logits, dim=1)
pred = logits.argmax(dim=1).item()
return pred, probs[0].cpu().numpy()
def visualize_batch(model, test_loader, device, num_samples=16, save_path="./outputs/predictions.png"):
os.makedirs(os.path.dirname(save_path), exist_ok=True)
model.eval()
# Collect some samples
images_list = []
labels_list = []
for images, labels in test_loader:
images_list.append(images)
labels_list.append(labels)
if sum(len(x) for x in images_list) >= num_samples:
break
images = torch.cat(images_list)[:num_samples]
labels = torch.cat(labels_list)[:num_samples]
images = images.to(device)
with torch.no_grad():
logits = model(images)
preds = logits.argmax(dim=1).cpu().numpy()
images = images.cpu().numpy()
labels = labels.cpu().numpy()
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
axes = axes.flatten()
for i in range(num_samples):
ax = axes[i]
img = images[i][0] # (1, 28, 28) -> (28, 28)
# Denormalize
img = img * 0.3081 + 0.1307
ax.imshow(img, cmap='gray')
color = 'green' if preds[i] == labels[i] else 'red'
ax.set_title(f"Pred: {preds[i]}\nTrue: {labels[i]}", color=color, fontsize=10)
ax.axis('off')
plt.tight_layout()
plt.savefig(save_path, dpi=150)
print(f"Visualization saved to {save_path}")
def export_onnx(model, save_path="./outputs/vit_mnist.onnx"):
os.makedirs(os.path.dirname(save_path), exist_ok=True)
# Export on CPU so the dummy input and the model weights live on the same device
model = model.cpu().eval()
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(
model,
dummy_input,
save_path,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
opset_version=11,
)
print(f"ONNX model exported to {save_path}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", type=str, default="./checkpoints/best_model.pt")
parser.add_argument("--mode", type=str, choices=["visualize", "onnx"], default="visualize")
parser.add_argument("--device", type=str, default="auto")
parser.add_argument("--output", type=str, default="./outputs/predictions.png")
args = parser.parse_args()
if args.device == "auto":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(args.device)
model = load_model(args.checkpoint, device)
if args.mode == "visualize":
_, _, test_loader = get_mnist_loaders(batch_size=128)
visualize_batch(model, test_loader, device, num_samples=16, save_path=args.output)
elif args.mode == "onnx":
export_onnx(model, save_path="./outputs/vit_mnist.onnx")
if __name__ == "__main__":
main()
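导出 ONNX 后,可以用 onnxruntime 做一次推理自检,确认输入输出名称与形状(onnxruntime 不在本文 pyproject 依赖中,属假设性补充):

```python
import numpy as np
import onnxruntime as ort

# 加载上一步导出的 ONNX 模型, 输入/输出名与 export_onnx 中一致
sess = ort.InferenceSession("./outputs/vit_mnist.onnx", providers=["CPUExecutionProvider"])
dummy = np.random.randn(1, 1, 28, 28).astype(np.float32)
(logits,) = sess.run(["output"], {"input": dummy})
print("ONNX logits shape:", logits.shape)   # (1, 10)
print("预测类别:", logits.argmax(axis=1))
```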
launch.py
#!/usr/bin/env python3
"""
项目启动器 - 自动检测并使用可用的 PyTorch 环境
"""
import sys
import os
import subprocess
# 优先使用 ultralytics 环境的 Python(已知有 PyTorch 2.9.1+cu128)
ULTRALYTICS_PYTHON = "/home/qsbye/.local/share/uv/tools/ultralytics/bin/python"
PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
def check_pytorch(python_path):
"""检查指定 Python 是否有 PyTorch"""
try:
result = subprocess.run(
[python_path, "-c", "import torch; print(torch.__version__)"],
capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
return result.stdout.strip()
except (OSError, subprocess.TimeoutExpired):
pass
return None
def main():
# 检查当前 Python 是否有 PyTorch
current_torch = check_pytorch(sys.executable)
if current_torch:
print(f"✅ 当前环境已有 PyTorch {current_torch}")
print(f" Python: {sys.executable}")
else:
print("⚠️ 当前环境未安装 PyTorch")
# 检查 ultralytics 环境
ultralytics_torch = check_pytorch(ULTRALYTICS_PYTHON)
if ultralytics_torch:
print(f"🔄 切换到 ultralytics 环境: PyTorch {ultralytics_torch}")
print(f" Python: {ULTRALYTICS_PYTHON}")
# 重新执行脚本
args = [ULTRALYTICS_PYTHON] + sys.argv
os.execv(ULTRALYTICS_PYTHON, args)
else:
print("❌ 未找到可用的 PyTorch 环境")
sys.exit(1)
# 导入项目模块
import torch
print(f"\n📊 PyTorch: {torch.__version__}")
print(f"📊 CUDA: {torch.version.cuda}")
if torch.cuda.is_available():
print(f"📊 GPU: {torch.cuda.get_device_name(0)}")
# 运行训练或推理
if len(sys.argv) > 1:
script = sys.argv[1]
script_path = os.path.join(PROJECT_DIR, "src", script)
if os.path.exists(script_path):
print(f"\n🚀 运行: {script_path}")
exec(open(script_path).read(), {"__name__": "__main__", "__file__": script_path})
else:
print(f"❌ 脚本不存在: {script_path}")
else:
print("\n用法: python launch.py [train.py|inference.py|vit_model.py]")
if __name__ == "__main__":
main()
训练及推理命令
# GPU训练
/home/qsbye/.local/share/uv/tools/ultralytics/bin/python src/train.py --device cuda
# GPU 推理
python src/inference.py --checkpoint ./checkpoints_gpu/best_model.pt --device cuda
训练输出
🚀 MNIST-ViT GPU vs CPU 训练对比报告
📊 实验配置
• 模型: ViT (embed_dim=96, depth=16, heads=4)
• 参数: 1,797,130
• 数据集: MNIST (60k训练 / 10k测试)
• 批次: 128
• 轮数: 5 epochs
• PyTorch: 2.9.1+cu128
━━━━━━━━━━━━━━━━━━━━━━━
⏱️ 训练速度对比
| 指标 | GPU (A100) | CPU | 对比 |
|------|-----------|-----|------|
| 总训练时间 | 165.94s | 1107.88s | GPU 快 6.7x |
| 每轮平均 | 33.19s | 221.58s | — |
| 节省时间 | — | — | 15.7 分钟 |
━━━━━━━━━━━━━━━━━━━━━━━
🎯 准确率对比
| 指标 | GPU | CPU |
|------|-----|-----|
| Best Val Acc | 96.63% | 97.03% |
| Test Acc | 96.93% | 97.39% |
| Test Top-5 | 99.93% | 99.90% |
━━━━━━━━━━━━━━━━━━━━━━━
💡 结论
1. 🚀 GPU (A100) 训练速度是 CPU 的 6.7 倍
2. 🎯 两者准确率接近 (~97%),GPU 略低可能因数值精度差异
3. ⚡ 对于 ViT 模型,GPU 加速效果显著,强烈推荐使用 GPU 训练
4. 💰 5轮训练 GPU 仅需 ~2.8 分钟,CPU 需 ~18.5 分钟
━━━━━━━━━━━━━━━━━━━━━━━
📁 生成文件
• checkpoints_gpu/best_model.pt
• checkpoints_cpu/best_model.pt
• outputs/predictions_gpu.png
• outputs/predictions_cpu.png
推理输出
(推理预测结果 4×4 网格图:outputs/predictions_gpu.png / outputs/predictions_cpu.png)
分析
从测试集取前 16 张图像,绘制 4×4 网格对比图:
- 绿色标题:预测正确
- 红色标题:预测错误(标题中同时给出预测值 Pred 与真实标签 True)
GPU 与 CPU 训练的准确率差异(~0.4%)属于正常波动范围,原因包括:
- 数值精度差异:GPU 与 CPU 的浮点实现和并行归约顺序不同(A100 上部分算子还可能走 TF32 Tensor Core),累积误差不同
- 随机性因素:数据增强、Dropout、权重初始化均引入随机性
- 优化器状态:AdamW 的二阶矩估计在不同硬件上数值路径略有差异
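若想进一步缩小 GPU 与 CPU 之间的数值与随机性差异,可以在训练脚本开头固定随机种子并关闭 TF32(假设性建议,可能损失部分速度):

```python
import random
import numpy as np
import torch

def set_seed(seed: int = 42):
    """固定各处随机种子, 使数据增强、Dropout、初始化可复现."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)
torch.backends.cuda.matmul.allow_tf32 = False   # 禁止矩阵乘法使用 TF32
torch.backends.cudnn.allow_tf32 = False         # 禁止 cuDNN 卷积使用 TF32
torch.backends.cudnn.deterministic = True       # 使用确定性算法(可能变慢)
torch.backends.cudnn.benchmark = False
```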

本文基于 Vision Transformer (ViT) 架构,在 MNIST 数据集上完成图像分类任务的端到端实践。通过 Patch Embedding、可学习位置编码、CLS Token、Multi-Head Self-Attention 与 MLP 分类头的模块化构建,配合 AdamW + CosineAnnealingLR 训练策略,在 PyTorch框架的CPU/GPU 环境下完成 20 轮训练。本方案的核心价值在于验证 ViT 从零构建、训练、推理的完整工程链路。

浙公网安备 33010602011771号