diff --git a/Pose2Sim/poseEstimation.py b/Pose2Sim/poseEstimation.py index e59b057..2d83964 100644 --- a/Pose2Sim/poseEstimation.py +++ b/Pose2Sim/poseEstimation.py @@ -99,6 +99,34 @@ def save_to_openpose(json_file_path, keypoints, scores): json.dump(json_output, json_file) +def sort_people_rtmlib(pose_tracker, keypoints, scores): + ''' + Associate persons across frames (RTMLib method) + + INPUTS: + - pose_tracker: PoseTracker. The initialized RTMLib pose tracker object + - keypoints: array of shape K, L, M with K the number of detected persons, + L the number of detected keypoints, M their 2D coordinates + - scores: array of shape K, L with K the number of detected persons, + L the confidence of detected keypoints + + OUTPUT: + - sorted_keypoints: array with reordered persons + - sorted_scores: array with reordered scores + ''' + + try: + desired_size = max(pose_tracker.track_ids_last_frame)+1 + sorted_keypoints = np.full((desired_size, keypoints.shape[1], 2), np.nan) + sorted_keypoints[pose_tracker.track_ids_last_frame] = keypoints[:len(pose_tracker.track_ids_last_frame), :, :] + sorted_scores = np.full((desired_size, scores.shape[1]), np.nan) + sorted_scores[pose_tracker.track_ids_last_frame] = scores[:len(pose_tracker.track_ids_last_frame), :] + except: + sorted_keypoints, sorted_scores = keypoints, scores + + return sorted_keypoints, sorted_scores + + def process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range): ''' Estimate pose from a video file @@ -160,14 +188,7 @@ def process_video(video_path, pose_tracker, tracking, output_format, save_video, # Reorder keypoints, scores if tracking: - max_id = max(pose_tracker.track_ids_last_frame) - num_frames, num_points, num_coordinates = keypoints.shape - keypoints_filled = np.zeros((max_id+1, num_points, num_coordinates)) - scores_filled = np.zeros((max_id+1, num_points)) - keypoints_filled[pose_tracker.track_ids_last_frame] = keypoints - scores_filled[pose_tracker.track_ids_last_frame] = scores - keypoints = keypoints_filled - scores = scores_filled + keypoints, scores = sort_people_rtmlib(pose_tracker, keypoints, scores) # Save to json if 'openpose' in output_format: