PyTorch并行训练

下面代码实现torch多卡同时训练

1
2
3
4
5
6
7
8
9
import pytorch as torch
from torch.utils.data.distributed import DistributedSample
torch.distributed.init_process_group(backend='nccl')#初始化并行训练
n_gpu = torch.cuda.device_count()#统计gpu数量
model = torch.nn.DataParallel(model)#多卡部署模型
data = TensorDataset(data)
data = DistributedSample(data)#分布训练
loss = model(data)
loss = loss.mean()#多卡取平均loss

PyTorch并行训练
http://example.com/2023/03/03/PyTorch并行训练/
作者
ZHUHAI
发布于
2023年3月3日
许可协议