fix: for CUDA version >= 12.6

This commit is contained in:
kiennt
2025-08-16 09:57:17 +00:00
parent e1420f9335
commit 546f444c1c
10 changed files with 1453 additions and 23 deletions

View File

@@ -16,7 +16,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.layers import DropPath, to_2tuple, trunc_normal_
from groundingdino.util.misc import NestedTensor
@@ -445,7 +445,7 @@ class BasicLayer(nn.Module):
for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
x = checkpoint.checkpoint(blk, x, attn_mask, use_reentrant=True)
else:
x = blk(x, attn_mask)
if self.downsample is not None: