feat : Update code, new args

This commit is contained in:
kiennt
2025-08-14 09:26:37 +00:00
parent 2111d9c52c
commit 34b17b0280
7 changed files with 13 additions and 9 deletions

View File

@@ -16,7 +16,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from timm.layers import DropPath, to_2tuple, trunc_normal_
from grounding_dino.groundingdino.util.misc import NestedTensor from grounding_dino.groundingdino.util.misc import NestedTensor
@@ -113,7 +113,7 @@ class WindowAttention(nn.Module):
# get pair-wise relative position index for each token inside the window # get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0]) coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1]) coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2

View File

@@ -8,7 +8,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.models.layers import DropPath from timm.layers import DropPath
class FeatureResizer(nn.Module): class FeatureResizer(nn.Module):

View File

@@ -470,6 +470,7 @@ class TransformerEncoder(nn.Module):
ref_y, ref_x = torch.meshgrid( ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
indexing="ij"
) )
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
@@ -859,7 +860,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
return tensor if pos is None else tensor + pos return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt): def forward_ffn(self, tgt):
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2) tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt) tgt = self.norm3(tgt)

View File

@@ -79,6 +79,7 @@ def gen_encoder_output_proposals(
grid_y, grid_x = torch.meshgrid( grid_y, grid_x = torch.meshgrid(
torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device), torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
indexing="ij"
) )
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2 grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2

View File

@@ -118,7 +118,7 @@ def masks_to_boxes(masks):
y = torch.arange(0, h, dtype=torch.float) y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float) x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x) y, x = torch.meshgrid(y, x, indexing="ij")
x_mask = masks * x.unsqueeze(0) x_mask = masks * x.unsqueeze(0)
x_max = x_mask.flatten(1).max(-1)[0] x_max = x_mask.flatten(1).max(-1)[0]

View File

@@ -63,6 +63,7 @@ def predict(
model = model.to(device) model = model.to(device)
image = image.to(device) image = image.to(device)
model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(image[None], captions=[caption]) outputs = model(image[None], captions=[caption])

View File

@@ -623,7 +623,7 @@ class Trainer:
# compute output # compute output
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast( with torch.amp.autocast("cuda",
enabled=(self.optim_conf.amp.enabled if self.optim_conf else False), enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
dtype=( dtype=(
get_amp_type(self.optim_conf.amp.amp_dtype) get_amp_type(self.optim_conf.amp.amp_dtype)
@@ -858,7 +858,8 @@ class Trainer:
# grads will also update a model even if the step doesn't produce # grads will also update a model even if the step doesn't produce
# gradients # gradients
self.optim.zero_grad(set_to_none=True) self.optim.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast( with torch.amp.autocast(
"cuda",
enabled=self.optim_conf.amp.enabled, enabled=self.optim_conf.amp.enabled,
dtype=get_amp_type(self.optim_conf.amp.amp_dtype), dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
): ):