init commit of samurai
This commit is contained in:
44
lib/train/actors/base_actor.py
Normal file
44
lib/train/actors/base_actor.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from lib.utils import TensorDict
|
||||
|
||||
|
||||
class BaseActor:
|
||||
""" Base class for actor. The actor class handles the passing of the data through the network
|
||||
and calculation the loss"""
|
||||
def __init__(self, net, objective):
|
||||
"""
|
||||
args:
|
||||
net - The network to train
|
||||
objective - The loss function
|
||||
"""
|
||||
self.net = net
|
||||
self.objective = objective
|
||||
|
||||
def __call__(self, data: TensorDict):
|
||||
""" Called in each training iteration. Should pass in input data through the network, calculate the loss, and
|
||||
return the training stats for the input data
|
||||
args:
|
||||
data - A TensorDict containing all the necessary data blocks.
|
||||
|
||||
returns:
|
||||
loss - loss for the input data
|
||||
stats - a dict containing detailed losses
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to(self, device):
|
||||
""" Move the network to device
|
||||
args:
|
||||
device - device to use. 'cpu' or 'cuda'
|
||||
"""
|
||||
self.net.to(device)
|
||||
|
||||
def train(self, mode=True):
|
||||
""" Set whether the network is in train mode.
|
||||
args:
|
||||
mode (True) - Bool specifying whether in training mode.
|
||||
"""
|
||||
self.net.train(mode)
|
||||
|
||||
def eval(self):
|
||||
""" Set network to eval mode"""
|
||||
self.train(False)
|
Reference in New Issue
Block a user