60 lines
1.9 KiB
Python
60 lines
1.9 KiB
Python
import numpy as np
|
|
|
|
|
|
############## Helpers for visualizing eliminated tokens #################
|
|
def get_keep_indices(decisions, num_stages=3):
    """Compose per-stage keep decisions into indices of the original tokens.

    The first decision is taken as-is. Each later decision is applied as a
    fancy index into the previous stage's result, so every returned entry
    refers to positions in the original (stage-0) token sequence.

    Args:
        decisions: Sequence of integer index arrays (e.g. NumPy arrays), one
            per stage; each must support indexing an array of the previous
            stage's result.
        num_stages: Number of leading entries of ``decisions`` to resolve.
            Defaults to 3 (the previously hard-coded count); pass ``None``
            to resolve every entry of ``decisions``.

    Returns:
        List of ``num_stages`` index arrays; entry ``i`` is the composition
        of decisions ``0..i``.
    """
    if num_stages is None:
        num_stages = len(decisions)

    keep_indices = []
    for stage in range(num_stages):
        if keep_indices:
            # Later decisions index into the previous stage's kept positions.
            keep_indices.append(keep_indices[-1][decisions[stage]])
        else:
            keep_indices.append(decisions[stage])
    return keep_indices
|
|
|
|
|
|
def gen_masked_tokens(tokens, indices, alpha=0.2):
    """Return a copy of ``tokens`` with the selected rows blended toward 255.

    Args:
        tokens: Array whose first axis indexes the tokens/patches.
        indices: Indexable whose first element, ``indices[0]``, is an array
            of row positions; it is cast with ``astype(int)`` so float
            positions are accepted.
        alpha: Weight kept from the original values; the remaining
            ``1 - alpha`` is taken from the constant 255 (white for 8-bit
            images).

    Returns:
        A new array with the same shape and dtype as ``tokens``; the input
        array is not modified.
    """
    rows = indices[0].astype(int)
    blended = tokens.copy()
    white_weight = 1 - alpha
    blended[rows] = alpha * blended[rows] + white_weight * 255
    return blended
|
|
|
|
|
|
def recover_image(tokens, H, W, Hp, Wp, patch_size, channels=3):
    """Reassemble a row-major grid of square patches into one image.

    Args:
        tokens: Patch data whose elements are ordered as
            ``(Hp * Wp, patch_size, patch_size, channels)`` (channels-last,
            patches in row-major grid order). Any array with that many
            elements in that order works, since it is reshaped.
        H, W: Output image height and width in pixels; the reshape only
            succeeds when ``H == Hp * patch_size`` and
            ``W == Wp * patch_size``.
        Hp, Wp: Number of patches along the height and width.
        patch_size: Side length of one square patch in pixels.
        channels: Number of channels per pixel. Defaults to 3, the value
            previously hard-coded.

    Returns:
        Array of shape ``(H, W, channels)``.
    """
    grid = tokens.reshape(Hp, Wp, patch_size, patch_size, channels)
    # (Hp, Wp, p, p, C) -> (Hp, p, Wp, p, C) so that pixel rows of each patch
    # row become contiguous before flattening to (H, W, C).
    return grid.swapaxes(1, 2).reshape(H, W, channels)
|
|
|
|
|
|
def pad_img(img, pad=8, fill=255):
    """Append a solid vertical strip to the right edge of an image.

    Args:
        img: Array of shape ``(height, width, channels)``.
        pad: Width in pixels of the strip added on the right. Defaults to 8.
        fill: Value used for every pixel of the strip. Defaults to 255.

    Returns:
        A new float64 array of shape ``(height, width + pad, channels)``
        with ``img`` copied into the left ``width`` columns. The output is
        float64 regardless of the input dtype.

    Raises:
        ValueError: If ``pad`` is negative.
    """
    if pad < 0:
        raise ValueError(f"pad must be non-negative, got {pad}")
    height, width, channels = img.shape
    canvas = np.full((height, width + pad, channels), fill, dtype=np.float64)
    canvas[:, :width, :] = img
    return canvas
|
|
|
|
|
|
def gen_visualization(image, mask_indices, patch_size=16):
    """Build a side-by-side strip showing the tokens faded out at each stage.

    Args:
        image: Array-like image of shape ``(H, W, C)``, converted with
            ``np.asarray``. ``H`` and ``W`` must be multiples of
            ``patch_size``.
        mask_indices: Sequence of per-stage index arrays. Stages are
            accumulated by concatenating along ``axis=1``, and each result is
            read as ``arr[0]`` by ``gen_masked_tokens`` (presumably each array
            is shaped ``(1, K)`` — confirm against the caller). Output stage
            ``i`` fades the tokens listed in stages ``0..i`` combined. The
            sequence is not modified.
        patch_size: Side length in pixels of one square patch.

    Returns:
        One array: the original image followed by one faded image per stage,
        each passed through ``pad_img`` and concatenated along axis 1 (the
        width axis).

    Raises:
        ValueError: If the image height or width is not a multiple of
            ``patch_size``.
    """
    image = np.asarray(image)
    H, W, C = image.shape
    if H % patch_size or W % patch_size:
        raise ValueError(
            f"image of size {H}x{W} cannot be split into "
            f"{patch_size}x{patch_size} patches"
        )
    Hp, Wp = H // patch_size, W // patch_size

    # Stage i fades every token dropped in stages 0..i, so accumulate the
    # per-stage indices. A new list is built so the caller's mask_indices is
    # left unmodified and repeat calls with the same list give the same result.
    cumulative = []
    for stage_indices in mask_indices:
        if cumulative:
            stage_indices = np.concatenate([cumulative[-1], stage_indices], axis=1)
        cumulative.append(stage_indices)

    # (H, W, C) -> (Hp * Wp, p, p, C): row-major grid of square patches.
    image_tokens = (
        image.reshape(Hp, patch_size, Wp, patch_size, C)
        .swapaxes(1, 2)
        .reshape(Hp * Wp, patch_size, patch_size, C)
    )

    stages = [
        recover_image(
            gen_masked_tokens(image_tokens, stage_indices),
            H, W, Hp, Wp, patch_size,
        )
        for stage_indices in cumulative
    ]

    imgs = [pad_img(img) for img in [image, *stages]]
    return np.concatenate(imgs, axis=1)
|