An Introduction to the Basic Methods of PyTorch Lightning

 

 

LIGHTNINGMODULE

A LightningModule organizes your PyTorch code into 5 sections:

  • Computations (__init__)
  • Train loop (training_step)
  • Validation loop (validation_step)
  • Test loop (test_step)
  • Optimizers (configure_optimizers)

Minimal Example

The required methods:

import torch
from torch.nn import functional as F

import pytorch_lightning as pl


class LitModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

Train with the following code:

import os

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
trainer = pl.Trainer()
model = LitModel()

trainer.fit(model, train_loader)
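
Under the hood, fit() automates roughly the loop you would otherwise write by hand. A simplified sketch of the idea (not Lightning's actual implementation):

# simplified sketch of what trainer.fit() automates
model.train()
optimizer = model.configure_optimizers()
max_epochs = 1  # illustrative

for epoch in range(max_epochs):
    for batch_idx, batch in enumerate(train_loader):
        loss = model.training_step(batch, batch_idx)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()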

Some basic methods

Training

Training loop

Add a training loop with the training_step method:

class LitClassifier(pl.LightningModule):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

To compute a metric at the epoch level and record it, use the *.log* method:

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss
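
With both on_step=True and on_epoch=True, the logger receives two series under suffixed keys, train_loss_step and train_loss_epoch. To record only the epoch-level average, drop the per-step series; a sketch, where 'train_acc' and acc are illustrative names:

# inside training_step, assuming an accuracy tensor `acc` was computed for this batch
self.log('train_acc', acc, on_step=False, on_epoch=True)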

To do something with the outputs of all the training_steps, override training_epoch_end:

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    return {'loss': loss, 'other_stuff': preds}

def training_epoch_end(self, training_step_outputs):
    for out in training_step_outputs:
        pred = out['other_stuff']
        # do something with each prediction

If each batch is split across several GPUs for training (data parallel), aggregate the per-GPU results with training_step_end:

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def training_step_end(self, batch_parts):
    # batch_parts holds one training_step output per GPU
    gpu_0_prediction = batch_parts[0]['pred']
    gpu_1_prediction = batch_parts[1]['pred']

    # do something with both outputs
    return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2

def training_epoch_end(self, training_step_outputs):
    for out in training_step_outputs:
        # do something with the aggregated outputs
        ...

Validation loop

To add a validation loop, override validation_step in the LightningModule:

class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('val_loss', loss)

For epoch-level validation metrics, override validation_epoch_end:

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return pred

def validation_epoch_end(self, validation_step_outputs):
    for pred in validation_step_outputs:
        # do something with each pred
        ...

If validation runs data parallel across multiple GPUs, aggregate the per-GPU results with validation_step_end:

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def validation_step_end(self, batch_parts):
    # batch_parts holds one validation_step output per GPU
    gpu_0_prediction = batch_parts[0]['pred']
    gpu_1_prediction = batch_parts[1]['pred']

    # do something with both outputs
    return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2

def validation_epoch_end(self, validation_step_outputs):
    for out in validation_step_outputs:
        # do something with the aggregated outputs
        ...

Test loop

Adding a test loop works exactly like adding the validation loop above; the only difference is that the test loop only runs when *.test()* is called:

model = Model()
trainer = Trainer()
trainer.fit(model)

# automatically loads the best weights for you
trainer.test(model)
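
For reference, a test_step mirroring the validation_step above might look like this (a sketch; the metric name is illustrative):

def test_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    self.log('test_loss', loss)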

There are two ways to call test():

# call after training
trainer = Trainer()
trainer.fit(model)

# automatically loads the best weights
trainer.test(test_dataloaders=test_dataloader)

# or call with a pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model, test_dataloaders=test_dataloader)

Inference

For research, LightningModules are structured like systems:

import pytorch_lightning as pl
import torch
from torch import nn

class Autoencoder(pl.LightningModule):

    def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

    def training_step(self, batch, batch_idx):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # reconstruction
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        return reconstruction_loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        recons = self.decoder(z)
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        self.log('val_reconstruction', reconstruction_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)

It can be trained like this:

autoencoder = Autoencoder()
trainer = pl.Trainer(gpus=1)
trainer.fit(autoencoder, train_dataloader, val_dataloader)
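
train_dataloader and val_dataloader here are ordinary PyTorch DataLoaders; for MNIST they could be built like this (a sketch, with an illustrative split and batch size):

import os
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_set, val_set = random_split(dataset, [55000, 5000])
train_dataloader = DataLoader(train_set, batch_size=32)
val_dataloader = DataLoader(val_set, batch_size=32)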

The LightningModule methods used here:

  • training_step
  • validation_step
  • test_step
  • configure_optimizers

Notice that in this example the train and val loops are exactly the same. We can reuse that code:

class Autoencoder(pl.LightningModule):

    def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log('val_loss', loss)

    def shared_step(self, batch):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # loss
        return nn.functional.mse_loss(recons, x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)

Note: we created a new method, shared_step, that every loop can use; its name is arbitrary.

Inference in research

To run inference with the system, add a forward method to the LightningModule:

class Autoencoder(pl.LightningModule):
    def forward(self, x):
        return self.decoder(x)
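
With forward defined, inference is a plain call on the module. A usage sketch, assuming a trained checkpoint saved at PATH:

autoencoder = Autoencoder.load_from_checkpoint(PATH)
autoencoder.eval()

z = torch.randn(4, 2)  # latent_dim = 2 in this example
with torch.no_grad():
    samples = autoencoder(z)  # decode latents into flattened 28 * 28 images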

The advantage of adding forward is that in complex systems it can run a much more involved inference procedure:

class Seq2Seq(pl.LightningModule):

    def forward(self, x):
        # self.embedding is an assumed submodule; the original called self(x),
        # which would recurse forever
        embeddings = self.embedding(x)
        hidden_states = self.encoder(embeddings)
        for h in hidden_states:
            # decode
            ...
        return decoded

Inference in production

In production, iterate over different models inside one LightningModule task:

import torch
from torch.nn import functional as F

import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM

class ClassificationTask(pl.LightningModule):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)

        # loss is a tensor; the checkpoint callback monitors the logged metrics
        metrics = {'val_acc': acc, 'val_loss': loss}
        self.log_dict(metrics)
        return metrics

    def test_step(self, batch, batch_idx):
        metrics = self.validation_step(batch, batch_idx)
        metrics = {'test_acc': metrics['val_acc'], 'test_loss': metrics['val_loss']}
        self.log_dict(metrics)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.02)

Then pass in any model that fits the task:

from torchvision.models import resnet50, vgg16

# BidirectionalRNN is a placeholder for a model of your own
for model in [resnet50(), vgg16(), BidirectionalRNN()]:
    task = ClassificationTask(model)

    trainer = Trainer(gpus=2)
    trainer.fit(task, train_dataloader, val_dataloader)

Tasks can be arbitrarily complex; for example, you could implement GAN training, self-supervised learning, or even RL:

class GANTask(pl.LightningModule):

    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
    ...
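
A GAN task typically trains the two sub-networks with separate optimizers. A minimal sketch of what configure_optimizers for the stub above could return (hyperparameters are illustrative):

def configure_optimizers(self):
    # one optimizer per sub-network; Lightning steps them alternately
    opt_g = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
    opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)
    return opt_g, opt_d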
