removed tracking option in pose estimation + fixed tests so that synchonization is not done in multiperson Demo + fixed multitab plots crashing
This commit is contained in:
parent
6fd237ecc9
commit
07afe0b0fb
@ -152,7 +152,7 @@ make_c3d = true # save triangulated data in c3d format in addition to trc
|
|||||||
|
|
||||||
[filtering]
|
[filtering]
|
||||||
type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
||||||
display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
|
display_figures = true # true or false (lowercase)
|
||||||
make_c3d = true # also save triangulated data in c3d format
|
make_c3d = true # also save triangulated data in c3d format
|
||||||
|
|
||||||
[filtering.butterworth]
|
[filtering.butterworth]
|
||||||
|
@ -152,7 +152,7 @@
|
|||||||
|
|
||||||
# [filtering]
|
# [filtering]
|
||||||
# type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
# type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
||||||
# display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
|
# display_figures = true # true or false (lowercase)
|
||||||
# make_c3d = true # also save triangulated data in c3d format
|
# make_c3d = true # also save triangulated data in c3d format
|
||||||
|
|
||||||
# [filtering.butterworth]
|
# [filtering.butterworth]
|
||||||
|
@ -152,7 +152,7 @@ keypoints_to_consider = 'all' # 'all' if all points should be considered, for ex
|
|||||||
|
|
||||||
# [filtering]
|
# [filtering]
|
||||||
# type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
# type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
||||||
# display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
|
# display_figures = true # true or false (lowercase)
|
||||||
# make_c3d = false # also save triangulated data in c3d format
|
# make_c3d = false # also save triangulated data in c3d format
|
||||||
|
|
||||||
# [filtering.butterworth]
|
# [filtering.butterworth]
|
||||||
|
@ -152,7 +152,7 @@ make_c3d = false # save triangulated data in c3d format in addition to trc
|
|||||||
|
|
||||||
[filtering]
|
[filtering]
|
||||||
type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
||||||
display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
|
display_figures = true # true or false (lowercase)
|
||||||
make_c3d = false # also save triangulated data in c3d format
|
make_c3d = false # also save triangulated data in c3d format
|
||||||
|
|
||||||
[filtering.butterworth]
|
[filtering.butterworth]
|
||||||
|
@ -152,7 +152,7 @@ make_c3d = true # save triangulated data in c3d format in addition to trc
|
|||||||
|
|
||||||
[filtering]
|
[filtering]
|
||||||
type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
||||||
display_figures = true # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
|
display_figures = true # true or false (lowercase)
|
||||||
make_c3d = true # also save triangulated data in c3d format
|
make_c3d = true # also save triangulated data in c3d format
|
||||||
|
|
||||||
[filtering.butterworth]
|
[filtering.butterworth]
|
||||||
|
@ -129,17 +129,19 @@ class TestWorkflow(unittest.TestCase):
|
|||||||
config_dict.get("synchronization").update({"display_sync_plots":False})
|
config_dict.get("synchronization").update({"display_sync_plots":False})
|
||||||
config_dict['filtering']['display_figures'] = False
|
config_dict['filtering']['display_figures'] = False
|
||||||
|
|
||||||
|
# Step by step
|
||||||
Pose2Sim.calibration(config_dict)
|
Pose2Sim.calibration(config_dict)
|
||||||
Pose2Sim.poseEstimation(config_dict)
|
Pose2Sim.poseEstimation(config_dict)
|
||||||
Pose2Sim.synchronization(config_dict)
|
# Pose2Sim.synchronization(config_dict) # No synchronization for multi-person for now
|
||||||
Pose2Sim.personAssociation(config_dict)
|
Pose2Sim.personAssociation(config_dict)
|
||||||
Pose2Sim.triangulation(config_dict)
|
Pose2Sim.triangulation(config_dict)
|
||||||
Pose2Sim.filtering(config_dict)
|
Pose2Sim.filtering(config_dict)
|
||||||
Pose2Sim.markerAugmentation(config_dict)
|
Pose2Sim.markerAugmentation(config_dict)
|
||||||
# Pose2Sim.kinematics(config_dict)
|
# Pose2Sim.kinematics(config_dict)
|
||||||
|
|
||||||
|
# Run all
|
||||||
config_dict.get("pose").update({"overwrite_pose":False})
|
config_dict.get("pose").update({"overwrite_pose":False})
|
||||||
Pose2Sim.runAll(config_dict)
|
Pose2Sim.runAll(config_dict, do_synchronization=False)
|
||||||
|
|
||||||
|
|
||||||
####################
|
####################
|
||||||
@ -149,7 +151,7 @@ class TestWorkflow(unittest.TestCase):
|
|||||||
project_dir = '../Demo_Batch'
|
project_dir = '../Demo_Batch'
|
||||||
os.chdir(project_dir)
|
os.chdir(project_dir)
|
||||||
|
|
||||||
Pose2Sim.runAll()
|
Pose2Sim.runAll(do_synchronization=False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -8,6 +8,10 @@
|
|||||||
##################################################
|
##################################################
|
||||||
|
|
||||||
Determine gait on and off from a TRC file of point coordinates.
|
Determine gait on and off from a TRC file of point coordinates.
|
||||||
|
|
||||||
|
N.B.: Could implement the methods listed there in the future.
|
||||||
|
Please feel free to make a pull-request or keep me informed if you do so!
|
||||||
|
|
||||||
Three available methods, each of them with their own pros and cons:
|
Three available methods, each of them with their own pros and cons:
|
||||||
|
|
||||||
- "forward_coordinates":
|
- "forward_coordinates":
|
||||||
|
@ -22,9 +22,9 @@ import sys
|
|||||||
import matplotlib as mpl
|
import matplotlib as mpl
|
||||||
mpl.use('qt5agg')
|
mpl.use('qt5agg')
|
||||||
mpl.rc('figure', max_open_warning=0)
|
mpl.rc('figure', max_open_warning=0)
|
||||||
|
from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QTabWidget, QVBoxLayout
|
||||||
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
||||||
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
|
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
|
||||||
from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QTabWidget, QVBoxLayout
|
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, module="c3d")
|
warnings.filterwarnings("ignore", category=UserWarning, module="c3d")
|
||||||
|
|
||||||
@ -482,7 +482,8 @@ class plotWindow():
|
|||||||
|
|
||||||
def __init__(self, parent=None):
|
def __init__(self, parent=None):
|
||||||
self.app = QApplication(sys.argv)
|
self.app = QApplication(sys.argv)
|
||||||
self.MainWindow = QMainWindow()
|
if not self.app:
|
||||||
|
self.app = QApplication(sys.argv)
|
||||||
self.MainWindow.__init__()
|
self.MainWindow.__init__()
|
||||||
self.MainWindow.setWindowTitle("Multitabs figure")
|
self.MainWindow.setWindowTitle("Multitabs figure")
|
||||||
self.canvases = []
|
self.canvases = []
|
||||||
|
@ -319,7 +319,7 @@ def median_filter_1d(config_dict, frame_rate, col):
|
|||||||
return col_filtered
|
return col_filtered
|
||||||
|
|
||||||
|
|
||||||
def display_figures_fun(Q_unfilt, Q_filt, time_col, keypoints_names):
|
def display_figures_fun(Q_unfilt, Q_filt, time_col, keypoints_names, person_id=0):
|
||||||
'''
|
'''
|
||||||
Displays filtered and unfiltered data for comparison
|
Displays filtered and unfiltered data for comparison
|
||||||
|
|
||||||
@ -334,6 +334,7 @@ def display_figures_fun(Q_unfilt, Q_filt, time_col, keypoints_names):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
pw = plotWindow()
|
pw = plotWindow()
|
||||||
|
pw.MainWindow.setWindowTitle('Person '+ str(person_id) + ' coordinates')
|
||||||
for id, keypoint in enumerate(keypoints_names):
|
for id, keypoint in enumerate(keypoints_names):
|
||||||
f = plt.figure()
|
f = plt.figure()
|
||||||
|
|
||||||
@ -472,7 +473,7 @@ def filter_all(config_dict):
|
|||||||
trc_f_out = [f'{os.path.basename(t).split(".")[0]}_filt_{filter_type}.trc' for t in trc_path_in]
|
trc_f_out = [f'{os.path.basename(t).split(".")[0]}_filt_{filter_type}.trc' for t in trc_path_in]
|
||||||
trc_path_out = [os.path.join(pose3d_dir, t) for t in trc_f_out]
|
trc_path_out = [os.path.join(pose3d_dir, t) for t in trc_f_out]
|
||||||
|
|
||||||
for t_in, t_out in zip(trc_path_in, trc_path_out):
|
for person_id, t_in, t_out in enumerate(zip(trc_path_in, trc_path_out)):
|
||||||
# Read trc header
|
# Read trc header
|
||||||
with open(t_in, 'r') as trc_file:
|
with open(t_in, 'r') as trc_file:
|
||||||
header = [next(trc_file) for line in range(5)]
|
header = [next(trc_file) for line in range(5)]
|
||||||
@ -489,7 +490,7 @@ def filter_all(config_dict):
|
|||||||
if display_figures:
|
if display_figures:
|
||||||
# Retrieve keypoints
|
# Retrieve keypoints
|
||||||
keypoints_names = pd.read_csv(t_in, sep="\t", skiprows=3, nrows=0).columns[2::3].to_numpy()
|
keypoints_names = pd.read_csv(t_in, sep="\t", skiprows=3, nrows=0).columns[2::3].to_numpy()
|
||||||
display_figures_fun(Q_coord, Q_filt, time_col, keypoints_names)
|
display_figures_fun(Q_coord, Q_filt, time_col, keypoints_names, person_id)
|
||||||
|
|
||||||
# Reconstruct trc file with filtered coordinates
|
# Reconstruct trc file with filtered coordinates
|
||||||
with open(t_out, 'w') as trc_o:
|
with open(t_out, 'w') as trc_o:
|
||||||
|
@ -127,14 +127,13 @@ def sort_people_rtmlib(pose_tracker, keypoints, scores):
|
|||||||
return sorted_keypoints, sorted_scores
|
return sorted_keypoints, sorted_scores
|
||||||
|
|
||||||
|
|
||||||
def process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range):
|
def process_video(video_path, pose_tracker, output_format, save_video, save_images, display_detection, frame_range):
|
||||||
'''
|
'''
|
||||||
Estimate pose from a video file
|
Estimate pose from a video file
|
||||||
|
|
||||||
INPUTS:
|
INPUTS:
|
||||||
- video_path: str. Path to the input video file
|
- video_path: str. Path to the input video file
|
||||||
- pose_tracker: PoseTracker. Initialized pose tracker object from RTMLib
|
- pose_tracker: PoseTracker. Initialized pose tracker object from RTMLib
|
||||||
- tracking: bool. Whether to give consistent person ID across frames
|
|
||||||
- output_format: str. Output format for the pose estimation results ('openpose', 'mmpose', 'deeplabcut')
|
- output_format: str. Output format for the pose estimation results ('openpose', 'mmpose', 'deeplabcut')
|
||||||
- save_video: bool. Whether to save the output video
|
- save_video: bool. Whether to save the output video
|
||||||
- save_images: bool. Whether to save the output images
|
- save_images: bool. Whether to save the output images
|
||||||
@ -186,10 +185,6 @@ def process_video(video_path, pose_tracker, tracking, output_format, save_video,
|
|||||||
# Perform pose estimation on the frame
|
# Perform pose estimation on the frame
|
||||||
keypoints, scores = pose_tracker(frame)
|
keypoints, scores = pose_tracker(frame)
|
||||||
|
|
||||||
# Reorder keypoints, scores
|
|
||||||
if tracking:
|
|
||||||
keypoints, scores = sort_people_rtmlib(pose_tracker, keypoints, scores)
|
|
||||||
|
|
||||||
# Save to json
|
# Save to json
|
||||||
if 'openpose' in output_format:
|
if 'openpose' in output_format:
|
||||||
json_file_path = os.path.join(json_output_dir, f'{video_name_wo_ext}_{frame_idx:06d}.json')
|
json_file_path = os.path.join(json_output_dir, f'{video_name_wo_ext}_{frame_idx:06d}.json')
|
||||||
@ -225,7 +220,7 @@ def process_video(video_path, pose_tracker, tracking, output_format, save_video,
|
|||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
def process_images(image_folder_path, vid_img_extension, pose_tracker, tracking, output_format, fps, save_video, save_images, display_detection, frame_range):
|
def process_images(image_folder_path, vid_img_extension, pose_tracker, output_format, fps, save_video, save_images, display_detection, frame_range):
|
||||||
'''
|
'''
|
||||||
Estimate pose estimation from a folder of images
|
Estimate pose estimation from a folder of images
|
||||||
|
|
||||||
@ -233,7 +228,6 @@ def process_images(image_folder_path, vid_img_extension, pose_tracker, tracking,
|
|||||||
- image_folder_path: str. Path to the input image folder
|
- image_folder_path: str. Path to the input image folder
|
||||||
- vid_img_extension: str. Extension of the image files
|
- vid_img_extension: str. Extension of the image files
|
||||||
- pose_tracker: PoseTracker. Initialized pose tracker object from RTMLib
|
- pose_tracker: PoseTracker. Initialized pose tracker object from RTMLib
|
||||||
- tracking: bool. Whether to give consistent person ID across frames
|
|
||||||
- output_format: str. Output format for the pose estimation results ('openpose', 'mmpose', 'deeplabcut')
|
- output_format: str. Output format for the pose estimation results ('openpose', 'mmpose', 'deeplabcut')
|
||||||
- save_video: bool. Whether to save the output video
|
- save_video: bool. Whether to save the output video
|
||||||
- save_images: bool. Whether to save the output images
|
- save_images: bool. Whether to save the output images
|
||||||
@ -276,17 +270,6 @@ def process_images(image_folder_path, vid_img_extension, pose_tracker, tracking,
|
|||||||
# Perform pose estimation on the image
|
# Perform pose estimation on the image
|
||||||
keypoints, scores = pose_tracker(frame)
|
keypoints, scores = pose_tracker(frame)
|
||||||
|
|
||||||
# Reorder keypoints, scores
|
|
||||||
if tracking:
|
|
||||||
max_id = max(pose_tracker.track_ids_last_frame)
|
|
||||||
num_frames, num_points, num_coordinates = keypoints.shape
|
|
||||||
keypoints_filled = np.zeros((max_id+1, num_points, num_coordinates))
|
|
||||||
scores_filled = np.zeros((max_id+1, num_points))
|
|
||||||
keypoints_filled[pose_tracker.track_ids_last_frame] = keypoints
|
|
||||||
scores_filled[pose_tracker.track_ids_last_frame] = scores
|
|
||||||
keypoints = keypoints_filled
|
|
||||||
scores = scores_filled
|
|
||||||
|
|
||||||
# Extract frame number from the filename
|
# Extract frame number from the filename
|
||||||
if 'openpose' in output_format:
|
if 'openpose' in output_format:
|
||||||
json_file_path = os.path.join(json_output_dir, f"{os.path.splitext(os.path.basename(image_file))[0]}_{frame_idx:06d}.json")
|
json_file_path = os.path.join(json_output_dir, f"{os.path.splitext(os.path.basename(image_file))[0]}_{frame_idx:06d}.json")
|
||||||
@ -361,9 +344,7 @@ def rtm_estimator(config_dict):
|
|||||||
save_images = True if 'to_images' in config_dict['pose']['save_video'] else False
|
save_images = True if 'to_images' in config_dict['pose']['save_video'] else False
|
||||||
display_detection = config_dict['pose']['display_detection']
|
display_detection = config_dict['pose']['display_detection']
|
||||||
overwrite_pose = config_dict['pose']['overwrite_pose']
|
overwrite_pose = config_dict['pose']['overwrite_pose']
|
||||||
|
|
||||||
det_frequency = config_dict['pose']['det_frequency']
|
det_frequency = config_dict['pose']['det_frequency']
|
||||||
tracking = config_dict['pose']['tracking']
|
|
||||||
|
|
||||||
# Determine frame rate
|
# Determine frame rate
|
||||||
video_files = glob.glob(os.path.join(video_dir, '*'+vid_img_extension))
|
video_files = glob.glob(os.path.join(video_dir, '*'+vid_img_extension))
|
||||||
@ -408,9 +389,6 @@ def rtm_estimator(config_dict):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid det_frequency: {det_frequency}. Must be an integer greater or equal to 1.")
|
raise ValueError(f"Invalid det_frequency: {det_frequency}. Must be an integer greater or equal to 1.")
|
||||||
|
|
||||||
if tracking:
|
|
||||||
logging.info(f'Pose estimation will attempt to give consistent person IDs across frames.\n')
|
|
||||||
|
|
||||||
# Select the appropriate model based on the model_type
|
# Select the appropriate model based on the model_type
|
||||||
if pose_model.upper() == 'HALPE_26':
|
if pose_model.upper() == 'HALPE_26':
|
||||||
ModelClass = BodyWithFeet
|
ModelClass = BodyWithFeet
|
||||||
@ -433,7 +411,7 @@ def rtm_estimator(config_dict):
|
|||||||
mode=mode,
|
mode=mode,
|
||||||
backend=backend,
|
backend=backend,
|
||||||
device=device,
|
device=device,
|
||||||
tracking=tracking,
|
tracking=False,
|
||||||
to_openpose=False)
|
to_openpose=False)
|
||||||
|
|
||||||
|
|
||||||
@ -454,7 +432,7 @@ def rtm_estimator(config_dict):
|
|||||||
logging.info(f'Found video files with extension {vid_img_extension}.')
|
logging.info(f'Found video files with extension {vid_img_extension}.')
|
||||||
for video_path in video_files:
|
for video_path in video_files:
|
||||||
pose_tracker.reset()
|
pose_tracker.reset()
|
||||||
process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range)
|
process_video(video_path, pose_tracker, output_format, save_video, save_images, display_detection, frame_range)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Process image folders
|
# Process image folders
|
||||||
@ -463,4 +441,4 @@ def rtm_estimator(config_dict):
|
|||||||
for image_folder in image_folders:
|
for image_folder in image_folders:
|
||||||
pose_tracker.reset()
|
pose_tracker.reset()
|
||||||
image_folder_path = os.path.join(video_dir, image_folder)
|
image_folder_path = os.path.join(video_dir, image_folder)
|
||||||
process_images(image_folder_path, vid_img_extension, pose_tracker, tracking, output_format, frame_rate, save_video, save_images, display_detection, frame_range)
|
process_images(image_folder_path, vid_img_extension, pose_tracker, output_format, frame_rate, save_video, save_images, display_detection, frame_range)
|
||||||
|
@ -177,11 +177,6 @@ def time_lagged_cross_corr(camx, camy, lag_range, show=True, ref_cam_id=0, cam_i
|
|||||||
if isinstance(lag_range, int):
|
if isinstance(lag_range, int):
|
||||||
lag_range = [-lag_range, lag_range]
|
lag_range = [-lag_range, lag_range]
|
||||||
|
|
||||||
import hashlib
|
|
||||||
print(repr(list(camx)), repr(list(camy)))
|
|
||||||
hashlib.md5(pd.util.hash_pandas_object(camx).values).hexdigest()
|
|
||||||
hashlib.md5(pd.util.hash_pandas_object(camy).values).hexdigest()
|
|
||||||
|
|
||||||
pearson_r = [camx.corr(camy.shift(lag)) for lag in range(lag_range[0], lag_range[1])]
|
pearson_r = [camx.corr(camy.shift(lag)) for lag in range(lag_range[0], lag_range[1])]
|
||||||
offset = int(np.floor(len(pearson_r)/2)-np.argmax(pearson_r))
|
offset = int(np.floor(len(pearson_r)/2)-np.argmax(pearson_r))
|
||||||
if not np.isnan(pearson_r).all():
|
if not np.isnan(pearson_r).all():
|
||||||
|
Loading…
Reference in New Issue
Block a user