add match func
This commit is contained in:
parent
4919cdc417
commit
a66a138314
49
library/pymatch/CMakeLists.txt
Normal file
49
library/pymatch/CMakeLists.txt
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.12.0)
|
||||||
|
project(MATCHLR VERSION 0.1.0)
|
||||||
|
|
||||||
|
set(CMAKE_CXX_STANDARD 11)
|
||||||
|
|
||||||
|
set(CMAKE_CXX_FLAGS "-pthread -O3 -fPIC")
|
||||||
|
|
||||||
|
add_definitions(-DPROJECT_SOURCE_DIR="${PROJECT_SOURCE_DIR}")
|
||||||
|
|
||||||
|
set(Eigen3_DIR "${PROJECT_SOURCE_DIR}/../../3rdparty/eigen-3.3.7/share/eigen3/cmake")
|
||||||
|
set(PYBIND11_DIR "${PROJECT_SOURCE_DIR}/../../3rdparty/pybind11")
|
||||||
|
|
||||||
|
find_package(Eigen3 3.3.7 REQUIRED)
|
||||||
|
message(STATUS "Eigen3 path is ${EIGEN3_INCLUDE_DIR}")
|
||||||
|
|
||||||
|
set(USE_OPENCV 0)
|
||||||
|
if(USE_OPENCV)
|
||||||
|
# OpenCV
|
||||||
|
find_package(OpenCV REQUIRED)
|
||||||
|
add_definitions(-D_USE_OPENCV_)
|
||||||
|
endif(USE_OPENCV)
|
||||||
|
|
||||||
|
IF (WIN32)
|
||||||
|
MESSAGE(STATUS "I don't test on Windows")
|
||||||
|
ELSEIF (APPLE)
|
||||||
|
MESSAGE(STATUS "Not use openmp")
|
||||||
|
ELSEIF (UNIX)
|
||||||
|
find_package(OpenMP REQUIRED)
|
||||||
|
set(OTHER_LIBS ${OTHER_LIBS} OpenMP::OpenMP_CXX)
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
|
||||||
|
ENDIF ()
|
||||||
|
|
||||||
|
set(PUBLIC_INCLUDE ${PROJECT_SOURCE_DIR}
|
||||||
|
${PROJECT_SOURCE_DIR}/include
|
||||||
|
${EIGEN3_INCLUDE_DIR}
|
||||||
|
)
|
||||||
|
|
||||||
|
set(OTHER_SRCS "")
|
||||||
|
|
||||||
|
|
||||||
|
if(USE_OPENCV)
|
||||||
|
set(PUBLIC_INCLUDE ${PUBLIC_INCLUDE} ${OpenCV_INCLUDE_DIRS})
|
||||||
|
set(OTHER_LIBS ${OTHER_LIBS} ${OpenCV_LIBRARIES})
|
||||||
|
endif(USE_OPENCV)
|
||||||
|
|
||||||
|
# add_subdirectory(test)
|
||||||
|
add_subdirectory(${PYBIND11_DIR} "PYBIND11_out")
|
||||||
|
add_subdirectory(python)
|
||||||
|
# add_subdirectory(src)
|
63
library/pymatch/include/Timer.hpp
Normal file
63
library/pymatch/include/Timer.hpp
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <iostream>
|
||||||
|
#include <chrono>
|
||||||
|
|
||||||
|
class Timer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
std::string name_;
|
||||||
|
Timer(std::string name):name_(name){};
|
||||||
|
Timer(){};
|
||||||
|
void start();
|
||||||
|
void tic();
|
||||||
|
void toc();
|
||||||
|
void toc(std::string things);
|
||||||
|
void end();
|
||||||
|
double now();
|
||||||
|
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::chrono::steady_clock::time_point t_s_; //start time ponit
|
||||||
|
std::chrono::steady_clock::time_point t_tic_; //tic time ponit
|
||||||
|
std::chrono::steady_clock::time_point t_toc_; //toc time ponit
|
||||||
|
std::chrono::steady_clock::time_point t_e_; //stop time point
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
void Timer::tic()
|
||||||
|
{
|
||||||
|
t_tic_ = std::chrono::steady_clock::now();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Timer::toc()
|
||||||
|
{
|
||||||
|
t_toc_ = std::chrono::steady_clock::now();
|
||||||
|
auto tt = std::chrono::duration_cast<std::chrono::duration<double>>(t_toc_ - t_tic_).count();
|
||||||
|
std::cout << "Time spend: " << tt << " seconds" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Timer::toc(std::string things)
|
||||||
|
{
|
||||||
|
t_toc_ = std::chrono::steady_clock::now();
|
||||||
|
auto tt = std::chrono::duration_cast<std::chrono::duration<double>>(t_toc_ - t_tic_).count();
|
||||||
|
std::cout << "Time spend: " << tt << " seconds when doing "<<things << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Timer::start()
|
||||||
|
{
|
||||||
|
t_s_ = std::chrono::steady_clock::now();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Timer::end()
|
||||||
|
{
|
||||||
|
t_e_ = std::chrono::steady_clock::now();
|
||||||
|
auto tt = std::chrono::duration_cast<std::chrono::duration<double>>(t_e_ - t_s_).count();
|
||||||
|
std::cout << "< "<<this->name_<<" > Time total spend: " << tt << " seconds" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
double Timer::now()
|
||||||
|
{
|
||||||
|
t_e_ = std::chrono::steady_clock::now();
|
||||||
|
auto tt = std::chrono::duration_cast<std::chrono::duration<double>>(t_e_ - t_s_).count();
|
||||||
|
return tt;
|
||||||
|
}
|
41
library/pymatch/include/base.h
Normal file
41
library/pymatch/include/base.h
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
/***
|
||||||
|
* @Date: 2020-09-19 16:10:21
|
||||||
|
* @Author: Qing Shuai
|
||||||
|
* @LastEditors: Qing Shuai
|
||||||
|
* @LastEditTime: 2020-09-24 21:09:58
|
||||||
|
* @FilePath: /MatchLR/include/match/base.h
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <vector>
|
||||||
|
#include "Eigen/Dense"
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace match
|
||||||
|
{
|
||||||
|
typedef float Type;
|
||||||
|
typedef Eigen::Matrix<Type, -1, -1> Mat;
|
||||||
|
typedef Eigen::Array<Type, -1, -1> Array;
|
||||||
|
template <typename T>
|
||||||
|
using Vec=std::vector<T>;
|
||||||
|
typedef std::vector<int> List;
|
||||||
|
typedef std::vector<List> ListList;
|
||||||
|
struct MatchInfo
|
||||||
|
{
|
||||||
|
int maxIter = 100;
|
||||||
|
float alpha = 200;
|
||||||
|
float beta = 0.1;
|
||||||
|
float tol = 1e-3;
|
||||||
|
float w_sparse = 0.1;
|
||||||
|
float w_rank = 50;
|
||||||
|
};
|
||||||
|
typedef std::unordered_map<std::string, float> Control;
|
||||||
|
|
||||||
|
void print(Vec<int>& lists, std::string name){
|
||||||
|
std::cout << name << ": [";
|
||||||
|
for(auto i:lists){
|
||||||
|
std::cout << i << ", ";
|
||||||
|
}
|
||||||
|
std::cout << "]" << std::endl;
|
||||||
|
}
|
||||||
|
} // namespace match
|
246
library/pymatch/include/matchSVT.hpp
Normal file
246
library/pymatch/include/matchSVT.hpp
Normal file
@ -0,0 +1,246 @@
|
|||||||
|
/***
|
||||||
|
* @Date: 2020-09-12 19:01:56
|
||||||
|
* @Author: Qing Shuai
|
||||||
|
* @LastEditors: Qing Shuai
|
||||||
|
* @LastEditTime: 2022-07-29 22:38:40
|
||||||
|
* @FilePath: /EasyMocapPublic/library/pymatch/include/matchSVT.hpp
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include "base.h"
|
||||||
|
#include "projfunc.hpp"
|
||||||
|
#include "Timer.hpp"
|
||||||
|
#include "visualize.hpp"
|
||||||
|
|
||||||
|
namespace match
|
||||||
|
{
|
||||||
|
struct Config
|
||||||
|
{
|
||||||
|
bool debug;
|
||||||
|
int max_iter;
|
||||||
|
float tol;
|
||||||
|
float w_rank;
|
||||||
|
float w_sparse;
|
||||||
|
Config(Control& control){
|
||||||
|
debug = (control["debug"] > 0.);
|
||||||
|
max_iter = int(control["maxIter"]);
|
||||||
|
tol = control["tol"];
|
||||||
|
w_rank = control["w_rank"];
|
||||||
|
w_sparse = control["w_sparse"];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// dimGroups: [0, nF1, nF1 + nF2, ...]
|
||||||
|
ListList getBlocksFromDimGroups(const List& dimGroups){
|
||||||
|
ListList block;
|
||||||
|
for(int i=0;i<dimGroups.size() - 1;i++){
|
||||||
|
// 这个视角没有找到人的情况
|
||||||
|
if(dimGroups[i] == dimGroups[i+1])continue;
|
||||||
|
for(int j=0;j<dimGroups.size()-1;j++){
|
||||||
|
if(i==j)continue;
|
||||||
|
if(dimGroups[j] == dimGroups[j+1])continue;
|
||||||
|
block.push_back({dimGroups[i], dimGroups[i+1] - dimGroups[i],
|
||||||
|
dimGroups[j], dimGroups[j+1] - dimGroups[j]});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return block;
|
||||||
|
}
|
||||||
|
// matchSVT with constraint and observation
|
||||||
|
// M_aff: (N, N): affinity matrix
|
||||||
|
// M_constr: =0, when (i, j) cannot be the same person
|
||||||
|
// if not consider this, set to 1(N, N)
|
||||||
|
// M_obs: =0, when (i, j) cannot be observed
|
||||||
|
// if not consider this, set to 1(N, N)
|
||||||
|
Mat matchSVT(Mat M_aff, List dimGroups, Mat M_constr, Mat M_obs, Control control)
|
||||||
|
{
|
||||||
|
bool debug = (control["debug"] > 0.);
|
||||||
|
int max_iter = int(control["maxIter"]);
|
||||||
|
float tol = control["tol"];
|
||||||
|
|
||||||
|
int N = M_aff.rows();
|
||||||
|
auto dual_blocks = getBlocksFromDimGroups(dimGroups);
|
||||||
|
// 对角线约束
|
||||||
|
for(int i=0;i<dimGroups.size() - 1;i++){
|
||||||
|
M_constr.block(dimGroups[i], dimGroups[i], dimGroups[i+1] - dimGroups[i], dimGroups[i+1]-dimGroups[i]).setZero();
|
||||||
|
}
|
||||||
|
M_constr.diagonal().setOnes();
|
||||||
|
// 将affinity乘一下constraint,保证满足约束
|
||||||
|
M_aff = (M_aff.array() * M_constr.array()).matrix();
|
||||||
|
// check一下所有区块,如果最大值和最小值差异过小的,直接认为是错误观测
|
||||||
|
for (auto block : dual_blocks)
|
||||||
|
{
|
||||||
|
Mat mat = M_aff.block(block[0], block[2], block[1], block[3]);
|
||||||
|
if(debug){
|
||||||
|
std::cout << "(" << block[0] << ", " << block[2] << "), ";
|
||||||
|
std::cout << "min: " << mat.minCoeff() << ", max: " << mat.maxCoeff() << std::endl;
|
||||||
|
}
|
||||||
|
if(mat.minCoeff() > 0.7 && block[1] > 1 && block[3] > 1){
|
||||||
|
// 如果大于0.9,说明区分度不够高啊,认为观测是虚假的
|
||||||
|
M_obs.block(block[0], block[2], block[1], block[3]).setZero();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// set the diag of M_aff to zeros
|
||||||
|
M_aff.diagonal().setConstant(0);
|
||||||
|
Mat X = M_aff;
|
||||||
|
Mat Y = Mat::Zero(N, N);
|
||||||
|
Mat Q = M_aff;
|
||||||
|
Mat W = (control["w_sparse"] - M_aff.array()).matrix();
|
||||||
|
float mu = 64;
|
||||||
|
Timer timer;
|
||||||
|
timer.tic();
|
||||||
|
for (int iter = 0; iter < max_iter; iter++)
|
||||||
|
{
|
||||||
|
Mat X0 = X;
|
||||||
|
// update Q with SVT
|
||||||
|
Q = 1.0 / mu * Y + X;
|
||||||
|
Eigen::BDCSVD<Mat> UDV(Q.bdcSvd(Eigen::ComputeThinU | Eigen::ComputeThinV));
|
||||||
|
Array Ds(Dsoft(UDV.singularValues(), control["w_rank"] / mu));
|
||||||
|
Mat Qnew(UDV.matrixU() * Ds.matrix().asDiagonal() * UDV.matrixV().adjoint());
|
||||||
|
Q = Qnew;
|
||||||
|
// update X
|
||||||
|
X = Q - (M_obs.array() * W.array() + Y.array()).matrix() / mu;
|
||||||
|
X = (X.array() * M_constr.array()).matrix();
|
||||||
|
// set the diagonal
|
||||||
|
X.diagonal().setOnes();
|
||||||
|
// 注意这个min,max
|
||||||
|
X = X.cwiseMin(1.f).cwiseMax(0.f);
|
||||||
|
#pragma omp parallel for
|
||||||
|
for(int i=0;i<dual_blocks.size();i++)
|
||||||
|
{
|
||||||
|
auto& block = dual_blocks[i];
|
||||||
|
X.block(block[0], block[2], block[1], block[3]) = myproj2dpam(X.block(block[0], block[2], block[1], block[3]), 1e-2);
|
||||||
|
}
|
||||||
|
X = (X + X.transpose().eval()) / 2;
|
||||||
|
Y = Y + mu * (X - Q);
|
||||||
|
float pRes = (X - Q).norm() / N;
|
||||||
|
float dRes = mu * (X - X0).norm() / N;
|
||||||
|
if(debug){
|
||||||
|
#ifdef _USE_OPENCV_
|
||||||
|
cv::imshow("Q", eigen2mat(Q));
|
||||||
|
cv::imshow("X", eigen2mat(X));
|
||||||
|
cv::waitKey(100);
|
||||||
|
#endif
|
||||||
|
std::cout << "Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
|
||||||
|
}
|
||||||
|
if (pRes < tol && dRes < tol)
|
||||||
|
{
|
||||||
|
std::cout << "End " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pRes > 10 * dRes)
|
||||||
|
{
|
||||||
|
mu *= 2;
|
||||||
|
}
|
||||||
|
else if (dRes > 10 * pRes)
|
||||||
|
{
|
||||||
|
mu /= 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(debug){
|
||||||
|
#ifdef _USE_OPENCV_
|
||||||
|
timer.toc("solving svt");
|
||||||
|
cv::imshow("X", eigen2mat(X));
|
||||||
|
cv::imshow("Q", eigen2mat(Q));
|
||||||
|
cv::waitKey(0);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
return X;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat matchALS(Mat M_aff, List dimGroups, Mat M_constr, Mat M_obs, Control control)
|
||||||
|
{
|
||||||
|
// This function is to solve
|
||||||
|
// min <W, X> + alpha||x||_* + beta||x||_1, st. X \in C
|
||||||
|
// <beta - W, AB^T> + alpha/2||A||^2 + alpha/2||B||^2
|
||||||
|
// st AB^T = Z, Z\in \Omega
|
||||||
|
const Config cfg(control);
|
||||||
|
Mat W = (M_aff + M_aff.transpose())/2;
|
||||||
|
// set the diag of W to zeros
|
||||||
|
for(int i=0;i<W.rows();i++){
|
||||||
|
W(i, i) = 0;
|
||||||
|
}
|
||||||
|
Mat X = W;
|
||||||
|
Mat Z = W;
|
||||||
|
Mat Y = W;
|
||||||
|
Y.setZero();
|
||||||
|
int mu = 64;
|
||||||
|
int n = X.rows();
|
||||||
|
int maxRank = 0;
|
||||||
|
for(size_t i=0;i<dimGroups.size() - 1;i++){
|
||||||
|
if(dimGroups[i+1]-dimGroups[i] > maxRank){
|
||||||
|
maxRank = dimGroups[i+1]-dimGroups[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::cout << "[matchALS] set the max rank = " << maxRank << std::endl;
|
||||||
|
Mat eyeRank = Mat::Identity(maxRank, maxRank);
|
||||||
|
// initial value
|
||||||
|
Mat A = Mat::Random(n, maxRank);
|
||||||
|
Mat B;
|
||||||
|
for(int iter=0;iter<cfg.max_iter;iter++){
|
||||||
|
Mat X0 = X;
|
||||||
|
X = Z - (((Y - W).array() + cfg.w_sparse)/mu).matrix();
|
||||||
|
B = ((A.transpose() * A + cfg.w_rank/mu * eyeRank).ldlt().solve(A.transpose() * X)).transpose();
|
||||||
|
A = ((B.transpose() * B + cfg.w_rank/mu * eyeRank).ldlt().solve(B.transpose() * X.transpose())).transpose();
|
||||||
|
|
||||||
|
X = A * B.transpose();
|
||||||
|
Z = X + Y/mu;
|
||||||
|
for(int i=0;i<dimGroups.size() - 1;i++){
|
||||||
|
int start = dimGroups[i];
|
||||||
|
int end = dimGroups[i+1];
|
||||||
|
Z.block(start, start, end-start, end-start).setIdentity();
|
||||||
|
}
|
||||||
|
// 注意这个min,max
|
||||||
|
Z = Z.cwiseMin(1.f).cwiseMax(0.f);
|
||||||
|
Y = Y + mu*(X - Z);
|
||||||
|
|
||||||
|
float pRes = (X - Z).norm()/n;
|
||||||
|
float dRes = mu*(X - X0).norm()/n;
|
||||||
|
|
||||||
|
if(cfg.debug){
|
||||||
|
#ifdef _USE_OPENCV_
|
||||||
|
cv::imshow("Z", eigen2mat(Z));
|
||||||
|
cv::imshow("X", eigen2mat(X));
|
||||||
|
cv::waitKey(10);
|
||||||
|
#endif
|
||||||
|
std::cout << "Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pRes < cfg.tol && dRes < cfg.tol)
|
||||||
|
{
|
||||||
|
std::cout << "End " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(pRes > 10*dRes){
|
||||||
|
mu *= 2;
|
||||||
|
}else if(dRes > 10*pRes){
|
||||||
|
mu /= 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
X = (X + X.transpose()) / 2;
|
||||||
|
return X;
|
||||||
|
}
|
||||||
|
|
||||||
|
Vec<int> getViewsFromDim(List& dimGroups){
|
||||||
|
Vec<int> lists(dimGroups.back(), -1);
|
||||||
|
for(int i=0;i<dimGroups.size() - 1;i++){
|
||||||
|
for(int c=dimGroups[i];c<dimGroups[i+1];c++){
|
||||||
|
lists[c] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lists;
|
||||||
|
}
|
||||||
|
|
||||||
|
Vec<int> getDimsFromViews(List& views){
|
||||||
|
Vec<int> dims = {0};
|
||||||
|
int startview = 0;
|
||||||
|
for(int i=0;i<views.size();i++){
|
||||||
|
if(views[i] != startview){
|
||||||
|
dims.push_back(i);
|
||||||
|
startview = views[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dims.push_back(views.size());
|
||||||
|
return dims;
|
||||||
|
}
|
||||||
|
} // namespace match
|
335
library/pymatch/include/projfunc.hpp
Normal file
335
library/pymatch/include/projfunc.hpp
Normal file
@ -0,0 +1,335 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <vector>
|
||||||
|
#include <iostream>
|
||||||
|
#include "base.h"
|
||||||
|
|
||||||
|
namespace match
|
||||||
|
{
|
||||||
|
Mat proj2pav(Mat y);
|
||||||
|
Mat projR(Mat X);
|
||||||
|
Mat projC(Mat X);
|
||||||
|
|
||||||
|
Mat myproj2dpam(Mat Y, float tol = 1e-4, bool debug = false)
|
||||||
|
{
|
||||||
|
Mat X0 = Y;
|
||||||
|
Mat X = Y;
|
||||||
|
Mat I2 = X;
|
||||||
|
I2.setZero();
|
||||||
|
Mat X1, I1, X2;
|
||||||
|
for (int iter = 0; iter < 10; iter++)
|
||||||
|
{
|
||||||
|
X1 = projR(X0 + I2);
|
||||||
|
I1 = X1 - (X0 + I2);
|
||||||
|
X2 = projC(X0 + I1);
|
||||||
|
I2 = X2 - (X0 + I1);
|
||||||
|
float chg = (X2 - X).array().abs().sum() / (X.rows() * X.cols());
|
||||||
|
X = X2;
|
||||||
|
if (chg < tol)
|
||||||
|
{
|
||||||
|
return X;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return X;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat projR(Mat X)
|
||||||
|
{
|
||||||
|
int n = X.cols();
|
||||||
|
// std::cout << "before projR: " << X << std::endl;
|
||||||
|
for (int i = 0; i < X.rows(); i++)
|
||||||
|
{
|
||||||
|
Mat x = proj2pav(X.block(i, 0, 1, n).transpose());
|
||||||
|
X.block(i, 0, 1, n) = x.transpose();
|
||||||
|
}
|
||||||
|
// std::cout << "after projR: " << X << std::endl;
|
||||||
|
return X;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat projC(Mat X)
|
||||||
|
{
|
||||||
|
int n = X.rows();
|
||||||
|
// std::cout << "before projC: " << X << std::endl;
|
||||||
|
for (int j = 0; j < X.cols(); j++)
|
||||||
|
{
|
||||||
|
Mat x = proj2pav(X.block(0, j, n, 1));
|
||||||
|
X.block(0, j, n, 1) = x;
|
||||||
|
}
|
||||||
|
// std::cout << "after projC: " << X << std::endl;
|
||||||
|
return X;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat proj2pav(Mat y)
|
||||||
|
{
|
||||||
|
y = y.cwiseMax(0.f);
|
||||||
|
Mat x = y;
|
||||||
|
x.setZero();
|
||||||
|
if (y.sum() < 1)
|
||||||
|
{
|
||||||
|
x = y;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
std::vector<float> u, sv;
|
||||||
|
for (int i = 0; i < y.rows(); i++)
|
||||||
|
{
|
||||||
|
u.push_back(y(i, 0));
|
||||||
|
}
|
||||||
|
// 排序
|
||||||
|
std::sort(u.begin(), u.end(), std::greater<float>());
|
||||||
|
float usum = 0;
|
||||||
|
for (int i = 0; i < u.size(); i++)
|
||||||
|
{
|
||||||
|
usum += u[i];
|
||||||
|
sv.push_back(usum);
|
||||||
|
}
|
||||||
|
int rho = 0;
|
||||||
|
for (int i = 0; i < u.size(); i++)
|
||||||
|
{
|
||||||
|
if (u[i] > (sv[i] - 1) / (i + 1))
|
||||||
|
{
|
||||||
|
rho = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
float theta = std::max(0.f, (sv[rho] - 1) / (rho + 1));
|
||||||
|
x = (y.array() - theta).matrix();
|
||||||
|
x = x.cwiseMax(0.f);
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat proj2pavC(Mat y)
|
||||||
|
{
|
||||||
|
// y: N, 1
|
||||||
|
// y[y<0] = 0
|
||||||
|
int n = y.rows();
|
||||||
|
y = y.cwiseMax(0.f);
|
||||||
|
Mat x = y;
|
||||||
|
if (y.sum() < 1)
|
||||||
|
{
|
||||||
|
x = y;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
std::vector<float> u;
|
||||||
|
for (int i = 0; i < y.rows(); i++)
|
||||||
|
{
|
||||||
|
u.push_back(y(i, 0));
|
||||||
|
}
|
||||||
|
// 排序
|
||||||
|
std::sort(u.begin(), u.end(), std::greater<float>());
|
||||||
|
float tmpsum = 0;
|
||||||
|
bool bget = false;
|
||||||
|
float tmax;
|
||||||
|
for (int ii = 0; ii < n - 1; ii++)
|
||||||
|
{
|
||||||
|
tmpsum += u[ii];
|
||||||
|
tmax = (tmpsum - 1) / (ii + 1);
|
||||||
|
if (tmax >= u[ii + 1])
|
||||||
|
{
|
||||||
|
bget = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!bget)
|
||||||
|
{
|
||||||
|
tmax = (tmpsum + u[n - 1] - 1) / n;
|
||||||
|
}
|
||||||
|
x = (y.array() - tmax).matrix();
|
||||||
|
x = x.cwiseMax(0.f);
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
int proj201(Mat &z)
|
||||||
|
{
|
||||||
|
z = z.cwiseMin(1.f).cwiseMax(0.f);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat proj2kav_(Mat x0, Mat A, Mat b)
|
||||||
|
{
|
||||||
|
// to solve:
|
||||||
|
// min 1/2||x - x_0||_F^2 + ||z||_1
|
||||||
|
// s.t. Ax = b, x-z=0, x>=0, x<=1
|
||||||
|
// convert to L(x, y) = 1/2||x - x_0||_F^2 + y^T(Ax - b)
|
||||||
|
// x = (I + \rho A^T @A)^-1 @ (x_0 - A^T@y + \rho A^T@b)
|
||||||
|
// y = y + \rho *(Ax - b)
|
||||||
|
int n = x0.rows();
|
||||||
|
Mat I(n, n);
|
||||||
|
I.setIdentity();
|
||||||
|
Mat X = x0;
|
||||||
|
Mat y = b;
|
||||||
|
float rho = 2;
|
||||||
|
y.setZero();
|
||||||
|
float tol = 1e-4;
|
||||||
|
Mat Y, B, Z, c;
|
||||||
|
for (int iter = 0; iter < 100; iter++)
|
||||||
|
{
|
||||||
|
Mat X0 = X;
|
||||||
|
// (x - x_0) + A^Ty + \rho A^T(Ax + By -c)
|
||||||
|
X = (I + rho * A.transpose() * A).ldlt().solve(x0 - A.transpose() * y + rho * A.transpose() * b);
|
||||||
|
y = y + rho * (A * X - b);
|
||||||
|
|
||||||
|
Y = Y + rho * (A * X + B * Z - c);
|
||||||
|
float pRes = (A * X + B * Z - c).norm() / n;
|
||||||
|
float dRes = rho * (X - X0).norm() / n;
|
||||||
|
// std::cout << " Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), rho = " << rho << std::endl;
|
||||||
|
|
||||||
|
if (pRes < tol && dRes < tol)
|
||||||
|
break;
|
||||||
|
if (pRes > 10 * dRes)
|
||||||
|
{
|
||||||
|
rho *= 2;
|
||||||
|
}
|
||||||
|
else if (dRes > 10 * pRes)
|
||||||
|
{
|
||||||
|
rho /= 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return X;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat softthres(Mat b, float thres)
|
||||||
|
{
|
||||||
|
// TODO:vector
|
||||||
|
for (int i = 0; i < b.rows(); i++)
|
||||||
|
{
|
||||||
|
if (b(i, 0) < -thres)
|
||||||
|
{
|
||||||
|
b(i, 0) += thres;
|
||||||
|
}
|
||||||
|
else if (b(i, 0) > thres)
|
||||||
|
{
|
||||||
|
b(i, 0) -= thres;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
b(i, 0) = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
Array Dsoft(const Array &d, float penalty)
|
||||||
|
{
|
||||||
|
// inverts the singular values
|
||||||
|
// takes advantage of the fact that singular values are never negative
|
||||||
|
Array di(d.rows(), d.cols());
|
||||||
|
int maxRank = 0;
|
||||||
|
for (int j = 0; j < d.size(); ++j)
|
||||||
|
{
|
||||||
|
double penalized = d(j, 0) - penalty;
|
||||||
|
if (penalized < 0)
|
||||||
|
{
|
||||||
|
di(j, 0) = 0;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
di(j, 0) = penalized;
|
||||||
|
maxRank++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// std::cout << "max rank: " << maxRank << std::endl;
|
||||||
|
return di;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat _proj2kav(Mat x0, Mat A, Mat b, Mat weight)
|
||||||
|
{
|
||||||
|
// to solve:
|
||||||
|
// min 1/2||x - x_0||_F^2 + 1/2\lambda||Ax - b||_F^2 + \alpha||z||_1
|
||||||
|
// s.t. x=z, z \in {z| z>=0, z<=1|
|
||||||
|
// convert to L(x, y) = 1/2||x - x_0||_F^2 + 1/2\lambda||Ax - b||_F^2 + \alpha||z||_1
|
||||||
|
// + <y, x - z> + \rho/2||x - z||_F^2
|
||||||
|
// update x:
|
||||||
|
// x = (1/rho + I + lambda/rhoA^TA)^-1 @ (1/rho x_0 + lambda/rho A^Tb + y)
|
||||||
|
// update z:
|
||||||
|
// z = softthres(x + 1/rho y, lambda/rho)
|
||||||
|
// update y:
|
||||||
|
// y = y + \rho *(x - z)
|
||||||
|
int n = x0.rows();
|
||||||
|
Mat I(n, n);
|
||||||
|
I.setIdentity();
|
||||||
|
Mat X = x0;
|
||||||
|
Mat Y = X;
|
||||||
|
Mat Z = Y;
|
||||||
|
Y.setZero();
|
||||||
|
float rho = 64;
|
||||||
|
// weight
|
||||||
|
float w_init = 1;
|
||||||
|
float w_paf = 1e-1;
|
||||||
|
float w_Ax = 100;
|
||||||
|
float w_l1 = 1e-1;
|
||||||
|
float tol = 1e-4;
|
||||||
|
std::cout << "x0: " << x0 << std::endl;
|
||||||
|
std::cout << "paf: " << weight << std::endl;
|
||||||
|
for (int iter = 0; iter < 100; iter++)
|
||||||
|
{
|
||||||
|
Mat X0 = X;
|
||||||
|
// update X
|
||||||
|
X = ((rho + w_init) * I + w_Ax * A.transpose() * A).ldlt().solve(x0 + w_Ax * A.transpose() * b + rho * Z - Y + w_paf * weight);
|
||||||
|
// update Z
|
||||||
|
Z = softthres(X + 1 / rho * Y, w_l1 / rho);
|
||||||
|
// projection Z
|
||||||
|
Z = Z.cwiseMin(1.f).cwiseMax(0.f);
|
||||||
|
// update Y
|
||||||
|
Y = Y + rho * (X - Z);
|
||||||
|
// convergence
|
||||||
|
float pRes = (X - Z).norm() / n;
|
||||||
|
float dRes = rho * (X - X0).norm() / n;
|
||||||
|
std::cout << " proj2kav Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), rho = " << rho << std::endl;
|
||||||
|
std::cout << " init= " << w_init * 0.5 * (X - x0).norm() / n
|
||||||
|
<< ", equ= " << 0.5 * w_Ax * (A * X - b).norm() / n
|
||||||
|
<< ", paf=" << -w_paf * weight.transpose() * X
|
||||||
|
<< ", l1= " << 1.0 * (X.array() > 0).count() / n << std::endl;
|
||||||
|
|
||||||
|
if (pRes < tol && dRes < tol)
|
||||||
|
break;
|
||||||
|
if (pRes > 10 * dRes)
|
||||||
|
{
|
||||||
|
rho *= 2;
|
||||||
|
}
|
||||||
|
else if (dRes > 10 * pRes)
|
||||||
|
{
|
||||||
|
rho /= 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return X;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat proj2kav(Mat x0, Mat A, Mat b, Mat paf)
|
||||||
|
{
|
||||||
|
// reduce this problem
|
||||||
|
|
||||||
|
std::vector<int> indices;
|
||||||
|
// here we directly set the non-zero entries
|
||||||
|
for (int j = 0; j < A.cols(); j++)
|
||||||
|
{
|
||||||
|
if (A(A.rows() - 1, j) != 0)
|
||||||
|
{
|
||||||
|
indices.push_back(j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// just use the last row
|
||||||
|
int n = indices.size();
|
||||||
|
Mat Areduce(1, n);
|
||||||
|
Areduce.setOnes();
|
||||||
|
Mat x0reduce(n, 1), pafreduce(n, 1);
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
{
|
||||||
|
x0reduce(i, 0) = x0(indices[i], 0);
|
||||||
|
pafreduce(i, 0) = paf(indices[i], 0);
|
||||||
|
}
|
||||||
|
Mat breduce(1, 1);
|
||||||
|
breduce(0, 0) = b(b.rows() - 1, 0);
|
||||||
|
|
||||||
|
Mat xreduce = _proj2kav(x0reduce, Areduce, breduce, pafreduce);
|
||||||
|
Mat X(x0.rows(), 1);
|
||||||
|
X.setOnes();
|
||||||
|
for (int i = 0; i < n; i++)
|
||||||
|
{
|
||||||
|
X(indices[i], 0) = xreduce(i, 0);
|
||||||
|
}
|
||||||
|
return X;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace match
|
34
library/pymatch/include/visualize.hpp
Normal file
34
library/pymatch/include/visualize.hpp
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
/***
|
||||||
|
* @Date: 2020-09-12 19:37:01
|
||||||
|
* @Author: Qing Shuai
|
||||||
|
* @LastEditors: Qing Shuai
|
||||||
|
* @LastEditTime: 2020-09-12 19:37:46
|
||||||
|
* @FilePath: /MatchLR/include/visualize.hpp
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#ifdef _USE_OPENCV_
|
||||||
|
#include "opencv2/opencv.hpp"
|
||||||
|
#include "opencv2/core/eigen.hpp"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace match
|
||||||
|
{
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef _USE_OPENCV_
|
||||||
|
cv::Mat eigen2mat(Mat Z){
|
||||||
|
cv::Mat showi, showd, showrgb;
|
||||||
|
auto Zmin = Z.minCoeff();
|
||||||
|
auto Zmax = Z.maxCoeff();
|
||||||
|
Z = ((Z.array() - Zmin)/(Zmax - Zmin)).matrix();
|
||||||
|
cv::eigen2cv(Z, showd);
|
||||||
|
showd.convertTo(showi, CV_8UC1, 255);
|
||||||
|
cv::applyColorMap(showi, showrgb, cv::COLORMAP_JET);
|
||||||
|
while(showrgb.rows < 600){
|
||||||
|
cv::resize(showrgb, showrgb, cv::Size(), 2, 2, cv::INTER_NEAREST);
|
||||||
|
}
|
||||||
|
return showrgb;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace match
|
8
library/pymatch/pymatchlr/__init__.py
Normal file
8
library/pymatch/pymatchlr/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
'''
|
||||||
|
* @ Date: 2020-09-12 17:58:02
|
||||||
|
* @ Author: Qing Shuai
|
||||||
|
* @ LastEditors: Qing Shuai
|
||||||
|
* @ LastEditTime: 2020-09-15 14:38:40
|
||||||
|
* @ FilePath: /MatchLR/pymatchlr/__init__.py
|
||||||
|
'''
|
||||||
|
from .pymatchlr import *
|
10
library/pymatch/python/CMakeLists.txt
Normal file
10
library/pymatch/python/CMakeLists.txt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
set(PYMODULES "pymatchlr")
|
||||||
|
|
||||||
|
foreach(module ${PYMODULES})
|
||||||
|
pybind11_add_module("${module}" "${module}.cpp")
|
||||||
|
target_link_libraries("${module}"
|
||||||
|
PRIVATE pybind11::module
|
||||||
|
PUBLIC ${OTHER_LIBS})
|
||||||
|
target_include_directories("${module}"
|
||||||
|
PUBLIC ${PUBLIC_INCLUDE})
|
||||||
|
endforeach(module ${PYMODULES})
|
0
library/pymatch/python/__init__.py
Normal file
0
library/pymatch/python/__init__.py
Normal file
33
library/pymatch/python/pymatchlr.cpp
Normal file
33
library/pymatch/python/pymatchlr.cpp
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
/***
|
||||||
|
* @Date: 2020-09-18 14:05:37
|
||||||
|
* @Author: Qing Shuai
|
||||||
|
* @LastEditors: Qing Shuai
|
||||||
|
* @LastEditTime: 2021-07-24 14:50:42
|
||||||
|
* @FilePath: /EasyMocap/library/pymatch/python/pymatchlr.cpp
|
||||||
|
*/
|
||||||
|
/*
|
||||||
|
* @Date: 2020-06-29 10:51:28
|
||||||
|
* @LastEditors: Qing Shuai
|
||||||
|
* @LastEditTime: 2020-07-12 17:11:43
|
||||||
|
* @Author: Qing Shuai
|
||||||
|
* @Mail: s_q@zju.edu.cn
|
||||||
|
*/
|
||||||
|
#include <iostream>
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "pybind11/stl.h"
|
||||||
|
#include "pybind11/numpy.h"
|
||||||
|
#include "pybind11/eigen.h"
|
||||||
|
#include "matchSVT.hpp"
|
||||||
|
|
||||||
|
#define myprint(x) std::cout << #x << ": " << std::endl << x.transpose() << std::endl;
|
||||||
|
#define printshape(x) std::cout << #x << ": (" << x.rows() << ", " << x.cols() << ")" << std::endl;
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
PYBIND11_MODULE(pymatchlr, m) {
|
||||||
|
m.def("matchSVT", &match::matchSVT, "SVT for matching",
|
||||||
|
py::arg("affinity"), py::arg("dimGroups"), py::arg("constraint"), py::arg("observe"), py::arg("debug"));
|
||||||
|
m.def("matchALS", &match::matchALS, "ALS for matching",
|
||||||
|
py::arg("affinity"), py::arg("dimGroups"), py::arg("constraint"), py::arg("observe"), py::arg("debug"));
|
||||||
|
m.attr("__version__") = "0.1.0";
|
||||||
|
}
|
81
library/pymatch/setup.py
Normal file
81
library/pymatch/setup.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
'''
|
||||||
|
* @ Date: 2020-09-12 17:35:46
|
||||||
|
* @ Author: Qing Shuai
|
||||||
|
* @ LastEditors: Qing Shuai
|
||||||
|
* @ LastEditTime: 2020-09-12 17:57:36
|
||||||
|
* @ FilePath: /MatchLR/setup.py
|
||||||
|
'''
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from setuptools import setup, Extension
|
||||||
|
from setuptools.command.build_ext import build_ext
|
||||||
|
from distutils.version import LooseVersion
|
||||||
|
|
||||||
|
|
||||||
|
class CMakeExtension(Extension):
|
||||||
|
def __init__(self, name, sourcedir=''):
|
||||||
|
Extension.__init__(self, name, sources=[])
|
||||||
|
self.sourcedir = os.path.abspath(sourcedir)
|
||||||
|
|
||||||
|
|
||||||
|
class CMakeBuild(build_ext):
|
||||||
|
def run(self):
|
||||||
|
try:
|
||||||
|
out = subprocess.check_output(['cmake', '--version'])
|
||||||
|
except OSError:
|
||||||
|
raise RuntimeError("CMake must be installed to build the following extensions: " +
|
||||||
|
", ".join(e.name for e in self.extensions))
|
||||||
|
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1))
|
||||||
|
if cmake_version < '3.1.0':
|
||||||
|
raise RuntimeError("CMake >= 3.1.0 is required on Windows")
|
||||||
|
|
||||||
|
for ext in self.extensions:
|
||||||
|
self.build_extension(ext)
|
||||||
|
|
||||||
|
def build_extension(self, ext):
|
||||||
|
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
|
||||||
|
# required for auto-detection of auxiliary "native" libs
|
||||||
|
if not extdir.endswith(os.path.sep):
|
||||||
|
extdir += os.path.sep
|
||||||
|
|
||||||
|
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
|
||||||
|
'-DPYTHON_EXECUTABLE=' + sys.executable]
|
||||||
|
|
||||||
|
cfg = 'Debug' if self.debug else 'Release'
|
||||||
|
build_args = ['--config', cfg]
|
||||||
|
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), extdir)]
|
||||||
|
if sys.maxsize > 2**32:
|
||||||
|
cmake_args += ['-A', 'x64']
|
||||||
|
build_args += ['--', '/m']
|
||||||
|
else:
|
||||||
|
cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg]
|
||||||
|
build_args += ['--', '-j2']
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env['CXXFLAGS'] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get('CXXFLAGS', ''),
|
||||||
|
self.distribution.get_version())
|
||||||
|
if not os.path.exists(self.build_temp):
|
||||||
|
os.makedirs(self.build_temp)
|
||||||
|
subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env)
|
||||||
|
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='pymatchlr',
|
||||||
|
version='0.0.2',
|
||||||
|
author='Qing Shuai',
|
||||||
|
author_email='s_q@zju.edu.cn',
|
||||||
|
description='A project for low rank matching algorithm',
|
||||||
|
long_description='',
|
||||||
|
ext_modules=[CMakeExtension('pymatchlr.pymatchlr')],
|
||||||
|
packages=['pymatchlr'],
|
||||||
|
cmdclass=dict(build_ext=CMakeBuild),
|
||||||
|
zip_safe=False,
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user