fixed RTMLib tracking
This commit is contained in:
parent
56d2565f37
commit
7248c8a582
@ -99,6 +99,34 @@ def save_to_openpose(json_file_path, keypoints, scores):
|
||||
json.dump(json_output, json_file)
|
||||
|
||||
|
||||
def sort_people_rtmlib(pose_tracker, keypoints, scores):
|
||||
'''
|
||||
Associate persons across frames (RTMLib method)
|
||||
|
||||
INPUTS:
|
||||
- pose_tracker: PoseTracker. The initialized RTMLib pose tracker object
|
||||
- keypoints: array of shape K, L, M with K the number of detected persons,
|
||||
L the number of detected keypoints, M their 2D coordinates
|
||||
- scores: array of shape K, L with K the number of detected persons,
|
||||
L the confidence of detected keypoints
|
||||
|
||||
OUTPUT:
|
||||
- sorted_keypoints: array with reordered persons
|
||||
- sorted_scores: array with reordered scores
|
||||
'''
|
||||
|
||||
try:
|
||||
desired_size = max(pose_tracker.track_ids_last_frame)+1
|
||||
sorted_keypoints = np.full((desired_size, keypoints.shape[1], 2), np.nan)
|
||||
sorted_keypoints[pose_tracker.track_ids_last_frame] = keypoints[:len(pose_tracker.track_ids_last_frame), :, :]
|
||||
sorted_scores = np.full((desired_size, scores.shape[1]), np.nan)
|
||||
sorted_scores[pose_tracker.track_ids_last_frame] = scores[:len(pose_tracker.track_ids_last_frame), :]
|
||||
except:
|
||||
sorted_keypoints, sorted_scores = keypoints, scores
|
||||
|
||||
return sorted_keypoints, sorted_scores
|
||||
|
||||
|
||||
def process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range):
|
||||
'''
|
||||
Estimate pose from a video file
|
||||
@ -160,14 +188,7 @@ def process_video(video_path, pose_tracker, tracking, output_format, save_video,
|
||||
|
||||
# Reorder keypoints, scores
|
||||
if tracking:
|
||||
max_id = max(pose_tracker.track_ids_last_frame)
|
||||
num_frames, num_points, num_coordinates = keypoints.shape
|
||||
keypoints_filled = np.zeros((max_id+1, num_points, num_coordinates))
|
||||
scores_filled = np.zeros((max_id+1, num_points))
|
||||
keypoints_filled[pose_tracker.track_ids_last_frame] = keypoints
|
||||
scores_filled[pose_tracker.track_ids_last_frame] = scores
|
||||
keypoints = keypoints_filled
|
||||
scores = scores_filled
|
||||
keypoints, scores = sort_people_rtmlib(pose_tracker, keypoints, scores)
|
||||
|
||||
# Save to json
|
||||
if 'openpose' in output_format:
|
||||
|
Loading…
Reference in New Issue
Block a user