feat : Update code, new args

This commit is contained in:
kiennt
2025-08-14 09:26:37 +00:00
parent 2111d9c52c
commit 34b17b0280
7 changed files with 13 additions and 9 deletions

View File

@@ -63,6 +63,7 @@ def predict(
model = model.to(device)
image = image.to(device)
model.eval()
with torch.no_grad():
outputs = model(image[None], captions=[caption])
@@ -76,10 +77,10 @@ def predict(
tokenizer = model.tokenizer
tokenized = tokenizer(caption)
if remove_combined:
sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]]
phrases = []
for logit in logits:
max_idx = logit.argmax()