add match func

This commit is contained in:
shuaiqing 2023-07-10 22:10:55 +08:00
parent 4919cdc417
commit a66a138314
11 changed files with 900 additions and 0 deletions

View File

@ -0,0 +1,49 @@
cmake_minimum_required(VERSION 3.12.0)
project(MATCHLR VERSION 0.1.0)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_FLAGS "-pthread -O3 -fPIC")
add_definitions(-DPROJECT_SOURCE_DIR="${PROJECT_SOURCE_DIR}")
set(Eigen3_DIR "${PROJECT_SOURCE_DIR}/../../3rdparty/eigen-3.3.7/share/eigen3/cmake")
set(PYBIND11_DIR "${PROJECT_SOURCE_DIR}/../../3rdparty/pybind11")
find_package(Eigen3 3.3.7 REQUIRED)
message(STATUS "Eigen3 path is ${EIGEN3_INCLUDE_DIR}")
set(USE_OPENCV 0)
if(USE_OPENCV)
# OpenCV
find_package(OpenCV REQUIRED)
add_definitions(-D_USE_OPENCV_)
endif(USE_OPENCV)
IF (WIN32)
MESSAGE(STATUS "I don't test on Windows")
ELSEIF (APPLE)
MESSAGE(STATUS "Not use openmp")
ELSEIF (UNIX)
find_package(OpenMP REQUIRED)
set(OTHER_LIBS ${OTHER_LIBS} OpenMP::OpenMP_CXX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
ENDIF ()
set(PUBLIC_INCLUDE ${PROJECT_SOURCE_DIR}
${PROJECT_SOURCE_DIR}/include
${EIGEN3_INCLUDE_DIR}
)
set(OTHER_SRCS "")
if(USE_OPENCV)
set(PUBLIC_INCLUDE ${PUBLIC_INCLUDE} ${OpenCV_INCLUDE_DIRS})
set(OTHER_LIBS ${OTHER_LIBS} ${OpenCV_LIBRARIES})
endif(USE_OPENCV)
# add_subdirectory(test)
add_subdirectory(${PYBIND11_DIR} "PYBIND11_out")
add_subdirectory(python)
# add_subdirectory(src)

View File

@ -0,0 +1,63 @@
#pragma once
#include <iostream>
#include <chrono>
class Timer
{
public:
std::string name_;
Timer(std::string name):name_(name){};
Timer(){};
void start();
void tic();
void toc();
void toc(std::string things);
void end();
double now();
private:
std::chrono::steady_clock::time_point t_s_; //start time ponit
std::chrono::steady_clock::time_point t_tic_; //tic time ponit
std::chrono::steady_clock::time_point t_toc_; //toc time ponit
std::chrono::steady_clock::time_point t_e_; //stop time point
};
void Timer::tic()
{
t_tic_ = std::chrono::steady_clock::now();
}
void Timer::toc()
{
t_toc_ = std::chrono::steady_clock::now();
auto tt = std::chrono::duration_cast<std::chrono::duration<double>>(t_toc_ - t_tic_).count();
std::cout << "Time spend: " << tt << " seconds" << std::endl;
}
void Timer::toc(std::string things)
{
t_toc_ = std::chrono::steady_clock::now();
auto tt = std::chrono::duration_cast<std::chrono::duration<double>>(t_toc_ - t_tic_).count();
std::cout << "Time spend: " << tt << " seconds when doing "<<things << std::endl;
}
void Timer::start()
{
t_s_ = std::chrono::steady_clock::now();
}
void Timer::end()
{
t_e_ = std::chrono::steady_clock::now();
auto tt = std::chrono::duration_cast<std::chrono::duration<double>>(t_e_ - t_s_).count();
std::cout << "< "<<this->name_<<" > Time total spend: " << tt << " seconds" << std::endl;
}
double Timer::now()
{
t_e_ = std::chrono::steady_clock::now();
auto tt = std::chrono::duration_cast<std::chrono::duration<double>>(t_e_ - t_s_).count();
return tt;
}

View File

@ -0,0 +1,41 @@
/***
* @Date: 2020-09-19 16:10:21
* @Author: Qing Shuai
* @LastEditors: Qing Shuai
* @LastEditTime: 2020-09-24 21:09:58
* @FilePath: /MatchLR/include/match/base.h
*/
#pragma once
#include <vector>
#include "Eigen/Dense"
#include <unordered_map>
#include <string>
namespace match
{
typedef float Type;
typedef Eigen::Matrix<Type, -1, -1> Mat;
typedef Eigen::Array<Type, -1, -1> Array;
template <typename T>
using Vec=std::vector<T>;
typedef std::vector<int> List;
typedef std::vector<List> ListList;
struct MatchInfo
{
int maxIter = 100;
float alpha = 200;
float beta = 0.1;
float tol = 1e-3;
float w_sparse = 0.1;
float w_rank = 50;
};
typedef std::unordered_map<std::string, float> Control;
void print(Vec<int>& lists, std::string name){
std::cout << name << ": [";
for(auto i:lists){
std::cout << i << ", ";
}
std::cout << "]" << std::endl;
}
} // namespace match

View File

@ -0,0 +1,246 @@
/***
* @Date: 2020-09-12 19:01:56
* @Author: Qing Shuai
* @LastEditors: Qing Shuai
* @LastEditTime: 2022-07-29 22:38:40
* @FilePath: /EasyMocapPublic/library/pymatch/include/matchSVT.hpp
*/
#pragma once
#include "base.h"
#include "projfunc.hpp"
#include "Timer.hpp"
#include "visualize.hpp"
namespace match
{
struct Config
{
bool debug;
int max_iter;
float tol;
float w_rank;
float w_sparse;
Config(Control& control){
debug = (control["debug"] > 0.);
max_iter = int(control["maxIter"]);
tol = control["tol"];
w_rank = control["w_rank"];
w_sparse = control["w_sparse"];
}
};
// dimGroups: [0, nF1, nF1 + nF2, ...]
ListList getBlocksFromDimGroups(const List& dimGroups){
ListList block;
for(int i=0;i<dimGroups.size() - 1;i++){
// 这个视角没有找到人的情况
if(dimGroups[i] == dimGroups[i+1])continue;
for(int j=0;j<dimGroups.size()-1;j++){
if(i==j)continue;
if(dimGroups[j] == dimGroups[j+1])continue;
block.push_back({dimGroups[i], dimGroups[i+1] - dimGroups[i],
dimGroups[j], dimGroups[j+1] - dimGroups[j]});
}
}
return block;
}
// matchSVT with constraint and observation
// M_aff: (N, N): affinity matrix
// M_constr: =0, when (i, j) cannot be the same person
// if not consider this, set to 1(N, N)
// M_obs: =0, when (i, j) cannot be observed
// if not consider this, set to 1(N, N)
Mat matchSVT(Mat M_aff, List dimGroups, Mat M_constr, Mat M_obs, Control control)
{
bool debug = (control["debug"] > 0.);
int max_iter = int(control["maxIter"]);
float tol = control["tol"];
int N = M_aff.rows();
auto dual_blocks = getBlocksFromDimGroups(dimGroups);
// 对角线约束
for(int i=0;i<dimGroups.size() - 1;i++){
M_constr.block(dimGroups[i], dimGroups[i], dimGroups[i+1] - dimGroups[i], dimGroups[i+1]-dimGroups[i]).setZero();
}
M_constr.diagonal().setOnes();
// 将affinity乘一下constraint保证满足约束
M_aff = (M_aff.array() * M_constr.array()).matrix();
// check一下所有区块如果最大值和最小值差异过小的直接认为是错误观测
for (auto block : dual_blocks)
{
Mat mat = M_aff.block(block[0], block[2], block[1], block[3]);
if(debug){
std::cout << "(" << block[0] << ", " << block[2] << "), ";
std::cout << "min: " << mat.minCoeff() << ", max: " << mat.maxCoeff() << std::endl;
}
if(mat.minCoeff() > 0.7 && block[1] > 1 && block[3] > 1){
// 如果大于0.9,说明区分度不够高啊,认为观测是虚假的
M_obs.block(block[0], block[2], block[1], block[3]).setZero();
}
}
// set the diag of M_aff to zeros
M_aff.diagonal().setConstant(0);
Mat X = M_aff;
Mat Y = Mat::Zero(N, N);
Mat Q = M_aff;
Mat W = (control["w_sparse"] - M_aff.array()).matrix();
float mu = 64;
Timer timer;
timer.tic();
for (int iter = 0; iter < max_iter; iter++)
{
Mat X0 = X;
// update Q with SVT
Q = 1.0 / mu * Y + X;
Eigen::BDCSVD<Mat> UDV(Q.bdcSvd(Eigen::ComputeThinU | Eigen::ComputeThinV));
Array Ds(Dsoft(UDV.singularValues(), control["w_rank"] / mu));
Mat Qnew(UDV.matrixU() * Ds.matrix().asDiagonal() * UDV.matrixV().adjoint());
Q = Qnew;
// update X
X = Q - (M_obs.array() * W.array() + Y.array()).matrix() / mu;
X = (X.array() * M_constr.array()).matrix();
// set the diagonal
X.diagonal().setOnes();
// 注意这个min,max
X = X.cwiseMin(1.f).cwiseMax(0.f);
#pragma omp parallel for
for(int i=0;i<dual_blocks.size();i++)
{
auto& block = dual_blocks[i];
X.block(block[0], block[2], block[1], block[3]) = myproj2dpam(X.block(block[0], block[2], block[1], block[3]), 1e-2);
}
X = (X + X.transpose().eval()) / 2;
Y = Y + mu * (X - Q);
float pRes = (X - Q).norm() / N;
float dRes = mu * (X - X0).norm() / N;
if(debug){
#ifdef _USE_OPENCV_
cv::imshow("Q", eigen2mat(Q));
cv::imshow("X", eigen2mat(X));
cv::waitKey(100);
#endif
std::cout << "Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
}
if (pRes < tol && dRes < tol)
{
std::cout << "End " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
break;
}
if (pRes > 10 * dRes)
{
mu *= 2;
}
else if (dRes > 10 * pRes)
{
mu /= 2;
}
}
if(debug){
#ifdef _USE_OPENCV_
timer.toc("solving svt");
cv::imshow("X", eigen2mat(X));
cv::imshow("Q", eigen2mat(Q));
cv::waitKey(0);
#endif
}
return X;
}
Mat matchALS(Mat M_aff, List dimGroups, Mat M_constr, Mat M_obs, Control control)
{
// This function is to solve
// min <W, X> + alpha||x||_* + beta||x||_1, st. X \in C
// <beta - W, AB^T> + alpha/2||A||^2 + alpha/2||B||^2
// st AB^T = Z, Z\in \Omega
const Config cfg(control);
Mat W = (M_aff + M_aff.transpose())/2;
// set the diag of W to zeros
for(int i=0;i<W.rows();i++){
W(i, i) = 0;
}
Mat X = W;
Mat Z = W;
Mat Y = W;
Y.setZero();
int mu = 64;
int n = X.rows();
int maxRank = 0;
for(size_t i=0;i<dimGroups.size() - 1;i++){
if(dimGroups[i+1]-dimGroups[i] > maxRank){
maxRank = dimGroups[i+1]-dimGroups[i];
}
}
std::cout << "[matchALS] set the max rank = " << maxRank << std::endl;
Mat eyeRank = Mat::Identity(maxRank, maxRank);
// initial value
Mat A = Mat::Random(n, maxRank);
Mat B;
for(int iter=0;iter<cfg.max_iter;iter++){
Mat X0 = X;
X = Z - (((Y - W).array() + cfg.w_sparse)/mu).matrix();
B = ((A.transpose() * A + cfg.w_rank/mu * eyeRank).ldlt().solve(A.transpose() * X)).transpose();
A = ((B.transpose() * B + cfg.w_rank/mu * eyeRank).ldlt().solve(B.transpose() * X.transpose())).transpose();
X = A * B.transpose();
Z = X + Y/mu;
for(int i=0;i<dimGroups.size() - 1;i++){
int start = dimGroups[i];
int end = dimGroups[i+1];
Z.block(start, start, end-start, end-start).setIdentity();
}
// 注意这个min,max
Z = Z.cwiseMin(1.f).cwiseMax(0.f);
Y = Y + mu*(X - Z);
float pRes = (X - Z).norm()/n;
float dRes = mu*(X - X0).norm()/n;
if(cfg.debug){
#ifdef _USE_OPENCV_
cv::imshow("Z", eigen2mat(Z));
cv::imshow("X", eigen2mat(X));
cv::waitKey(10);
#endif
std::cout << "Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
}
if (pRes < cfg.tol && dRes < cfg.tol)
{
std::cout << "End " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
break;
}
if(pRes > 10*dRes){
mu *= 2;
}else if(dRes > 10*pRes){
mu /= 2;
}
}
X = (X + X.transpose()) / 2;
return X;
}
Vec<int> getViewsFromDim(List& dimGroups){
Vec<int> lists(dimGroups.back(), -1);
for(int i=0;i<dimGroups.size() - 1;i++){
for(int c=dimGroups[i];c<dimGroups[i+1];c++){
lists[c] = i;
}
}
return lists;
}
Vec<int> getDimsFromViews(List& views){
Vec<int> dims = {0};
int startview = 0;
for(int i=0;i<views.size();i++){
if(views[i] != startview){
dims.push_back(i);
startview = views[i];
}
}
dims.push_back(views.size());
return dims;
}
} // namespace match

View File

@ -0,0 +1,335 @@
#pragma once
#include <vector>
#include <iostream>
#include "base.h"
namespace match
{
Mat proj2pav(Mat y);
Mat projR(Mat X);
Mat projC(Mat X);
Mat myproj2dpam(Mat Y, float tol = 1e-4, bool debug = false)
{
Mat X0 = Y;
Mat X = Y;
Mat I2 = X;
I2.setZero();
Mat X1, I1, X2;
for (int iter = 0; iter < 10; iter++)
{
X1 = projR(X0 + I2);
I1 = X1 - (X0 + I2);
X2 = projC(X0 + I1);
I2 = X2 - (X0 + I1);
float chg = (X2 - X).array().abs().sum() / (X.rows() * X.cols());
X = X2;
if (chg < tol)
{
return X;
}
}
return X;
}
Mat projR(Mat X)
{
int n = X.cols();
// std::cout << "before projR: " << X << std::endl;
for (int i = 0; i < X.rows(); i++)
{
Mat x = proj2pav(X.block(i, 0, 1, n).transpose());
X.block(i, 0, 1, n) = x.transpose();
}
// std::cout << "after projR: " << X << std::endl;
return X;
}
Mat projC(Mat X)
{
int n = X.rows();
// std::cout << "before projC: " << X << std::endl;
for (int j = 0; j < X.cols(); j++)
{
Mat x = proj2pav(X.block(0, j, n, 1));
X.block(0, j, n, 1) = x;
}
// std::cout << "after projC: " << X << std::endl;
return X;
}
Mat proj2pav(Mat y)
{
y = y.cwiseMax(0.f);
Mat x = y;
x.setZero();
if (y.sum() < 1)
{
x = y;
}
else
{
std::vector<float> u, sv;
for (int i = 0; i < y.rows(); i++)
{
u.push_back(y(i, 0));
}
// 排序
std::sort(u.begin(), u.end(), std::greater<float>());
float usum = 0;
for (int i = 0; i < u.size(); i++)
{
usum += u[i];
sv.push_back(usum);
}
int rho = 0;
for (int i = 0; i < u.size(); i++)
{
if (u[i] > (sv[i] - 1) / (i + 1))
{
rho = i;
}
}
float theta = std::max(0.f, (sv[rho] - 1) / (rho + 1));
x = (y.array() - theta).matrix();
x = x.cwiseMax(0.f);
}
return x;
}
Mat proj2pavC(Mat y)
{
// y: N, 1
// y[y<0] = 0
int n = y.rows();
y = y.cwiseMax(0.f);
Mat x = y;
if (y.sum() < 1)
{
x = y;
}
else
{
std::vector<float> u;
for (int i = 0; i < y.rows(); i++)
{
u.push_back(y(i, 0));
}
// 排序
std::sort(u.begin(), u.end(), std::greater<float>());
float tmpsum = 0;
bool bget = false;
float tmax;
for (int ii = 0; ii < n - 1; ii++)
{
tmpsum += u[ii];
tmax = (tmpsum - 1) / (ii + 1);
if (tmax >= u[ii + 1])
{
bget = true;
break;
}
}
if (!bget)
{
tmax = (tmpsum + u[n - 1] - 1) / n;
}
x = (y.array() - tmax).matrix();
x = x.cwiseMax(0.f);
}
return x;
}
int proj201(Mat &z)
{
z = z.cwiseMin(1.f).cwiseMax(0.f);
return 0;
}
Mat proj2kav_(Mat x0, Mat A, Mat b)
{
// to solve:
// min 1/2||x - x_0||_F^2 + ||z||_1
// s.t. Ax = b, x-z=0, x>=0, x<=1
// convert to L(x, y) = 1/2||x - x_0||_F^2 + y^T(Ax - b)
// x = (I + \rho A^T @A)^-1 @ (x_0 - A^T@y + \rho A^T@b)
// y = y + \rho *(Ax - b)
int n = x0.rows();
Mat I(n, n);
I.setIdentity();
Mat X = x0;
Mat y = b;
float rho = 2;
y.setZero();
float tol = 1e-4;
Mat Y, B, Z, c;
for (int iter = 0; iter < 100; iter++)
{
Mat X0 = X;
// (x - x_0) + A^Ty + \rho A^T(Ax + By -c)
X = (I + rho * A.transpose() * A).ldlt().solve(x0 - A.transpose() * y + rho * A.transpose() * b);
y = y + rho * (A * X - b);
Y = Y + rho * (A * X + B * Z - c);
float pRes = (A * X + B * Z - c).norm() / n;
float dRes = rho * (X - X0).norm() / n;
// std::cout << " Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), rho = " << rho << std::endl;
if (pRes < tol && dRes < tol)
break;
if (pRes > 10 * dRes)
{
rho *= 2;
}
else if (dRes > 10 * pRes)
{
rho /= 2;
}
}
return X;
}
Mat softthres(Mat b, float thres)
{
// TODO:vector
for (int i = 0; i < b.rows(); i++)
{
if (b(i, 0) < -thres)
{
b(i, 0) += thres;
}
else if (b(i, 0) > thres)
{
b(i, 0) -= thres;
}
else
{
b(i, 0) = 0;
}
}
return b;
}
Array Dsoft(const Array &d, float penalty)
{
// inverts the singular values
// takes advantage of the fact that singular values are never negative
Array di(d.rows(), d.cols());
int maxRank = 0;
for (int j = 0; j < d.size(); ++j)
{
double penalized = d(j, 0) - penalty;
if (penalized < 0)
{
di(j, 0) = 0;
}
else
{
di(j, 0) = penalized;
maxRank++;
}
}
// std::cout << "max rank: " << maxRank << std::endl;
return di;
}
Mat _proj2kav(Mat x0, Mat A, Mat b, Mat weight)
{
// to solve:
// min 1/2||x - x_0||_F^2 + 1/2\lambda||Ax - b||_F^2 + \alpha||z||_1
// s.t. x=z, z \in {z| z>=0, z<=1|
// convert to L(x, y) = 1/2||x - x_0||_F^2 + 1/2\lambda||Ax - b||_F^2 + \alpha||z||_1
// + <y, x - z> + \rho/2||x - z||_F^2
// update x:
// x = (1/rho + I + lambda/rhoA^TA)^-1 @ (1/rho x_0 + lambda/rho A^Tb + y)
// update z:
// z = softthres(x + 1/rho y, lambda/rho)
// update y:
// y = y + \rho *(x - z)
int n = x0.rows();
Mat I(n, n);
I.setIdentity();
Mat X = x0;
Mat Y = X;
Mat Z = Y;
Y.setZero();
float rho = 64;
// weight
float w_init = 1;
float w_paf = 1e-1;
float w_Ax = 100;
float w_l1 = 1e-1;
float tol = 1e-4;
std::cout << "x0: " << x0 << std::endl;
std::cout << "paf: " << weight << std::endl;
for (int iter = 0; iter < 100; iter++)
{
Mat X0 = X;
// update X
X = ((rho + w_init) * I + w_Ax * A.transpose() * A).ldlt().solve(x0 + w_Ax * A.transpose() * b + rho * Z - Y + w_paf * weight);
// update Z
Z = softthres(X + 1 / rho * Y, w_l1 / rho);
// projection Z
Z = Z.cwiseMin(1.f).cwiseMax(0.f);
// update Y
Y = Y + rho * (X - Z);
// convergence
float pRes = (X - Z).norm() / n;
float dRes = rho * (X - X0).norm() / n;
std::cout << " proj2kav Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), rho = " << rho << std::endl;
std::cout << " init= " << w_init * 0.5 * (X - x0).norm() / n
<< ", equ= " << 0.5 * w_Ax * (A * X - b).norm() / n
<< ", paf=" << -w_paf * weight.transpose() * X
<< ", l1= " << 1.0 * (X.array() > 0).count() / n << std::endl;
if (pRes < tol && dRes < tol)
break;
if (pRes > 10 * dRes)
{
rho *= 2;
}
else if (dRes > 10 * pRes)
{
rho /= 2;
}
}
return X;
}
Mat proj2kav(Mat x0, Mat A, Mat b, Mat paf)
{
// reduce this problem
std::vector<int> indices;
// here we directly set the non-zero entries
for (int j = 0; j < A.cols(); j++)
{
if (A(A.rows() - 1, j) != 0)
{
indices.push_back(j);
}
}
// just use the last row
int n = indices.size();
Mat Areduce(1, n);
Areduce.setOnes();
Mat x0reduce(n, 1), pafreduce(n, 1);
for (int i = 0; i < n; i++)
{
x0reduce(i, 0) = x0(indices[i], 0);
pafreduce(i, 0) = paf(indices[i], 0);
}
Mat breduce(1, 1);
breduce(0, 0) = b(b.rows() - 1, 0);
Mat xreduce = _proj2kav(x0reduce, Areduce, breduce, pafreduce);
Mat X(x0.rows(), 1);
X.setOnes();
for (int i = 0; i < n; i++)
{
X(indices[i], 0) = xreduce(i, 0);
}
return X;
}
} // namespace match

View File

@ -0,0 +1,34 @@
/***
* @Date: 2020-09-12 19:37:01
* @Author: Qing Shuai
* @LastEditors: Qing Shuai
* @LastEditTime: 2020-09-12 19:37:46
* @FilePath: /MatchLR/include/visualize.hpp
*/
#pragma once
#ifdef _USE_OPENCV_
#include "opencv2/opencv.hpp"
#include "opencv2/core/eigen.hpp"
#endif
namespace match
{
#ifdef _USE_OPENCV_
cv::Mat eigen2mat(Mat Z){
cv::Mat showi, showd, showrgb;
auto Zmin = Z.minCoeff();
auto Zmax = Z.maxCoeff();
Z = ((Z.array() - Zmin)/(Zmax - Zmin)).matrix();
cv::eigen2cv(Z, showd);
showd.convertTo(showi, CV_8UC1, 255);
cv::applyColorMap(showi, showrgb, cv::COLORMAP_JET);
while(showrgb.rows < 600){
cv::resize(showrgb, showrgb, cv::Size(), 2, 2, cv::INTER_NEAREST);
}
return showrgb;
}
#endif
} // namespace match

View File

@ -0,0 +1,8 @@
'''
* @ Date: 2020-09-12 17:58:02
* @ Author: Qing Shuai
* @ LastEditors: Qing Shuai
* @ LastEditTime: 2020-09-15 14:38:40
* @ FilePath: /MatchLR/pymatchlr/__init__.py
'''
from .pymatchlr import *

View File

@ -0,0 +1,10 @@
set(PYMODULES "pymatchlr")
foreach(module ${PYMODULES})
pybind11_add_module("${module}" "${module}.cpp")
target_link_libraries("${module}"
PRIVATE pybind11::module
PUBLIC ${OTHER_LIBS})
target_include_directories("${module}"
PUBLIC ${PUBLIC_INCLUDE})
endforeach(module ${PYMODULES})

View File

View File

@ -0,0 +1,33 @@
/***
* @Date: 2020-09-18 14:05:37
* @Author: Qing Shuai
* @LastEditors: Qing Shuai
* @LastEditTime: 2021-07-24 14:50:42
* @FilePath: /EasyMocap/library/pymatch/python/pymatchlr.cpp
*/
/*
* @Date: 2020-06-29 10:51:28
* @LastEditors: Qing Shuai
* @LastEditTime: 2020-07-12 17:11:43
* @Author: Qing Shuai
* @Mail: s_q@zju.edu.cn
*/
#include <iostream>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/numpy.h"
#include "pybind11/eigen.h"
#include "matchSVT.hpp"
#define myprint(x) std::cout << #x << ": " << std::endl << x.transpose() << std::endl;
#define printshape(x) std::cout << #x << ": (" << x.rows() << ", " << x.cols() << ")" << std::endl;
namespace py = pybind11;
PYBIND11_MODULE(pymatchlr, m) {
m.def("matchSVT", &match::matchSVT, "SVT for matching",
py::arg("affinity"), py::arg("dimGroups"), py::arg("constraint"), py::arg("observe"), py::arg("debug"));
m.def("matchALS", &match::matchALS, "ALS for matching",
py::arg("affinity"), py::arg("dimGroups"), py::arg("constraint"), py::arg("observe"), py::arg("debug"));
m.attr("__version__") = "0.1.0";
}

81
library/pymatch/setup.py Normal file
View File

@ -0,0 +1,81 @@
'''
* @ Date: 2020-09-12 17:35:46
* @ Author: Qing Shuai
* @ LastEditors: Qing Shuai
* @ LastEditTime: 2020-09-12 17:57:36
* @ FilePath: /MatchLR/setup.py
'''
import os
import re
import sys
import platform
import subprocess
from setuptools import setup, Extension
from setuptools.command.build_ext import build_ext
from distutils.version import LooseVersion
class CMakeExtension(Extension):
def __init__(self, name, sourcedir=''):
Extension.__init__(self, name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)
class CMakeBuild(build_ext):
def run(self):
try:
out = subprocess.check_output(['cmake', '--version'])
except OSError:
raise RuntimeError("CMake must be installed to build the following extensions: " +
", ".join(e.name for e in self.extensions))
if platform.system() == "Windows":
cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1))
if cmake_version < '3.1.0':
raise RuntimeError("CMake >= 3.1.0 is required on Windows")
for ext in self.extensions:
self.build_extension(ext)
def build_extension(self, ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
# required for auto-detection of auxiliary "native" libs
if not extdir.endswith(os.path.sep):
extdir += os.path.sep
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
'-DPYTHON_EXECUTABLE=' + sys.executable]
cfg = 'Debug' if self.debug else 'Release'
build_args = ['--config', cfg]
if platform.system() == "Windows":
cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), extdir)]
if sys.maxsize > 2**32:
cmake_args += ['-A', 'x64']
build_args += ['--', '/m']
else:
cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg]
build_args += ['--', '-j2']
env = os.environ.copy()
env['CXXFLAGS'] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get('CXXFLAGS', ''),
self.distribution.get_version())
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
setup(
name='pymatchlr',
version='0.0.2',
author='Qing Shuai',
author_email='s_q@zju.edu.cn',
description='A project for low rank matching algorithm',
long_description='',
ext_modules=[CMakeExtension('pymatchlr.pymatchlr')],
packages=['pymatchlr'],
cmdclass=dict(build_ext=CMakeBuild),
zip_safe=False,
)