PyTorch Lightning 基本方法介绍

LightningModule
LightningModule将PyTorch代码整理成5个部分:
- Computations (__init__)
- Train loop (training_step)
- Validation loop (validation_step)
- Test loop (test_step)
- Optimizers (configure_optimizers)
Minimal Example
所需要的方法:
import pytorch_lightning as pl
import torch
from torch.nn import functional as F


class LitModel(pl.LightningModule):
    """Minimal LightningModule: one linear layer scoring flattened 28x28 inputs into 10 classes."""

    def __init__(self):
        super().__init__()
        # 28 * 28 flattened input features -> 10 class scores.
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        # Flatten everything except the batch dimension before the linear layer.
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)  # self(x) runs forward()
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
使用下面的代码进行训练:
import os

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# MNIST is downloaded into the current working directory on first use.
# NOTE: DataLoader defaults to batch_size=1; pass batch_size=... for faster training.
train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
trainer = pl.Trainer()
model = LitModel()  # the LightningModule from the minimal example above

trainer.fit(model, train_loader)
一些基本方法
Training
Training loop
使用training_step方法来增加training loop
class LitClassifier(pl.LightningModule):
    """Wrap an arbitrary classifier network and train it with cross-entropy loss."""

    def __init__(self, model):
        super().__init__()
        # The wrapped network; every prediction is delegated to it.
        self.model = model

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self.model(inputs)
        # The returned loss is what Lightning back-propagates.
        return F.cross_entropy(logits, targets)
如果需要在epoch级别计算并记录指标,可以使用`self.log`方法:
def training_step(self, batch, batch_idx):
    """Compute the cross-entropy loss for one batch and log it.

    The loss is logged both for every step and as an epoch-level average,
    and is sent to the progress bar as well as to the logger.
    """
    inputs, targets = batch
    logits = self.model(inputs)
    loss = F.cross_entropy(logits, targets)
    log_options = dict(on_step=True, on_epoch=True, prog_bar=True, logger=True)
    self.log('train_loss', loss, **log_options)
    return loss
如果需要在epoch结束时对所有training_step的输出统一做一些处理,可以重写training_epoch_end方法:
def training_step(self, batch, batch_idx):
    """Compute the loss for one batch and return extra outputs alongside it.

    The returned dict carries the loss under 'loss'; the other entries are
    made available to training_epoch_end.
    """
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...  # placeholder: derive predictions from y_hat here
    return {'loss': loss, 'other_stuff': preds}


def training_epoch_end(self, training_step_outputs):
    """Post-process the outputs collected from the epoch's training steps."""
    # Each item is one of the dicts returned by training_step above.
    for out in training_step_outputs:
        preds = out['other_stuff']
        # do something with preds
        ...
在数据并行(DP)等多GPU模式下,每个batch会被拆分到不同GPU上分别执行training_step;如果需要汇总各GPU的输出(例如对各GPU的loss求平均),可以实现training_step_end方法:
def training_step(self, batch, batch_idx):
    """Compute the loss and predictions for the part of the batch on this GPU."""
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...  # placeholder: derive predictions from y_hat here
    return {'loss': loss, 'pred': pred}


def training_step_end(self, batch_parts):
    """Combine the per-GPU outputs of training_step into a single loss.

    batch_parts is a mapping, so entries are read with ['key'], not with
    attribute access. Under data parallel, batch_parts['loss'] and
    batch_parts['pred'] hold one entry per GPU (index 0 for GPU 0, ...).
    """
    per_gpu_preds = batch_parts['pred']
    per_gpu_losses = batch_parts['loss']

    # do something with per_gpu_preds (e.g. per_gpu_preds[0] comes from GPU 0)

    # Average over however many GPUs took part, instead of hard-coding two;
    # a 0-d loss (single device) is returned unchanged by mean().
    return per_gpu_losses.mean()


def training_epoch_end(self, training_step_outputs):
    """Post-process the per-batch results collected over the epoch."""
    # NOTE(review): items are derived from what training_step_end returned;
    # confirm the exact wrapping for your Lightning version.
    for out in training_step_outputs:
        # do something with out
        ...
Validation loop
增加一个validation loop,可以通过改写LightningModule中的validation_step来实现
class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        inputs, targets = batch
        val_loss = F.cross_entropy(self.model(inputs), targets)
        self.log('val_loss', val_loss)
对validation进行epoch-level度量,可以通过改写validation_epoch_end实现
def validation_step(self, batch, batch_idx):
    """Compute validation loss and predictions for one batch; return the predictions."""
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    # The loss was previously computed and then discarded; log it.
    self.log('val_loss', loss)
    pred = ...  # placeholder: derive predictions from y_hat here
    return pred


def validation_epoch_end(self, validation_step_outputs):
    """Post-process the predictions returned by every validation_step of the epoch."""
    for pred in validation_step_outputs:
        # do something with a pred
        ...
在多GPU数据并行验证时,如果需要汇总各GPU上validation_step的输出,可以实现validation_step_end方法:
def validation_step(self, batch, batch_idx):
    """Compute the loss and predictions for the part of the batch on this GPU."""
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...  # placeholder: derive predictions from y_hat here
    return {'loss': loss, 'pred': pred}


def validation_step_end(self, batch_parts):
    """Combine the per-GPU outputs of validation_step into a single loss.

    batch_parts is a mapping, so entries are read with ['key'], not with
    attribute access. Under data parallel, batch_parts['loss'] and
    batch_parts['pred'] hold one entry per GPU (index 0 for GPU 0, ...).
    """
    per_gpu_preds = batch_parts['pred']
    per_gpu_losses = batch_parts['loss']

    # do something with per_gpu_preds (e.g. per_gpu_preds[0] comes from GPU 0)

    # Average over however many GPUs took part, instead of hard-coding two;
    # a 0-d loss (single device) is returned unchanged by mean().
    return per_gpu_losses.mean()


def validation_epoch_end(self, validation_step_outputs):
    """Post-process the per-batch validation results collected over the epoch."""
    # NOTE(review): items are derived from what validation_step_end returned;
    # confirm the exact wrapping for your Lightning version.
    for out in validation_step_outputs:
        # do something with out
        ...
Test loop
添加test loop的方式与添加validation loop相同,唯一的区别是:test loop只有在调用`.test()`时才会被执行。
model = Model()
trainer = Trainer()
# fit() needs the LightningModule to train.
trainer.fit(model)

# Calling test() without a model and with ckpt_path='best' reloads the best
# checkpoint saved during fit(). Passing the model instead would evaluate its
# current (last-epoch) weights rather than the best ones.
trainer.test(ckpt_path='best')
这里,有两种方式调用test():
# Option 1: test right after training. With no model argument,
# test() loads the best weights found during fit().
trainer = Trainer()
trainer.fit(model)
trainer.test(test_dataloaders=test_dataloader)

# Option 2: test a pretrained model restored from a checkpoint file.
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model, test_dataloaders=test_dataloader)
Inference
在研究场景中,LightningModule最好按“系统”(system)的方式来组织:
import pytorch_lightning as pl
import torch
from torch import nn


class Autoencoder(pl.LightningModule):
    """Fully connected autoencoder over flattened 28x28 inputs.

    The encoder compresses 28 * 28 features to ``latent_dim`` values and the
    decoder maps them back; the loss is the MSE between the flattened input
    and its reconstruction.
    """

    def __init__(self, latent_dim=2):
        super().__init__()
        hidden = 256
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 28 * 28),
        )

    def training_step(self, batch, batch_idx):
        inputs, _ = batch
        flat = inputs.view(inputs.size(0), -1)
        code = self.encoder(flat)
        reconstruction = self.decoder(code)
        return nn.functional.mse_loss(reconstruction, flat)

    def validation_step(self, batch, batch_idx):
        inputs, _ = batch
        flat = inputs.view(inputs.size(0), -1)
        reconstruction = self.decoder(self.encoder(flat))
        self.log('val_reconstruction', nn.functional.mse_loss(reconstruction, flat))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)
可以用如下方式训练
# Build the model and a single-GPU trainer, then fit on the train/val loaders.
autoencoder = Autoencoder()
trainer = pl.Trainer(gpus=1)
trainer.fit(autoencoder, train_dataloader, val_dataloader)
上面这些方法属于Lightning的接口(interface):
- training_step
- validation_step
- test_step
- configure_optimizers
注意到在这个例子中,train loop和val loop完全相同,我们可以重复使用这部分代码
class Autoencoder(pl.LightningModule):
    """Autoencoder whose training and validation steps share one loss helper."""

    def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch)

    def validation_step(self, batch, batch_idx):
        self.log('val_loss', self.shared_step(batch))

    def shared_step(self, batch):
        """Return the MSE reconstruction loss for ``batch``; the second element is ignored."""
        inputs, _ = batch
        flat = inputs.view(inputs.size(0), -1)
        reconstruction = self.decoder(self.encoder(flat))
        return nn.functional.mse_loss(reconstruction, flat)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)
注:我们创建了所有loop都可以使用的一个新方法shared_step,这个方法的名字可以任意取
Inference in research
如果需要用这个系统进行推理(inference),可以在LightningModule中添加forward方法:
class Autoencoder(pl.LightningModule):
    def forward(self, x):
        """Inference entry point: run ``x`` through the decoder network.

        ``x`` is expected to be a latent code, i.e. the decoder's input.
        """
        decoded = self.decoder(x)
        return decoded
添加forward的好处在于:在复杂系统中可以实现更复杂的推理过程,例如文本生成:
class Seq2Seq(pl.LightningModule):

    def forward(self, x):
        """Run the full inference procedure (e.g. text generation) for ``x``.

        Sketch only: assumes the module defines ``self.embedding`` and
        ``self.encoder`` (and whatever decoding logic it needs) in ``__init__``.
        """
        # Do NOT call self(x) here: self(x) dispatches back into forward()
        # and would recurse until RecursionError.
        embeddings = self.embedding(x)
        hidden_states = self.encoder(embeddings)
        decoded = []
        for h in hidden_states:
            # decode h and append the result to `decoded`
            ...
        return decoded
Inference in production
在生产环境等场景中,可能需要在同一个LightningModule中迭代使用不同的模型:
import pytorch_lightning as pl
import torch
from torch.nn import functional as F
from pytorch_lightning.metrics import functional as FM


class ClassificationTask(pl.LightningModule):
    """Generic classification task that trains whatever ``model`` is passed in."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {'val_acc': acc, 'val_loss': loss}
        self.log_dict(metrics)
        return metrics

    def test_step(self, batch, batch_idx):
        # Use the shared helper rather than validation_step, which would also
        # log 'val_acc'/'val_loss' while testing.
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {'test_acc': acc, 'test_loss': loss}
        self.log_dict(metrics)
        return metrics

    def _shared_eval_step(self, batch, batch_idx):
        """Return (loss, accuracy) for one evaluation batch without logging anything."""
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        # Compare predicted class labels with the targets; older metric
        # versions expect labels rather than raw scores.
        acc = FM.accuracy(y_hat.argmax(dim=1), y)
        return loss, acc

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.02)
然后将任意适合该task的模型传进去
# Reuse the same ClassificationTask for several backbone networks;
# every run gets its own fresh Trainer.
for backbone in [resnet50(), vgg16(), BidirectionalRNN()]:
    task = ClassificationTask(backbone)
    trainer = Trainer(gpus=2)
    trainer.fit(task, train_dataloader, val_dataloader)
tasks可以任意复杂,比如,可以实现GAN训练,self-supervised,甚至RL
class GANTask(pl.LightningModule):
    """Task that trains a generator/discriminator pair (training logic elided)."""

    def __init__(self, generator, discriminator):
        super().__init__()
        # Stored as attributes so that, if they are nn.Modules, they are
        # registered as submodules of the task.
        self.generator = generator
        self.discriminator = discriminator
        ...
用同样的方式训练该task:
# Train the task on two GPUs with the training and validation loaders.
trainer = Trainer(gpus=2)
trainer.fit(task, train_dataloader, val_dataloader)
# Tasks can be arbitrarily complex, e.g. GAN training, self-supervised
# learning, or even reinforcement learning.
class GANTask(pl.LightningModule):

    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        ...

浙公网安备 33010602011771号