# (Original file-viewer listing metadata: Python, 51 lines, 1.5 KiB.)
import functools
import sys

import torch
from bytecode import Bytecode, Instr


class get_local:
    """Decorator factory that records one local variable of a function per call.

    Example::

        get_local.activate()  # must happen before the decorator is applied

        @get_local('attn_map')
        def forward(self, x):
            attn_map = ...
            return out

    Each call to the decorated function returns exactly what the function
    returns. In addition, the value bound to ``attn_map`` when the function
    returned is converted with ``.detach().cpu().numpy()`` (element-wise for a
    list/tuple) and appended to ``get_local.cache[forward.__qualname__]``.
    """

    # __qualname__ of each decorated function -> list with one recorded value per call.
    cache = {}
    # Read at decoration time only: functions decorated while this is False stay unwrapped.
    is_activate = False

    def __init__(self, varname):
        """Store the name of the local variable to record."""
        self.varname = varname

    def __call__(self, func):
        """Return a wrapper around *func* that records the local ``self.varname``.

        The local is read from *func*'s own frame through a temporary
        ``sys.setprofile`` hook when that frame returns. This covers every
        ``return`` statement, including the implicit ``return None``, and leaves
        *func* and its code object untouched. A profile function previously
        installed with ``sys.setprofile`` is suspended for the duration of the
        call and reinstated afterwards.

        If recording has not been activated, *func* is returned unchanged.

        The returned wrapper raises:
            UnboundLocalError: the variable was not bound when *func* returned.
            NotImplementedError: the variable is neither a ``torch.Tensor`` nor
                a list/tuple of them.
        """
        cls = type(self)
        if not cls.is_activate:
            return func

        varname = self.varname
        key = func.__qualname__
        code = func.__code__
        missing = object()  # sentinel: varname was unbound when func returned
        cls.cache[key] = []

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            captured = [missing]

            def profiler(frame, event, arg):
                # 'return' is reported for every Python frame that exits; only
                # func's own frame is of interest.
                if event == 'return' and frame.f_code is code:
                    captured[0] = frame.f_locals.get(varname, missing)

            previous = sys.getprofile()
            sys.setprofile(profiler)
            try:
                res = func(*args, **kwargs)
            finally:
                # Always reinstate the previous hook, even if func raised.
                sys.setprofile(previous)

            values = captured[0]
            if values is missing:
                raise UnboundLocalError(
                    f"local variable {varname!r} was not assigned when "
                    f"{key} returned"
                )
            if isinstance(values, torch.Tensor):
                record = values.detach().cpu().numpy()
            elif isinstance(values, (list, tuple)):  # sequence of Tensors
                record = [value.detach().cpu().numpy() for value in values]
            else:
                raise NotImplementedError(
                    f"cannot record {varname!r} of type {type(values).__name__}; "
                    "expected a torch.Tensor or a list/tuple of Tensors"
                )
            # Look the list up at call time: clear() swaps in a fresh list.
            cls.cache[key].append(record)
            return res

        return wrapper

    @classmethod
    def clear(cls):
        """Empty every recorded list, keeping the decorated functions' keys."""
        for key in cls.cache:
            cls.cache[key] = []

    @classmethod
    def activate(cls):
        """Enable recording for functions decorated after this call."""
        cls.is_activate = True