diff --git a/library/pymatch/CMakeLists.txt b/library/pymatch/CMakeLists.txt new file mode 100644 index 0000000..276f056 --- /dev/null +++ b/library/pymatch/CMakeLists.txt @@ -0,0 +1,49 @@ +cmake_minimum_required(VERSION 3.12.0) +project(MATCHLR VERSION 0.1.0) + +set(CMAKE_CXX_STANDARD 11) + +set(CMAKE_CXX_FLAGS "-pthread -O3 -fPIC") + +add_definitions(-DPROJECT_SOURCE_DIR="${PROJECT_SOURCE_DIR}") + +set(Eigen3_DIR "${PROJECT_SOURCE_DIR}/../../3rdparty/eigen-3.3.7/share/eigen3/cmake") +set(PYBIND11_DIR "${PROJECT_SOURCE_DIR}/../../3rdparty/pybind11") + +find_package(Eigen3 3.3.7 REQUIRED) +message(STATUS "Eigen3 path is ${EIGEN3_INCLUDE_DIR}") + +set(USE_OPENCV 0) +if(USE_OPENCV) + # OpenCV + find_package(OpenCV REQUIRED) + add_definitions(-D_USE_OPENCV_) +endif(USE_OPENCV) + +IF (WIN32) + MESSAGE(STATUS "I don't test on Windows") +ELSEIF (APPLE) + MESSAGE(STATUS "Not use openmp") +ELSEIF (UNIX) + find_package(OpenMP REQUIRED) + set(OTHER_LIBS ${OTHER_LIBS} OpenMP::OpenMP_CXX) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") +ENDIF () + +set(PUBLIC_INCLUDE ${PROJECT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/include + ${EIGEN3_INCLUDE_DIR} +) + +set(OTHER_SRCS "") + + +if(USE_OPENCV) + set(PUBLIC_INCLUDE ${PUBLIC_INCLUDE} ${OpenCV_INCLUDE_DIRS}) + set(OTHER_LIBS ${OTHER_LIBS} ${OpenCV_LIBRARIES}) +endif(USE_OPENCV) + +# add_subdirectory(test) +add_subdirectory(${PYBIND11_DIR} "PYBIND11_out") +add_subdirectory(python) +# add_subdirectory(src) diff --git a/library/pymatch/include/Timer.hpp b/library/pymatch/include/Timer.hpp new file mode 100644 index 0000000..139622c --- /dev/null +++ b/library/pymatch/include/Timer.hpp @@ -0,0 +1,63 @@ +#pragma once +#include +#include + +class Timer +{ + public: + std::string name_; + Timer(std::string name):name_(name){}; + Timer(){}; + void start(); + void tic(); + void toc(); + void toc(std::string things); + void end(); + double now(); + + + private: + std::chrono::steady_clock::time_point t_s_; //start time ponit + std::chrono::steady_clock::time_point t_tic_; //tic time ponit + std::chrono::steady_clock::time_point t_toc_; //toc time ponit + std::chrono::steady_clock::time_point t_e_; //stop time point +}; + + +void Timer::tic() +{ + t_tic_ = std::chrono::steady_clock::now(); +} + +void Timer::toc() +{ + t_toc_ = std::chrono::steady_clock::now(); + auto tt = std::chrono::duration_cast>(t_toc_ - t_tic_).count(); + std::cout << "Time spend: " << tt << " seconds" << std::endl; +} + +void Timer::toc(std::string things) +{ + t_toc_ = std::chrono::steady_clock::now(); + auto tt = std::chrono::duration_cast>(t_toc_ - t_tic_).count(); + std::cout << "Time spend: " << tt << " seconds when doing "<>(t_e_ - t_s_).count(); + std::cout << "< "<name_<<" > Time total spend: " << tt << " seconds" << std::endl; +} + +double Timer::now() +{ + t_e_ = std::chrono::steady_clock::now(); + auto tt = std::chrono::duration_cast>(t_e_ - t_s_).count(); + return tt; +} diff --git a/library/pymatch/include/base.h b/library/pymatch/include/base.h new file mode 100644 index 0000000..18fb398 --- /dev/null +++ b/library/pymatch/include/base.h @@ -0,0 +1,41 @@ +/*** + * @Date: 2020-09-19 16:10:21 + * @Author: Qing Shuai + * @LastEditors: Qing Shuai + * @LastEditTime: 2020-09-24 21:09:58 + * @FilePath: /MatchLR/include/match/base.h + */ +#pragma once +#include +#include "Eigen/Dense" +#include +#include + +namespace match +{ + typedef float Type; + typedef Eigen::Matrix Mat; + typedef Eigen::Array Array; + template + using Vec=std::vector; + typedef std::vector List; + typedef std::vector ListList; + struct MatchInfo + { + int maxIter = 100; + float alpha = 200; + float beta = 0.1; + float tol = 1e-3; + float w_sparse = 0.1; + float w_rank = 50; + }; + typedef std::unordered_map Control; + + void print(Vec& lists, std::string name){ + std::cout << name << ": ["; + for(auto i:lists){ + std::cout << i << ", "; + } + std::cout << "]" << std::endl; + } +} // namespace match diff --git a/library/pymatch/include/matchSVT.hpp b/library/pymatch/include/matchSVT.hpp new file mode 100644 index 0000000..22704d5 --- /dev/null +++ b/library/pymatch/include/matchSVT.hpp @@ -0,0 +1,246 @@ +/*** + * @Date: 2020-09-12 19:01:56 + * @Author: Qing Shuai + * @LastEditors: Qing Shuai + * @LastEditTime: 2022-07-29 22:38:40 + * @FilePath: /EasyMocapPublic/library/pymatch/include/matchSVT.hpp + */ +#pragma once +#include "base.h" +#include "projfunc.hpp" +#include "Timer.hpp" +#include "visualize.hpp" + +namespace match +{ + struct Config + { + bool debug; + int max_iter; + float tol; + float w_rank; + float w_sparse; + Config(Control& control){ + debug = (control["debug"] > 0.); + max_iter = int(control["maxIter"]); + tol = control["tol"]; + w_rank = control["w_rank"]; + w_sparse = control["w_sparse"]; + } + }; + + // dimGroups: [0, nF1, nF1 + nF2, ...] + ListList getBlocksFromDimGroups(const List& dimGroups){ + ListList block; + for(int i=0;i 0.); + int max_iter = int(control["maxIter"]); + float tol = control["tol"]; + + int N = M_aff.rows(); + auto dual_blocks = getBlocksFromDimGroups(dimGroups); + // 对角线约束 + for(int i=0;i 0.7 && block[1] > 1 && block[3] > 1){ + // 如果大于0.9,说明区分度不够高啊,认为观测是虚假的 + M_obs.block(block[0], block[2], block[1], block[3]).setZero(); + } + } + // set the diag of M_aff to zeros + M_aff.diagonal().setConstant(0); + Mat X = M_aff; + Mat Y = Mat::Zero(N, N); + Mat Q = M_aff; + Mat W = (control["w_sparse"] - M_aff.array()).matrix(); + float mu = 64; + Timer timer; + timer.tic(); + for (int iter = 0; iter < max_iter; iter++) + { + Mat X0 = X; + // update Q with SVT + Q = 1.0 / mu * Y + X; + Eigen::BDCSVD UDV(Q.bdcSvd(Eigen::ComputeThinU | Eigen::ComputeThinV)); + Array Ds(Dsoft(UDV.singularValues(), control["w_rank"] / mu)); + Mat Qnew(UDV.matrixU() * Ds.matrix().asDiagonal() * UDV.matrixV().adjoint()); + Q = Qnew; + // update X + X = Q - (M_obs.array() * W.array() + Y.array()).matrix() / mu; + X = (X.array() * M_constr.array()).matrix(); + // set the diagonal + X.diagonal().setOnes(); + // 注意这个min,max + X = X.cwiseMin(1.f).cwiseMax(0.f); + #pragma omp parallel for + for(int i=0;i + alpha||x||_* + beta||x||_1, st. X \in C + // + alpha/2||A||^2 + alpha/2||B||^2 + // st AB^T = Z, Z\in \Omega + const Config cfg(control); + Mat W = (M_aff + M_aff.transpose())/2; + // set the diag of W to zeros + for(int i=0;i maxRank){ + maxRank = dimGroups[i+1]-dimGroups[i]; + } + } + std::cout << "[matchALS] set the max rank = " << maxRank << std::endl; + Mat eyeRank = Mat::Identity(maxRank, maxRank); + // initial value + Mat A = Mat::Random(n, maxRank); + Mat B; + for(int iter=0;iter u, sv; + for (int i = 0; i < y.rows(); i++) + { + u.push_back(y(i, 0)); + } + // 排序 + std::sort(u.begin(), u.end(), std::greater()); + float usum = 0; + for (int i = 0; i < u.size(); i++) + { + usum += u[i]; + sv.push_back(usum); + } + int rho = 0; + for (int i = 0; i < u.size(); i++) + { + if (u[i] > (sv[i] - 1) / (i + 1)) + { + rho = i; + } + } + float theta = std::max(0.f, (sv[rho] - 1) / (rho + 1)); + x = (y.array() - theta).matrix(); + x = x.cwiseMax(0.f); + } + return x; + } + + Mat proj2pavC(Mat y) + { + // y: N, 1 + // y[y<0] = 0 + int n = y.rows(); + y = y.cwiseMax(0.f); + Mat x = y; + if (y.sum() < 1) + { + x = y; + } + else + { + std::vector u; + for (int i = 0; i < y.rows(); i++) + { + u.push_back(y(i, 0)); + } + // 排序 + std::sort(u.begin(), u.end(), std::greater()); + float tmpsum = 0; + bool bget = false; + float tmax; + for (int ii = 0; ii < n - 1; ii++) + { + tmpsum += u[ii]; + tmax = (tmpsum - 1) / (ii + 1); + if (tmax >= u[ii + 1]) + { + bget = true; + break; + } + } + if (!bget) + { + tmax = (tmpsum + u[n - 1] - 1) / n; + } + x = (y.array() - tmax).matrix(); + x = x.cwiseMax(0.f); + } + return x; + } + + int proj201(Mat &z) + { + z = z.cwiseMin(1.f).cwiseMax(0.f); + return 0; + } + + Mat proj2kav_(Mat x0, Mat A, Mat b) + { + // to solve: + // min 1/2||x - x_0||_F^2 + ||z||_1 + // s.t. Ax = b, x-z=0, x>=0, x<=1 + // convert to L(x, y) = 1/2||x - x_0||_F^2 + y^T(Ax - b) + // x = (I + \rho A^T @A)^-1 @ (x_0 - A^T@y + \rho A^T@b) + // y = y + \rho *(Ax - b) + int n = x0.rows(); + Mat I(n, n); + I.setIdentity(); + Mat X = x0; + Mat y = b; + float rho = 2; + y.setZero(); + float tol = 1e-4; + Mat Y, B, Z, c; + for (int iter = 0; iter < 100; iter++) + { + Mat X0 = X; + // (x - x_0) + A^Ty + \rho A^T(Ax + By -c) + X = (I + rho * A.transpose() * A).ldlt().solve(x0 - A.transpose() * y + rho * A.transpose() * b); + y = y + rho * (A * X - b); + + Y = Y + rho * (A * X + B * Z - c); + float pRes = (A * X + B * Z - c).norm() / n; + float dRes = rho * (X - X0).norm() / n; + // std::cout << " Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), rho = " << rho << std::endl; + + if (pRes < tol && dRes < tol) + break; + if (pRes > 10 * dRes) + { + rho *= 2; + } + else if (dRes > 10 * pRes) + { + rho /= 2; + } + } + return X; + } + + Mat softthres(Mat b, float thres) + { + // TODO:vector + for (int i = 0; i < b.rows(); i++) + { + if (b(i, 0) < -thres) + { + b(i, 0) += thres; + } + else if (b(i, 0) > thres) + { + b(i, 0) -= thres; + } + else + { + b(i, 0) = 0; + } + } + return b; + } + + Array Dsoft(const Array &d, float penalty) + { + // inverts the singular values + // takes advantage of the fact that singular values are never negative + Array di(d.rows(), d.cols()); + int maxRank = 0; + for (int j = 0; j < d.size(); ++j) + { + double penalized = d(j, 0) - penalty; + if (penalized < 0) + { + di(j, 0) = 0; + } + else + { + di(j, 0) = penalized; + maxRank++; + } + } + // std::cout << "max rank: " << maxRank << std::endl; + return di; + } + + Mat _proj2kav(Mat x0, Mat A, Mat b, Mat weight) + { + // to solve: + // min 1/2||x - x_0||_F^2 + 1/2\lambda||Ax - b||_F^2 + \alpha||z||_1 + // s.t. x=z, z \in {z| z>=0, z<=1| + // convert to L(x, y) = 1/2||x - x_0||_F^2 + 1/2\lambda||Ax - b||_F^2 + \alpha||z||_1 + // + + \rho/2||x - z||_F^2 + // update x: + // x = (1/rho + I + lambda/rhoA^TA)^-1 @ (1/rho x_0 + lambda/rho A^Tb + y) + // update z: + // z = softthres(x + 1/rho y, lambda/rho) + // update y: + // y = y + \rho *(x - z) + int n = x0.rows(); + Mat I(n, n); + I.setIdentity(); + Mat X = x0; + Mat Y = X; + Mat Z = Y; + Y.setZero(); + float rho = 64; + // weight + float w_init = 1; + float w_paf = 1e-1; + float w_Ax = 100; + float w_l1 = 1e-1; + float tol = 1e-4; + std::cout << "x0: " << x0 << std::endl; + std::cout << "paf: " << weight << std::endl; + for (int iter = 0; iter < 100; iter++) + { + Mat X0 = X; + // update X + X = ((rho + w_init) * I + w_Ax * A.transpose() * A).ldlt().solve(x0 + w_Ax * A.transpose() * b + rho * Z - Y + w_paf * weight); + // update Z + Z = softthres(X + 1 / rho * Y, w_l1 / rho); + // projection Z + Z = Z.cwiseMin(1.f).cwiseMax(0.f); + // update Y + Y = Y + rho * (X - Z); + // convergence + float pRes = (X - Z).norm() / n; + float dRes = rho * (X - X0).norm() / n; + std::cout << " proj2kav Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), rho = " << rho << std::endl; + std::cout << " init= " << w_init * 0.5 * (X - x0).norm() / n + << ", equ= " << 0.5 * w_Ax * (A * X - b).norm() / n + << ", paf=" << -w_paf * weight.transpose() * X + << ", l1= " << 1.0 * (X.array() > 0).count() / n << std::endl; + + if (pRes < tol && dRes < tol) + break; + if (pRes > 10 * dRes) + { + rho *= 2; + } + else if (dRes > 10 * pRes) + { + rho /= 2; + } + } + return X; + } + + Mat proj2kav(Mat x0, Mat A, Mat b, Mat paf) + { + // reduce this problem + + std::vector indices; + // here we directly set the non-zero entries + for (int j = 0; j < A.cols(); j++) + { + if (A(A.rows() - 1, j) != 0) + { + indices.push_back(j); + } + } + // just use the last row + int n = indices.size(); + Mat Areduce(1, n); + Areduce.setOnes(); + Mat x0reduce(n, 1), pafreduce(n, 1); + for (int i = 0; i < n; i++) + { + x0reduce(i, 0) = x0(indices[i], 0); + pafreduce(i, 0) = paf(indices[i], 0); + } + Mat breduce(1, 1); + breduce(0, 0) = b(b.rows() - 1, 0); + + Mat xreduce = _proj2kav(x0reduce, Areduce, breduce, pafreduce); + Mat X(x0.rows(), 1); + X.setOnes(); + for (int i = 0; i < n; i++) + { + X(indices[i], 0) = xreduce(i, 0); + } + return X; + } + +} // namespace match diff --git a/library/pymatch/include/visualize.hpp b/library/pymatch/include/visualize.hpp new file mode 100644 index 0000000..ff3704a --- /dev/null +++ b/library/pymatch/include/visualize.hpp @@ -0,0 +1,34 @@ +/*** + * @Date: 2020-09-12 19:37:01 + * @Author: Qing Shuai + * @LastEditors: Qing Shuai + * @LastEditTime: 2020-09-12 19:37:46 + * @FilePath: /MatchLR/include/visualize.hpp + */ +#pragma once +#ifdef _USE_OPENCV_ +#include "opencv2/opencv.hpp" +#include "opencv2/core/eigen.hpp" +#endif + +namespace match +{ + + +#ifdef _USE_OPENCV_ +cv::Mat eigen2mat(Mat Z){ + cv::Mat showi, showd, showrgb; + auto Zmin = Z.minCoeff(); + auto Zmax = Z.maxCoeff(); + Z = ((Z.array() - Zmin)/(Zmax - Zmin)).matrix(); + cv::eigen2cv(Z, showd); + showd.convertTo(showi, CV_8UC1, 255); + cv::applyColorMap(showi, showrgb, cv::COLORMAP_JET); + while(showrgb.rows < 600){ + cv::resize(showrgb, showrgb, cv::Size(), 2, 2, cv::INTER_NEAREST); + } + return showrgb; +} +#endif + +} // namespace match diff --git a/library/pymatch/pymatchlr/__init__.py b/library/pymatch/pymatchlr/__init__.py new file mode 100644 index 0000000..eaf8dc4 --- /dev/null +++ b/library/pymatch/pymatchlr/__init__.py @@ -0,0 +1,8 @@ +''' + * @ Date: 2020-09-12 17:58:02 + * @ Author: Qing Shuai + * @ LastEditors: Qing Shuai + * @ LastEditTime: 2020-09-15 14:38:40 + * @ FilePath: /MatchLR/pymatchlr/__init__.py +''' +from .pymatchlr import * diff --git a/library/pymatch/python/CMakeLists.txt b/library/pymatch/python/CMakeLists.txt new file mode 100644 index 0000000..7e54860 --- /dev/null +++ b/library/pymatch/python/CMakeLists.txt @@ -0,0 +1,10 @@ +set(PYMODULES "pymatchlr") + +foreach(module ${PYMODULES}) + pybind11_add_module("${module}" "${module}.cpp") + target_link_libraries("${module}" + PRIVATE pybind11::module + PUBLIC ${OTHER_LIBS}) + target_include_directories("${module}" + PUBLIC ${PUBLIC_INCLUDE}) +endforeach(module ${PYMODULES}) diff --git a/library/pymatch/python/__init__.py b/library/pymatch/python/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/library/pymatch/python/pymatchlr.cpp b/library/pymatch/python/pymatchlr.cpp new file mode 100644 index 0000000..b2ef6f1 --- /dev/null +++ b/library/pymatch/python/pymatchlr.cpp @@ -0,0 +1,33 @@ +/*** + * @Date: 2020-09-18 14:05:37 + * @Author: Qing Shuai + * @LastEditors: Qing Shuai + * @LastEditTime: 2021-07-24 14:50:42 + * @FilePath: /EasyMocap/library/pymatch/python/pymatchlr.cpp + */ +/* + * @Date: 2020-06-29 10:51:28 + * @LastEditors: Qing Shuai + * @LastEditTime: 2020-07-12 17:11:43 + * @Author: Qing Shuai + * @Mail: s_q@zju.edu.cn + */ +#include +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/numpy.h" +#include "pybind11/eigen.h" +#include "matchSVT.hpp" + +#define myprint(x) std::cout << #x << ": " << std::endl << x.transpose() << std::endl; +#define printshape(x) std::cout << #x << ": (" << x.rows() << ", " << x.cols() << ")" << std::endl; + +namespace py = pybind11; + +PYBIND11_MODULE(pymatchlr, m) { + m.def("matchSVT", &match::matchSVT, "SVT for matching", + py::arg("affinity"), py::arg("dimGroups"), py::arg("constraint"), py::arg("observe"), py::arg("debug")); + m.def("matchALS", &match::matchALS, "ALS for matching", + py::arg("affinity"), py::arg("dimGroups"), py::arg("constraint"), py::arg("observe"), py::arg("debug")); + m.attr("__version__") = "0.1.0"; +} \ No newline at end of file diff --git a/library/pymatch/setup.py b/library/pymatch/setup.py new file mode 100644 index 0000000..476d1fc --- /dev/null +++ b/library/pymatch/setup.py @@ -0,0 +1,81 @@ +''' + * @ Date: 2020-09-12 17:35:46 + * @ Author: Qing Shuai + * @ LastEditors: Qing Shuai + * @ LastEditTime: 2020-09-12 17:57:36 + * @ FilePath: /MatchLR/setup.py +''' +import os +import re +import sys +import platform +import subprocess + +from setuptools import setup, Extension +from setuptools.command.build_ext import build_ext +from distutils.version import LooseVersion + + +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=''): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + +class CMakeBuild(build_ext): + def run(self): + try: + out = subprocess.check_output(['cmake', '--version']) + except OSError: + raise RuntimeError("CMake must be installed to build the following extensions: " + + ", ".join(e.name for e in self.extensions)) + + if platform.system() == "Windows": + cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1)) + if cmake_version < '3.1.0': + raise RuntimeError("CMake >= 3.1.0 is required on Windows") + + for ext in self.extensions: + self.build_extension(ext) + + def build_extension(self, ext): + extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + # required for auto-detection of auxiliary "native" libs + if not extdir.endswith(os.path.sep): + extdir += os.path.sep + + cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir, + '-DPYTHON_EXECUTABLE=' + sys.executable] + + cfg = 'Debug' if self.debug else 'Release' + build_args = ['--config', cfg] + + if platform.system() == "Windows": + cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), extdir)] + if sys.maxsize > 2**32: + cmake_args += ['-A', 'x64'] + build_args += ['--', '/m'] + else: + cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg] + build_args += ['--', '-j2'] + + env = os.environ.copy() + env['CXXFLAGS'] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get('CXXFLAGS', ''), + self.distribution.get_version()) + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env) + subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp) + +setup( + name='pymatchlr', + version='0.0.2', + author='Qing Shuai', + author_email='s_q@zju.edu.cn', + description='A project for low rank matching algorithm', + long_description='', + ext_modules=[CMakeExtension('pymatchlr.pymatchlr')], + packages=['pymatchlr'], + cmdclass=dict(build_ext=CMakeBuild), + zip_safe=False, +)