init commit of samurai
This commit is contained in:
1
lib/utils/__init__.py
Normal file
1
lib/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .tensor import TensorDict, TensorList
|
106
lib/utils/box_ops.py
Normal file
106
lib/utils/box_ops.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import torch
|
||||
from torchvision.ops.boxes import box_area
|
||||
import numpy as np
|
||||
|
||||
|
||||
def box_cxcywh_to_xyxy(x):
|
||||
x_c, y_c, w, h = x.unbind(-1)
|
||||
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
|
||||
(x_c + 0.5 * w), (y_c + 0.5 * h)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def box_xywh_to_xyxy(x):
|
||||
x1, y1, w, h = x.unbind(-1)
|
||||
b = [x1, y1, x1 + w, y1 + h]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def box_xyxy_to_xywh(x):
|
||||
x1, y1, x2, y2 = x.unbind(-1)
|
||||
b = [x1, y1, x2 - x1, y2 - y1]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def box_xyxy_to_cxcywh(x):
|
||||
x0, y0, x1, y1 = x.unbind(-1)
|
||||
b = [(x0 + x1) / 2, (y0 + y1) / 2,
|
||||
(x1 - x0), (y1 - y0)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
# modified from torchvision to also return the union
|
||||
'''Note that this function only supports shape (N,4)'''
|
||||
|
||||
|
||||
def box_iou(boxes1, boxes2):
|
||||
"""
|
||||
|
||||
:param boxes1: (N, 4) (x1,y1,x2,y2)
|
||||
:param boxes2: (N, 4) (x1,y1,x2,y2)
|
||||
:return:
|
||||
"""
|
||||
area1 = box_area(boxes1) # (N,)
|
||||
area2 = box_area(boxes2) # (N,)
|
||||
|
||||
lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # (N,2)
|
||||
rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # (N,2)
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # (N,2)
|
||||
inter = wh[:, 0] * wh[:, 1] # (N,)
|
||||
|
||||
union = area1 + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
'''Note that this implementation is different from DETR's'''
|
||||
|
||||
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/
|
||||
|
||||
The boxes should be in [x0, y0, x1, y1] format
|
||||
|
||||
boxes1: (N, 4)
|
||||
boxes2: (N, 4)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
# try:
|
||||
#assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
|
||||
# assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
|
||||
iou, union = box_iou(boxes1, boxes2) # (N,)
|
||||
|
||||
lt = torch.min(boxes1[:, :2], boxes2[:, :2])
|
||||
rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # (N,2)
|
||||
area = wh[:, 0] * wh[:, 1] # (N,)
|
||||
|
||||
return iou - (area - union) / area, iou
|
||||
|
||||
|
||||
def giou_loss(boxes1, boxes2):
|
||||
"""
|
||||
|
||||
:param boxes1: (N, 4) (x1,y1,x2,y2)
|
||||
:param boxes2: (N, 4) (x1,y1,x2,y2)
|
||||
:return:
|
||||
"""
|
||||
giou, iou = generalized_box_iou(boxes1, boxes2)
|
||||
return (1 - giou).mean(), iou
|
||||
|
||||
|
||||
def clip_box(box: list, H, W, margin=0):
|
||||
x1, y1, w, h = box
|
||||
x2, y2 = x1 + w, y1 + h
|
||||
x1 = min(max(0, x1), W-margin)
|
||||
x2 = min(max(margin, x2), W)
|
||||
y1 = min(max(0, y1), H-margin)
|
||||
y2 = min(max(margin, y2), H)
|
||||
w = max(margin, x2-x1)
|
||||
h = max(margin, y2-y1)
|
||||
return [x1, y1, w, h]
|
80
lib/utils/ce_utils.py
Normal file
80
lib/utils/ce_utils.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def generate_bbox_mask(bbox_mask, bbox):
|
||||
b, h, w = bbox_mask.shape
|
||||
for i in range(b):
|
||||
bbox_i = bbox[i].cpu().tolist()
|
||||
bbox_mask[i, int(bbox_i[1]):int(bbox_i[1] + bbox_i[3] - 1), int(bbox_i[0]):int(bbox_i[0] + bbox_i[2] - 1)] = 1
|
||||
return bbox_mask
|
||||
|
||||
|
||||
def generate_mask_cond(cfg, bs, device, gt_bbox):
|
||||
template_size = cfg.DATA.TEMPLATE.SIZE
|
||||
stride = cfg.MODEL.BACKBONE.STRIDE
|
||||
template_feat_size = template_size // stride
|
||||
|
||||
if cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'ALL':
|
||||
box_mask_z = None
|
||||
elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_POINT':
|
||||
if template_feat_size == 8:
|
||||
index = slice(3, 4)
|
||||
elif template_feat_size == 12:
|
||||
index = slice(5, 6)
|
||||
elif template_feat_size == 7:
|
||||
index = slice(3, 4)
|
||||
elif template_feat_size == 14:
|
||||
index = slice(6, 7)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device)
|
||||
box_mask_z[:, index, index] = 1
|
||||
box_mask_z = box_mask_z.flatten(1).to(torch.bool)
|
||||
elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_REC':
|
||||
# use fixed 4x4 region, 3:5 for 8x8
|
||||
# use fixed 4x4 region 5:6 for 12x12
|
||||
if template_feat_size == 8:
|
||||
index = slice(3, 5)
|
||||
elif template_feat_size == 12:
|
||||
index = slice(5, 7)
|
||||
elif template_feat_size == 7:
|
||||
index = slice(3, 4)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device)
|
||||
box_mask_z[:, index, index] = 1
|
||||
box_mask_z = box_mask_z.flatten(1).to(torch.bool)
|
||||
|
||||
elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'GT_BOX':
|
||||
box_mask_z = torch.zeros([bs, template_size, template_size], device=device)
|
||||
# box_mask_z_ori = data['template_seg'][0].view(-1, 1, *data['template_seg'].shape[2:]) # (batch, 1, 128, 128)
|
||||
box_mask_z = generate_bbox_mask(box_mask_z, gt_bbox * template_size).unsqueeze(1).to(
|
||||
torch.float) # (batch, 1, 128, 128)
|
||||
# box_mask_z_vis = box_mask_z.cpu().numpy()
|
||||
box_mask_z = F.interpolate(box_mask_z, scale_factor=1. / cfg.MODEL.BACKBONE.STRIDE, mode='bilinear',
|
||||
align_corners=False)
|
||||
box_mask_z = box_mask_z.flatten(1).to(torch.bool)
|
||||
# box_mask_z_vis = box_mask_z[:, 0, ...].cpu().numpy()
|
||||
# gaussian_maps_vis = generate_heatmap(data['template_anno'], self.cfg.DATA.TEMPLATE.SIZE, self.cfg.MODEL.STRIDE)[0].cpu().numpy()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return box_mask_z
|
||||
|
||||
|
||||
def adjust_keep_rate(epoch, warmup_epochs, total_epochs, ITERS_PER_EPOCH, base_keep_rate=0.5, max_keep_rate=1, iters=-1):
|
||||
if epoch < warmup_epochs:
|
||||
return 1
|
||||
if epoch >= total_epochs:
|
||||
return base_keep_rate
|
||||
if iters == -1:
|
||||
iters = epoch * ITERS_PER_EPOCH
|
||||
total_iters = ITERS_PER_EPOCH * (total_epochs - warmup_epochs)
|
||||
iters = iters - ITERS_PER_EPOCH * warmup_epochs
|
||||
keep_rate = base_keep_rate + (max_keep_rate - base_keep_rate) \
|
||||
* (math.cos(iters / total_iters * math.pi) + 1) * 0.5
|
||||
|
||||
return keep_rate
|
63
lib/utils/focal_loss.py
Normal file
63
lib/utils/focal_loss.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from abc import ABC
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FocalLoss(nn.Module, ABC):
|
||||
def __init__(self, alpha=2, beta=4):
|
||||
super(FocalLoss, self).__init__()
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
|
||||
def forward(self, prediction, target):
|
||||
positive_index = target.eq(1).float()
|
||||
negative_index = target.lt(1).float()
|
||||
|
||||
negative_weights = torch.pow(1 - target, self.beta)
|
||||
# clamp min value is set to 1e-12 to maintain the numerical stability
|
||||
prediction = torch.clamp(prediction, 1e-12)
|
||||
|
||||
positive_loss = torch.log(prediction) * torch.pow(1 - prediction, self.alpha) * positive_index
|
||||
negative_loss = torch.log(1 - prediction) * torch.pow(prediction,
|
||||
self.alpha) * negative_weights * negative_index
|
||||
|
||||
num_positive = positive_index.float().sum()
|
||||
positive_loss = positive_loss.sum()
|
||||
negative_loss = negative_loss.sum()
|
||||
|
||||
if num_positive == 0:
|
||||
loss = -negative_loss
|
||||
else:
|
||||
loss = -(positive_loss + negative_loss) / num_positive
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class LBHinge(nn.Module):
|
||||
"""Loss that uses a 'hinge' on the lower bound.
|
||||
This means that for samples with a label value smaller than the threshold, the loss is zero if the prediction is
|
||||
also smaller than that threshold.
|
||||
args:
|
||||
error_matric: What base loss to use (MSE by default).
|
||||
threshold: Threshold to use for the hinge.
|
||||
clip: Clip the loss if it is above this value.
|
||||
"""
|
||||
def __init__(self, error_metric=nn.MSELoss(), threshold=None, clip=None):
|
||||
super().__init__()
|
||||
self.error_metric = error_metric
|
||||
self.threshold = threshold if threshold is not None else -100
|
||||
self.clip = clip
|
||||
|
||||
def forward(self, prediction, label, target_bb=None):
|
||||
negative_mask = (label < self.threshold).float()
|
||||
positive_mask = (1.0 - negative_mask)
|
||||
|
||||
prediction = negative_mask * F.relu(prediction) + positive_mask * prediction
|
||||
|
||||
loss = self.error_metric(prediction, positive_mask * label)
|
||||
|
||||
if self.clip is not None:
|
||||
loss = torch.min(loss, torch.tensor([self.clip], device=loss.device))
|
||||
return loss
|
150
lib/utils/heapmap_utils.py
Normal file
150
lib/utils/heapmap_utils.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def generate_heatmap(bboxes, patch_size=320, stride=16):
|
||||
"""
|
||||
Generate ground truth heatmap same as CenterNet
|
||||
Args:
|
||||
bboxes (torch.Tensor): shape of [num_search, bs, 4]
|
||||
|
||||
Returns:
|
||||
gaussian_maps: list of generated heatmap
|
||||
|
||||
"""
|
||||
gaussian_maps = []
|
||||
heatmap_size = patch_size // stride
|
||||
for single_patch_bboxes in bboxes:
|
||||
bs = single_patch_bboxes.shape[0]
|
||||
gt_scoremap = torch.zeros(bs, heatmap_size, heatmap_size)
|
||||
classes = torch.arange(bs).to(torch.long)
|
||||
bbox = single_patch_bboxes * heatmap_size
|
||||
wh = bbox[:, 2:]
|
||||
centers_int = (bbox[:, :2] + wh / 2).round()
|
||||
CenterNetHeatMap.generate_score_map(gt_scoremap, classes, wh, centers_int, 0.7)
|
||||
gaussian_maps.append(gt_scoremap.to(bbox.device))
|
||||
return gaussian_maps
|
||||
|
||||
|
||||
class CenterNetHeatMap(object):
|
||||
@staticmethod
|
||||
def generate_score_map(fmap, gt_class, gt_wh, centers_int, min_overlap):
|
||||
radius = CenterNetHeatMap.get_gaussian_radius(gt_wh, min_overlap)
|
||||
radius = torch.clamp_min(radius, 0)
|
||||
radius = radius.type(torch.int).cpu().numpy()
|
||||
for i in range(gt_class.shape[0]):
|
||||
channel_index = gt_class[i]
|
||||
CenterNetHeatMap.draw_gaussian(fmap[channel_index], centers_int[i], radius[i])
|
||||
|
||||
@staticmethod
|
||||
def get_gaussian_radius(box_size, min_overlap):
|
||||
"""
|
||||
copyed from CornerNet
|
||||
box_size (w, h), it could be a torch.Tensor, numpy.ndarray, list or tuple
|
||||
notice: we are using a bug-version, please refer to fix bug version in CornerNet
|
||||
"""
|
||||
# box_tensor = torch.Tensor(box_size)
|
||||
box_tensor = box_size
|
||||
width, height = box_tensor[..., 0], box_tensor[..., 1]
|
||||
|
||||
a1 = 1
|
||||
b1 = height + width
|
||||
c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
|
||||
sq1 = torch.sqrt(b1 ** 2 - 4 * a1 * c1)
|
||||
r1 = (b1 + sq1) / 2
|
||||
|
||||
a2 = 4
|
||||
b2 = 2 * (height + width)
|
||||
c2 = (1 - min_overlap) * width * height
|
||||
sq2 = torch.sqrt(b2 ** 2 - 4 * a2 * c2)
|
||||
r2 = (b2 + sq2) / 2
|
||||
|
||||
a3 = 4 * min_overlap
|
||||
b3 = -2 * min_overlap * (height + width)
|
||||
c3 = (min_overlap - 1) * width * height
|
||||
sq3 = torch.sqrt(b3 ** 2 - 4 * a3 * c3)
|
||||
r3 = (b3 + sq3) / 2
|
||||
|
||||
return torch.min(r1, torch.min(r2, r3))
|
||||
|
||||
@staticmethod
|
||||
def gaussian2D(radius, sigma=1):
|
||||
# m, n = [(s - 1.) / 2. for s in shape]
|
||||
m, n = radius
|
||||
y, x = np.ogrid[-m: m + 1, -n: n + 1]
|
||||
|
||||
gauss = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
|
||||
gauss[gauss < np.finfo(gauss.dtype).eps * gauss.max()] = 0
|
||||
return gauss
|
||||
|
||||
@staticmethod
|
||||
def draw_gaussian(fmap, center, radius, k=1):
|
||||
diameter = 2 * radius + 1
|
||||
gaussian = CenterNetHeatMap.gaussian2D((radius, radius), sigma=diameter / 6)
|
||||
gaussian = torch.Tensor(gaussian)
|
||||
x, y = int(center[0]), int(center[1])
|
||||
height, width = fmap.shape[:2]
|
||||
|
||||
left, right = min(x, radius), min(width - x, radius + 1)
|
||||
top, bottom = min(y, radius), min(height - y, radius + 1)
|
||||
|
||||
masked_fmap = fmap[y - top: y + bottom, x - left: x + right]
|
||||
masked_gaussian = gaussian[radius - top: radius + bottom, radius - left: radius + right]
|
||||
if min(masked_gaussian.shape) > 0 and min(masked_fmap.shape) > 0:
|
||||
masked_fmap = torch.max(masked_fmap, masked_gaussian * k)
|
||||
fmap[y - top: y + bottom, x - left: x + right] = masked_fmap
|
||||
# return fmap
|
||||
|
||||
|
||||
def compute_grids(features, strides):
|
||||
"""
|
||||
grids regret to the input image size
|
||||
"""
|
||||
grids = []
|
||||
for level, feature in enumerate(features):
|
||||
h, w = feature.size()[-2:]
|
||||
shifts_x = torch.arange(
|
||||
0, w * strides[level],
|
||||
step=strides[level],
|
||||
dtype=torch.float32, device=feature.device)
|
||||
shifts_y = torch.arange(
|
||||
0, h * strides[level],
|
||||
step=strides[level],
|
||||
dtype=torch.float32, device=feature.device)
|
||||
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
|
||||
shift_x = shift_x.reshape(-1)
|
||||
shift_y = shift_y.reshape(-1)
|
||||
grids_per_level = torch.stack((shift_x, shift_y), dim=1) + \
|
||||
strides[level] // 2
|
||||
grids.append(grids_per_level)
|
||||
return grids
|
||||
|
||||
|
||||
def get_center3x3(locations, centers, strides, range=3):
|
||||
'''
|
||||
Inputs:
|
||||
locations: M x 2
|
||||
centers: N x 2
|
||||
strides: M
|
||||
'''
|
||||
range = (range - 1) / 2
|
||||
M, N = locations.shape[0], centers.shape[0]
|
||||
locations_expanded = locations.view(M, 1, 2).expand(M, N, 2) # M x N x 2
|
||||
centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2
|
||||
strides_expanded = strides.view(M, 1, 1).expand(M, N, 2) # M x N
|
||||
centers_discret = ((centers_expanded / strides_expanded).int() * strides_expanded).float() + \
|
||||
strides_expanded / 2 # M x N x 2
|
||||
dist_x = (locations_expanded[:, :, 0] - centers_discret[:, :, 0]).abs()
|
||||
dist_y = (locations_expanded[:, :, 1] - centers_discret[:, :, 1]).abs()
|
||||
return (dist_x <= strides_expanded[:, :, 0] * range) & \
|
||||
(dist_y <= strides_expanded[:, :, 0] * range)
|
||||
|
||||
|
||||
def get_pred(score_map_ctr, size_map, offset_map, feat_size):
|
||||
max_score, idx = torch.max(score_map_ctr.flatten(1), dim=1, keepdim=True)
|
||||
|
||||
idx = idx.unsqueeze(1).expand(idx.shape[0], 2, 1)
|
||||
size = size_map.flatten(2).gather(dim=2, index=idx).squeeze(-1)
|
||||
offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1)
|
||||
|
||||
return size * feat_size, offset
|
55
lib/utils/lmdb_utils.py
Normal file
55
lib/utils/lmdb_utils.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import lmdb
|
||||
import numpy as np
|
||||
import cv2
|
||||
import json
|
||||
|
||||
LMDB_ENVS = dict()
|
||||
LMDB_HANDLES = dict()
|
||||
LMDB_FILELISTS = dict()
|
||||
|
||||
|
||||
def get_lmdb_handle(name):
|
||||
global LMDB_HANDLES, LMDB_FILELISTS
|
||||
item = LMDB_HANDLES.get(name, None)
|
||||
if item is None:
|
||||
env = lmdb.open(name, readonly=True, lock=False, readahead=False, meminit=False)
|
||||
LMDB_ENVS[name] = env
|
||||
item = env.begin(write=False)
|
||||
LMDB_HANDLES[name] = item
|
||||
|
||||
return item
|
||||
|
||||
|
||||
def decode_img(lmdb_fname, key_name):
|
||||
handle = get_lmdb_handle(lmdb_fname)
|
||||
binfile = handle.get(key_name.encode())
|
||||
if binfile is None:
|
||||
print("Illegal data detected. %s %s" % (lmdb_fname, key_name))
|
||||
s = np.frombuffer(binfile, np.uint8)
|
||||
x = cv2.cvtColor(cv2.imdecode(s, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
|
||||
return x
|
||||
|
||||
|
||||
def decode_str(lmdb_fname, key_name):
|
||||
handle = get_lmdb_handle(lmdb_fname)
|
||||
binfile = handle.get(key_name.encode())
|
||||
string = binfile.decode()
|
||||
return string
|
||||
|
||||
|
||||
def decode_json(lmdb_fname, key_name):
|
||||
return json.loads(decode_str(lmdb_fname, key_name))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
lmdb_fname = "/data/sda/v-yanbi/iccv21/LittleBoy_clean/data/got10k_lmdb"
|
||||
'''Decode image'''
|
||||
# key_name = "test/GOT-10k_Test_000001/00000001.jpg"
|
||||
# img = decode_img(lmdb_fname, key_name)
|
||||
# cv2.imwrite("001.jpg", img)
|
||||
'''Decode str'''
|
||||
# key_name = "test/list.txt"
|
||||
# key_name = "train/GOT-10k_Train_000001/groundtruth.txt"
|
||||
key_name = "train/GOT-10k_Train_000001/absence.label"
|
||||
str_ = decode_str(lmdb_fname, key_name)
|
||||
print(str_)
|
29
lib/utils/merge.py
Normal file
29
lib/utils/merge.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import torch
|
||||
|
||||
|
||||
def merge_template_search(inp_list, return_search=False, return_template=False):
|
||||
"""NOTICE: search region related features must be in the last place"""
|
||||
seq_dict = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0),
|
||||
"mask": torch.cat([x["mask"] for x in inp_list], dim=1),
|
||||
"pos": torch.cat([x["pos"] for x in inp_list], dim=0)}
|
||||
if return_search:
|
||||
x = inp_list[-1]
|
||||
seq_dict.update({"feat_x": x["feat"], "mask_x": x["mask"], "pos_x": x["pos"]})
|
||||
if return_template:
|
||||
z = inp_list[0]
|
||||
seq_dict.update({"feat_z": z["feat"], "mask_z": z["mask"], "pos_z": z["pos"]})
|
||||
return seq_dict
|
||||
|
||||
|
||||
def get_qkv(inp_list):
|
||||
"""The 1st element of the inp_list is about the template,
|
||||
the 2nd (the last) element is about the search region"""
|
||||
dict_x = inp_list[-1]
|
||||
dict_c = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0),
|
||||
"mask": torch.cat([x["mask"] for x in inp_list], dim=1),
|
||||
"pos": torch.cat([x["pos"] for x in inp_list], dim=0)} # concatenated dict
|
||||
q = dict_x["feat"] + dict_x["pos"]
|
||||
k = dict_c["feat"] + dict_c["pos"]
|
||||
v = dict_c["feat"]
|
||||
key_padding_mask = dict_c["mask"]
|
||||
return q, k, v, key_padding_mask
|
468
lib/utils/misc.py
Normal file
468
lib/utils/misc.py
Normal file
@@ -0,0 +1,468 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Misc functions, including distributed helpers.
|
||||
|
||||
Mostly copy-paste from torchvision references.
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
import datetime
|
||||
import pickle
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
||||
import torchvision
|
||||
vers = torchvision.__version__.split('.')
|
||||
if int(vers[0]) <= 0 and int(vers[1]) < 7:
|
||||
from torchvision.ops import _new_empty_tensor
|
||||
from torchvision.ops.misc import _output_size
|
||||
|
||||
|
||||
class SmoothedValue(object):
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median,
|
||||
avg=self.avg,
|
||||
global_avg=self.global_avg,
|
||||
max=self.max,
|
||||
value=self.value)
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to("cuda")
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.tensor([tensor.numel()], device="cuda")
|
||||
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
||||
if local_size != max_size:
|
||||
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Args:
|
||||
input_dict (dict): all the values will be reduced
|
||||
average (bool): whether to do average or sum
|
||||
Reduce the values in the dictionary from all processes so that all processes
|
||||
have the averaged results. Returns a dict with the same fields as
|
||||
input_dict, after reduction.
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.all_reduce(values)
|
||||
if average:
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||
return reduced_dict
|
||||
|
||||
|
||||
class MetricLogger(object):
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
assert isinstance(v, (float, int))
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(
|
||||
type(self).__name__, attr))
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append(
|
||||
"{}: {}".format(name, str(meter))
|
||||
)
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
i = 0
|
||||
if not header:
|
||||
header = ''
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
data_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}',
|
||||
'max mem: {memory:.0f}'
|
||||
])
|
||||
else:
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}'
|
||||
])
|
||||
MB = 1024.0 * 1024.0
|
||||
for obj in iterable:
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / MB))
|
||||
else:
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time)))
|
||||
i += 1
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('{} Total time: {} ({:.4f} s / it)'.format(
|
||||
header, total_time_str, total_time / len(iterable)))
|
||||
|
||||
|
||||
def get_sha():
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def _run(command):
|
||||
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
||||
sha = 'N/A'
|
||||
diff = "clean"
|
||||
branch = 'N/A'
|
||||
try:
|
||||
sha = _run(['git', 'rev-parse', 'HEAD'])
|
||||
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
||||
diff = _run(['git', 'diff-index', 'HEAD'])
|
||||
diff = "has uncommited changes" if diff else "clean"
|
||||
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
||||
except Exception:
|
||||
pass
|
||||
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||
return message
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
batch = list(zip(*batch))
|
||||
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
||||
return tuple(batch)
|
||||
|
||||
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0] # get the first one
|
||||
for sublist in the_list[1:]: # [h,w,3]
|
||||
for index, item in enumerate(sublist): # index: 0,1,2
|
||||
maxes[index] = max(maxes[index], item) # compare current max with the other elements in the whole
|
||||
return maxes
|
||||
|
||||
|
||||
class NestedTensor(object):
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
# type: (Device) -> NestedTensor # noqa
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
assert mask is not None
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
# TODO make this more general
|
||||
if tensor_list[0].ndim == 3:
|
||||
if torchvision._is_tracing():
|
||||
# nested_tensor_from_tensor_list() does not export well to ONNX
|
||||
# call _onnx_nested_tensor_from_tensor_list() instead
|
||||
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
||||
|
||||
# TODO make it support different-sized images
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list]) # [[3,h1,w1], [3,h2,w2], [3,h3,w3], ...]
|
||||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
||||
batch_shape = [len(tensor_list)] + max_size # ()
|
||||
b, c, h, w = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) # copy valid regions of the images to the largest padded base.
|
||||
m[: img.shape[1], :img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError('not supported')
|
||||
return NestedTensor(tensor, mask)
|
||||
|
||||
|
||||
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
||||
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
||||
@torch.jit.unused
|
||||
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
||||
max_size = []
|
||||
for i in range(tensor_list[0].dim()):
|
||||
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
|
||||
max_size.append(max_size_i)
|
||||
max_size = tuple(max_size)
|
||||
|
||||
# work around for
|
||||
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
# m[: img.shape[1], :img.shape[2]] = False
|
||||
# which is not yet supported in onnx
|
||||
padded_imgs = []
|
||||
padded_masks = []
|
||||
for img in tensor_list:
|
||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
|
||||
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
||||
padded_imgs.append(padded_img)
|
||||
|
||||
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
||||
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
||||
padded_masks.append(padded_mask.to(torch.bool))
|
||||
|
||||
tensor = torch.stack(padded_imgs)
|
||||
mask = torch.stack(padded_masks)
|
||||
|
||||
return NestedTensor(tensor, mask=mask)
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop('force', False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ['WORLD_SIZE'])
|
||||
args.gpu = int(os.environ['LOCAL_RANK'])
|
||||
elif 'SLURM_PROCID' in os.environ:
|
||||
args.rank = int(os.environ['SLURM_PROCID'])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print('Not using distributed mode')
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = 'nccl'
|
||||
print('| distributed init (rank {}): {}'.format(
|
||||
args.rank, args.dist_url), flush=True)
|
||||
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
||||
world_size=args.world_size, rank=args.rank)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
if target.numel() == 0:
|
||||
return [torch.zeros([], device=output.device)]
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
||||
"""
|
||||
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
||||
This will eventually be supported natively by PyTorch, and this
|
||||
class can go away.
|
||||
"""
|
||||
if float(torchvision.__version__[:3]) < 0.7:
|
||||
if input.numel() > 0:
|
||||
return torch.nn.functional.interpolate(
|
||||
input, size, scale_factor, mode, align_corners
|
||||
)
|
||||
|
||||
output_shape = _output_size(2, input, size, scale_factor)
|
||||
output_shape = list(input.shape[:-2]) + list(output_shape)
|
||||
return _new_empty_tensor(input, output_shape)
|
||||
else:
|
||||
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
244
lib/utils/tensor.py
Normal file
244
lib/utils/tensor.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import functools
|
||||
import torch
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class TensorDict(OrderedDict):
|
||||
"""Container mainly used for dicts of torch tensors. Extends OrderedDict with pytorch functionality."""
|
||||
|
||||
def concat(self, other):
|
||||
"""Concatenates two dicts without copying internal data."""
|
||||
return TensorDict(self, **other)
|
||||
|
||||
def copy(self):
|
||||
return TensorDict(super(TensorDict, self).copy())
|
||||
|
||||
def __deepcopy__(self, memodict={}):
|
||||
return TensorDict(copy.deepcopy(list(self), memodict))
|
||||
|
||||
def __getattr__(self, name):
|
||||
if not hasattr(torch.Tensor, name):
|
||||
raise AttributeError('\'TensorDict\' object has not attribute \'{}\''.format(name))
|
||||
|
||||
def apply_attr(*args, **kwargs):
|
||||
return TensorDict({n: getattr(e, name)(*args, **kwargs) if hasattr(e, name) else e for n, e in self.items()})
|
||||
return apply_attr
|
||||
|
||||
def attribute(self, attr: str, *args):
|
||||
return TensorDict({n: getattr(e, attr, *args) for n, e in self.items()})
|
||||
|
||||
def apply(self, fn, *args, **kwargs):
|
||||
return TensorDict({n: fn(e, *args, **kwargs) for n, e in self.items()})
|
||||
|
||||
@staticmethod
|
||||
def _iterable(a):
|
||||
return isinstance(a, (TensorDict, list))
|
||||
|
||||
|
||||
class TensorList(list):
|
||||
"""Container mainly used for lists of torch tensors. Extends lists with pytorch functionality."""
|
||||
|
||||
def __init__(self, list_of_tensors = None):
|
||||
if list_of_tensors is None:
|
||||
list_of_tensors = list()
|
||||
super(TensorList, self).__init__(list_of_tensors)
|
||||
|
||||
def __deepcopy__(self, memodict={}):
|
||||
return TensorList(copy.deepcopy(list(self), memodict))
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, int):
|
||||
return super(TensorList, self).__getitem__(item)
|
||||
elif isinstance(item, (tuple, list)):
|
||||
return TensorList([super(TensorList, self).__getitem__(i) for i in item])
|
||||
else:
|
||||
return TensorList(super(TensorList, self).__getitem__(item))
|
||||
|
||||
def __add__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e1 + e2 for e1, e2 in zip(self, other)])
|
||||
return TensorList([e + other for e in self])
|
||||
|
||||
def __radd__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e2 + e1 for e1, e2 in zip(self, other)])
|
||||
return TensorList([other + e for e in self])
|
||||
|
||||
def __iadd__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
for i, e2 in enumerate(other):
|
||||
self[i] += e2
|
||||
else:
|
||||
for i in range(len(self)):
|
||||
self[i] += other
|
||||
return self
|
||||
|
||||
def __sub__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e1 - e2 for e1, e2 in zip(self, other)])
|
||||
return TensorList([e - other for e in self])
|
||||
|
||||
def __rsub__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e2 - e1 for e1, e2 in zip(self, other)])
|
||||
return TensorList([other - e for e in self])
|
||||
|
||||
def __isub__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
for i, e2 in enumerate(other):
|
||||
self[i] -= e2
|
||||
else:
|
||||
for i in range(len(self)):
|
||||
self[i] -= other
|
||||
return self
|
||||
|
||||
def __mul__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e1 * e2 for e1, e2 in zip(self, other)])
|
||||
return TensorList([e * other for e in self])
|
||||
|
||||
def __rmul__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e2 * e1 for e1, e2 in zip(self, other)])
|
||||
return TensorList([other * e for e in self])
|
||||
|
||||
def __imul__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
for i, e2 in enumerate(other):
|
||||
self[i] *= e2
|
||||
else:
|
||||
for i in range(len(self)):
|
||||
self[i] *= other
|
||||
return self
|
||||
|
||||
def __truediv__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e1 / e2 for e1, e2 in zip(self, other)])
|
||||
return TensorList([e / other for e in self])
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e2 / e1 for e1, e2 in zip(self, other)])
|
||||
return TensorList([other / e for e in self])
|
||||
|
||||
def __itruediv__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
for i, e2 in enumerate(other):
|
||||
self[i] /= e2
|
||||
else:
|
||||
for i in range(len(self)):
|
||||
self[i] /= other
|
||||
return self
|
||||
|
||||
def __matmul__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e1 @ e2 for e1, e2 in zip(self, other)])
|
||||
return TensorList([e @ other for e in self])
|
||||
|
||||
def __rmatmul__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e2 @ e1 for e1, e2 in zip(self, other)])
|
||||
return TensorList([other @ e for e in self])
|
||||
|
||||
def __imatmul__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
for i, e2 in enumerate(other):
|
||||
self[i] @= e2
|
||||
else:
|
||||
for i in range(len(self)):
|
||||
self[i] @= other
|
||||
return self
|
||||
|
||||
def __mod__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e1 % e2 for e1, e2 in zip(self, other)])
|
||||
return TensorList([e % other for e in self])
|
||||
|
||||
def __rmod__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e2 % e1 for e1, e2 in zip(self, other)])
|
||||
return TensorList([other % e for e in self])
|
||||
|
||||
def __pos__(self):
|
||||
return TensorList([+e for e in self])
|
||||
|
||||
def __neg__(self):
|
||||
return TensorList([-e for e in self])
|
||||
|
||||
def __le__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e1 <= e2 for e1, e2 in zip(self, other)])
|
||||
return TensorList([e <= other for e in self])
|
||||
|
||||
def __ge__(self, other):
|
||||
if TensorList._iterable(other):
|
||||
return TensorList([e1 >= e2 for e1, e2 in zip(self, other)])
|
||||
return TensorList([e >= other for e in self])
|
||||
|
||||
def concat(self, other):
|
||||
return TensorList(super(TensorList, self).__add__(other))
|
||||
|
||||
def copy(self):
|
||||
return TensorList(super(TensorList, self).copy())
|
||||
|
||||
def unroll(self):
|
||||
if not any(isinstance(t, TensorList) for t in self):
|
||||
return self
|
||||
|
||||
new_list = TensorList()
|
||||
for t in self:
|
||||
if isinstance(t, TensorList):
|
||||
new_list.extend(t.unroll())
|
||||
else:
|
||||
new_list.append(t)
|
||||
return new_list
|
||||
|
||||
def list(self):
|
||||
return list(self)
|
||||
|
||||
def attribute(self, attr: str, *args):
|
||||
return TensorList([getattr(e, attr, *args) for e in self])
|
||||
|
||||
def apply(self, fn):
|
||||
return TensorList([fn(e) for e in self])
|
||||
|
||||
def __getattr__(self, name):
|
||||
if not hasattr(torch.Tensor, name):
|
||||
raise AttributeError('\'TensorList\' object has not attribute \'{}\''.format(name))
|
||||
|
||||
def apply_attr(*args, **kwargs):
|
||||
return TensorList([getattr(e, name)(*args, **kwargs) for e in self])
|
||||
|
||||
return apply_attr
|
||||
|
||||
@staticmethod
|
||||
def _iterable(a):
|
||||
return isinstance(a, (TensorList, list))
|
||||
|
||||
|
||||
def tensor_operation(op):
|
||||
def islist(a):
|
||||
return isinstance(a, TensorList)
|
||||
|
||||
@functools.wraps(op)
|
||||
def oplist(*args, **kwargs):
|
||||
if len(args) == 0:
|
||||
raise ValueError('Must be at least one argument without keyword (i.e. operand).')
|
||||
|
||||
if len(args) == 1:
|
||||
if islist(args[0]):
|
||||
return TensorList([op(a, **kwargs) for a in args[0]])
|
||||
else:
|
||||
# Multiple operands, assume max two
|
||||
if islist(args[0]) and islist(args[1]):
|
||||
return TensorList([op(a, b, *args[2:], **kwargs) for a, b in zip(*args[:2])])
|
||||
if islist(args[0]):
|
||||
return TensorList([op(a, *args[1:], **kwargs) for a in args[0]])
|
||||
if islist(args[1]):
|
||||
return TensorList([op(args[0], b, *args[2:], **kwargs) for b in args[1]])
|
||||
|
||||
# None of the operands are lists
|
||||
return op(*args, **kwargs)
|
||||
|
||||
return oplist
|
50
lib/utils/variable_hook.py
Normal file
50
lib/utils/variable_hook.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import torch
|
||||
from bytecode import Bytecode, Instr
|
||||
|
||||
|
||||
class get_local(object):
|
||||
cache = {}
|
||||
is_activate = False
|
||||
|
||||
def __init__(self, varname):
|
||||
self.varname = varname
|
||||
|
||||
def __call__(self, func):
|
||||
if not type(self).is_activate:
|
||||
return func
|
||||
|
||||
type(self).cache[func.__qualname__] = []
|
||||
c = Bytecode.from_code(func.__code__)
|
||||
extra_code = [
|
||||
Instr('STORE_FAST', '_res'),
|
||||
Instr('LOAD_FAST', self.varname),
|
||||
Instr('STORE_FAST', '_value'),
|
||||
Instr('LOAD_FAST', '_res'),
|
||||
Instr('LOAD_FAST', '_value'),
|
||||
Instr('BUILD_TUPLE', 2),
|
||||
Instr('STORE_FAST', '_result_tuple'),
|
||||
Instr('LOAD_FAST', '_result_tuple'),
|
||||
]
|
||||
c[-1:-1] = extra_code
|
||||
func.__code__ = c.to_code()
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
res, values = func(*args, **kwargs)
|
||||
if isinstance(values, torch.Tensor):
|
||||
type(self).cache[func.__qualname__].append(values.detach().cpu().numpy())
|
||||
elif isinstance(values, list): # list of Tensor
|
||||
type(self).cache[func.__qualname__].append([value.detach().cpu().numpy() for value in values])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return res
|
||||
|
||||
return wrapper
|
||||
|
||||
@classmethod
|
||||
def clear(cls):
|
||||
for key in cls.cache.keys():
|
||||
cls.cache[key] = []
|
||||
|
||||
@classmethod
|
||||
def activate(cls):
|
||||
cls.is_activate = True
|
Reference in New Issue
Block a user