149 lines
5.4 KiB
Python
149 lines
5.4 KiB
Python
|
from typing import Any
|
||
|
from easymocap.config import Config, load_object
|
||
|
from easymocap.mytools.debug_utils import mywarn, log
|
||
|
import numpy as np
|
||
|
import time
|
||
|
from tabulate import tabulate
|
||
|
|
||
|
class Timer:
|
||
|
def __init__(self, record, verbose) -> None:
|
||
|
self.keys = list(record.keys())
|
||
|
self.header = self.keys
|
||
|
self.verbose = verbose
|
||
|
|
||
|
def update(self, timer):
|
||
|
if not self.verbose:
|
||
|
return
|
||
|
contents = []
|
||
|
for key in self.keys:
|
||
|
if key not in timer:
|
||
|
contents.append('skip')
|
||
|
else:
|
||
|
contents.append('{:.3f}s'.format(timer[key]))
|
||
|
print(tabulate(headers=self.header, tabular_data=[contents], tablefmt='fancy_grid'))
|
||
|
|
||
|
class MultiStage:
|
||
|
def load_final(self):
|
||
|
at_finals = {}
|
||
|
for key, val in self._at_final.items():
|
||
|
if val['module'] == 'skip':
|
||
|
mywarn('Stage {} is not used'.format(key))
|
||
|
continue
|
||
|
log('[{}] loading {}'.format(self.__class__.__name__, key))
|
||
|
model = load_object(val['module'], val['args'])
|
||
|
model.output = self.output
|
||
|
at_finals[key] = model
|
||
|
self.model_finals = at_finals
|
||
|
|
||
|
def __init__(self, output, at_step, at_final) -> None:
|
||
|
log('[{}] writing the results to {}'.format(self.__class__.__name__, output))
|
||
|
at_steps = {}
|
||
|
for key, val in at_step.items():
|
||
|
if val['module'] == 'skip':
|
||
|
mywarn('Stage {} is not used'.format(key))
|
||
|
continue
|
||
|
log('[{}] loading module {}'.format(self.__class__.__name__, key))
|
||
|
model = load_object(val['module'], val['args'])
|
||
|
model.output = output
|
||
|
at_steps[key] = model
|
||
|
self.output = output
|
||
|
self.model_steps = at_steps
|
||
|
self._at_step = at_step
|
||
|
self._at_final = at_final
|
||
|
self.timer = Timer(at_steps, verbose=False)
|
||
|
|
||
|
def at_step(self, data, index):
|
||
|
ret = {}
|
||
|
if 'meta' in data:
|
||
|
ret['meta'] = data['meta']
|
||
|
timer = {}
|
||
|
for key, model in self.model_steps.items():
|
||
|
for k in self._at_step[key].get('key_keep', []):
|
||
|
ret[k] = data[k]
|
||
|
if self._at_step[key].get('skip', False):
|
||
|
continue
|
||
|
inputs = {}
|
||
|
for k in self._at_step[key].get('key_from_data', []):
|
||
|
inputs[k] = data[k]
|
||
|
for k in self._at_step[key].get('key_from_previous', []):
|
||
|
inputs[k] = ret[k]
|
||
|
start = time.time()
|
||
|
try:
|
||
|
output = model(**inputs)
|
||
|
except:
|
||
|
print('[{}] Error in {}'.format('Stages', key))
|
||
|
raise Exception
|
||
|
timer[key] = time.time() - start
|
||
|
if output is not None:
|
||
|
ret.update(output)
|
||
|
|
||
|
self.timer.update(timer)
|
||
|
return ret
|
||
|
|
||
|
@staticmethod
|
||
|
def merge_data(infos_all):
|
||
|
info0 = infos_all[0]
|
||
|
data = {}
|
||
|
for key, val in info0.items():
|
||
|
data[key] = [info[key] for info in infos_all]
|
||
|
if isinstance(val, np.ndarray):
|
||
|
try:
|
||
|
data[key] = np.stack(data[key])
|
||
|
except ValueError:
|
||
|
print('[{}] Skip merge {}'.format('Stages', key))
|
||
|
pass
|
||
|
elif isinstance(val, dict):
|
||
|
data[key] = MultiStage.merge_data(data[key])
|
||
|
return data
|
||
|
|
||
|
def at_final(self, infos_all):
|
||
|
self.load_final()
|
||
|
data = self.merge_data(infos_all)
|
||
|
log('Keep keys: {}'.format(list(data.keys())))
|
||
|
ret = {}
|
||
|
for key, model in self.model_finals.items():
|
||
|
for iter_ in range(self._at_final[key].get('repeat', 1)):
|
||
|
inputs = {}
|
||
|
for k in self._at_final[key].get('key_from_data', []):
|
||
|
inputs[k] = data[k]
|
||
|
for k in self._at_final[key].get('key_from_previous', []):
|
||
|
inputs[k] = ret[k]
|
||
|
try:
|
||
|
output = model(**inputs)
|
||
|
except:
|
||
|
print('[{}] Error in {}'.format('Stages', key))
|
||
|
raise Exception
|
||
|
if output is not None:
|
||
|
ret.update(output)
|
||
|
return ret
|
||
|
|
||
|
class StageForFittingEach:
|
||
|
def __init__(self, stages, keys_keep) -> None:
|
||
|
stages_ = {}
|
||
|
for key, val in stages.items():
|
||
|
if val['module'] == 'skip':
|
||
|
mywarn('Stage {} is not used'.format(key))
|
||
|
continue
|
||
|
model = load_object(val['module'], val['args'])
|
||
|
stages_[key] = model
|
||
|
self.stages = stages_
|
||
|
self.stages_args = stages
|
||
|
self.keys_keep = keys_keep
|
||
|
|
||
|
def __call__(self, results, **ret):
|
||
|
for pid, result in results.items():
|
||
|
ret0 = {}
|
||
|
ret0.update(ret)
|
||
|
for key, stage in self.stages.items():
|
||
|
for iter_ in range(self.stages_args[key].get('repeat', 1)):
|
||
|
inputs = {}
|
||
|
for k in self.stages_args[key].get('key_from_data', []):
|
||
|
inputs[k] = result[k]
|
||
|
for k in self.stages_args[key].get('key_from_previous', []):
|
||
|
inputs[k] = ret0[k]
|
||
|
output = stage(**inputs)
|
||
|
if output is not None:
|
||
|
ret0.update(output)
|
||
|
for key in self.keys_keep:
|
||
|
result[key] = ret0[key]
|
||
|
return {'results': results}
|