lightning框架
基本知识
max_epoch:指训练过程中完整遍历整个训练数据集的总次数。例如,max_epoch=10表示模型会把训练集中的所有样本完整过一遍,重复 10 次。
batch_size:指每次模型更新参数时输入的样本数量。例如,batch_size=32表示每次从数据集中抽取 32 个样本,计算损失并更新一次模型参数。具体在模型里的tran_step或者val_step、test_step中
2. 核心关系:共同决定训练总迭代次数
假设训练数据集总样本数为N,则:
每个 epoch 的迭代次数(steps) = 总样本数 /batch_size(向上取整)。
例如:若N=1000,batch_size=32,则每个 epoch 需要 1000÷32≈31.25 → 32 次迭代(最后一个 batch 可能不足 32 个样本)
设定epoch=1
def test_step(self, batch, batch_idx):
data, curves, plain_target = batch
predict_curve = self(data)
mse = nn.functional.mse_loss(predict_curve, curves)
self.log("test_mse",mse,on_epoch=True,prog_bar=True)
test_step 中计算的 mse 仅针对当前这个批次batch,是批次内样本的平均 MSE。
on_epoch=True 会让 Lightning 在整个测试 epoch 结束后,对所有批次的 mse 结果进行聚合(默认是取平均)
TGZ

浙公网安备 33010602011771号