From 4346e0e001e0ac980aa0eef3d168c021d734d058 Mon Sep 17 00:00:00 2001 From: HunMinKim <144449115+rlagnsals@users.noreply.github.com> Date: Tue, 9 Jan 2024 20:30:06 +0900 Subject: [PATCH] Update Pose2Sim.py Update augmenter --- Pose2Sim/Pose2Sim.py | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/Pose2Sim/Pose2Sim.py b/Pose2Sim/Pose2Sim.py index 438677e..0e7e374 100644 --- a/Pose2Sim/Pose2Sim.py +++ b/Pose2Sim/Pose2Sim.py @@ -469,4 +469,42 @@ def opensimProcessing(config=None): # end = time.time() # logging.info(f'Model scaling took {end-start:.2f} s.') - \ No newline at end of file + +def augmenter(config=None): + ''' + Augmentation process for marker data. + + config can be a dictionary, + or the directory path of a trial, participant, or session, + or the function can be called without an argument, in which case the config directory is the current one. + ''' + + from Pose2Sim.augmenter import augmentTRC + + level, config_dicts = read_config_files(config) + + if type(config) == dict: + config_dict = config_dicts[0] + if config_dict.get('project').get('project_dir') is None: + raise ValueError('Please specify the project directory in config_dict:\n \ + config_dict.get("project").update({"project_dir":""})') + + session_dir = os.path.realpath(os.path.join(config_dicts[0].get('project').get('project_dir'), '..', '..')) + setup_logging(session_dir) + + for config_dict in config_dicts: + start = time.time() + project_dir = os.path.realpath(config_dict.get('project').get('project_dir')) + seq_name = os.path.basename(project_dir) + frame_range = config_dict.get('project').get('frame_range') + frames = ["all frames" if frame_range == [] else f"frames {frame_range[0]} to {frame_range[1]}"][0] + + logging.info("\n\n---------------------------------------------------------------------") + logging.info(f"Augmentation process for {seq_name}, for {frames}.") + logging.info("---------------------------------------------------------------------") + logging.info(f"\nProject directory: {project_dir}") + + augmentTRC(config_dict) + + end = time.time() + logging.info(f'Augmentation took {end - start:.2f} s.')