Fix demos for CPU inference (#104)
This commit is contained in:
@@ -44,7 +44,7 @@ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
# use bfloat16
|
||||
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
||||
|
||||
if torch.cuda.get_device_properties(0).major >= 8:
|
||||
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
|
||||
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
Reference in New Issue
Block a user