133 lines
4.2 KiB
Python
133 lines
4.2 KiB
Python
|
'''
|
||
|
@ Date: 2021-09-05 20:11:27
|
||
|
@ Author: Qing Shuai
|
||
|
@ LastEditors: Qing Shuai
|
||
|
@ LastEditTime: 2021-09-05 20:11:27
|
||
|
@ FilePath: /EasyMocap/easymocap/neuralbody/trainer/recorder.py
|
||
|
'''
|
||
|
from collections import deque, defaultdict
|
||
|
import torch
|
||
|
from tensorboardX import SummaryWriter
|
||
|
import os
|
||
|
|
||
|
from termcolor import colored
|
||
|
|
||
|
|
||
|
class SmoothedValue(object):
|
||
|
"""Track a series of values and provide access to smoothed values over a
|
||
|
window or the global series average.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, window_size=20):
|
||
|
self.deque = deque(maxlen=window_size)
|
||
|
self.total = 0.0
|
||
|
self.count = 0
|
||
|
|
||
|
def update(self, value):
|
||
|
self.deque.append(value)
|
||
|
self.count += 1
|
||
|
self.total += value
|
||
|
|
||
|
@property
|
||
|
def median(self):
|
||
|
d = torch.tensor(list(self.deque))
|
||
|
return d.median().item()
|
||
|
|
||
|
@property
|
||
|
def avg(self):
|
||
|
d = torch.tensor(list(self.deque))
|
||
|
return d.mean().item()
|
||
|
|
||
|
@property
|
||
|
def global_avg(self):
|
||
|
return self.total / self.count
|
||
|
|
||
|
class Recorder(object):
|
||
|
def __init__(self, local_rank=0, resume=False, log_dir="", task=""):
|
||
|
self.local_rank = local_rank
|
||
|
if local_rank > 0:
|
||
|
return
|
||
|
if not resume:
|
||
|
print(colored('[{}] remove contents of directory {}'.format(local_rank, log_dir), 'red'))
|
||
|
os.system('rm -r %s/*' % log_dir)
|
||
|
# self.writer = SummaryWriter(log_dir=log_dir, flush_secs=5)
|
||
|
self.writer = SummaryWriter(log_dir=log_dir)
|
||
|
self.log_dir = log_dir
|
||
|
# scalars
|
||
|
self.epoch = 0
|
||
|
self.step = 0
|
||
|
self.loss_stats = defaultdict(SmoothedValue)
|
||
|
self.batch_time = SmoothedValue()
|
||
|
self.data_time = SmoothedValue()
|
||
|
|
||
|
# images
|
||
|
self.image_stats = defaultdict(object)
|
||
|
# if 'process_' + cfg.task in globals():
|
||
|
# self.processor = globals()['process_' + cfg.task]
|
||
|
# else:
|
||
|
# self.processor = None
|
||
|
self.processor = None
|
||
|
# self.cfg = cfg
|
||
|
|
||
|
def update_loss_stats(self, loss_dict):
|
||
|
if self.local_rank > 0:
|
||
|
return
|
||
|
for k, v in loss_dict.items():
|
||
|
self.loss_stats[k].update(v.detach().cpu())
|
||
|
|
||
|
def update_image_stats(self, image_stats):
|
||
|
if self.local_rank > 0:
|
||
|
return
|
||
|
# if self.processor is None:
|
||
|
# return
|
||
|
# image_stats = self.processor(image_stats)
|
||
|
for k, v in image_stats.items():
|
||
|
self.image_stats[k] = v #.detach().cpu()
|
||
|
|
||
|
def record(self, prefix, step=-1, loss_stats=None, image_stats=None):
|
||
|
if self.local_rank > 0:
|
||
|
return
|
||
|
|
||
|
pattern = prefix + '/{}'
|
||
|
step = step if step >= 0 else self.step
|
||
|
loss_stats = loss_stats if loss_stats else self.loss_stats
|
||
|
image_stats = image_stats if image_stats else self.image_stats
|
||
|
|
||
|
for k, v in loss_stats.items():
|
||
|
if isinstance(v, SmoothedValue):
|
||
|
self.writer.add_scalar(pattern.format(k), v.median, step)
|
||
|
else:
|
||
|
self.writer.add_scalar(pattern.format(k), v, step)
|
||
|
for k, v in image_stats.items():
|
||
|
if len(v.shape) == 2:
|
||
|
self.writer.add_image(pattern.format(k), v, step, dataformats='HW')
|
||
|
else:
|
||
|
self.writer.add_image(pattern.format(k), v, step)
|
||
|
|
||
|
def state_dict(self):
|
||
|
if self.local_rank > 0:
|
||
|
return
|
||
|
scalar_dict = {}
|
||
|
scalar_dict['step'] = self.step
|
||
|
return scalar_dict
|
||
|
|
||
|
def load_state_dict(self, scalar_dict):
|
||
|
if self.local_rank > 0:
|
||
|
return
|
||
|
self.step = scalar_dict['step']
|
||
|
|
||
|
def __str__(self):
|
||
|
if self.local_rank > 0:
|
||
|
return
|
||
|
loss_state = []
|
||
|
for k, v in self.loss_stats.items():
|
||
|
loss_state.append('{}: {:.4f}'.format(k, v.avg))
|
||
|
loss_state = ' '.join(loss_state)
|
||
|
|
||
|
recording_state = ' '.join(['epoch: {}', 'step: {}', '{}', 'data: {:.4f}', 'batch: {:.4f}'])
|
||
|
return recording_state.format(self.epoch, self.step, loss_state, self.data_time.avg, self.batch_time.avg)
|
||
|
|
||
|
def write_cfg(self, cfg):
|
||
|
if self.local_rank > 0:
|
||
|
return
|
||
|
print(cfg, file=open(os.path.join(self.log_dir, 'exp.yml'), 'w'))
|