feat : Update code, new args
This commit is contained in:
@@ -16,7 +16,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
from timm.layers import DropPath, to_2tuple, trunc_normal_
|
||||
|
||||
from grounding_dino.groundingdino.util.misc import NestedTensor
|
||||
|
||||
@@ -113,7 +113,7 @@ class WindowAttention(nn.Module):
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
|
@@ -8,7 +8,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from timm.models.layers import DropPath
|
||||
from timm.layers import DropPath
|
||||
|
||||
|
||||
class FeatureResizer(nn.Module):
|
||||
|
@@ -470,6 +470,7 @@ class TransformerEncoder(nn.Module):
|
||||
ref_y, ref_x = torch.meshgrid(
|
||||
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
|
||||
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
|
||||
indexing="ij"
|
||||
)
|
||||
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
|
||||
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
|
||||
@@ -859,7 +860,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_ffn(self, tgt):
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout4(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
|
@@ -79,6 +79,7 @@ def gen_encoder_output_proposals(
|
||||
grid_y, grid_x = torch.meshgrid(
|
||||
torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
|
||||
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
|
||||
indexing="ij"
|
||||
)
|
||||
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
|
||||
|
||||
|
@@ -118,7 +118,7 @@ def masks_to_boxes(masks):
|
||||
|
||||
y = torch.arange(0, h, dtype=torch.float)
|
||||
x = torch.arange(0, w, dtype=torch.float)
|
||||
y, x = torch.meshgrid(y, x)
|
||||
y, x = torch.meshgrid(y, x, indexing="ij")
|
||||
|
||||
x_mask = masks * x.unsqueeze(0)
|
||||
x_max = x_mask.flatten(1).max(-1)[0]
|
||||
|
@@ -63,6 +63,7 @@ def predict(
|
||||
|
||||
model = model.to(device)
|
||||
image = image.to(device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(image[None], captions=[caption])
|
||||
@@ -76,10 +77,10 @@ def predict(
|
||||
|
||||
tokenizer = model.tokenizer
|
||||
tokenized = tokenizer(caption)
|
||||
|
||||
|
||||
if remove_combined:
|
||||
sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]]
|
||||
|
||||
|
||||
phrases = []
|
||||
for logit in logits:
|
||||
max_idx = logit.argmax()
|
||||
|
Reference in New Issue
Block a user