feat : Update code, new args
This commit is contained in:
@@ -63,6 +63,7 @@ def predict(
|
||||
|
||||
model = model.to(device)
|
||||
image = image.to(device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(image[None], captions=[caption])
|
||||
@@ -76,10 +77,10 @@ def predict(
|
||||
|
||||
tokenizer = model.tokenizer
|
||||
tokenized = tokenizer(caption)
|
||||
|
||||
|
||||
if remove_combined:
|
||||
sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]]
|
||||
|
||||
|
||||
phrases = []
|
||||
for logit in logits:
|
||||
max_idx = logit.argmax()
|
||||
|
Reference in New Issue
Block a user