removed tracking option in pose estimation + fixed tests so that synchonization is not done in multiperson Demo + fixed multitab plots crashing
This commit is contained in:
parent
6fd237ecc9
commit
07afe0b0fb
@ -152,7 +152,7 @@ make_c3d = true # save triangulated data in c3d format in addition to trc
|
||||
|
||||
[filtering]
|
||||
type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
||||
display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
|
||||
display_figures = true # true or false (lowercase)
|
||||
make_c3d = true # also save triangulated data in c3d format
|
||||
|
||||
[filtering.butterworth]
|
||||
|
@ -152,7 +152,7 @@
|
||||
|
||||
# [filtering]
|
||||
# type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
||||
# display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
|
||||
# display_figures = true # true or false (lowercase)
|
||||
# make_c3d = true # also save triangulated data in c3d format
|
||||
|
||||
# [filtering.butterworth]
|
||||
|
@ -152,7 +152,7 @@ keypoints_to_consider = 'all' # 'all' if all points should be considered, for ex
|
||||
|
||||
# [filtering]
|
||||
# type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
||||
# display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
|
||||
# display_figures = true # true or false (lowercase)
|
||||
# make_c3d = false # also save triangulated data in c3d format
|
||||
|
||||
# [filtering.butterworth]
|
||||
|
@ -152,7 +152,7 @@ make_c3d = false # save triangulated data in c3d format in addition to trc
|
||||
|
||||
[filtering]
|
||||
type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
||||
display_figures = false # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
|
||||
display_figures = true # true or false (lowercase)
|
||||
make_c3d = false # also save triangulated data in c3d format
|
||||
|
||||
[filtering.butterworth]
|
||||
|
@ -152,7 +152,7 @@ make_c3d = true # save triangulated data in c3d format in addition to trc
|
||||
|
||||
[filtering]
|
||||
type = 'butterworth' # butterworth, kalman, gaussian, LOESS, median, butterworth_on_speed
|
||||
display_figures = true # true or false (lowercase) # fails when run multiple times https://github.com/superjax/plotWindow/issues/7
|
||||
display_figures = true # true or false (lowercase)
|
||||
make_c3d = true # also save triangulated data in c3d format
|
||||
|
||||
[filtering.butterworth]
|
||||
|
@ -129,17 +129,19 @@ class TestWorkflow(unittest.TestCase):
|
||||
config_dict.get("synchronization").update({"display_sync_plots":False})
|
||||
config_dict['filtering']['display_figures'] = False
|
||||
|
||||
# Step by step
|
||||
Pose2Sim.calibration(config_dict)
|
||||
Pose2Sim.poseEstimation(config_dict)
|
||||
Pose2Sim.synchronization(config_dict)
|
||||
# Pose2Sim.synchronization(config_dict) # No synchronization for multi-person for now
|
||||
Pose2Sim.personAssociation(config_dict)
|
||||
Pose2Sim.triangulation(config_dict)
|
||||
Pose2Sim.filtering(config_dict)
|
||||
Pose2Sim.markerAugmentation(config_dict)
|
||||
# Pose2Sim.kinematics(config_dict)
|
||||
|
||||
# Run all
|
||||
config_dict.get("pose").update({"overwrite_pose":False})
|
||||
Pose2Sim.runAll(config_dict)
|
||||
Pose2Sim.runAll(config_dict, do_synchronization=False)
|
||||
|
||||
|
||||
####################
|
||||
@ -149,7 +151,7 @@ class TestWorkflow(unittest.TestCase):
|
||||
project_dir = '../Demo_Batch'
|
||||
os.chdir(project_dir)
|
||||
|
||||
Pose2Sim.runAll()
|
||||
Pose2Sim.runAll(do_synchronization=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -8,6 +8,10 @@
|
||||
##################################################
|
||||
|
||||
Determine gait on and off from a TRC file of point coordinates.
|
||||
|
||||
N.B.: Could implement the methods listed there in the future.
|
||||
Please feel free to make a pull-request or keep me informed if you do so!
|
||||
|
||||
Three available methods, each of them with their own pros and cons:
|
||||
|
||||
- "forward_coordinates":
|
||||
|
@ -22,9 +22,9 @@ import sys
|
||||
import matplotlib as mpl
|
||||
mpl.use('qt5agg')
|
||||
mpl.rc('figure', max_open_warning=0)
|
||||
from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QTabWidget, QVBoxLayout
|
||||
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
||||
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
|
||||
from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QTabWidget, QVBoxLayout
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="c3d")
|
||||
|
||||
@ -482,7 +482,8 @@ class plotWindow():
|
||||
|
||||
def __init__(self, parent=None):
|
||||
self.app = QApplication(sys.argv)
|
||||
self.MainWindow = QMainWindow()
|
||||
if not self.app:
|
||||
self.app = QApplication(sys.argv)
|
||||
self.MainWindow.__init__()
|
||||
self.MainWindow.setWindowTitle("Multitabs figure")
|
||||
self.canvases = []
|
||||
|
@ -319,7 +319,7 @@ def median_filter_1d(config_dict, frame_rate, col):
|
||||
return col_filtered
|
||||
|
||||
|
||||
def display_figures_fun(Q_unfilt, Q_filt, time_col, keypoints_names):
|
||||
def display_figures_fun(Q_unfilt, Q_filt, time_col, keypoints_names, person_id=0):
|
||||
'''
|
||||
Displays filtered and unfiltered data for comparison
|
||||
|
||||
@ -334,6 +334,7 @@ def display_figures_fun(Q_unfilt, Q_filt, time_col, keypoints_names):
|
||||
'''
|
||||
|
||||
pw = plotWindow()
|
||||
pw.MainWindow.setWindowTitle('Person '+ str(person_id) + ' coordinates')
|
||||
for id, keypoint in enumerate(keypoints_names):
|
||||
f = plt.figure()
|
||||
|
||||
@ -472,7 +473,7 @@ def filter_all(config_dict):
|
||||
trc_f_out = [f'{os.path.basename(t).split(".")[0]}_filt_{filter_type}.trc' for t in trc_path_in]
|
||||
trc_path_out = [os.path.join(pose3d_dir, t) for t in trc_f_out]
|
||||
|
||||
for t_in, t_out in zip(trc_path_in, trc_path_out):
|
||||
for person_id, t_in, t_out in enumerate(zip(trc_path_in, trc_path_out)):
|
||||
# Read trc header
|
||||
with open(t_in, 'r') as trc_file:
|
||||
header = [next(trc_file) for line in range(5)]
|
||||
@ -489,7 +490,7 @@ def filter_all(config_dict):
|
||||
if display_figures:
|
||||
# Retrieve keypoints
|
||||
keypoints_names = pd.read_csv(t_in, sep="\t", skiprows=3, nrows=0).columns[2::3].to_numpy()
|
||||
display_figures_fun(Q_coord, Q_filt, time_col, keypoints_names)
|
||||
display_figures_fun(Q_coord, Q_filt, time_col, keypoints_names, person_id)
|
||||
|
||||
# Reconstruct trc file with filtered coordinates
|
||||
with open(t_out, 'w') as trc_o:
|
||||
|
@ -127,14 +127,13 @@ def sort_people_rtmlib(pose_tracker, keypoints, scores):
|
||||
return sorted_keypoints, sorted_scores
|
||||
|
||||
|
||||
def process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range):
|
||||
def process_video(video_path, pose_tracker, output_format, save_video, save_images, display_detection, frame_range):
|
||||
'''
|
||||
Estimate pose from a video file
|
||||
|
||||
INPUTS:
|
||||
- video_path: str. Path to the input video file
|
||||
- pose_tracker: PoseTracker. Initialized pose tracker object from RTMLib
|
||||
- tracking: bool. Whether to give consistent person ID across frames
|
||||
- output_format: str. Output format for the pose estimation results ('openpose', 'mmpose', 'deeplabcut')
|
||||
- save_video: bool. Whether to save the output video
|
||||
- save_images: bool. Whether to save the output images
|
||||
@ -186,10 +185,6 @@ def process_video(video_path, pose_tracker, tracking, output_format, save_video,
|
||||
# Perform pose estimation on the frame
|
||||
keypoints, scores = pose_tracker(frame)
|
||||
|
||||
# Reorder keypoints, scores
|
||||
if tracking:
|
||||
keypoints, scores = sort_people_rtmlib(pose_tracker, keypoints, scores)
|
||||
|
||||
# Save to json
|
||||
if 'openpose' in output_format:
|
||||
json_file_path = os.path.join(json_output_dir, f'{video_name_wo_ext}_{frame_idx:06d}.json')
|
||||
@ -225,7 +220,7 @@ def process_video(video_path, pose_tracker, tracking, output_format, save_video,
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def process_images(image_folder_path, vid_img_extension, pose_tracker, tracking, output_format, fps, save_video, save_images, display_detection, frame_range):
|
||||
def process_images(image_folder_path, vid_img_extension, pose_tracker, output_format, fps, save_video, save_images, display_detection, frame_range):
|
||||
'''
|
||||
Estimate pose estimation from a folder of images
|
||||
|
||||
@ -233,7 +228,6 @@ def process_images(image_folder_path, vid_img_extension, pose_tracker, tracking,
|
||||
- image_folder_path: str. Path to the input image folder
|
||||
- vid_img_extension: str. Extension of the image files
|
||||
- pose_tracker: PoseTracker. Initialized pose tracker object from RTMLib
|
||||
- tracking: bool. Whether to give consistent person ID across frames
|
||||
- output_format: str. Output format for the pose estimation results ('openpose', 'mmpose', 'deeplabcut')
|
||||
- save_video: bool. Whether to save the output video
|
||||
- save_images: bool. Whether to save the output images
|
||||
@ -275,17 +269,6 @@ def process_images(image_folder_path, vid_img_extension, pose_tracker, tracking,
|
||||
|
||||
# Perform pose estimation on the image
|
||||
keypoints, scores = pose_tracker(frame)
|
||||
|
||||
# Reorder keypoints, scores
|
||||
if tracking:
|
||||
max_id = max(pose_tracker.track_ids_last_frame)
|
||||
num_frames, num_points, num_coordinates = keypoints.shape
|
||||
keypoints_filled = np.zeros((max_id+1, num_points, num_coordinates))
|
||||
scores_filled = np.zeros((max_id+1, num_points))
|
||||
keypoints_filled[pose_tracker.track_ids_last_frame] = keypoints
|
||||
scores_filled[pose_tracker.track_ids_last_frame] = scores
|
||||
keypoints = keypoints_filled
|
||||
scores = scores_filled
|
||||
|
||||
# Extract frame number from the filename
|
||||
if 'openpose' in output_format:
|
||||
@ -361,9 +344,7 @@ def rtm_estimator(config_dict):
|
||||
save_images = True if 'to_images' in config_dict['pose']['save_video'] else False
|
||||
display_detection = config_dict['pose']['display_detection']
|
||||
overwrite_pose = config_dict['pose']['overwrite_pose']
|
||||
|
||||
det_frequency = config_dict['pose']['det_frequency']
|
||||
tracking = config_dict['pose']['tracking']
|
||||
|
||||
# Determine frame rate
|
||||
video_files = glob.glob(os.path.join(video_dir, '*'+vid_img_extension))
|
||||
@ -407,9 +388,6 @@ def rtm_estimator(config_dict):
|
||||
logging.info(f'Inference run on every single frame.')
|
||||
else:
|
||||
raise ValueError(f"Invalid det_frequency: {det_frequency}. Must be an integer greater or equal to 1.")
|
||||
|
||||
if tracking:
|
||||
logging.info(f'Pose estimation will attempt to give consistent person IDs across frames.\n')
|
||||
|
||||
# Select the appropriate model based on the model_type
|
||||
if pose_model.upper() == 'HALPE_26':
|
||||
@ -433,7 +411,7 @@ def rtm_estimator(config_dict):
|
||||
mode=mode,
|
||||
backend=backend,
|
||||
device=device,
|
||||
tracking=tracking,
|
||||
tracking=False,
|
||||
to_openpose=False)
|
||||
|
||||
|
||||
@ -454,7 +432,7 @@ def rtm_estimator(config_dict):
|
||||
logging.info(f'Found video files with extension {vid_img_extension}.')
|
||||
for video_path in video_files:
|
||||
pose_tracker.reset()
|
||||
process_video(video_path, pose_tracker, tracking, output_format, save_video, save_images, display_detection, frame_range)
|
||||
process_video(video_path, pose_tracker, output_format, save_video, save_images, display_detection, frame_range)
|
||||
|
||||
else:
|
||||
# Process image folders
|
||||
@ -463,4 +441,4 @@ def rtm_estimator(config_dict):
|
||||
for image_folder in image_folders:
|
||||
pose_tracker.reset()
|
||||
image_folder_path = os.path.join(video_dir, image_folder)
|
||||
process_images(image_folder_path, vid_img_extension, pose_tracker, tracking, output_format, frame_rate, save_video, save_images, display_detection, frame_range)
|
||||
process_images(image_folder_path, vid_img_extension, pose_tracker, output_format, frame_rate, save_video, save_images, display_detection, frame_range)
|
||||
|
@ -177,11 +177,6 @@ def time_lagged_cross_corr(camx, camy, lag_range, show=True, ref_cam_id=0, cam_i
|
||||
if isinstance(lag_range, int):
|
||||
lag_range = [-lag_range, lag_range]
|
||||
|
||||
import hashlib
|
||||
print(repr(list(camx)), repr(list(camy)))
|
||||
hashlib.md5(pd.util.hash_pandas_object(camx).values).hexdigest()
|
||||
hashlib.md5(pd.util.hash_pandas_object(camy).values).hexdigest()
|
||||
|
||||
pearson_r = [camx.corr(camy.shift(lag)) for lag in range(lag_range[0], lag_range[1])]
|
||||
offset = int(np.floor(len(pearson_r)/2)-np.argmax(pearson_r))
|
||||
if not np.isnan(pearson_r).all():
|
||||
|
Loading…
Reference in New Issue
Block a user