init commit of samurai
This commit is contained in:
50
lib/utils/variable_hook.py
Normal file
50
lib/utils/variable_hook.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import torch
|
||||
from bytecode import Bytecode, Instr
|
||||
|
||||
|
||||
class get_local(object):
|
||||
cache = {}
|
||||
is_activate = False
|
||||
|
||||
def __init__(self, varname):
|
||||
self.varname = varname
|
||||
|
||||
def __call__(self, func):
|
||||
if not type(self).is_activate:
|
||||
return func
|
||||
|
||||
type(self).cache[func.__qualname__] = []
|
||||
c = Bytecode.from_code(func.__code__)
|
||||
extra_code = [
|
||||
Instr('STORE_FAST', '_res'),
|
||||
Instr('LOAD_FAST', self.varname),
|
||||
Instr('STORE_FAST', '_value'),
|
||||
Instr('LOAD_FAST', '_res'),
|
||||
Instr('LOAD_FAST', '_value'),
|
||||
Instr('BUILD_TUPLE', 2),
|
||||
Instr('STORE_FAST', '_result_tuple'),
|
||||
Instr('LOAD_FAST', '_result_tuple'),
|
||||
]
|
||||
c[-1:-1] = extra_code
|
||||
func.__code__ = c.to_code()
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
res, values = func(*args, **kwargs)
|
||||
if isinstance(values, torch.Tensor):
|
||||
type(self).cache[func.__qualname__].append(values.detach().cpu().numpy())
|
||||
elif isinstance(values, list): # list of Tensor
|
||||
type(self).cache[func.__qualname__].append([value.detach().cpu().numpy() for value in values])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return res
|
||||
|
||||
return wrapper
|
||||
|
||||
@classmethod
|
||||
def clear(cls):
|
||||
for key in cls.cache.keys():
|
||||
cls.cache[key] = []
|
||||
|
||||
@classmethod
|
||||
def activate(cls):
|
||||
cls.is_activate = True
|
Reference in New Issue
Block a user