import torch.nn as nn
# Use DistributedDataParallel (DDP) rather than DataParallel (DP) for multi-GPU training.
def is_multi_gpu(net):
    """Return True if ``net`` is wrapped for multi-GPU (DDP) training.

    ``MultiGPU`` subclasses ``DistributedDataParallel``, so a single
    isinstance check against the DDP base class covers both wrappers.
    """
    return isinstance(net, nn.parallel.distributed.DistributedDataParallel)
|
class MultiGPU(nn.parallel.distributed.DistributedDataParallel):
    """DDP wrapper that transparently forwards attribute access to the wrapped model.

    DDP stores the user's network as ``self.module``; plain DDP requires callers
    to write ``ddp.module.custom_method()``. This subclass lets them keep writing
    ``net.custom_method()`` whether or not the model is wrapped.
    """

    def __getattr__(self, item):
        # First try normal resolution (DDP attributes, registered parameters,
        # buffers, submodules via nn.Module.__getattr__) ...
        try:
            return super().__getattr__(item)
        # ... and only on a genuine miss delegate to the wrapped module.
        # Catching AttributeError specifically (not a bare ``except:``) avoids
        # swallowing unrelated errors such as KeyboardInterrupt or RecursionError.
        except AttributeError:
            return getattr(self.module, item)