raise err if augment on model other than BODY_25/B
This commit is contained in:
parent
c29abcc9d9
commit
07ec3e46c1
@ -68,6 +68,7 @@ def augmentTRC(config_dict):
|
|||||||
session_dir = os.path.realpath(os.path.join(project_dir, '..', '..'))
|
session_dir = os.path.realpath(os.path.join(project_dir, '..', '..'))
|
||||||
pathInputTRCFile = os.path.realpath(os.path.join(project_dir, 'pose-3d'))
|
pathInputTRCFile = os.path.realpath(os.path.join(project_dir, 'pose-3d'))
|
||||||
pathOutputTRCFile = os.path.realpath(os.path.join(project_dir, 'pose-3d'))
|
pathOutputTRCFile = os.path.realpath(os.path.join(project_dir, 'pose-3d'))
|
||||||
|
pose_model = config_dict.get('pose').get('pose_model')
|
||||||
subject_height = config_dict.get('markerAugmentation').get('participant_height')
|
subject_height = config_dict.get('markerAugmentation').get('participant_height')
|
||||||
if subject_height is None or subject_height == 0:
|
if subject_height is None or subject_height == 0:
|
||||||
raise ValueError("Subject height is not set or invalid in the config file.")
|
raise ValueError("Subject height is not set or invalid in the config file.")
|
||||||
@ -77,6 +78,9 @@ def augmentTRC(config_dict):
|
|||||||
augmenter_model = 'v0.3'
|
augmenter_model = 'v0.3'
|
||||||
offset = True
|
offset = True
|
||||||
|
|
||||||
|
if pose_model not in ['BODY_25', 'BODY_25B']:
|
||||||
|
raise ValueError('Marker augmentation is only supported with OpenPose BODY_25 and BODY_25B models.')
|
||||||
|
|
||||||
# Apply all trc files
|
# Apply all trc files
|
||||||
trc_files = [f for f in glob.glob(os.path.join(pathInputTRCFile, '*.trc')) if '_LSTM' not in f]
|
trc_files = [f for f in glob.glob(os.path.join(pathInputTRCFile, '*.trc')) if '_LSTM' not in f]
|
||||||
for pathInputTRCFile in trc_files:
|
for pathInputTRCFile in trc_files:
|
||||||
|
Loading…
Reference in New Issue
Block a user