feat : Update code, new args
This commit is contained in:
@@ -118,7 +118,7 @@ def masks_to_boxes(masks):
|
||||
|
||||
y = torch.arange(0, h, dtype=torch.float)
|
||||
x = torch.arange(0, w, dtype=torch.float)
|
||||
y, x = torch.meshgrid(y, x)
|
||||
y, x = torch.meshgrid(y, x, indexing="ij")
|
||||
|
||||
x_mask = masks * x.unsqueeze(0)
|
||||
x_max = x_mask.flatten(1).max(-1)[0]
|
||||
|
@@ -63,6 +63,7 @@ def predict(
|
||||
|
||||
model = model.to(device)
|
||||
image = image.to(device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(image[None], captions=[caption])
|
||||
@@ -76,10 +77,10 @@ def predict(
|
||||
|
||||
tokenizer = model.tokenizer
|
||||
tokenized = tokenizer(caption)
|
||||
|
||||
|
||||
if remove_combined:
|
||||
sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]]
|
||||
|
||||
|
||||
phrases = []
|
||||
for logit in logits:
|
||||
max_idx = logit.argmax()
|
||||
|
Reference in New Issue
Block a user