Deep Learning (Single-Machine Multi-GPU Training)

If a machine has more than one GPU, you can train on several of them at once.

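When the dataset and the model are large, this usually speeds up training noticeably. When the model and data are small, it can actually be slower, because the cost of communication between the GPUs outweighs the gain from splitting the work.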
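To make that trade-off concrete, here is a toy cost model in plain Python (no PyTorch needed). It is a rough sketch under stated assumptions, not a measurement: it assumes the compute for one training step divides evenly across GPUs, while each additional GPU adds a fixed communication cost per step. The function names and all timings are hypothetical, chosen only for illustration.

```python
# Toy cost model for one data-parallel training step.
# ASSUMPTIONS (hypothetical): compute splits evenly across GPUs, and each
# extra GPU adds a fixed per-step communication cost. Timings are made up.


def step_time(compute_s: float, comm_per_gpu_s: float, num_gpus: int) -> float:
    """Estimated seconds per training step on num_gpus GPUs."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    # A single GPU pays no inter-GPU communication cost.
    return compute_s / num_gpus + comm_per_gpu_s * (num_gpus - 1)


def speedup(compute_s: float, comm_per_gpu_s: float, num_gpus: int) -> float:
    """Single-GPU step time divided by the multi-GPU step time."""
    return step_time(compute_s, 0.0, 1) / step_time(compute_s, comm_per_gpu_s, num_gpus)


if __name__ == "__main__":
    # Large model/batch: compute dominates, so 4 GPUs pay off.
    print(round(speedup(compute_s=0.40, comm_per_gpu_s=0.01, num_gpus=4), 2))  # -> 3.08
    # Small model/batch: communication dominates, so 4 GPUs are slower than 1.
    print(round(speedup(compute_s=0.02, comm_per_gpu_s=0.01, num_gpus=4), 2))  # -> 0.57
```

With these made-up numbers the compute-heavy case runs about 3x faster on 4 GPUs, while the small case drops to a little over half the single-GPU speed. Real overheads depend on the model size, batch size and GPU interconnect, so it is worth timing both setups on your own hardware.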

Here is a simple example:

import time
import torch
import torchvision.models
from torchvision import transforms
from torch import nn, optim
from torchvision.datasets import CIFAR10

if __name__ == "__main__":

    device = torch.device("cuda")
    
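    # Data augmentation for 32x32 CIFAR-10 images; Normalize uses the ImageNet mean/std
    dataTransforms = transforms.Compose([
            transforms.ToTensor()
            , transforms.RandomCrop(32, padding=4)      # random 32x32 crop after 4-pixel padding
            , transforms.RandomHorizontalFlip(p=0.5)    # flip left-right with probability 0.5
            , transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])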

    trainset = CIFAR10(root='./data', train=True, download=True, transform=dataTransforms)
    # batch_size is the total batch per step; DataParallel splits it across the GPUs
    trainLoader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)
 
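    # ResNet-18 trained from scratch (weights=None replaces the deprecated pretrained=False)
    model = torchvision.models.resnet18(weights=None)
    # Adapt the stem to 32x32 inputs: a 3x3 stride-1 conv, and a 1x1 max-pool that does not downsample
    model.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=True)
    model.maxpool = nn.MaxPool2d(1, 1, 0)
    # 10 output classes for CIFAR-10
    model.fc = nn.Linear(model.fc.in_features, 10)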
 
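    # DataParallel expects the model's parameters on the first GPU (cuda:0) before wrapping
    model.to(device)

    # Wrap the model in DataParallel: each forward pass replicates it onto every
    # visible GPU, splits the batch along dimension 0, and gathers the outputs on cuda:0
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)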

    cross = nn.CrossEntropyLoss()
    cross.to(device)

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

    start = time.time()
    for epoch in range(10):
   
        model.train()  

        correctSum = 0.0
        lossSum = 0.0
        dataLen = 0

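        for inputs, labels in trainLoader:
            # Move the batch to cuda:0; DataParallel scatters it to the other GPUs
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = cross(outputs, labels)  # mean loss over the batch

            _, preds = torch.max(outputs, 1)  # predicted class for each sample

            loss.backward()
            optimizer.step()

            # loss is a batch mean, so weight it by the batch size before summing;
            # .item() turns the GPU tensors into plain Python numbers
            correctSum += (preds == labels).sum().item()
            lossSum += loss.item() * inputs.size(0)
            dataLen += inputs.size(0)

        # Per-sample average loss and accuracy over the epoch
        print('epoch {}: loss {:.4f}, acc {:.4f}'.format(epoch, lossSum / dataLen, correctSum / dataLen))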

    timeElapsed = time.time() - start
    print('Elapsed {:.0f}m {:.0f}s'.format(timeElapsed // 60, timeElapsed % 60))
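With `nn.DataParallel`, the `batch_size=256` passed to the `DataLoader` is the total batch per step, not the per-GPU batch: each forward pass splits the batch along dimension 0 and sends one piece to each GPU. The plain-Python sketch below estimates the per-GPU sub-batch sizes, assuming the split follows `torch.chunk` semantics (chunks of `ceil(batch / num_gpus)` samples, with a smaller last chunk). `chunk_sizes` is a hypothetical helper for illustration, not a PyTorch API.

```python
from typing import List


def chunk_sizes(batch_size: int, num_gpus: int) -> List[int]:
    """Hypothetical helper: per-GPU sub-batch sizes for a torch.chunk-style split."""
    if batch_size < 1 or num_gpus < 1:
        raise ValueError("batch_size and num_gpus must be positive")
    # Each chunk holds ceil(batch_size / num_gpus) samples, except possibly the last.
    per_chunk = -(-batch_size // num_gpus)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        sizes.append(min(per_chunk, remaining))
        remaining -= sizes[-1]
    # With very small batches this can yield fewer chunks than GPUs.
    return sizes


if __name__ == "__main__":
    # batch_size=256 from the training script, on 2, 3 and 4 GPUs
    for n in (2, 3, 4):
        print(n, chunk_sizes(256, n))
    # CIFAR-10 has 50,000 training images, so the last batch holds 50000 % 256 = 80 samples
    print(chunk_sizes(50000 % 256, 3))
```

For example, 256 samples on 4 GPUs gives 64 per GPU, and the final batch of 80 samples would be split 27, 27 and 26 across 3 GPUs. The smaller each sub-batch gets, the less work each GPU does per step relative to the communication cost, which is one reason small workloads may not benefit from extra GPUs.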
posted @ 2024-03-31 17:53  Dsp Tian