EasyMocap/myeasymocap/stages/basestage.py

149 lines
5.4 KiB
Python
Raw Normal View History

2023-06-19 16:39:27 +08:00
from typing import Any
from easymocap.config import Config, load_object
from easymocap.mytools.debug_utils import mywarn, log
import numpy as np
import time
from tabulate import tabulate
class Timer:
    """Print per-stage execution times as a one-row table.

    The column headers are fixed at construction time from the keys of
    ``record``. Nothing is printed unless ``verbose`` is truthy.
    """

    def __init__(self, record, verbose) -> None:
        self.keys = [name for name in record]
        # Header and key order are the same list object.
        self.header = self.keys
        self.verbose = verbose

    def update(self, timer):
        """Print one row of timings taken from ``timer`` (stage name -> seconds).

        Stages missing from ``timer`` are shown as ``'skip'``.
        """
        if self.verbose:
            row = [
                '{:.3f}s'.format(timer[name]) if name in timer else 'skip'
                for name in self.keys
            ]
            print(tabulate(headers=self.header, tabular_data=[row], tablefmt='fancy_grid'))
class MultiStage:
    """Run a configurable chain of modules on every step, and another chain at the end.

    ``at_step`` and ``at_final`` map a stage name to a config mapping with
    ``module`` (or the string ``'skip'`` to disable it) and ``args``, which are
    passed to ``load_object``. The mapping may also contain these optional
    wiring keys:

    * ``key_from_data``: inputs taken from the incoming data.
    * ``key_from_previous``: inputs taken from outputs of earlier modules.
    * ``key_keep`` (step only): keys copied from the data into the result.
    * ``skip`` (step only): the module stays loaded but is not run.
    * ``repeat`` (final only): number of times the module is run.
    """

    def _load_modules(self, configs, message):
        """Instantiate every non-skipped module in ``configs``, keyed by stage name.

        ``message`` is a two-slot format string (class name, stage name) used
        for the per-module log line.
        """
        models = {}
        for key, val in configs.items():
            if val['module'] == 'skip':
                mywarn('Stage {} is not used'.format(key))
                continue
            log(message.format(self.__class__.__name__, key))
            model = load_object(val['module'], val['args'])
            # Every module shares the pipeline's output target.
            model.output = self.output
            models[key] = model
        return models

    @staticmethod
    def _collect_inputs(config, data, previous):
        """Build the keyword arguments for one module call.

        Values named in ``key_from_data`` are read from ``data``. Values named
        in ``key_from_previous`` are read from ``previous``. Previous values
        override data values with the same name.
        """
        inputs = {k: data[k] for k in config.get('key_from_data', [])}
        for k in config.get('key_from_previous', []):
            inputs[k] = previous[k]
        return inputs

    @staticmethod
    def _run_module(key, model, inputs):
        """Call ``model(**inputs)`` and return its output.

        If the call raises, print which stage failed and re-raise the
        original exception, so its type and traceback are preserved.
        """
        try:
            return model(**inputs)
        except Exception:
            print('[{}] Error in {}'.format('Stages', key))
            raise

    def load_final(self):
        """(Re)load the modules configured in ``at_final`` into ``self.model_finals``."""
        self.model_finals = self._load_modules(self._at_final, '[{}] loading {}')

    def __init__(self, output, at_step, at_final) -> None:
        log('[{}] writing the results to {}'.format(self.__class__.__name__, output))
        self.output = output
        self._at_step = at_step
        self._at_final = at_final
        self.model_steps = self._load_modules(at_step, '[{}] loading module {}')
        # Timing table is built but kept silent (verbose=False).
        self.timer = Timer(self.model_steps, verbose=False)

    def at_step(self, data, index):
        """Run every per-step module on one step's ``data`` and return the merged outputs.

        ``index`` is accepted for interface compatibility and is not used here.
        Any exception raised by a module propagates after the failing stage
        name is printed.
        """
        ret = {}
        if 'meta' in data:
            ret['meta'] = data['meta']
        timer = {}
        for key, model in self.model_steps.items():
            config = self._at_step[key]
            # key_keep is honoured even for stages marked as skipped.
            for k in config.get('key_keep', []):
                ret[k] = data[k]
            if config.get('skip', False):
                continue
            inputs = self._collect_inputs(config, data, ret)
            start = time.perf_counter()
            output = self._run_module(key, model, inputs)
            timer[key] = time.perf_counter() - start
            if output is not None:
                ret.update(output)
        self.timer.update(timer)
        return ret

    @staticmethod
    def merge_data(infos_all):
        """Merge a list of per-step result dicts into one dict of sequences.

        Keys are taken from the first entry, and each key is merged as follows:

        * ndarray values are stacked along a new leading axis. They are left
          as a plain list when they cannot be stacked, for example when their
          shapes differ.
        * dict values are merged recursively.
        * Any other value becomes a list.

        An empty input returns ``{}``.
        """
        if not infos_all:
            return {}
        data = {}
        for key, val in infos_all[0].items():
            values = [info[key] for info in infos_all]
            if isinstance(val, np.ndarray):
                try:
                    values = np.stack(values)
                except ValueError:
                    print('[{}] Skip merge {}'.format('Stages', key))
            elif isinstance(val, dict):
                values = MultiStage.merge_data(values)
            data[key] = values
        return data

    def at_final(self, infos_all):
        """Reload the final modules, merge ``infos_all``, and run each final module.

        Each module runs ``repeat`` times (default 1). Outputs accumulate into
        the returned dict, so later modules and repeats can read them via
        ``key_from_previous``.
        """
        self.load_final()
        data = self.merge_data(infos_all)
        log('Keep keys: {}'.format(list(data.keys())))
        ret = {}
        for key, model in self.model_finals.items():
            config = self._at_final[key]
            for _ in range(config.get('repeat', 1)):
                inputs = self._collect_inputs(config, data, ret)
                output = self._run_module(key, model, inputs)
                if output is not None:
                    ret.update(output)
        return ret
class StageForFittingEach:
    """Run a chain of fitting stages independently for every entry of ``results``.

    ``stages`` maps a stage name to a config with ``module`` (or ``'skip'``),
    ``args`` and optional ``key_from_data`` / ``key_from_previous`` / ``repeat``.
    After all stages have run for one entry, the keys listed in ``keys_keep``
    are copied from the accumulated stage outputs back into that entry.
    """

    def __init__(self, stages, keys_keep) -> None:
        loaded = {}
        for name, cfg in stages.items():
            if cfg['module'] == 'skip':
                mywarn('Stage {} is not used'.format(name))
            else:
                loaded[name] = load_object(cfg['module'], cfg['args'])
        self.stages = loaded
        # The full config (including skipped entries) is kept for wiring lookups.
        self.stages_args = stages
        self.keys_keep = keys_keep

    def __call__(self, results, **ret):
        """Fit each entry of ``results`` in place and return ``{'results': results}``.

        The extra keyword arguments in ``ret`` seed a fresh per-entry state. The
        state is available to stages through ``key_from_previous`` and is
        updated with every non-None stage output.
        """
        for result in results.values():
            state = dict(ret)
            for name, stage in self.stages.items():
                cfg = self.stages_args[name]
                for _ in range(cfg.get('repeat', 1)):
                    kwargs = {k: result[k] for k in cfg.get('key_from_data', [])}
                    kwargs.update((k, state[k]) for k in cfg.get('key_from_previous', []))
                    produced = stage(**kwargs)
                    if produced is not None:
                        state.update(produced)
            for k in self.keys_keep:
                result[k] = state[k]
        return {'results': results}