fix: for CUDA version >= 12.6
@@ -554,6 +554,7 @@ class TransformerEncoder(nn.Module):
                     memory_text,
                     key_padding_mask,
                     text_attention_mask,
+                    use_reentrant=True,
                 )
             else:
                 output, memory_text = self.fusion_layers[layer_id](
@@ -581,6 +582,7 @@ class TransformerEncoder(nn.Module):
                     spatial_shapes,
                     level_start_index,
                     key_padding_mask,
+                    use_reentrant=True,
                 )
             else:
                 output = layer(
@@ -859,7 +861,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
         return tensor if pos is None else tensor + pos
 
     def forward_ffn(self, tgt):
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast("cuda", enabled=False):
             tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
         tgt = tgt + self.dropout4(tgt2)
         tgt = self.norm3(tgt)
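For reference, a minimal sketch of the two call patterns this diff moves to: the device-agnostic torch.amp.autocast("cuda", ...) in place of the deprecated torch.cuda.amp.autocast(...), and an explicit use_reentrant argument to torch.utils.checkpoint.checkpoint, which recent PyTorch releases warn about when omitted. TinyFFN, its dimensions, and the tensor shapes below are illustrative stand-ins, not code from this repository.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyFFN(nn.Module):
    # Illustrative stand-in for the decoder's feed-forward block.
    def __init__(self, d_model=256, d_ffn=1024):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.activation = nn.ReLU()

    def forward(self, tgt):
        # Keep the FFN in full precision, as the diff does with autocast disabled.
        with torch.amp.autocast("cuda", enabled=False):
            return self.linear2(self.activation(self.linear1(tgt)))

ffn = TinyFFN()
x = torch.randn(2, 4, 256, requires_grad=True)
# Passing use_reentrant explicitly mirrors the checkpoint calls in the diff.
out = checkpoint(ffn, x, use_reentrant=True)
out.sum().backward()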