feat: migrate torch.cuda.amp.autocast to torch.amp.autocast("cuda", ...), new args
This commit is contained in:
@@ -623,7 +623,7 @@ class Trainer:
|
||||
|
||||
# compute output
|
||||
with torch.no_grad():
|
||||
with torch.cuda.amp.autocast(
|
||||
with torch.amp.autocast("cuda",
|
||||
enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
|
||||
dtype=(
|
||||
get_amp_type(self.optim_conf.amp.amp_dtype)
|
||||
@@ -858,7 +858,8 @@ class Trainer:
|
||||
# grads will also update a model even if the step doesn't produce
|
||||
# gradients
|
||||
self.optim.zero_grad(set_to_none=True)
|
||||
with torch.cuda.amp.autocast(
|
||||
with torch.amp.autocast(
|
||||
"cuda",
|
||||
enabled=self.optim_conf.amp.enabled,
|
||||
dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
|
||||
):
|
||||
|
Reference in New Issue
Block a user