feat : Update code, new args

This commit is contained in:
kiennt
2025-08-14 09:26:37 +00:00
parent 2111d9c52c
commit 34b17b0280
7 changed files with 13 additions and 9 deletions

View File

@@ -623,7 +623,7 @@ class Trainer:
# compute output
with torch.no_grad():
with torch.cuda.amp.autocast(
with torch.amp.autocast("cuda",
enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
dtype=(
get_amp_type(self.optim_conf.amp.amp_dtype)
@@ -858,7 +858,8 @@ class Trainer:
# grads will also update a model even if the step doesn't produce
# gradients
self.optim.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast(
with torch.amp.autocast(
"cuda",
enabled=self.optim_conf.amp.enabled,
dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
):