init commit of samurai
This commit is contained in:
29
lib/utils/merge.py
Normal file
29
lib/utils/merge.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import torch
|
||||
|
||||
|
||||
def merge_template_search(inp_list, return_search=False, return_template=False):
|
||||
"""NOTICE: search region related features must be in the last place"""
|
||||
seq_dict = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0),
|
||||
"mask": torch.cat([x["mask"] for x in inp_list], dim=1),
|
||||
"pos": torch.cat([x["pos"] for x in inp_list], dim=0)}
|
||||
if return_search:
|
||||
x = inp_list[-1]
|
||||
seq_dict.update({"feat_x": x["feat"], "mask_x": x["mask"], "pos_x": x["pos"]})
|
||||
if return_template:
|
||||
z = inp_list[0]
|
||||
seq_dict.update({"feat_z": z["feat"], "mask_z": z["mask"], "pos_z": z["pos"]})
|
||||
return seq_dict
|
||||
|
||||
|
||||
def get_qkv(inp_list):
|
||||
"""The 1st element of the inp_list is about the template,
|
||||
the 2nd (the last) element is about the search region"""
|
||||
dict_x = inp_list[-1]
|
||||
dict_c = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0),
|
||||
"mask": torch.cat([x["mask"] for x in inp_list], dim=1),
|
||||
"pos": torch.cat([x["pos"] for x in inp_list], dim=0)} # concatenated dict
|
||||
q = dict_x["feat"] + dict_x["pos"]
|
||||
k = dict_c["feat"] + dict_c["pos"]
|
||||
v = dict_c["feat"]
|
||||
key_padding_mask = dict_c["mask"]
|
||||
return q, k, v, key_padding_mask
|
Reference in New Issue
Block a user