removed tracking option in pose estimation + fixed tests so that synchonization is not done in multiperson Demo + fixed multitab plots crashing

This commit is contained in:
davidpagnon 2024-09-17 23:35:40 +02:00
parent 6fd237ecc9
commit 07afe0b0fb
11 changed files with 26 additions and 45 deletions

View File

@ -152,7 +152,7 @@ make_c3d = true # save triangulated data in c3d format in addition to trc
[filtering]
type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
display_figures = true # true or false (lowercase)
make_c3d = true # also save triangulated data in c3d format
[filtering.butterworth]

View File

@ -152,7 +152,7 @@
# [filtering]
# type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
# display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
# display_figures = true # true or false (lowercase)
# make_c3d = true # also save triangulated data in c3d format
# [filtering.butterworth]

View File

@ -152,7 +152,7 @@ keypoints_to_consider = 'all' # 'all' if all points should be considered, for ex
# [filtering]
# type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
# display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
# display_figures = true # true or false (lowercase)
# make_c3d = false # also save triangulated data in c3d format
# [filtering.butterworth]

View File

@ -152,7 +152,7 @@ make_c3d = false # save triangulated data in c3d format in addition to trc
[filtering]
type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
display_figures = true # true or false (lowercase)
make_c3d = false # also save triangulated data in c3d format
[filtering.butterworth]

View File

@ -152,7 +152,7 @@ make_c3d = true # save triangulated data in c3d format in addition to trc
[filtering]
type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
display_figures = true # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
display_figures = true # true or false (lowercase)
make_c3d = true # also save triangulated data in c3d format
[filtering.butterworth]

View File

@ -129,17 +129,19 @@ class TestWorkflow(unittest.TestCase):
config_dict.get("synchronization").update({"display_sync_plots":False})
config_dict['filtering']['display_figures'] = False
# Step by step
Pose2Sim.calibration(config_dict)
Pose2Sim.poseEstimation(config_dict)
Pose2Sim.synchronization(config_dict)
# Pose2Sim.synchronization(config_dict) # No synchronization for multi-person for now
Pose2Sim.personAssociation(config_dict)
Pose2Sim.triangulation(config_dict)
Pose2Sim.filtering(config_dict)
Pose2Sim.markerAugmentation(config_dict)
# Pose2Sim.kinematics(config_dict)
# Run all
config_dict.get("pose").update({"overwrite_pose":False})
Pose2Sim.runAll(config_dict)
Pose2Sim.runAll(config_dict, do_synchronization=False)
####################
@ -149,7 +151,7 @@ class TestWorkflow(unittest.TestCase):
project_dir = '../Demo_Batch'
os.chdir(project_dir)
Pose2Sim.runAll()
Pose2Sim.runAll(do_synchronization=False)
if __name__ == '__main__':

View File

@ -8,6 +8,10 @@
##################################################
Determine gait on and off from a TRC file of point coordinates.
N.B.: Could implement the methods listed there in the future.
Please feel free to make a pull-request or keep me informed if you do so!
Three available methods, each of them with their own pros and cons:
- "forward_coordinates":

View File

@ -22,9 +22,9 @@ import sys
import matplotlib as mpl
mpl.use('qt5agg')
mpl.rc('figure', max_open_warning=0)
from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QTabWidget, QVBoxLayout
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QTabWidget, QVBoxLayout
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="c3d")
@ -482,7 +482,8 @@ class plotWindow():
def __init__(self, parent=None):
self.app = QApplication(sys.argv)
self.MainWindow = QMainWindow()
if not self.app:
self.app = QApplication(sys.argv)
self.MainWindow.__init__()
self.MainWindow.setWindowTitle("Multitabs figure")
self.canvases = []

View File

@ -319,7 +319,7 @@ def median_filter_1d(config_dict, frame_rate, col):
return col_filtered
def display_figures_fun(Q_unfilt, Q_filt, time_col, keypoints_names):
def display_figures_fun(Q_unfilt, Q_filt, time_col, keypoints_names, person_id=0):
'''
Displays filtered and unfiltered data for comparison
@ -334,6 +334,7 @@ def display_figures_fun(Q_unfilt, Q_filt, time_col, keypoints_names):
'''
pw = plotWindow()
pw.MainWindow.setWindowTitle('Person '+ str(person_id) + ' coordinates')
for id, keypoint in enumerate(keypoints_names):
f = plt.figure()
@ -472,7 +473,7 @@ def filter_all(config_dict):
trc_f_out = [f'{os.path.basename(t).split(".")[0]}_filt_{filter_type}.trc' for t in trc_path_in]
trc_path_out = [os.path.join(pose3d_dir, t) for t in trc_f_out]
for t_in, t_out in zip(trc_path_in, trc_path_out):
for person_id, t_in, t_out in enumerate(zip(trc_path_in, trc_path_out)):
# Read trc header
with open(t_in, 'r') as trc_file:
header = [next(trc_file) for line in range(5)]
@ -489,7 +490,7 @@ def filter_all(config_dict):
if display_figures:
# Retrieve keypoints
keypoints_names = pd.read_csv(t_in, sep="\t", skiprows=3, nrows=0).columns[2::3].to_numpy()
display_figures_fun(Q_coord, Q_filt, time_col, keypoints_names)
display_figures_fun(Q_coord, Q_filt, time_col, keypoints_names, person_id)
# Reconstruct trc file with filtered coordinates
with open(t_out, 'w') as trc_o:

View File

@ -127,14 +127,13 @@ def sort_people_rtmlib(pose_tracker, keypoints, scores):
return sorted_keypoints, sorted_scores
def process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range):
def process_video(video_path, pose_tracker, output_format, save_video, save_images, display_detection, frame_range):
'''
Estimate pose from a video file
INPUTS:
- video_path: str. Path to the input video file
- pose_tracker: PoseTracker. Initialized pose tracker object from RTMLib
- tracking: bool. Whether to give consistent person ID across frames
- output_format: str. Output format for the pose estimation results ('openpose', 'mmpose', 'deeplabcut')
- save_video: bool. Whether to save the output video
- save_images: bool. Whether to save the output images
@ -186,10 +185,6 @@ def process_video(video_path, pose_tracker, tracking, output_format, save_video,
# Perform pose estimation on the frame
keypoints, scores = pose_tracker(frame)
# Reorder keypoints, scores
if tracking:
keypoints, scores = sort_people_rtmlib(pose_tracker, keypoints, scores)
# Save to json
if 'openpose' in output_format:
json_file_path = os.path.join(json_output_dir, f'{video_name_wo_ext}_{frame_idx:06d}.json')
@ -225,7 +220,7 @@ def process_video(video_path, pose_tracker, tracking, output_format, save_video,
cv2.destroyAllWindows()
def process_images(image_folder_path, vid_img_extension, pose_tracker, tracking, output_format, fps, save_video, save_images, display_detection, frame_range):
def process_images(image_folder_path, vid_img_extension, pose_tracker, output_format, fps, save_video, save_images, display_detection, frame_range):
'''
Estimate pose estimation from a folder of images
@ -233,7 +228,6 @@ def process_images(image_folder_path, vid_img_extension, pose_tracker, tracking,
- image_folder_path: str. Path to the input image folder
- vid_img_extension: str. Extension of the image files
- pose_tracker: PoseTracker. Initialized pose tracker object from RTMLib
- tracking: bool. Whether to give consistent person ID across frames
- output_format: str. Output format for the pose estimation results ('openpose', 'mmpose', 'deeplabcut')
- save_video: bool. Whether to save the output video
- save_images: bool. Whether to save the output images
@ -275,17 +269,6 @@ def process_images(image_folder_path, vid_img_extension, pose_tracker, tracking,
# Perform pose estimation on the image
keypoints, scores = pose_tracker(frame)
# Reorder keypoints, scores
if tracking:
max_id = max(pose_tracker.track_ids_last_frame)
num_frames, num_points, num_coordinates = keypoints.shape
keypoints_filled = np.zeros((max_id+1, num_points, num_coordinates))
scores_filled = np.zeros((max_id+1, num_points))
keypoints_filled[pose_tracker.track_ids_last_frame] = keypoints
scores_filled[pose_tracker.track_ids_last_frame] = scores
keypoints = keypoints_filled
scores = scores_filled
# Extract frame number from the filename
if 'openpose' in output_format:
@ -361,9 +344,7 @@ def rtm_estimator(config_dict):
save_images = True if 'to_images' in config_dict['pose']['save_video'] else False
display_detection = config_dict['pose']['display_detection']
overwrite_pose = config_dict['pose']['overwrite_pose']
det_frequency = config_dict['pose']['det_frequency']
tracking = config_dict['pose']['tracking']
# Determine frame rate
video_files = glob.glob(os.path.join(video_dir, '*'+vid_img_extension))
@ -407,9 +388,6 @@ def rtm_estimator(config_dict):
logging.info(f'Inference run on every single frame.')
else:
raise ValueError(f"Invalid det_frequency: {det_frequency}. Must be an integer greater or equal to 1.")
if tracking:
logging.info(f'Pose estimation will attempt to give consistent person IDs across frames.\n')
# Select the appropriate model based on the model_type
if pose_model.upper() == 'HALPE_26':
@ -433,7 +411,7 @@ def rtm_estimator(config_dict):
mode=mode,
backend=backend,
device=device,
tracking=tracking,
tracking=False,
to_openpose=False)
@ -454,7 +432,7 @@ def rtm_estimator(config_dict):
logging.info(f'Found video files with extension {vid_img_extension}.')
for video_path in video_files:
pose_tracker.reset()
process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range)
process_video(video_path, pose_tracker, output_format, save_video, save_images, display_detection, frame_range)
else:
# Process image folders
@ -463,4 +441,4 @@ def rtm_estimator(config_dict):
for image_folder in image_folders:
pose_tracker.reset()
image_folder_path = os.path.join(video_dir, image_folder)
process_images(image_folder_path, vid_img_extension, pose_tracker, tracking, output_format, frame_rate, save_video, save_images, display_detection, frame_range)
process_images(image_folder_path, vid_img_extension, pose_tracker, output_format, frame_rate, save_video, save_images, display_detection, frame_range)

View File

@ -177,11 +177,6 @@ def time_lagged_cross_corr(camx, camy, lag_range, show=True, ref_cam_id=0, cam_i
if isinstance(lag_range, int):
lag_range = [-lag_range, lag_range]
import hashlib
print(repr(list(camx)), repr(list(camy)))
hashlib.md5(pd.util.hash_pandas_object(camx).values).hexdigest()
hashlib.md5(pd.util.hash_pandas_object(camy).values).hexdigest()
pearson_r = [camx.corr(camy.shift(lag)) for lag in range(lag_range[0], lag_range[1])]
offset = int(np.floor(len(pearson_r)/2)-np.argmax(pearson_r))
if not np.isnan(pearson_r).all():