feat : Update code, new args

This commit is contained in:
kiennt
2025-08-14 09:26:37 +00:00
parent 2111d9c52c
commit 34b17b0280
7 changed files with 13 additions and 9 deletions

View File

@@ -118,7 +118,7 @@ def masks_to_boxes(masks):
y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x)
y, x = torch.meshgrid(y, x, indexing="ij")
x_mask = masks * x.unsqueeze(0)
x_max = x_mask.flatten(1).max(-1)[0]

View File

@@ -63,6 +63,7 @@ def predict(
model = model.to(device)
image = image.to(device)
model.eval()
with torch.no_grad():
outputs = model(image[None], captions=[caption])
@@ -76,10 +77,10 @@ def predict(
tokenizer = model.tokenizer
tokenized = tokenizer(caption)
if remove_combined:
sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]]
phrases = []
for logit in logits:
max_idx = logit.argmax()