raise err if augment on model other than BODY_25/B

This commit is contained in:
davidpagnon 2024-01-19 22:18:16 +01:00
parent c29abcc9d9
commit 07ec3e46c1

View File

@ -68,6 +68,7 @@ def augmentTRC(config_dict):
session_dir = os.path.realpath(os.path.join(project_dir, '..', '..')) session_dir = os.path.realpath(os.path.join(project_dir, '..', '..'))
pathInputTRCFile = os.path.realpath(os.path.join(project_dir, 'pose-3d')) pathInputTRCFile = os.path.realpath(os.path.join(project_dir, 'pose-3d'))
pathOutputTRCFile = os.path.realpath(os.path.join(project_dir, 'pose-3d')) pathOutputTRCFile = os.path.realpath(os.path.join(project_dir, 'pose-3d'))
pose_model = config_dict.get('pose').get('pose_model')
subject_height = config_dict.get('markerAugmentation').get('participant_height') subject_height = config_dict.get('markerAugmentation').get('participant_height')
if subject_height is None or subject_height == 0: if subject_height is None or subject_height == 0:
raise ValueError("Subject height is not set or invalid in the config file.") raise ValueError("Subject height is not set or invalid in the config file.")
@ -77,6 +78,9 @@ def augmentTRC(config_dict):
augmenter_model = 'v0.3' augmenter_model = 'v0.3'
offset = True offset = True
if pose_model not in ['BODY_25', 'BODY_25B']:
raise ValueError('Marker augmentation is only supported with OpenPose BODY_25 and BODY_25B models.')
# Apply all trc files # Apply all trc files
trc_files = [f for f in glob.glob(os.path.join(pathInputTRCFile, '*.trc')) if '_LSTM' not in f] trc_files = [f for f in glob.glob(os.path.join(pathInputTRCFile, '*.trc')) if '_LSTM' not in f]
for pathInputTRCFile in trc_files: for pathInputTRCFile in trc_files: