option to skip pose estimation if already done

This commit is contained in:
davidpagnon 2024-07-17 10:57:12 +02:00
parent d35b64efcb
commit e5dd81d94d
3 changed files with 31 additions and 16 deletions

View File

@ -46,6 +46,7 @@ det_frequency = 1 # Run person detection only every N frames, and inbetween trac
# Equal to or greater than 1, can be as high as you want in simple uncrowded cases. Much faster, but might be less accurate.
tracking = true # Gives consistent person ID across frames. Slightly slower but might facilitate synchronization if other people are in the background
display_detection = true
overwrite_pose = false # set to false if you don't want to recalculate pose estimation when it has already been done
save_video = 'to_video' # 'to_video' or 'to_images', 'none', or ['to_video', 'to_images']
output_format = 'openpose' # 'openpose', 'mmpose', 'deeplabcut', 'none' or a list of them # /!\ only 'openpose' is supported for now

View File

@ -46,6 +46,7 @@ det_frequency = 1 # Run person detection only every N frames, and inbetween trac
# Equal to or greater than 1, can be as high as you want in simple uncrowded cases. Much faster, but might be less accurate.
tracking = true # Gives consistent person ID across frames. Slightly slower but might facilitate synchronization if other people are in the background
display_detection = true
overwrite_pose = false # set to false if you don't want to recalculate pose estimation when it has already been done
save_video = 'to_video' # 'to_video' or 'to_images', 'none', or ['to_video', 'to_images']
output_format = 'openpose' # 'openpose', 'mmpose', 'deeplabcut', 'none' or a list of them # /!\ only 'openpose' is supported for now

View File

@ -331,6 +331,7 @@ def rtm_estimator(config_dict):
session_dir = session_dir if 'Config.toml' in os.listdir(session_dir) else os.getcwd()
frame_range = config_dict.get('project').get('frame_range')
video_dir = os.path.join(project_dir, 'videos')
pose_dir = os.path.join(project_dir, 'pose')
pose_model = config_dict['pose']['pose_model']
mode = config_dict['pose']['mode'] # lightweight, balanced, performance
@ -340,6 +341,7 @@ def rtm_estimator(config_dict):
save_video = True if 'to_video' in config_dict['pose']['save_video'] else False
save_images = True if 'to_images' in config_dict['pose']['save_video'] else False
display_detection = config_dict['pose']['display_detection']
overwrite_pose = config_dict['pose']['overwrite_pose']
det_frequency = config_dict['pose']['det_frequency']
tracking = config_dict['pose']['tracking']
@ -405,20 +407,31 @@ def rtm_estimator(config_dict):
tracking=tracking,
to_openpose=False)
logging.info('\nEstimating pose...')
video_files = glob.glob(os.path.join(video_dir, '*'+vid_img_extension))
if not len(video_files) == 0:
# Process video files
logging.info(f'Found video files with extension {vid_img_extension}.')
for video_path in video_files:
pose_tracker.reset()
process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range)
else:
# Process image folders
logging.info(f'Found image folders with extension {vid_img_extension}.')
image_folders = [f for f in os.listdir(video_dir) if os.path.isdir(os.path.join(video_dir, f))]
for image_folder in image_folders:
pose_tracker.reset()
image_folder_path = os.path.join(video_dir, image_folder)
process_images(image_folder_path, vid_img_extension, pose_tracker, tracking, output_format, frame_rate, save_video, save_images, display_detection, frame_range)
logging.info('\nEstimating pose...')
try:
pose_listdirs_names = next(os.walk(pose_dir))[1]
os.listdir(os.path.join(pose_dir, pose_listdirs_names[0]))[0]
if not overwrite_pose:
logging.info('Skipping pose estimation as it has already been done. Set overwrite_pose to true in Config.toml if you want to run it again.')
else:
logging.info('Overwriting previous pose estimation. Set overwrite_pose to false in Config.toml if you want to keep the previous results.')
raise
except:
video_files = glob.glob(os.path.join(video_dir, '*'+vid_img_extension))
if not len(video_files) == 0:
# Process video files
logging.info(f'Found video files with extension {vid_img_extension}.')
for video_path in video_files:
pose_tracker.reset()
process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range)
else:
# Process image folders
logging.info(f'Found image folders with extension {vid_img_extension}.')
image_folders = [f for f in os.listdir(video_dir) if os.path.isdir(os.path.join(video_dir, f))]
for image_folder in image_folders:
pose_tracker.reset()
image_folder_path = os.path.join(video_dir, image_folder)
process_images(image_folder_path, vid_img_extension, pose_tracker, tracking, output_format, frame_rate, save_video, save_images, display_detection, frame_range)