pose2sim/Pose2Sim/filter_3d.py
2022-09-30 17:12:41 +02:00

330 lines
11 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
###########################################################################
## FILTER 3D COORDINATES ##
###########################################################################
Filter trc 3D coordinates.
Available filters: Butterworth, Butterworth on speed, Gaussian, LOESS, Median
Set your parameters in Config.toml
INPUTS:
- a trc file
- filtering parameters in Config.toml
OUTPUT:
- a filtered trc file
'''
## INIT
import os
import fnmatch
import pandas as pd
import matplotlib.pyplot as plt
import logging
from scipy import signal
from scipy.ndimage import gaussian_filter1d
from statsmodels.nonparametric.smoothers_lowess import lowess
from Pose2Sim.common import plotWindow
## AUTHORSHIP INFORMATION
# Author, maintainer and sole credited contributor are the same person,
# so those fields are derived from __author__.
__author__ = "David Pagnon"
__maintainer__ = __author__
__credits__ = [__author__]
__email__ = "contact@david-pagnon.com"
__copyright__ = "Copyright 2021, Pose2Sim"
__license__ = "BSD 3-Clause License"
__version__ = "0.1"
__status__ = "Development"
## FUNCTIONS
def butterworth_filter_1d(config, col):
    '''
    1D zero-phase Butterworth filter (dual pass).

    The filter is designed with half the configured order because
    signal.filtfilt applies it twice (forward and backward), which doubles
    the order of the combined response. An even configured order is
    therefore expected; scipy rejects a non-integer filter order.

    INPUT:
    - config: configuration dictionary (from Config.toml). Reads
      ['3d-filtering']['butterworth']['type', 'order', 'cut_off_frequency']
      and ['project']['frame_rate']
    - col: Pandas dataframe column

    OUTPUT
    - col_filtered: filtered column (numpy array)
    '''
    butterworth_config = config.get('3d-filtering').get('butterworth')
    butterworth_filter_type = butterworth_config.get('type')
    butterworth_filter_order = int(butterworth_config.get('order'))
    # float, not int: a fractional cut-off (e.g. 2.5 Hz) must not be truncated
    butterworth_filter_cutoff = float(butterworth_config.get('cut_off_frequency'))
    frame_rate = config.get('project').get('frame_rate')

    # Cut-off is normalized by the Nyquist frequency (frame_rate / 2)
    b, a = signal.butter(butterworth_filter_order/2, butterworth_filter_cutoff/(frame_rate/2), butterworth_filter_type, analog=False)
    col_filtered = signal.filtfilt(b, a, col)
    return col_filtered
def butterworth_on_speed_filter_1d(config, col):
    '''
    1D zero-phase Butterworth filter (dual pass) applied to the derivative.

    The column is differentiated, the derivative is filtered with
    signal.filtfilt (so, as for the plain Butterworth filter, the designed
    order is half the configured one), and the result is integrated back:
    cumulative sum plus the first sample of the original column.

    INPUT:
    - config: configuration dictionary (from Config.toml). Reads
      ['3d-filtering']['butterworth_on_speed']['type', 'order', 'cut_off_frequency']
      and ['project']['frame_rate']
    - col: Pandas dataframe column

    OUTPUT
    - col_filtered: filtered column (numpy array)
    '''
    speed_config = config.get('3d-filtering').get('butterworth_on_speed')
    butter_speed_filter_type = speed_config.get('type')
    butter_speed_filter_order = int(speed_config.get('order'))
    # float, not int: a fractional cut-off (e.g. 2.5 Hz) must not be truncated
    butter_speed_filter_cutoff = float(speed_config.get('cut_off_frequency'))
    frame_rate = config.get('project').get('frame_rate')

    # Cut-off is normalized by the Nyquist frequency (frame_rate / 2)
    b, a = signal.butter(butter_speed_filter_order/2, butter_speed_filter_cutoff/(frame_rate/2), butter_speed_filter_type, analog=False)
    col_diff = col.diff()  # derivative; diff() leaves the first value as NaN
    # Every NaN in the derivative (normally only the first one) is replaced by
    # half of the second difference, so the filter gets a value at the start
    col_diff = col_diff.fillna(col_diff.iloc[1]/2)
    col_diff_filt = signal.filtfilt(b, a, col_diff)  # filter derivative
    col_filtered = col_diff_filt.cumsum() + col.iloc[0]  # integrate filtered derivative
    return col_filtered
def gaussian_filter_1d(config, col):
    '''
    1D Gaussian filter.

    INPUT:
    - config: configuration dictionary (from Config.toml). Reads
      ['3d-filtering']['gaussian']['sigma_kernel'], the standard deviation
      of the Gaussian kernel, in samples
    - col: Pandas dataframe column

    OUTPUT
    - col_filtered: filtered column (numpy array of floats)
    '''
    # float, not int: sigma is a standard deviation and may be fractional;
    # truncating e.g. 1.5 to 1 would silently change the amount of smoothing
    gaussian_filter_sigma_kernel = float(config.get('3d-filtering').get('gaussian').get('sigma_kernel'))
    # gaussian_filter1d returns the input dtype by default, so cast to float
    # to avoid rounding the smoothed values of an integer-typed column
    col_filtered = gaussian_filter1d(col.astype(float), gaussian_filter_sigma_kernel)
    return col_filtered
def loess_filter_1d(config, col):
    '''
    1D LOWESS filter (Locally Weighted Scatterplot Smoothing).

    The column values are smoothed against the column index. The fraction of
    data used for each local fit is
        frac = nb_values_used / number_of_frames

    INPUT:
    - config: configuration dictionary (from Config.toml). Reads
      ['3d-filtering']['LOESS']['nb_values_used']
    - col: Pandas dataframe column

    OUTPUT
    - filtered column (numpy array of fitted values)
    '''
    nb_values_used = config.get('3d-filtering').get('LOESS').get('nb_values_used')
    smoothing_fraction = nb_values_used / len(col)
    # it=0: a single fit, no residual-based reweighting iterations
    fitted = lowess(col, col.index, is_sorted=True, frac=smoothing_fraction, it=0)
    # lowess returns (x, fitted y) rows; keep only the fitted values
    return fitted[:, 1]
def median_filter_1d(config, col):
    '''
    1D median filter.

    INPUT:
    - config: configuration dictionary (from Config.toml). Reads
      ['3d-filtering']['median']['kernel_size'], passed unchanged to
      signal.medfilt (which requires an odd kernel size)
    - col: Pandas dataframe column

    OUTPUT
    - filtered column (numpy array)
    '''
    kernel_size = config.get('3d-filtering').get('median').get('kernel_size')
    return signal.medfilt(col, kernel_size=kernel_size)
def display_figures_fun(Q_unfilt, Q_filt, time_col, keypoints_names):
    '''
    Show unfiltered vs filtered coordinates, one tab per keypoint.

    Each tab holds three stacked subplots (X, Y, Z) plotted against time;
    only the bottom one keeps its x tick labels and gets the 'Time' label.

    INPUTS:
    - Q_unfilt: pandas dataframe of unfiltered 3D coordinates, columns laid
      out as consecutive X, Y, Z triplets, one triplet per keypoint
    - Q_filt: pandas dataframe of filtered 3D coordinates, same layout
    - time_col: pandas column used as the x axis
    - keypoints_names: list of strings, in the same order as the triplets

    OUTPUT:
    - matplotlib window with tabbed figures for each keypoint
    '''
    time_values = time_col.to_numpy()
    pw = plotWindow()
    for kpt_idx, keypoint in enumerate(keypoints_names):
        fig = plt.figure()
        for axis_idx, axis_name in enumerate('XYZ'):
            col_idx = kpt_idx * 3 + axis_idx
            is_bottom = axis_idx == 2
            ax = plt.subplot(3, 1, axis_idx + 1)
            plt.plot(time_values, Q_unfilt.iloc[:, col_idx].to_numpy(), label='unfiltered')
            plt.plot(time_values, Q_filt.iloc[:, col_idx].to_numpy(), label='filtered')
            if not is_bottom:
                # Upper subplots hide their x tick labels
                plt.setp(ax.get_xticklabels(), visible=False)
            ax.set_ylabel(keypoint + ' ' + axis_name)
            if is_bottom:
                ax.set_xlabel('Time')
            plt.legend()
        pw.addPlot(keypoint, fig)
    pw.show()
def filter1d(col, config, filter_type):
    '''
    Filter one column with the filter selected in Config.toml.

    INPUT:
    - col: Pandas dataframe column
    - config: configuration dictionary, forwarded to the selected filter
    - filter_type: one of 'butterworth', 'butterworth_on_speed', 'gaussian',
      'LOESS', 'median'

    OUTPUT
    - filtered column, as returned by the selected filter function

    Raises KeyError when filter_type is not one of the names above.
    '''
    available_filters = {
        'butterworth': butterworth_filter_1d,
        'butterworth_on_speed': butterworth_on_speed_filter_1d,
        'gaussian': gaussian_filter_1d,
        'LOESS': loess_filter_1d,
        'median': median_filter_1d,
    }
    return available_filters[filter_type](config, col)
def recap_filter3d(config, trc_path):
    '''
    Log a message giving the filtering parameters and the output path.

    Only the parameters of the filter selected by ['3d-filtering']['type']
    are read, so the config sections of unused filters may be absent.
    Where the messages end up (console, presumably User/logs.txt) depends on
    the caller's logging configuration.

    INPUT:
    - config: configuration dictionary (from Config.toml)
    - trc_path: path of the filtered trc file, reported in the log

    OUTPUT:
    - two logging.info messages

    Raises KeyError when the filter type is unknown.
    '''
    filtering_config = config.get('3d-filtering')
    filter_type = filtering_config.get('type')

    def butterworth_recap(section_name, label):
        # Shared message for the two Butterworth variants
        section = filtering_config.get(section_name)
        order = int(section.get('order'))
        # float + ':g' prints integer cut-offs as before ("6") and fractional
        # ones exactly ("2.5") instead of truncating them
        cutoff = float(section.get('cut_off_frequency'))
        return f'--> Filter type: {label} {section.get("type")}-pass. Order {order}, Cut-off frequency {cutoff:g} Hz.'

    def gaussian_recap():
        sigma = float(filtering_config.get('gaussian').get('sigma_kernel'))
        return f'--> Filter type: Gaussian. Standard deviation kernel: {sigma:g}'

    def loess_recap():
        nb_values = filtering_config.get('LOESS').get('nb_values_used')
        return f'--> Filter type: LOESS. Number of values used: {nb_values}'

    def median_recap():
        kernel_size = filtering_config.get('median').get('kernel_size')
        return f'--> Filter type: Median. Kernel size: {kernel_size}'

    # Builders are only called for the selected filter
    recap_builders = {
        'butterworth': lambda: butterworth_recap('butterworth', 'Butterworth'),
        'butterworth_on_speed': lambda: butterworth_recap('butterworth_on_speed', 'Butterworth on speed'),
        'gaussian': gaussian_recap,
        'LOESS': loess_recap,
        'median': median_recap,
    }
    logging.info(recap_builders[filter_type]())
    logging.info(f'Filtered 3D coordinates are stored at {trc_path}.')
def filter_all(config):
    '''
    Filter the 3D coordinates of the trc file.
    Displays filtered coordinates for checking.

    INPUTS:
    - config: configuration dictionary (from Config.toml)
    - a trc file '<seq_name>_<first>-<last>.trc' in the pose3d folder

    OUTPUT:
    - a filtered trc file '<seq_name>_filt_<first>-<last>.trc' in the same folder
    '''
    import inspect  # used only for the pandas to_csv compatibility check below

    # Read config
    project_config = config.get('project')
    filtering_config = config.get('3d-filtering')
    project_dir = project_config.get('project_dir')
    if not project_dir:
        project_dir = os.getcwd()
    seq_name = os.path.basename(project_dir)
    pose3d_dir = os.path.join(project_dir, project_config.get('pose3d_folder_name'))
    frame_range = project_config.get('frame_range')
    display_figures = filtering_config.get('display_figures')
    filter_type = filtering_config.get('type')

    # Frames range: when none is configured, use [0, number of json files in
    # the json folder that contains the fewest of them]
    if frame_range:
        f_range = frame_range
    else:
        # Prefer the tracked pose folder, fall back to the raw one when the
        # key is absent. dict.get returns None instead of raising, so the
        # fallback must test for None explicitly.
        pose_folder_name = project_config.get('poseTracked_folder_name')
        if pose_folder_name is None:
            pose_folder_name = project_config.get('pose_folder_name')
        pose_dir = os.path.join(project_dir, pose_folder_name)
        json_folder_extension = project_config.get('pose_json_folder_extension')
        pose_listdirs_names = next(os.walk(pose_dir))[1]
        json_dirs_names = [k for k in pose_listdirs_names if json_folder_extension in k]
        json_files_names = [fnmatch.filter(os.listdir(os.path.join(pose_dir, js_dir)), '*.json') for js_dir in json_dirs_names]
        f_range = [0, min(len(j) for j in json_files_names)]

    # Trc paths
    trc_f_in = f'{seq_name}_{f_range[0]}-{f_range[1]}.trc'
    trc_f_out = f'{seq_name}_filt_{f_range[0]}-{f_range[1]}.trc'
    trc_path_in = os.path.join(pose3d_dir, trc_f_in)
    trc_path_out = os.path.join(pose3d_dir, trc_f_out)

    # Read the 5 header lines; they are copied verbatim into the output file
    with open(trc_path_in, 'r') as trc_file:
        header = [next(trc_file) for _ in range(5)]
    # Read coordinates: the first two columns are frame number and time, the
    # remaining ones are the coordinate columns to filter
    trc_df = pd.read_csv(trc_path_in, sep="\t", skiprows=4)
    frames_col, time_col = trc_df.iloc[:, 0], trc_df.iloc[:, 1]
    Q_coord = trc_df.drop(trc_df.columns[[0, 1]], axis=1)

    # Filter each coordinate column independently
    Q_filt = Q_coord.apply(filter1d, axis=0, args=[config, filter_type])

    # Display figures; accept a boolean True as well as the string 'True'
    if str(display_figures).strip().lower() == 'true':
        # Keypoint names are on header line 4, every third column from column 2
        keypoints_names = pd.read_csv(trc_path_in, sep="\t", skiprows=3, nrows=0).columns[2::3].to_numpy()
        display_figures_fun(Q_coord, Q_filt, time_col, keypoints_names)

    # Reconstruct trc file with filtered coordinates
    Q_filt.insert(0, 'Frame#', frames_col)
    Q_filt.insert(1, 'Time', time_col)
    # pandas 1.5 renamed to_csv's 'line_terminator' to 'lineterminator' and
    # pandas 2.0 removed the old name: pass whichever this version accepts
    to_csv_params = inspect.signature(pd.DataFrame.to_csv).parameters
    terminator_kwarg = 'lineterminator' if 'lineterminator' in to_csv_params else 'line_terminator'
    with open(trc_path_out, 'w') as trc_o:
        trc_o.writelines(header)
        Q_filt.to_csv(trc_o, sep='\t', index=False, header=None, **{terminator_kwarg: '\n'})

    # Recap
    recap_filter3d(config, trc_path_out)