16 lines
467 B
Python
16 lines
467 B
Python
![]() |
import torch.nn as nn
|
||
|
# Here we use DistributedDataParallel(DDP) rather than DataParallel(DP) for multiple GPUs training
|
||
|
|
||
|
|
||
|
def is_multi_gpu(net):
|
||
|
return isinstance(net, (MultiGPU, nn.parallel.distributed.DistributedDataParallel))
|
||
|
|
||
|
|
||
|
class MultiGPU(nn.parallel.distributed.DistributedDataParallel):
|
||
|
def __getattr__(self, item):
|
||
|
try:
|
||
|
return super().__getattr__(item)
|
||
|
except:
|
||
|
pass
|
||
|
return getattr(self.module, item)
|