import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Multi-GPU setup fragment. `model` and the initial `data` are not defined
# in this snippet and must already exist before it runs.

# Join the default process group; the NCCL backend needs CUDA devices.
# The sampler created below reads its rank and world size from this group.
torch.distributed.init_process_group(backend='nccl')
n_gpu = torch.cuda.device_count()

# NOTE(review): DataParallel replicates the model across GPUs inside ONE
# process. Once a process group is initialised, the usual wrapper is
# torch.nn.parallel.DistributedDataParallel -- confirm which was intended.
model = torch.nn.DataParallel(model)

# TensorDataset indexes its tensor argument(s) along the first dimension.
data = TensorDataset(data)
# From here on `data` is the sampler, not the dataset.
data = DistributedSampler(data)

# NOTE(review): a DistributedSampler yields dataset indices, not tensors.
# It is normally handed to DataLoader(dataset, sampler=...) and the model
# is fed the loader's batches. The call is kept as written because the
# model's forward signature is not visible here -- confirm and fix.
loss = model(data)
# Reduce the model output to a scalar by averaging.
loss = loss.mean()