From 07afe0b0fb81f0d543ad6e3ce9714797cfb182c6 Mon Sep 17 00:00:00 2001 From: davidpagnon Date: Tue, 17 Sep 2024 23:35:40 +0200 Subject: [PATCH] removed tracking option in pose estimation + fixed tests so that synchonization is not done in multiperson Demo + fixed multitab plots crashing --- Pose2Sim/Demo_Batch/Config.toml | 2 +- Pose2Sim/Demo_Batch/Trial_1/Config.toml | 2 +- Pose2Sim/Demo_Batch/Trial_2/Config.toml | 2 +- Pose2Sim/Demo_MultiPerson/Config.toml | 2 +- Pose2Sim/Demo_SinglePerson/Config.toml | 2 +- Pose2Sim/Utilities/tests.py | 8 ++++--- Pose2Sim/Utilities/trc_gaitevents.py | 4 ++++ Pose2Sim/common.py | 5 ++-- Pose2Sim/filtering.py | 7 +++--- Pose2Sim/poseEstimation.py | 32 ++++--------------------- Pose2Sim/synchronization.py | 5 ---- 11 files changed, 26 insertions(+), 45 deletions(-) diff --git a/Pose2Sim/Demo_Batch/Config.toml b/Pose2Sim/Demo_Batch/Config.toml index 2447a05..4bd0847 100644 --- a/Pose2Sim/Demo_Batch/Config.toml +++ b/Pose2Sim/Demo_Batch/Config.toml @@ -152,7 +152,7 @@ make_c3d = true # save triangulated data in c3d format in addition to trc [filtering] type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed -display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7 +display_figures = true # true or false (lowercase) make_c3d = true # also save triangulated data in c3d format [filtering.butterworth] diff --git a/Pose2Sim/Demo_Batch/Trial_1/Config.toml b/Pose2Sim/Demo_Batch/Trial_1/Config.toml index 7a200f6..0511ce0 100644 --- a/Pose2Sim/Demo_Batch/Trial_1/Config.toml +++ b/Pose2Sim/Demo_Batch/Trial_1/Config.toml @@ -152,7 +152,7 @@ # [filtering] # type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed -# display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7 +# display_figures = true # true or false (lowercase) # make_c3d = true # also save triangulated data in c3d format # [filtering.butterworth] diff --git a/Pose2Sim/Demo_Batch/Trial_2/Config.toml b/Pose2Sim/Demo_Batch/Trial_2/Config.toml index 84b6c82..708943a 100644 --- a/Pose2Sim/Demo_Batch/Trial_2/Config.toml +++ b/Pose2Sim/Demo_Batch/Trial_2/Config.toml @@ -152,7 +152,7 @@ keypoints_to_consider = 'all' # 'all' if all points should be considered, for ex # [filtering] # type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed -# display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7 +# display_figures = true # true or false (lowercase) # make_c3d = false # also save triangulated data in c3d format # [filtering.butterworth] diff --git a/Pose2Sim/Demo_MultiPerson/Config.toml b/Pose2Sim/Demo_MultiPerson/Config.toml index 5c8273d..d63a32a 100644 --- a/Pose2Sim/Demo_MultiPerson/Config.toml +++ b/Pose2Sim/Demo_MultiPerson/Config.toml @@ -152,7 +152,7 @@ make_c3d = false # save triangulated data in c3d format in addition to trc [filtering] type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed -display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7 +display_figures = true # true or false (lowercase) make_c3d = false # also save triangulated data in c3d format [filtering.butterworth] diff --git a/Pose2Sim/Demo_SinglePerson/Config.toml b/Pose2Sim/Demo_SinglePerson/Config.toml index dcd5e53..b0a37b1 100644 --- a/Pose2Sim/Demo_SinglePerson/Config.toml +++ b/Pose2Sim/Demo_SinglePerson/Config.toml @@ -152,7 +152,7 @@ make_c3d = true # save triangulated data in c3d format in addition to trc [filtering] type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed -display_figures = true # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7 +display_figures = true # true or false (lowercase) make_c3d = true # also save triangulated data in c3d format [filtering.butterworth] diff --git a/Pose2Sim/Utilities/tests.py b/Pose2Sim/Utilities/tests.py index 76a8f47..7931195 100644 --- a/Pose2Sim/Utilities/tests.py +++ b/Pose2Sim/Utilities/tests.py @@ -129,17 +129,19 @@ class TestWorkflow(unittest.TestCase): config_dict.get("synchronization").update({"display_sync_plots":False}) config_dict['filtering']['display_figures'] = False + # Step by step Pose2Sim.calibration(config_dict) Pose2Sim.poseEstimation(config_dict) - Pose2Sim.synchronization(config_dict) + # Pose2Sim.synchronization(config_dict) # No synchronization for multi-person for now Pose2Sim.personAssociation(config_dict) Pose2Sim.triangulation(config_dict) Pose2Sim.filtering(config_dict) Pose2Sim.markerAugmentation(config_dict) # Pose2Sim.kinematics(config_dict) + # Run all config_dict.get("pose").update({"overwrite_pose":False}) - Pose2Sim.runAll(config_dict) + Pose2Sim.runAll(config_dict, do_synchronization=False) #################### @@ -149,7 +151,7 @@ class TestWorkflow(unittest.TestCase): project_dir = '../Demo_Batch' os.chdir(project_dir) - Pose2Sim.runAll() + Pose2Sim.runAll(do_synchronization=False) if __name__ == '__main__': diff --git a/Pose2Sim/Utilities/trc_gaitevents.py b/Pose2Sim/Utilities/trc_gaitevents.py index ff94b32..b8f0cfd 100644 --- a/Pose2Sim/Utilities/trc_gaitevents.py +++ b/Pose2Sim/Utilities/trc_gaitevents.py @@ -8,6 +8,10 @@ ################################################## Determine gait on and off from a TRC file of point coordinates. + + N.B.: Could implement the methods listed there in the future. + Please feel free to make a pull-request or keep me informed if you do so! + Three available methods, each of them with their own pros and cons: - "forward_coordinates": diff --git a/Pose2Sim/common.py b/Pose2Sim/common.py index 35a72e5..7fd5c4e 100644 --- a/Pose2Sim/common.py +++ b/Pose2Sim/common.py @@ -22,9 +22,9 @@ import sys import matplotlib as mpl mpl.use('qt5agg') mpl.rc('figure', max_open_warning=0) +from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QTabWidget, QVBoxLayout from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar -from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QTabWidget, QVBoxLayout import warnings warnings.filterwarnings("ignore", category=UserWarning, module="c3d") @@ -482,7 +482,8 @@ class plotWindow(): def __init__(self, parent=None): self.app = QApplication(sys.argv) - self.MainWindow = QMainWindow() + if not self.app: + self.app = QApplication(sys.argv) self.MainWindow.__init__() self.MainWindow.setWindowTitle("Multitabs figure") self.canvases = [] diff --git a/Pose2Sim/filtering.py b/Pose2Sim/filtering.py index 6215413..17833ee 100644 --- a/Pose2Sim/filtering.py +++ b/Pose2Sim/filtering.py @@ -319,7 +319,7 @@ def median_filter_1d(config_dict, frame_rate, col): return col_filtered -def display_figures_fun(Q_unfilt, Q_filt, time_col, keypoints_names): +def display_figures_fun(Q_unfilt, Q_filt, time_col, keypoints_names, person_id=0): ''' Displays filtered and unfiltered data for comparison @@ -334,6 +334,7 @@ def display_figures_fun(Q_unfilt, Q_filt, time_col, keypoints_names): ''' pw = plotWindow() + pw.MainWindow.setWindowTitle('Person '+ str(person_id) + ' coordinates') for id, keypoint in enumerate(keypoints_names): f = plt.figure() @@ -472,7 +473,7 @@ def filter_all(config_dict): trc_f_out = [f'{os.path.basename(t).split(".")[0]}_filt_{filter_type}.trc' for t in trc_path_in] trc_path_out = [os.path.join(pose3d_dir, t) for t in trc_f_out] - for t_in, t_out in zip(trc_path_in, trc_path_out): + for person_id, t_in, t_out in enumerate(zip(trc_path_in, trc_path_out)): # Read trc header with open(t_in, 'r') as trc_file: header = [next(trc_file) for line in range(5)] @@ -489,7 +490,7 @@ def filter_all(config_dict): if display_figures: # Retrieve keypoints keypoints_names = pd.read_csv(t_in, sep="\t", skiprows=3, nrows=0).columns[2::3].to_numpy() - display_figures_fun(Q_coord, Q_filt, time_col, keypoints_names) + display_figures_fun(Q_coord, Q_filt, time_col, keypoints_names, person_id) # Reconstruct trc file with filtered coordinates with open(t_out, 'w') as trc_o: diff --git a/Pose2Sim/poseEstimation.py b/Pose2Sim/poseEstimation.py index 2d83964..a689567 100644 --- a/Pose2Sim/poseEstimation.py +++ b/Pose2Sim/poseEstimation.py @@ -127,14 +127,13 @@ def sort_people_rtmlib(pose_tracker, keypoints, scores): return sorted_keypoints, sorted_scores -def process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range): +def process_video(video_path, pose_tracker, output_format, save_video, save_images, display_detection, frame_range): ''' Estimate pose from a video file INPUTS: - video_path: str. Path to the input video file - pose_tracker: PoseTracker. Initialized pose tracker object from RTMLib - - tracking: bool. Whether to give consistent person ID across frames - output_format: str. Output format for the pose estimation results ('openpose', 'mmpose', 'deeplabcut') - save_video: bool. Whether to save the output video - save_images: bool. Whether to save the output images @@ -186,10 +185,6 @@ def process_video(video_path, pose_tracker, tracking, output_format, save_video, # Perform pose estimation on the frame keypoints, scores = pose_tracker(frame) - # Reorder keypoints, scores - if tracking: - keypoints, scores = sort_people_rtmlib(pose_tracker, keypoints, scores) - # Save to json if 'openpose' in output_format: json_file_path = os.path.join(json_output_dir, f'{video_name_wo_ext}_{frame_idx:06d}.json') @@ -225,7 +220,7 @@ def process_video(video_path, pose_tracker, tracking, output_format, save_video, cv2.destroyAllWindows() -def process_images(image_folder_path, vid_img_extension, pose_tracker, tracking, output_format, fps, save_video, save_images, display_detection, frame_range): +def process_images(image_folder_path, vid_img_extension, pose_tracker, output_format, fps, save_video, save_images, display_detection, frame_range): ''' Estimate pose estimation from a folder of images @@ -233,7 +228,6 @@ def process_images(image_folder_path, vid_img_extension, pose_tracker, tracking, - image_folder_path: str. Path to the input image folder - vid_img_extension: str. Extension of the image files - pose_tracker: PoseTracker. Initialized pose tracker object from RTMLib - - tracking: bool. Whether to give consistent person ID across frames - output_format: str. Output format for the pose estimation results ('openpose', 'mmpose', 'deeplabcut') - save_video: bool. Whether to save the output video - save_images: bool. Whether to save the output images @@ -275,17 +269,6 @@ def process_images(image_folder_path, vid_img_extension, pose_tracker, tracking, # Perform pose estimation on the image keypoints, scores = pose_tracker(frame) - - # Reorder keypoints, scores - if tracking: - max_id = max(pose_tracker.track_ids_last_frame) - num_frames, num_points, num_coordinates = keypoints.shape - keypoints_filled = np.zeros((max_id+1, num_points, num_coordinates)) - scores_filled = np.zeros((max_id+1, num_points)) - keypoints_filled[pose_tracker.track_ids_last_frame] = keypoints - scores_filled[pose_tracker.track_ids_last_frame] = scores - keypoints = keypoints_filled - scores = scores_filled # Extract frame number from the filename if 'openpose' in output_format: @@ -361,9 +344,7 @@ def rtm_estimator(config_dict): save_images = True if 'to_images' in config_dict['pose']['save_video'] else False display_detection = config_dict['pose']['display_detection'] overwrite_pose = config_dict['pose']['overwrite_pose'] - det_frequency = config_dict['pose']['det_frequency'] - tracking = config_dict['pose']['tracking'] # Determine frame rate video_files = glob.glob(os.path.join(video_dir, '*'+vid_img_extension)) @@ -407,9 +388,6 @@ def rtm_estimator(config_dict): logging.info(f'Inference run on every single frame.') else: raise ValueError(f"Invalid det_frequency: {det_frequency}. Must be an integer greater or equal to 1.") - - if tracking: - logging.info(f'Pose estimation will attempt to give consistent person IDs across frames.\n') # Select the appropriate model based on the model_type if pose_model.upper() == 'HALPE_26': @@ -433,7 +411,7 @@ def rtm_estimator(config_dict): mode=mode, backend=backend, device=device, - tracking=tracking, + tracking=False, to_openpose=False) @@ -454,7 +432,7 @@ def rtm_estimator(config_dict): logging.info(f'Found video files with extension {vid_img_extension}.') for video_path in video_files: pose_tracker.reset() - process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range) + process_video(video_path, pose_tracker, output_format, save_video, save_images, display_detection, frame_range) else: # Process image folders @@ -463,4 +441,4 @@ def rtm_estimator(config_dict): for image_folder in image_folders: pose_tracker.reset() image_folder_path = os.path.join(video_dir, image_folder) - process_images(image_folder_path, vid_img_extension, pose_tracker, tracking, output_format, frame_rate, save_video, save_images, display_detection, frame_range) + process_images(image_folder_path, vid_img_extension, pose_tracker, output_format, frame_rate, save_video, save_images, display_detection, frame_range) diff --git a/Pose2Sim/synchronization.py b/Pose2Sim/synchronization.py index 2fbdf22..77e5399 100644 --- a/Pose2Sim/synchronization.py +++ b/Pose2Sim/synchronization.py @@ -177,11 +177,6 @@ def time_lagged_cross_corr(camx, camy, lag_range, show=True, ref_cam_id=0, cam_i if isinstance(lag_range, int): lag_range = [-lag_range, lag_range] - import hashlib - print(repr(list(camx)), repr(list(camy))) - hashlib.md5(pd.util.hash_pandas_object(camx).values).hexdigest() - hashlib.md5(pd.util.hash_pandas_object(camy).values).hexdigest() - pearson_r = [camx.corr(camy.shift(lag)) for lag in range(lag_range[0], lag_range[1])] offset = int(np.floor(len(pearson_r)/2)-np.argmax(pearson_r)) if not np.isnan(pearson_r).all():