Files
Grounded-SAM-2/lib/test/tracker/vis_utils.py

60 lines
1.9 KiB
Python
Raw Normal View History

2024-11-19 22:12:54 -08:00
import numpy as np
############## used for visualizing eliminated tokens #################
def get_keep_indices(decisions, num_stages=3):
    """Compose per-stage keep decisions into cumulative index arrays.

    ``decisions[0]`` is used as-is. For every later stage ``i``, ``decisions[i]``
    indexes into the previous stage's result, so each output entry is expressed
    in the index space of ``decisions[0]`` rather than relative to the previous
    stage.

    Args:
        decisions: sequence of index arrays (anything supporting fancy
            indexing, e.g. numpy integer arrays), one per stage.
        num_stages: how many leading stages of ``decisions`` to compose.
            Defaults to 3 for backward compatibility; pass ``len(decisions)``
            to use every stage.

    Returns:
        list of ``num_stages`` index arrays.

    Raises:
        IndexError: (for list/array inputs) if ``decisions`` has fewer than
            ``num_stages`` entries.
    """
    keep_indices = []
    for stage in range(num_stages):
        current = decisions[stage]
        if keep_indices:
            # Map this stage's relative picks through the previous stage's result.
            current = keep_indices[-1][current]
        keep_indices.append(current)
    return keep_indices
def gen_masked_tokens(tokens, indices, alpha=0.2):
    """Return a copy of ``tokens`` with the selected tokens faded toward white.

    Args:
        tokens: array whose first axis indexes tokens, e.g.
            (num_tokens, patch_size, patch_size, 3) as built by
            ``gen_visualization``.
        indices: array of token indices. Only its first row (``indices[0]``)
            is used, cast to ``int``. Presumably shaped (1, n) with a leading
            batch axis; TODO confirm with callers.
        alpha: weight kept from the original values. The remaining
            ``1 - alpha`` is blended with 255 (white).

    Returns:
        A new array. ``tokens`` itself is not modified. Blended values are cast
        to ``tokens.dtype`` on assignment.
    """
    rows = indices[0].astype(int)
    out = tokens.copy()
    white_share = (1 - alpha) * 255
    out[rows] = out[rows] * alpha + white_share
    return out
def recover_image(tokens, H, W, Hp, Wp, patch_size):
    """Reassemble a row-major grid of square patches into one (H, W, 3) image.

    Args:
        tokens: array holding ``Hp * Wp * patch_size * patch_size * 3``
            elements, e.g. shape (Hp * Wp, patch_size, patch_size, 3) with
            patches ordered row by row across the grid.
        H, W: output height and width. These are expected to equal
            ``Hp * patch_size`` and ``Wp * patch_size``. Other values with the
            same product either fail to reshape or scramble the pixels.
        Hp, Wp: number of patches vertically and horizontally.
        patch_size: side length of one patch in pixels.

    Returns:
        The reconstructed image array of shape (H, W, 3).
    """
    grid = tokens.reshape(Hp, Wp, patch_size, patch_size, 3)
    # (Hp, Wp, ps, ps, 3) -> (Hp, ps, Wp, ps, 3) so each pixel row is contiguous.
    pixel_rows = grid.swapaxes(1, 2)
    return pixel_rows.reshape(H, W, 3)
def pad_img(img):
    """Append an 8-pixel-wide white strip to the right edge of ``img``.

    Args:
        img: array of shape (height, width, channels).

    Returns:
        A new float64 array of shape (height, width + 8, channels). The left
        ``width`` columns hold ``img`` and the remaining columns are 255. The
        strip serves as a separator when panels are concatenated side by side.
    """
    height, width, channels = img.shape
    canvas = np.full((height, width + 8, channels), 255.0)
    canvas[:, :width, :] = img
    return canvas
def gen_visualization(image, mask_indices, patch_size=16):
    """Build a side-by-side visualization of progressively eliminated tokens.

    Args:
        image: array-like of shape (H, W, 3), e.g. 224 x 224 x 3. H and W are
            expected to be multiples of ``patch_size``; otherwise the patchify
            reshape raises.
        mask_indices: sequence of per-stage arrays of masked token indices.
            Each is passed to ``gen_masked_tokens``, which reads its first row,
            so entries are presumably shaped (1, n_i). TODO confirm with callers.
            Stage ``i`` is rendered with the indices of stages ``0..i``
            concatenated along axis 1. The caller's sequence is not modified.
        patch_size: side length, in pixels, of one square token patch.

    Returns:
        A float64 array of shape (H, (W + 8) * (len(mask_indices) + 1), 3).
        It holds the original image followed by one faded panel per stage,
        each padded on the right by ``pad_img``.
    """
    image = np.asarray(image)
    H, W, _ = image.shape
    Hp, Wp = H // patch_size, W // patch_size

    # Masks accumulate across stages: a token eliminated earlier stays
    # eliminated later. Build the cumulative arrays in a new list instead of
    # overwriting ``mask_indices``, so the caller's list is left untouched and
    # repeated calls with the same list give the same result.
    cumulative = []
    for stage_indices in mask_indices:
        if cumulative:
            stage_indices = np.concatenate([cumulative[-1], stage_indices], axis=1)
        cumulative.append(stage_indices)

    # Split the image into a row-major list of (patch_size, patch_size, 3) patches.
    image_tokens = (
        image.reshape(Hp, patch_size, Wp, patch_size, 3)
        .swapaxes(1, 2)
        .reshape(Hp * Wp, patch_size, patch_size, 3)
    )
    stages = [
        recover_image(gen_masked_tokens(image_tokens, stage_indices), H, W, Hp, Wp, patch_size)
        for stage_indices in cumulative
    ]
    panels = [pad_img(img) for img in [image, *stages]]
    return np.concatenate(panels, axis=1)