import functools import torch import copy from collections import OrderedDict class TensorDict(OrderedDict): """Container mainly used for dicts of torch tensors. Extends OrderedDict with pytorch functionality.""" def concat(self, other): """Concatenates two dicts without copying internal data.""" return TensorDict(self, **other) def copy(self): return TensorDict(super(TensorDict, self).copy()) def __deepcopy__(self, memodict={}): return TensorDict(copy.deepcopy(list(self), memodict)) def __getattr__(self, name): if not hasattr(torch.Tensor, name): raise AttributeError('\'TensorDict\' object has not attribute \'{}\''.format(name)) def apply_attr(*args, **kwargs): return TensorDict({n: getattr(e, name)(*args, **kwargs) if hasattr(e, name) else e for n, e in self.items()}) return apply_attr def attribute(self, attr: str, *args): return TensorDict({n: getattr(e, attr, *args) for n, e in self.items()}) def apply(self, fn, *args, **kwargs): return TensorDict({n: fn(e, *args, **kwargs) for n, e in self.items()}) @staticmethod def _iterable(a): return isinstance(a, (TensorDict, list)) class TensorList(list): """Container mainly used for lists of torch tensors. Extends lists with pytorch functionality.""" def __init__(self, list_of_tensors = None): if list_of_tensors is None: list_of_tensors = list() super(TensorList, self).__init__(list_of_tensors) def __deepcopy__(self, memodict={}): return TensorList(copy.deepcopy(list(self), memodict)) def __getitem__(self, item): if isinstance(item, int): return super(TensorList, self).__getitem__(item) elif isinstance(item, (tuple, list)): return TensorList([super(TensorList, self).__getitem__(i) for i in item]) else: return TensorList(super(TensorList, self).__getitem__(item)) def __add__(self, other): if TensorList._iterable(other): return TensorList([e1 + e2 for e1, e2 in zip(self, other)]) return TensorList([e + other for e in self]) def __radd__(self, other): if TensorList._iterable(other): return TensorList([e2 + e1 for e1, e2 in zip(self, other)]) return TensorList([other + e for e in self]) def __iadd__(self, other): if TensorList._iterable(other): for i, e2 in enumerate(other): self[i] += e2 else: for i in range(len(self)): self[i] += other return self def __sub__(self, other): if TensorList._iterable(other): return TensorList([e1 - e2 for e1, e2 in zip(self, other)]) return TensorList([e - other for e in self]) def __rsub__(self, other): if TensorList._iterable(other): return TensorList([e2 - e1 for e1, e2 in zip(self, other)]) return TensorList([other - e for e in self]) def __isub__(self, other): if TensorList._iterable(other): for i, e2 in enumerate(other): self[i] -= e2 else: for i in range(len(self)): self[i] -= other return self def __mul__(self, other): if TensorList._iterable(other): return TensorList([e1 * e2 for e1, e2 in zip(self, other)]) return TensorList([e * other for e in self]) def __rmul__(self, other): if TensorList._iterable(other): return TensorList([e2 * e1 for e1, e2 in zip(self, other)]) return TensorList([other * e for e in self]) def __imul__(self, other): if TensorList._iterable(other): for i, e2 in enumerate(other): self[i] *= e2 else: for i in range(len(self)): self[i] *= other return self def __truediv__(self, other): if TensorList._iterable(other): return TensorList([e1 / e2 for e1, e2 in zip(self, other)]) return TensorList([e / other for e in self]) def __rtruediv__(self, other): if TensorList._iterable(other): return TensorList([e2 / e1 for e1, e2 in zip(self, other)]) return TensorList([other / e for e in self]) def __itruediv__(self, other): if TensorList._iterable(other): for i, e2 in enumerate(other): self[i] /= e2 else: for i in range(len(self)): self[i] /= other return self def __matmul__(self, other): if TensorList._iterable(other): return TensorList([e1 @ e2 for e1, e2 in zip(self, other)]) return TensorList([e @ other for e in self]) def __rmatmul__(self, other): if TensorList._iterable(other): return TensorList([e2 @ e1 for e1, e2 in zip(self, other)]) return TensorList([other @ e for e in self]) def __imatmul__(self, other): if TensorList._iterable(other): for i, e2 in enumerate(other): self[i] @= e2 else: for i in range(len(self)): self[i] @= other return self def __mod__(self, other): if TensorList._iterable(other): return TensorList([e1 % e2 for e1, e2 in zip(self, other)]) return TensorList([e % other for e in self]) def __rmod__(self, other): if TensorList._iterable(other): return TensorList([e2 % e1 for e1, e2 in zip(self, other)]) return TensorList([other % e for e in self]) def __pos__(self): return TensorList([+e for e in self]) def __neg__(self): return TensorList([-e for e in self]) def __le__(self, other): if TensorList._iterable(other): return TensorList([e1 <= e2 for e1, e2 in zip(self, other)]) return TensorList([e <= other for e in self]) def __ge__(self, other): if TensorList._iterable(other): return TensorList([e1 >= e2 for e1, e2 in zip(self, other)]) return TensorList([e >= other for e in self]) def concat(self, other): return TensorList(super(TensorList, self).__add__(other)) def copy(self): return TensorList(super(TensorList, self).copy()) def unroll(self): if not any(isinstance(t, TensorList) for t in self): return self new_list = TensorList() for t in self: if isinstance(t, TensorList): new_list.extend(t.unroll()) else: new_list.append(t) return new_list def list(self): return list(self) def attribute(self, attr: str, *args): return TensorList([getattr(e, attr, *args) for e in self]) def apply(self, fn): return TensorList([fn(e) for e in self]) def __getattr__(self, name): if not hasattr(torch.Tensor, name): raise AttributeError('\'TensorList\' object has not attribute \'{}\''.format(name)) def apply_attr(*args, **kwargs): return TensorList([getattr(e, name)(*args, **kwargs) for e in self]) return apply_attr @staticmethod def _iterable(a): return isinstance(a, (TensorList, list)) def tensor_operation(op): def islist(a): return isinstance(a, TensorList) @functools.wraps(op) def oplist(*args, **kwargs): if len(args) == 0: raise ValueError('Must be at least one argument without keyword (i.e. operand).') if len(args) == 1: if islist(args[0]): return TensorList([op(a, **kwargs) for a in args[0]]) else: # Multiple operands, assume max two if islist(args[0]) and islist(args[1]): return TensorList([op(a, b, *args[2:], **kwargs) for a, b in zip(*args[:2])]) if islist(args[0]): return TensorList([op(a, *args[1:], **kwargs) for a in args[0]]) if islist(args[1]): return TensorList([op(args[0], b, *args[2:], **kwargs) for b in args[1]]) # None of the operands are lists return op(*args, **kwargs) return oplist