add match func
This commit is contained in:
parent
4919cdc417
commit
a66a138314
49
library/pymatch/CMakeLists.txt
Normal file
49
library/pymatch/CMakeLists.txt
Normal file
@ -0,0 +1,49 @@
|
||||
cmake_minimum_required(VERSION 3.12.0)
|
||||
project(MATCHLR VERSION 0.1.0)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "-pthread -O3 -fPIC")
|
||||
|
||||
add_definitions(-DPROJECT_SOURCE_DIR="${PROJECT_SOURCE_DIR}")
|
||||
|
||||
set(Eigen3_DIR "${PROJECT_SOURCE_DIR}/../../3rdparty/eigen-3.3.7/share/eigen3/cmake")
|
||||
set(PYBIND11_DIR "${PROJECT_SOURCE_DIR}/../../3rdparty/pybind11")
|
||||
|
||||
find_package(Eigen3 3.3.7 REQUIRED)
|
||||
message(STATUS "Eigen3 path is ${EIGEN3_INCLUDE_DIR}")
|
||||
|
||||
set(USE_OPENCV 0)
|
||||
if(USE_OPENCV)
|
||||
# OpenCV
|
||||
find_package(OpenCV REQUIRED)
|
||||
add_definitions(-D_USE_OPENCV_)
|
||||
endif(USE_OPENCV)
|
||||
|
||||
IF (WIN32)
|
||||
MESSAGE(STATUS "I don't test on Windows")
|
||||
ELSEIF (APPLE)
|
||||
MESSAGE(STATUS "Not use openmp")
|
||||
ELSEIF (UNIX)
|
||||
find_package(OpenMP REQUIRED)
|
||||
set(OTHER_LIBS ${OTHER_LIBS} OpenMP::OpenMP_CXX)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
|
||||
ENDIF ()
|
||||
|
||||
set(PUBLIC_INCLUDE ${PROJECT_SOURCE_DIR}
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${EIGEN3_INCLUDE_DIR}
|
||||
)
|
||||
|
||||
set(OTHER_SRCS "")
|
||||
|
||||
|
||||
if(USE_OPENCV)
|
||||
set(PUBLIC_INCLUDE ${PUBLIC_INCLUDE} ${OpenCV_INCLUDE_DIRS})
|
||||
set(OTHER_LIBS ${OTHER_LIBS} ${OpenCV_LIBRARIES})
|
||||
endif(USE_OPENCV)
|
||||
|
||||
# add_subdirectory(test)
|
||||
add_subdirectory(${PYBIND11_DIR} "PYBIND11_out")
|
||||
add_subdirectory(python)
|
||||
# add_subdirectory(src)
|
63
library/pymatch/include/Timer.hpp
Normal file
63
library/pymatch/include/Timer.hpp
Normal file
@ -0,0 +1,63 @@
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <chrono>
|
||||
|
||||
class Timer
|
||||
{
|
||||
public:
|
||||
std::string name_;
|
||||
Timer(std::string name):name_(name){};
|
||||
Timer(){};
|
||||
void start();
|
||||
void tic();
|
||||
void toc();
|
||||
void toc(std::string things);
|
||||
void end();
|
||||
double now();
|
||||
|
||||
|
||||
private:
|
||||
std::chrono::steady_clock::time_point t_s_; //start time ponit
|
||||
std::chrono::steady_clock::time_point t_tic_; //tic time ponit
|
||||
std::chrono::steady_clock::time_point t_toc_; //toc time ponit
|
||||
std::chrono::steady_clock::time_point t_e_; //stop time point
|
||||
};
|
||||
|
||||
|
||||
void Timer::tic()
|
||||
{
|
||||
t_tic_ = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
void Timer::toc()
|
||||
{
|
||||
t_toc_ = std::chrono::steady_clock::now();
|
||||
auto tt = std::chrono::duration_cast<std::chrono::duration<double>>(t_toc_ - t_tic_).count();
|
||||
std::cout << "Time spend: " << tt << " seconds" << std::endl;
|
||||
}
|
||||
|
||||
void Timer::toc(std::string things)
|
||||
{
|
||||
t_toc_ = std::chrono::steady_clock::now();
|
||||
auto tt = std::chrono::duration_cast<std::chrono::duration<double>>(t_toc_ - t_tic_).count();
|
||||
std::cout << "Time spend: " << tt << " seconds when doing "<<things << std::endl;
|
||||
}
|
||||
|
||||
void Timer::start()
|
||||
{
|
||||
t_s_ = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
void Timer::end()
|
||||
{
|
||||
t_e_ = std::chrono::steady_clock::now();
|
||||
auto tt = std::chrono::duration_cast<std::chrono::duration<double>>(t_e_ - t_s_).count();
|
||||
std::cout << "< "<<this->name_<<" > Time total spend: " << tt << " seconds" << std::endl;
|
||||
}
|
||||
|
||||
double Timer::now()
|
||||
{
|
||||
t_e_ = std::chrono::steady_clock::now();
|
||||
auto tt = std::chrono::duration_cast<std::chrono::duration<double>>(t_e_ - t_s_).count();
|
||||
return tt;
|
||||
}
|
41
library/pymatch/include/base.h
Normal file
41
library/pymatch/include/base.h
Normal file
@ -0,0 +1,41 @@
|
||||
/***
|
||||
* @Date: 2020-09-19 16:10:21
|
||||
* @Author: Qing Shuai
|
||||
* @LastEditors: Qing Shuai
|
||||
* @LastEditTime: 2020-09-24 21:09:58
|
||||
* @FilePath: /MatchLR/include/match/base.h
|
||||
*/
|
||||
#pragma once
|
||||
#include <vector>
|
||||
#include "Eigen/Dense"
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
|
||||
namespace match
|
||||
{
|
||||
typedef float Type;
|
||||
typedef Eigen::Matrix<Type, -1, -1> Mat;
|
||||
typedef Eigen::Array<Type, -1, -1> Array;
|
||||
template <typename T>
|
||||
using Vec=std::vector<T>;
|
||||
typedef std::vector<int> List;
|
||||
typedef std::vector<List> ListList;
|
||||
struct MatchInfo
|
||||
{
|
||||
int maxIter = 100;
|
||||
float alpha = 200;
|
||||
float beta = 0.1;
|
||||
float tol = 1e-3;
|
||||
float w_sparse = 0.1;
|
||||
float w_rank = 50;
|
||||
};
|
||||
typedef std::unordered_map<std::string, float> Control;
|
||||
|
||||
void print(Vec<int>& lists, std::string name){
|
||||
std::cout << name << ": [";
|
||||
for(auto i:lists){
|
||||
std::cout << i << ", ";
|
||||
}
|
||||
std::cout << "]" << std::endl;
|
||||
}
|
||||
} // namespace match
|
246
library/pymatch/include/matchSVT.hpp
Normal file
246
library/pymatch/include/matchSVT.hpp
Normal file
@ -0,0 +1,246 @@
|
||||
/***
|
||||
* @Date: 2020-09-12 19:01:56
|
||||
* @Author: Qing Shuai
|
||||
* @LastEditors: Qing Shuai
|
||||
* @LastEditTime: 2022-07-29 22:38:40
|
||||
* @FilePath: /EasyMocapPublic/library/pymatch/include/matchSVT.hpp
|
||||
*/
|
||||
#pragma once
|
||||
#include "base.h"
|
||||
#include "projfunc.hpp"
|
||||
#include "Timer.hpp"
|
||||
#include "visualize.hpp"
|
||||
|
||||
namespace match
|
||||
{
|
||||
struct Config
|
||||
{
|
||||
bool debug;
|
||||
int max_iter;
|
||||
float tol;
|
||||
float w_rank;
|
||||
float w_sparse;
|
||||
Config(Control& control){
|
||||
debug = (control["debug"] > 0.);
|
||||
max_iter = int(control["maxIter"]);
|
||||
tol = control["tol"];
|
||||
w_rank = control["w_rank"];
|
||||
w_sparse = control["w_sparse"];
|
||||
}
|
||||
};
|
||||
|
||||
// dimGroups: [0, nF1, nF1 + nF2, ...]
|
||||
ListList getBlocksFromDimGroups(const List& dimGroups){
|
||||
ListList block;
|
||||
for(int i=0;i<dimGroups.size() - 1;i++){
|
||||
// 这个视角没有找到人的情况
|
||||
if(dimGroups[i] == dimGroups[i+1])continue;
|
||||
for(int j=0;j<dimGroups.size()-1;j++){
|
||||
if(i==j)continue;
|
||||
if(dimGroups[j] == dimGroups[j+1])continue;
|
||||
block.push_back({dimGroups[i], dimGroups[i+1] - dimGroups[i],
|
||||
dimGroups[j], dimGroups[j+1] - dimGroups[j]});
|
||||
}
|
||||
}
|
||||
return block;
|
||||
}
|
||||
// matchSVT with constraint and observation
|
||||
// M_aff: (N, N): affinity matrix
|
||||
// M_constr: =0, when (i, j) cannot be the same person
|
||||
// if not consider this, set to 1(N, N)
|
||||
// M_obs: =0, when (i, j) cannot be observed
|
||||
// if not consider this, set to 1(N, N)
|
||||
Mat matchSVT(Mat M_aff, List dimGroups, Mat M_constr, Mat M_obs, Control control)
|
||||
{
|
||||
bool debug = (control["debug"] > 0.);
|
||||
int max_iter = int(control["maxIter"]);
|
||||
float tol = control["tol"];
|
||||
|
||||
int N = M_aff.rows();
|
||||
auto dual_blocks = getBlocksFromDimGroups(dimGroups);
|
||||
// 对角线约束
|
||||
for(int i=0;i<dimGroups.size() - 1;i++){
|
||||
M_constr.block(dimGroups[i], dimGroups[i], dimGroups[i+1] - dimGroups[i], dimGroups[i+1]-dimGroups[i]).setZero();
|
||||
}
|
||||
M_constr.diagonal().setOnes();
|
||||
// 将affinity乘一下constraint,保证满足约束
|
||||
M_aff = (M_aff.array() * M_constr.array()).matrix();
|
||||
// check一下所有区块,如果最大值和最小值差异过小的,直接认为是错误观测
|
||||
for (auto block : dual_blocks)
|
||||
{
|
||||
Mat mat = M_aff.block(block[0], block[2], block[1], block[3]);
|
||||
if(debug){
|
||||
std::cout << "(" << block[0] << ", " << block[2] << "), ";
|
||||
std::cout << "min: " << mat.minCoeff() << ", max: " << mat.maxCoeff() << std::endl;
|
||||
}
|
||||
if(mat.minCoeff() > 0.7 && block[1] > 1 && block[3] > 1){
|
||||
// 如果大于0.9,说明区分度不够高啊,认为观测是虚假的
|
||||
M_obs.block(block[0], block[2], block[1], block[3]).setZero();
|
||||
}
|
||||
}
|
||||
// set the diag of M_aff to zeros
|
||||
M_aff.diagonal().setConstant(0);
|
||||
Mat X = M_aff;
|
||||
Mat Y = Mat::Zero(N, N);
|
||||
Mat Q = M_aff;
|
||||
Mat W = (control["w_sparse"] - M_aff.array()).matrix();
|
||||
float mu = 64;
|
||||
Timer timer;
|
||||
timer.tic();
|
||||
for (int iter = 0; iter < max_iter; iter++)
|
||||
{
|
||||
Mat X0 = X;
|
||||
// update Q with SVT
|
||||
Q = 1.0 / mu * Y + X;
|
||||
Eigen::BDCSVD<Mat> UDV(Q.bdcSvd(Eigen::ComputeThinU | Eigen::ComputeThinV));
|
||||
Array Ds(Dsoft(UDV.singularValues(), control["w_rank"] / mu));
|
||||
Mat Qnew(UDV.matrixU() * Ds.matrix().asDiagonal() * UDV.matrixV().adjoint());
|
||||
Q = Qnew;
|
||||
// update X
|
||||
X = Q - (M_obs.array() * W.array() + Y.array()).matrix() / mu;
|
||||
X = (X.array() * M_constr.array()).matrix();
|
||||
// set the diagonal
|
||||
X.diagonal().setOnes();
|
||||
// 注意这个min,max
|
||||
X = X.cwiseMin(1.f).cwiseMax(0.f);
|
||||
#pragma omp parallel for
|
||||
for(int i=0;i<dual_blocks.size();i++)
|
||||
{
|
||||
auto& block = dual_blocks[i];
|
||||
X.block(block[0], block[2], block[1], block[3]) = myproj2dpam(X.block(block[0], block[2], block[1], block[3]), 1e-2);
|
||||
}
|
||||
X = (X + X.transpose().eval()) / 2;
|
||||
Y = Y + mu * (X - Q);
|
||||
float pRes = (X - Q).norm() / N;
|
||||
float dRes = mu * (X - X0).norm() / N;
|
||||
if(debug){
|
||||
#ifdef _USE_OPENCV_
|
||||
cv::imshow("Q", eigen2mat(Q));
|
||||
cv::imshow("X", eigen2mat(X));
|
||||
cv::waitKey(100);
|
||||
#endif
|
||||
std::cout << "Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
|
||||
}
|
||||
if (pRes < tol && dRes < tol)
|
||||
{
|
||||
std::cout << "End " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
|
||||
break;
|
||||
}
|
||||
|
||||
if (pRes > 10 * dRes)
|
||||
{
|
||||
mu *= 2;
|
||||
}
|
||||
else if (dRes > 10 * pRes)
|
||||
{
|
||||
mu /= 2;
|
||||
}
|
||||
}
|
||||
if(debug){
|
||||
#ifdef _USE_OPENCV_
|
||||
timer.toc("solving svt");
|
||||
cv::imshow("X", eigen2mat(X));
|
||||
cv::imshow("Q", eigen2mat(Q));
|
||||
cv::waitKey(0);
|
||||
#endif
|
||||
}
|
||||
return X;
|
||||
}
|
||||
|
||||
Mat matchALS(Mat M_aff, List dimGroups, Mat M_constr, Mat M_obs, Control control)
|
||||
{
|
||||
// This function is to solve
|
||||
// min <W, X> + alpha||x||_* + beta||x||_1, st. X \in C
|
||||
// <beta - W, AB^T> + alpha/2||A||^2 + alpha/2||B||^2
|
||||
// st AB^T = Z, Z\in \Omega
|
||||
const Config cfg(control);
|
||||
Mat W = (M_aff + M_aff.transpose())/2;
|
||||
// set the diag of W to zeros
|
||||
for(int i=0;i<W.rows();i++){
|
||||
W(i, i) = 0;
|
||||
}
|
||||
Mat X = W;
|
||||
Mat Z = W;
|
||||
Mat Y = W;
|
||||
Y.setZero();
|
||||
int mu = 64;
|
||||
int n = X.rows();
|
||||
int maxRank = 0;
|
||||
for(size_t i=0;i<dimGroups.size() - 1;i++){
|
||||
if(dimGroups[i+1]-dimGroups[i] > maxRank){
|
||||
maxRank = dimGroups[i+1]-dimGroups[i];
|
||||
}
|
||||
}
|
||||
std::cout << "[matchALS] set the max rank = " << maxRank << std::endl;
|
||||
Mat eyeRank = Mat::Identity(maxRank, maxRank);
|
||||
// initial value
|
||||
Mat A = Mat::Random(n, maxRank);
|
||||
Mat B;
|
||||
for(int iter=0;iter<cfg.max_iter;iter++){
|
||||
Mat X0 = X;
|
||||
X = Z - (((Y - W).array() + cfg.w_sparse)/mu).matrix();
|
||||
B = ((A.transpose() * A + cfg.w_rank/mu * eyeRank).ldlt().solve(A.transpose() * X)).transpose();
|
||||
A = ((B.transpose() * B + cfg.w_rank/mu * eyeRank).ldlt().solve(B.transpose() * X.transpose())).transpose();
|
||||
|
||||
X = A * B.transpose();
|
||||
Z = X + Y/mu;
|
||||
for(int i=0;i<dimGroups.size() - 1;i++){
|
||||
int start = dimGroups[i];
|
||||
int end = dimGroups[i+1];
|
||||
Z.block(start, start, end-start, end-start).setIdentity();
|
||||
}
|
||||
// 注意这个min,max
|
||||
Z = Z.cwiseMin(1.f).cwiseMax(0.f);
|
||||
Y = Y + mu*(X - Z);
|
||||
|
||||
float pRes = (X - Z).norm()/n;
|
||||
float dRes = mu*(X - X0).norm()/n;
|
||||
|
||||
if(cfg.debug){
|
||||
#ifdef _USE_OPENCV_
|
||||
cv::imshow("Z", eigen2mat(Z));
|
||||
cv::imshow("X", eigen2mat(X));
|
||||
cv::waitKey(10);
|
||||
#endif
|
||||
std::cout << "Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
|
||||
}
|
||||
|
||||
if (pRes < cfg.tol && dRes < cfg.tol)
|
||||
{
|
||||
std::cout << "End " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
|
||||
break;
|
||||
}
|
||||
|
||||
if(pRes > 10*dRes){
|
||||
mu *= 2;
|
||||
}else if(dRes > 10*pRes){
|
||||
mu /= 2;
|
||||
}
|
||||
}
|
||||
X = (X + X.transpose()) / 2;
|
||||
return X;
|
||||
}
|
||||
|
||||
Vec<int> getViewsFromDim(List& dimGroups){
|
||||
Vec<int> lists(dimGroups.back(), -1);
|
||||
for(int i=0;i<dimGroups.size() - 1;i++){
|
||||
for(int c=dimGroups[i];c<dimGroups[i+1];c++){
|
||||
lists[c] = i;
|
||||
}
|
||||
}
|
||||
return lists;
|
||||
}
|
||||
|
||||
Vec<int> getDimsFromViews(List& views){
|
||||
Vec<int> dims = {0};
|
||||
int startview = 0;
|
||||
for(int i=0;i<views.size();i++){
|
||||
if(views[i] != startview){
|
||||
dims.push_back(i);
|
||||
startview = views[i];
|
||||
}
|
||||
}
|
||||
dims.push_back(views.size());
|
||||
return dims;
|
||||
}
|
||||
} // namespace match
|
335
library/pymatch/include/projfunc.hpp
Normal file
335
library/pymatch/include/projfunc.hpp
Normal file
@ -0,0 +1,335 @@
|
||||
#pragma once
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include "base.h"
|
||||
|
||||
namespace match
|
||||
{
|
||||
Mat proj2pav(Mat y);
|
||||
Mat projR(Mat X);
|
||||
Mat projC(Mat X);
|
||||
|
||||
Mat myproj2dpam(Mat Y, float tol = 1e-4, bool debug = false)
|
||||
{
|
||||
Mat X0 = Y;
|
||||
Mat X = Y;
|
||||
Mat I2 = X;
|
||||
I2.setZero();
|
||||
Mat X1, I1, X2;
|
||||
for (int iter = 0; iter < 10; iter++)
|
||||
{
|
||||
X1 = projR(X0 + I2);
|
||||
I1 = X1 - (X0 + I2);
|
||||
X2 = projC(X0 + I1);
|
||||
I2 = X2 - (X0 + I1);
|
||||
float chg = (X2 - X).array().abs().sum() / (X.rows() * X.cols());
|
||||
X = X2;
|
||||
if (chg < tol)
|
||||
{
|
||||
return X;
|
||||
}
|
||||
}
|
||||
return X;
|
||||
}
|
||||
|
||||
Mat projR(Mat X)
|
||||
{
|
||||
int n = X.cols();
|
||||
// std::cout << "before projR: " << X << std::endl;
|
||||
for (int i = 0; i < X.rows(); i++)
|
||||
{
|
||||
Mat x = proj2pav(X.block(i, 0, 1, n).transpose());
|
||||
X.block(i, 0, 1, n) = x.transpose();
|
||||
}
|
||||
// std::cout << "after projR: " << X << std::endl;
|
||||
return X;
|
||||
}
|
||||
|
||||
Mat projC(Mat X)
|
||||
{
|
||||
int n = X.rows();
|
||||
// std::cout << "before projC: " << X << std::endl;
|
||||
for (int j = 0; j < X.cols(); j++)
|
||||
{
|
||||
Mat x = proj2pav(X.block(0, j, n, 1));
|
||||
X.block(0, j, n, 1) = x;
|
||||
}
|
||||
// std::cout << "after projC: " << X << std::endl;
|
||||
return X;
|
||||
}
|
||||
|
||||
Mat proj2pav(Mat y)
|
||||
{
|
||||
y = y.cwiseMax(0.f);
|
||||
Mat x = y;
|
||||
x.setZero();
|
||||
if (y.sum() < 1)
|
||||
{
|
||||
x = y;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<float> u, sv;
|
||||
for (int i = 0; i < y.rows(); i++)
|
||||
{
|
||||
u.push_back(y(i, 0));
|
||||
}
|
||||
// 排序
|
||||
std::sort(u.begin(), u.end(), std::greater<float>());
|
||||
float usum = 0;
|
||||
for (int i = 0; i < u.size(); i++)
|
||||
{
|
||||
usum += u[i];
|
||||
sv.push_back(usum);
|
||||
}
|
||||
int rho = 0;
|
||||
for (int i = 0; i < u.size(); i++)
|
||||
{
|
||||
if (u[i] > (sv[i] - 1) / (i + 1))
|
||||
{
|
||||
rho = i;
|
||||
}
|
||||
}
|
||||
float theta = std::max(0.f, (sv[rho] - 1) / (rho + 1));
|
||||
x = (y.array() - theta).matrix();
|
||||
x = x.cwiseMax(0.f);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
Mat proj2pavC(Mat y)
|
||||
{
|
||||
// y: N, 1
|
||||
// y[y<0] = 0
|
||||
int n = y.rows();
|
||||
y = y.cwiseMax(0.f);
|
||||
Mat x = y;
|
||||
if (y.sum() < 1)
|
||||
{
|
||||
x = y;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<float> u;
|
||||
for (int i = 0; i < y.rows(); i++)
|
||||
{
|
||||
u.push_back(y(i, 0));
|
||||
}
|
||||
// 排序
|
||||
std::sort(u.begin(), u.end(), std::greater<float>());
|
||||
float tmpsum = 0;
|
||||
bool bget = false;
|
||||
float tmax;
|
||||
for (int ii = 0; ii < n - 1; ii++)
|
||||
{
|
||||
tmpsum += u[ii];
|
||||
tmax = (tmpsum - 1) / (ii + 1);
|
||||
if (tmax >= u[ii + 1])
|
||||
{
|
||||
bget = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!bget)
|
||||
{
|
||||
tmax = (tmpsum + u[n - 1] - 1) / n;
|
||||
}
|
||||
x = (y.array() - tmax).matrix();
|
||||
x = x.cwiseMax(0.f);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
int proj201(Mat &z)
|
||||
{
|
||||
z = z.cwiseMin(1.f).cwiseMax(0.f);
|
||||
return 0;
|
||||
}
|
||||
|
||||
Mat proj2kav_(Mat x0, Mat A, Mat b)
|
||||
{
|
||||
// to solve:
|
||||
// min 1/2||x - x_0||_F^2 + ||z||_1
|
||||
// s.t. Ax = b, x-z=0, x>=0, x<=1
|
||||
// convert to L(x, y) = 1/2||x - x_0||_F^2 + y^T(Ax - b)
|
||||
// x = (I + \rho A^T @A)^-1 @ (x_0 - A^T@y + \rho A^T@b)
|
||||
// y = y + \rho *(Ax - b)
|
||||
int n = x0.rows();
|
||||
Mat I(n, n);
|
||||
I.setIdentity();
|
||||
Mat X = x0;
|
||||
Mat y = b;
|
||||
float rho = 2;
|
||||
y.setZero();
|
||||
float tol = 1e-4;
|
||||
Mat Y, B, Z, c;
|
||||
for (int iter = 0; iter < 100; iter++)
|
||||
{
|
||||
Mat X0 = X;
|
||||
// (x - x_0) + A^Ty + \rho A^T(Ax + By -c)
|
||||
X = (I + rho * A.transpose() * A).ldlt().solve(x0 - A.transpose() * y + rho * A.transpose() * b);
|
||||
y = y + rho * (A * X - b);
|
||||
|
||||
Y = Y + rho * (A * X + B * Z - c);
|
||||
float pRes = (A * X + B * Z - c).norm() / n;
|
||||
float dRes = rho * (X - X0).norm() / n;
|
||||
// std::cout << " Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), rho = " << rho << std::endl;
|
||||
|
||||
if (pRes < tol && dRes < tol)
|
||||
break;
|
||||
if (pRes > 10 * dRes)
|
||||
{
|
||||
rho *= 2;
|
||||
}
|
||||
else if (dRes > 10 * pRes)
|
||||
{
|
||||
rho /= 2;
|
||||
}
|
||||
}
|
||||
return X;
|
||||
}
|
||||
|
||||
Mat softthres(Mat b, float thres)
|
||||
{
|
||||
// TODO:vector
|
||||
for (int i = 0; i < b.rows(); i++)
|
||||
{
|
||||
if (b(i, 0) < -thres)
|
||||
{
|
||||
b(i, 0) += thres;
|
||||
}
|
||||
else if (b(i, 0) > thres)
|
||||
{
|
||||
b(i, 0) -= thres;
|
||||
}
|
||||
else
|
||||
{
|
||||
b(i, 0) = 0;
|
||||
}
|
||||
}
|
||||
return b;
|
||||
}
|
||||
|
||||
Array Dsoft(const Array &d, float penalty)
|
||||
{
|
||||
// inverts the singular values
|
||||
// takes advantage of the fact that singular values are never negative
|
||||
Array di(d.rows(), d.cols());
|
||||
int maxRank = 0;
|
||||
for (int j = 0; j < d.size(); ++j)
|
||||
{
|
||||
double penalized = d(j, 0) - penalty;
|
||||
if (penalized < 0)
|
||||
{
|
||||
di(j, 0) = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
di(j, 0) = penalized;
|
||||
maxRank++;
|
||||
}
|
||||
}
|
||||
// std::cout << "max rank: " << maxRank << std::endl;
|
||||
return di;
|
||||
}
|
||||
|
||||
Mat _proj2kav(Mat x0, Mat A, Mat b, Mat weight)
|
||||
{
|
||||
// to solve:
|
||||
// min 1/2||x - x_0||_F^2 + 1/2\lambda||Ax - b||_F^2 + \alpha||z||_1
|
||||
// s.t. x=z, z \in {z| z>=0, z<=1|
|
||||
// convert to L(x, y) = 1/2||x - x_0||_F^2 + 1/2\lambda||Ax - b||_F^2 + \alpha||z||_1
|
||||
// + <y, x - z> + \rho/2||x - z||_F^2
|
||||
// update x:
|
||||
// x = (1/rho + I + lambda/rhoA^TA)^-1 @ (1/rho x_0 + lambda/rho A^Tb + y)
|
||||
// update z:
|
||||
// z = softthres(x + 1/rho y, lambda/rho)
|
||||
// update y:
|
||||
// y = y + \rho *(x - z)
|
||||
int n = x0.rows();
|
||||
Mat I(n, n);
|
||||
I.setIdentity();
|
||||
Mat X = x0;
|
||||
Mat Y = X;
|
||||
Mat Z = Y;
|
||||
Y.setZero();
|
||||
float rho = 64;
|
||||
// weight
|
||||
float w_init = 1;
|
||||
float w_paf = 1e-1;
|
||||
float w_Ax = 100;
|
||||
float w_l1 = 1e-1;
|
||||
float tol = 1e-4;
|
||||
std::cout << "x0: " << x0 << std::endl;
|
||||
std::cout << "paf: " << weight << std::endl;
|
||||
for (int iter = 0; iter < 100; iter++)
|
||||
{
|
||||
Mat X0 = X;
|
||||
// update X
|
||||
X = ((rho + w_init) * I + w_Ax * A.transpose() * A).ldlt().solve(x0 + w_Ax * A.transpose() * b + rho * Z - Y + w_paf * weight);
|
||||
// update Z
|
||||
Z = softthres(X + 1 / rho * Y, w_l1 / rho);
|
||||
// projection Z
|
||||
Z = Z.cwiseMin(1.f).cwiseMax(0.f);
|
||||
// update Y
|
||||
Y = Y + rho * (X - Z);
|
||||
// convergence
|
||||
float pRes = (X - Z).norm() / n;
|
||||
float dRes = rho * (X - X0).norm() / n;
|
||||
std::cout << " proj2kav Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), rho = " << rho << std::endl;
|
||||
std::cout << " init= " << w_init * 0.5 * (X - x0).norm() / n
|
||||
<< ", equ= " << 0.5 * w_Ax * (A * X - b).norm() / n
|
||||
<< ", paf=" << -w_paf * weight.transpose() * X
|
||||
<< ", l1= " << 1.0 * (X.array() > 0).count() / n << std::endl;
|
||||
|
||||
if (pRes < tol && dRes < tol)
|
||||
break;
|
||||
if (pRes > 10 * dRes)
|
||||
{
|
||||
rho *= 2;
|
||||
}
|
||||
else if (dRes > 10 * pRes)
|
||||
{
|
||||
rho /= 2;
|
||||
}
|
||||
}
|
||||
return X;
|
||||
}
|
||||
|
||||
Mat proj2kav(Mat x0, Mat A, Mat b, Mat paf)
|
||||
{
|
||||
// reduce this problem
|
||||
|
||||
std::vector<int> indices;
|
||||
// here we directly set the non-zero entries
|
||||
for (int j = 0; j < A.cols(); j++)
|
||||
{
|
||||
if (A(A.rows() - 1, j) != 0)
|
||||
{
|
||||
indices.push_back(j);
|
||||
}
|
||||
}
|
||||
// just use the last row
|
||||
int n = indices.size();
|
||||
Mat Areduce(1, n);
|
||||
Areduce.setOnes();
|
||||
Mat x0reduce(n, 1), pafreduce(n, 1);
|
||||
for (int i = 0; i < n; i++)
|
||||
{
|
||||
x0reduce(i, 0) = x0(indices[i], 0);
|
||||
pafreduce(i, 0) = paf(indices[i], 0);
|
||||
}
|
||||
Mat breduce(1, 1);
|
||||
breduce(0, 0) = b(b.rows() - 1, 0);
|
||||
|
||||
Mat xreduce = _proj2kav(x0reduce, Areduce, breduce, pafreduce);
|
||||
Mat X(x0.rows(), 1);
|
||||
X.setOnes();
|
||||
for (int i = 0; i < n; i++)
|
||||
{
|
||||
X(indices[i], 0) = xreduce(i, 0);
|
||||
}
|
||||
return X;
|
||||
}
|
||||
|
||||
} // namespace match
|
34
library/pymatch/include/visualize.hpp
Normal file
34
library/pymatch/include/visualize.hpp
Normal file
@ -0,0 +1,34 @@
|
||||
/***
|
||||
* @Date: 2020-09-12 19:37:01
|
||||
* @Author: Qing Shuai
|
||||
* @LastEditors: Qing Shuai
|
||||
* @LastEditTime: 2020-09-12 19:37:46
|
||||
* @FilePath: /MatchLR/include/visualize.hpp
|
||||
*/
|
||||
#pragma once
|
||||
#ifdef _USE_OPENCV_
|
||||
#include "opencv2/opencv.hpp"
|
||||
#include "opencv2/core/eigen.hpp"
|
||||
#endif
|
||||
|
||||
namespace match
|
||||
{
|
||||
|
||||
|
||||
#ifdef _USE_OPENCV_
|
||||
cv::Mat eigen2mat(Mat Z){
|
||||
cv::Mat showi, showd, showrgb;
|
||||
auto Zmin = Z.minCoeff();
|
||||
auto Zmax = Z.maxCoeff();
|
||||
Z = ((Z.array() - Zmin)/(Zmax - Zmin)).matrix();
|
||||
cv::eigen2cv(Z, showd);
|
||||
showd.convertTo(showi, CV_8UC1, 255);
|
||||
cv::applyColorMap(showi, showrgb, cv::COLORMAP_JET);
|
||||
while(showrgb.rows < 600){
|
||||
cv::resize(showrgb, showrgb, cv::Size(), 2, 2, cv::INTER_NEAREST);
|
||||
}
|
||||
return showrgb;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace match
|
8
library/pymatch/pymatchlr/__init__.py
Normal file
8
library/pymatch/pymatchlr/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
'''
|
||||
* @ Date: 2020-09-12 17:58:02
|
||||
* @ Author: Qing Shuai
|
||||
* @ LastEditors: Qing Shuai
|
||||
* @ LastEditTime: 2020-09-15 14:38:40
|
||||
* @ FilePath: /MatchLR/pymatchlr/__init__.py
|
||||
'''
|
||||
from .pymatchlr import *
|
10
library/pymatch/python/CMakeLists.txt
Normal file
10
library/pymatch/python/CMakeLists.txt
Normal file
@ -0,0 +1,10 @@
|
||||
set(PYMODULES "pymatchlr")
|
||||
|
||||
foreach(module ${PYMODULES})
|
||||
pybind11_add_module("${module}" "${module}.cpp")
|
||||
target_link_libraries("${module}"
|
||||
PRIVATE pybind11::module
|
||||
PUBLIC ${OTHER_LIBS})
|
||||
target_include_directories("${module}"
|
||||
PUBLIC ${PUBLIC_INCLUDE})
|
||||
endforeach(module ${PYMODULES})
|
0
library/pymatch/python/__init__.py
Normal file
0
library/pymatch/python/__init__.py
Normal file
33
library/pymatch/python/pymatchlr.cpp
Normal file
33
library/pymatch/python/pymatchlr.cpp
Normal file
@ -0,0 +1,33 @@
|
||||
/***
|
||||
* @Date: 2020-09-18 14:05:37
|
||||
* @Author: Qing Shuai
|
||||
* @LastEditors: Qing Shuai
|
||||
* @LastEditTime: 2021-07-24 14:50:42
|
||||
* @FilePath: /EasyMocap/library/pymatch/python/pymatchlr.cpp
|
||||
*/
|
||||
/*
|
||||
* @Date: 2020-06-29 10:51:28
|
||||
* @LastEditors: Qing Shuai
|
||||
* @LastEditTime: 2020-07-12 17:11:43
|
||||
* @Author: Qing Shuai
|
||||
* @Mail: s_q@zju.edu.cn
|
||||
*/
|
||||
#include <iostream>
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/numpy.h"
|
||||
#include "pybind11/eigen.h"
|
||||
#include "matchSVT.hpp"
|
||||
|
||||
#define myprint(x) std::cout << #x << ": " << std::endl << x.transpose() << std::endl;
|
||||
#define printshape(x) std::cout << #x << ": (" << x.rows() << ", " << x.cols() << ")" << std::endl;
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(pymatchlr, m) {
|
||||
m.def("matchSVT", &match::matchSVT, "SVT for matching",
|
||||
py::arg("affinity"), py::arg("dimGroups"), py::arg("constraint"), py::arg("observe"), py::arg("debug"));
|
||||
m.def("matchALS", &match::matchALS, "ALS for matching",
|
||||
py::arg("affinity"), py::arg("dimGroups"), py::arg("constraint"), py::arg("observe"), py::arg("debug"));
|
||||
m.attr("__version__") = "0.1.0";
|
||||
}
|
81
library/pymatch/setup.py
Normal file
81
library/pymatch/setup.py
Normal file
@ -0,0 +1,81 @@
|
||||
'''
|
||||
* @ Date: 2020-09-12 17:35:46
|
||||
* @ Author: Qing Shuai
|
||||
* @ LastEditors: Qing Shuai
|
||||
* @ LastEditTime: 2020-09-12 17:57:36
|
||||
* @ FilePath: /MatchLR/setup.py
|
||||
'''
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
from setuptools import setup, Extension
|
||||
from setuptools.command.build_ext import build_ext
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
|
||||
class CMakeExtension(Extension):
|
||||
def __init__(self, name, sourcedir=''):
|
||||
Extension.__init__(self, name, sources=[])
|
||||
self.sourcedir = os.path.abspath(sourcedir)
|
||||
|
||||
|
||||
class CMakeBuild(build_ext):
|
||||
def run(self):
|
||||
try:
|
||||
out = subprocess.check_output(['cmake', '--version'])
|
||||
except OSError:
|
||||
raise RuntimeError("CMake must be installed to build the following extensions: " +
|
||||
", ".join(e.name for e in self.extensions))
|
||||
|
||||
if platform.system() == "Windows":
|
||||
cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1))
|
||||
if cmake_version < '3.1.0':
|
||||
raise RuntimeError("CMake >= 3.1.0 is required on Windows")
|
||||
|
||||
for ext in self.extensions:
|
||||
self.build_extension(ext)
|
||||
|
||||
def build_extension(self, ext):
|
||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
|
||||
# required for auto-detection of auxiliary "native" libs
|
||||
if not extdir.endswith(os.path.sep):
|
||||
extdir += os.path.sep
|
||||
|
||||
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
|
||||
'-DPYTHON_EXECUTABLE=' + sys.executable]
|
||||
|
||||
cfg = 'Debug' if self.debug else 'Release'
|
||||
build_args = ['--config', cfg]
|
||||
|
||||
if platform.system() == "Windows":
|
||||
cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), extdir)]
|
||||
if sys.maxsize > 2**32:
|
||||
cmake_args += ['-A', 'x64']
|
||||
build_args += ['--', '/m']
|
||||
else:
|
||||
cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg]
|
||||
build_args += ['--', '-j2']
|
||||
|
||||
env = os.environ.copy()
|
||||
env['CXXFLAGS'] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get('CXXFLAGS', ''),
|
||||
self.distribution.get_version())
|
||||
if not os.path.exists(self.build_temp):
|
||||
os.makedirs(self.build_temp)
|
||||
subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env)
|
||||
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
|
||||
|
||||
setup(
|
||||
name='pymatchlr',
|
||||
version='0.0.2',
|
||||
author='Qing Shuai',
|
||||
author_email='s_q@zju.edu.cn',
|
||||
description='A project for low rank matching algorithm',
|
||||
long_description='',
|
||||
ext_modules=[CMakeExtension('pymatchlr.pymatchlr')],
|
||||
packages=['pymatchlr'],
|
||||
cmdclass=dict(build_ext=CMakeBuild),
|
||||
zip_safe=False,
|
||||
)
|
Loading…
Reference in New Issue
Block a user