# Files
# Grounded-SAM-2/lib/train/data/wandb_logger.py
# 2024-11-19 22:12:54 -08:00
#
# 34 lines
# 1.1 KiB
# Python

from collections import OrderedDict
try:
import wandb
except ImportError:
raise ImportError(
'Please run "pip install wandb" to install wandb')
class WandbWriter:
    """Forward training statistics to Weights & Biases (W&B).

    Every call to :meth:`write_log` advances an internal counter by one and
    logs all metrics at W&B step ``counter * interval``.
    """

    def __init__(self, exp_name, cfg, output_dir, cur_step=0, step_interval=0):
        """Start a W&B run in the "tracking" project.

        Args:
            exp_name: Run name passed to ``wandb.init``.
            cfg: Object recorded as the run's config.
            output_dir: Directory passed to ``wandb.init`` as ``dir``.
            cur_step: Initial value of the write counter (e.g. when resuming).
            step_interval: W&B steps advanced per ``write_log`` call. A
                non-positive value is treated as 1. With the old behaviour,
                ``counter * 0`` pinned every log to step 0 and collapsed the
                whole history into a single row.
        """
        self.wandb = wandb
        self.step = cur_step
        # Guard the default of 0 (and other non-positive values) so that the
        # logged step actually increases between calls.
        self.interval = step_interval if step_interval > 0 else 1
        self.wandb.init(project="tracking", name=exp_name, config=cfg, dir=output_dir)

    def write_log(self, stats: OrderedDict, epoch=-1):
        """Log one round of statistics for every loader.

        Args:
            stats: Mapping ``loader_name -> {metric_name: meter}``. A loader
                whose value is ``None`` is skipped. A meter that has an
                ``avg`` attribute is logged by that value; otherwise its
                ``val`` attribute is logged.
            epoch: If non-negative, it is also logged as
                ``"<loader_name>/epoch"`` for each logged loader.
        """
        self.step += 1
        # All loaders in this call share the same W&B step.
        wandb_step = self.step * self.interval
        for loader_name, loader_stats in stats.items():
            if loader_stats is None:
                continue
            log_dict = {}
            for var_name, val in loader_stats.items():
                # Prefer the running average when the meter tracks one.
                log_dict[f"{loader_name}/{var_name}"] = val.avg if hasattr(val, "avg") else val.val
            if epoch >= 0:
                log_dict[f"{loader_name}/epoch"] = epoch
            self.wandb.log(log_dict, step=wandb_step)