fixed RTMLib tracking

This commit is contained in:
davidpagnon 2024-09-17 00:47:04 +02:00
parent 56d2565f37
commit 7248c8a582

View File

@ -99,6 +99,34 @@ def save_to_openpose(json_file_path, keypoints, scores):
json.dump(json_output, json_file) json.dump(json_output, json_file)
def sort_people_rtmlib(pose_tracker, keypoints, scores):
'''
Associate persons across frames (RTMLib method)
INPUTS:
- pose_tracker: PoseTracker. The initialized RTMLib pose tracker object
- keypoints: array of shape K, L, M with K the number of detected persons,
L the number of detected keypoints, M their 2D coordinates
- scores: array of shape K, L with K the number of detected persons,
L the confidence of detected keypoints
OUTPUT:
- sorted_keypoints: array with reordered persons
- sorted_scores: array with reordered scores
'''
try:
desired_size = max(pose_tracker.track_ids_last_frame)+1
sorted_keypoints = np.full((desired_size, keypoints.shape[1], 2), np.nan)
sorted_keypoints[pose_tracker.track_ids_last_frame] = keypoints[:len(pose_tracker.track_ids_last_frame), :, :]
sorted_scores = np.full((desired_size, scores.shape[1]), np.nan)
sorted_scores[pose_tracker.track_ids_last_frame] = scores[:len(pose_tracker.track_ids_last_frame), :]
except:
sorted_keypoints, sorted_scores = keypoints, scores
return sorted_keypoints, sorted_scores
def process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range): def process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range):
''' '''
Estimate pose from a video file Estimate pose from a video file
@ -160,14 +188,7 @@ def process_video(video_path, pose_tracker, tracking, output_format, save_video,
# Reorder keypoints, scores # Reorder keypoints, scores
if tracking: if tracking:
max_id = max(pose_tracker.track_ids_last_frame) keypoints, scores = sort_people_rtmlib(pose_tracker, keypoints, scores)
num_frames, num_points, num_coordinates = keypoints.shape
keypoints_filled = np.zeros((max_id+1, num_points, num_coordinates))
scores_filled = np.zeros((max_id+1, num_points))
keypoints_filled[pose_tracker.track_ids_last_frame] = keypoints
scores_filled[pose_tracker.track_ids_last_frame] = scores
keypoints = keypoints_filled
scores = scores_filled
# Save to json # Save to json
if 'openpose' in output_format: if 'openpose' in output_format: