init commit of samurai
This commit is contained in:
15
lib/train/admin/multigpu.py
Normal file
15
lib/train/admin/multigpu.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch.nn as nn
|
||||
# Here we use DistributedDataParallel(DDP) rather than DataParallel(DP) for multiple GPUs training
|
||||
|
||||
|
||||
def is_multi_gpu(net):
|
||||
return isinstance(net, (MultiGPU, nn.parallel.distributed.DistributedDataParallel))
|
||||
|
||||
|
||||
class MultiGPU(nn.parallel.distributed.DistributedDataParallel):
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return super().__getattr__(item)
|
||||
except:
|
||||
pass
|
||||
return getattr(self.module, item)
|
Reference in New Issue
Block a user