EasyMocap/library/pymatch/include/matchSVT.hpp

247 lines
8.6 KiB
C++
Raw Normal View History

2023-07-10 22:10:55 +08:00
/***
* @Date: 2020-09-12 19:01:56
* @Author: Qing Shuai
* @LastEditors: Qing Shuai
* @LastEditTime: 2022-07-29 22:38:40
* @FilePath: /EasyMocapPublic/library/pymatch/include/matchSVT.hpp
*/
#pragma once
#include "base.h"
#include "projfunc.hpp"
#include "Timer.hpp"
#include "visualize.hpp"
namespace match
{
// Solver options, read once from a Control lookup table.
struct Config
{
    bool debug;      // true when control["debug"] is strictly positive
    int max_iter;    // control["maxIter"], converted to int
    float tol;       // control["tol"]
    float w_rank;    // control["w_rank"]
    float w_sparse;  // control["w_sparse"]

    // Members are initialised in declaration order, i.e. the same order in
    // which the keys are looked up.
    Config(Control& control)
        : debug(control["debug"] > 0.),
          max_iter(int(control["maxIter"])),
          tol(control["tol"]),
          w_rank(control["w_rank"]),
          w_sparse(control["w_sparse"])
    {
    }
};
// Enumerate the off-diagonal blocks between every ordered pair of distinct,
// non-empty groups.
// dimGroups: cumulative offsets [0, nF1, nF1 + nF2, ...]; group k spans the
//            index range [dimGroups[k], dimGroups[k+1]).
// Returns one entry {start_i, size_i, start_j, size_j} per ordered pair
// (i, j) with i != j. Both (i, j) and (j, i) are emitted. An empty or
// single-element dimGroups yields an empty result.
ListList getBlocksFromDimGroups(const List& dimGroups){
    ListList block;
    // size() is unsigned, so `size() - 1` wraps around for an empty list and
    // the loops would read out of bounds; use a signed group count instead.
    const int numGroups = static_cast<int>(dimGroups.size()) - 1;
    for(int i=0;i<numGroups;i++){
        // The case where no person was found in this view.
        if(dimGroups[i] == dimGroups[i+1])continue;
        for(int j=0;j<numGroups;j++){
            if(i==j)continue;
            if(dimGroups[j] == dimGroups[j+1])continue;
            block.push_back({dimGroups[i], dimGroups[i+1] - dimGroups[i],
                             dimGroups[j], dimGroups[j+1] - dimGroups[j]});
        }
    }
    return block;
}
// matchSVT with constraint and observation
// All matrices are taken by value and only modified locally.
//   M_aff:     (N, N) affinity matrix
//   dimGroups: cumulative offsets [0, nF1, nF1 + nF2, ...]
//   M_constr:  =0, when (i, j) cannot be the same person
//              if not consider this, set to 1(N, N)
//   M_obs:     =0, when (i, j) cannot be observed
//              if not consider this, set to 1(N, N)
//   control:   reads "debug", "maxIter", "tol", "w_sparse" and "w_rank"
// Returns the matrix X after the last iteration (an empty matrix when
// M_aff has no rows).
Mat matchSVT(Mat M_aff, List dimGroups, Mat M_constr, Mat M_obs, Control control)
{
    bool debug = (control["debug"] > 0.);
    int max_iter = int(control["maxIter"]);
    float tol = control["tol"];
    int N = M_aff.rows();
    // Nothing to match: return before the residuals below divide by N.
    if(N == 0){
        return M_aff;
    }
    auto dual_blocks = getBlocksFromDimGroups(dimGroups);
    // size() is unsigned, so `size() - 1` would wrap around for an empty list.
    const int numGroups = static_cast<int>(dimGroups.size()) - 1;
    // Diagonal-block constraint: zero every within-group block of M_constr,
    // then restore the main diagonal to one.
    for(int i=0;i<numGroups;i++){
        M_constr.block(dimGroups[i], dimGroups[i], dimGroups[i+1] - dimGroups[i], dimGroups[i+1]-dimGroups[i]).setZero();
    }
    M_constr.diagonal().setOnes();
    // Multiply the affinity by the constraint so that it satisfies the constraint.
    M_aff = (M_aff.array() * M_constr.array()).matrix();
    // Check every cross-group block: when its values are all high, the block
    // cannot discriminate between candidates and is treated as a bad observation.
    for (const auto& block : dual_blocks)
    {
        Mat mat = M_aff.block(block[0], block[2], block[1], block[3]);
        if(debug){
            std::cout << "(" << block[0] << ", " << block[2] << "), ";
            std::cout << "min: " << mat.minCoeff() << ", max: " << mat.maxCoeff() << std::endl;
        }
        if(mat.minCoeff() > 0.7 && block[1] > 1 && block[3] > 1){
            // Every affinity in the block exceeds 0.7 and both groups hold more
            // than one entry: regard the observation as spurious.
            M_obs.block(block[0], block[2], block[1], block[3]).setZero();
        }
    }
    // set the diag of M_aff to zeros
    M_aff.diagonal().setConstant(0);
    Mat X = M_aff;
    Mat Y = Mat::Zero(N, N);
    Mat Q = M_aff;
    Mat W = (control["w_sparse"] - M_aff.array()).matrix();
    float mu = 64;
    Timer timer;
    timer.tic();
    // Loop bound hoisted as a signed value for the OpenMP loop below.
    const int numBlocks = static_cast<int>(dual_blocks.size());
    for (int iter = 0; iter < max_iter; iter++)
    {
        Mat X0 = X;
        // update Q with SVT: soft-threshold the singular values of Y / mu + X
        Q = 1.0 / mu * Y + X;
        Eigen::BDCSVD<Mat> UDV(Q.bdcSvd(Eigen::ComputeThinU | Eigen::ComputeThinV));
        Array Ds(Dsoft(UDV.singularValues(), control["w_rank"] / mu));
        Mat Qnew(UDV.matrixU() * Ds.matrix().asDiagonal() * UDV.matrixV().adjoint());
        Q = Qnew;
        // update X
        X = Q - (M_obs.array() * W.array() + Y.array()).matrix() / mu;
        X = (X.array() * M_constr.array()).matrix();
        // set the diagonal
        X.diagonal().setOnes();
        // Note the clamp: entries are limited to the range [0, 1].
        X = X.cwiseMin(1.f).cwiseMax(0.f);
        // Each iteration reads and writes only its own block of X.
#pragma omp parallel for
        for(int i=0;i<numBlocks;i++)
        {
            auto& block = dual_blocks[i];
            X.block(block[0], block[2], block[1], block[3]) = myproj2dpam(X.block(block[0], block[2], block[1], block[3]), 1e-2);
        }
        // Symmetrise; .eval() avoids aliasing between X and its transpose.
        X = (X + X.transpose().eval()) / 2;
        Y = Y + mu * (X - Q);
        // pRes / dRes: presumably the primal and dual residuals, scaled by N.
        float pRes = (X - Q).norm() / N;
        float dRes = mu * (X - X0).norm() / N;
        if(debug){
#ifdef _USE_OPENCV_
            cv::imshow("Q", eigen2mat(Q));
            cv::imshow("X", eigen2mat(X));
            cv::waitKey(100);
#endif
            std::cout << "Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
        }
        if (pRes < tol && dRes < tol)
        {
            std::cout << "End " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
            break;
        }
        // Rebalance mu when one residual dominates the other by more than 10x.
        if (pRes > 10 * dRes)
        {
            mu *= 2;
        }
        else if (dRes > 10 * pRes)
        {
            mu /= 2;
        }
    }
    if(debug){
#ifdef _USE_OPENCV_
        timer.toc("solving svt");
        cv::imshow("X", eigen2mat(X));
        cv::imshow("Q", eigen2mat(Q));
        cv::waitKey(0);
#endif
    }
    return X;
}
// Low-rank matching through an alternating factorisation X = A * B^T.
//   M_aff:     affinity matrix; it is symmetrised before use
//   dimGroups: cumulative offsets [0, nF1, nF1 + nF2, ...]
//   M_constr, M_obs: accepted for signature parity with matchSVT but not
//              read by this function
//   control:   options, read through Config
// Returns the symmetrised product A * B^T of the last iteration.
Mat matchALS(Mat M_aff, List dimGroups, Mat M_constr, Mat M_obs, Control control)
{
    // This function is to solve
    //   min <W, X> + alpha||x||_* + beta||x||_1, st. X \in C
    //   <beta - W, AB^T> + alpha/2||A||^2 + alpha/2||B||^2
    //   st AB^T = Z, Z\in \Omega
    const Config cfg(control);
    Mat W = (M_aff + M_aff.transpose())/2;
    // set the diag of W to zeros
    W.diagonal().setZero();
    Mat X = W;
    Mat Z = W;
    Mat Y = W;
    Y.setZero();
    // mu must be a float: as an int, repeated halving reaches 0 and every
    // later division by mu is a division by zero.
    float mu = 64;
    int n = X.rows();
    // size() is unsigned, so `size() - 1` would wrap around for an empty list.
    const int numGroups = static_cast<int>(dimGroups.size()) - 1;
    // The factor width is the size of the largest group.
    int maxRank = 0;
    for(int i=0;i<numGroups;i++){
        if(dimGroups[i+1]-dimGroups[i] > maxRank){
            maxRank = dimGroups[i+1]-dimGroups[i];
        }
    }
    std::cout << "[matchALS] set the max rank = " << maxRank << std::endl;
    Mat eyeRank = Mat::Identity(maxRank, maxRank);
    // initial value
    Mat A = Mat::Random(n, maxRank);
    Mat B;
    for(int iter=0;iter<cfg.max_iter;iter++){
        Mat X0 = X;
        X = Z - (((Y - W).array() + cfg.w_sparse)/mu).matrix();
        // Alternating regularised least-squares solves for B, then for A.
        B = ((A.transpose() * A + cfg.w_rank/mu * eyeRank).ldlt().solve(A.transpose() * X)).transpose();
        A = ((B.transpose() * B + cfg.w_rank/mu * eyeRank).ldlt().solve(B.transpose() * X.transpose())).transpose();
        X = A * B.transpose();
        Z = X + Y/mu;
        // Force every within-group diagonal block of Z to the identity.
        for(int i=0;i<numGroups;i++){
            int start = dimGroups[i];
            int end = dimGroups[i+1];
            Z.block(start, start, end-start, end-start).setIdentity();
        }
        // Note the clamp: entries are limited to the range [0, 1].
        Z = Z.cwiseMin(1.f).cwiseMax(0.f);
        Y = Y + mu*(X - Z);
        // pRes / dRes: presumably the primal and dual residuals, scaled by n.
        float pRes = (X - Z).norm()/n;
        float dRes = mu*(X - X0).norm()/n;
        if(cfg.debug){
#ifdef _USE_OPENCV_
            cv::imshow("Z", eigen2mat(Z));
            cv::imshow("X", eigen2mat(X));
            cv::waitKey(10);
#endif
            std::cout << "Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
        }
        if (pRes < cfg.tol && dRes < cfg.tol)
        {
            std::cout << "End " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
            break;
        }
        // Rebalance mu when one residual dominates the other by more than 10x.
        if(pRes > 10*dRes){
            mu *= 2;
        }else if(dRes > 10*pRes){
            mu /= 2;
        }
    }
    // Symmetrise. The transpose is evaluated into a temporary first, because
    // assigning an expression containing X.transpose() back to X aliases.
    X = (X + X.transpose().eval()) / 2;
    return X;
}
// Expand cumulative group offsets into a per-element group index.
// dimGroups: cumulative offsets [0, nF1, nF1 + nF2, ...]
// Returns a vector of length dimGroups.back() whose entry c is the index i of
// the group with dimGroups[i] <= c < dimGroups[i+1]; entries covered by no
// group keep the value -1. An empty dimGroups yields an empty vector.
Vec<int> getViewsFromDim(List& dimGroups){
    // back() on an empty list is undefined behaviour, and the unsigned
    // `size() - 1` would wrap around, so handle the empty case up front.
    if(dimGroups.empty()){
        return Vec<int>();
    }
    Vec<int> lists(dimGroups.back(), -1);
    const int numGroups = static_cast<int>(dimGroups.size()) - 1;
    for(int i=0;i<numGroups;i++){
        for(int c=dimGroups[i];c<dimGroups[i+1];c++){
            lists[c] = i;
        }
    }
    return lists;
}
// Collapse a per-element view index list into boundary offsets.
// The result starts with 0, gains the position of every element whose view id
// differs from the currently tracked one (tracking starts at view 0), and
// ends with views.size(). A jump across several view ids therefore adds a
// single boundary only.
Vec<int> getDimsFromViews(List& views){
    Vec<int> boundaries(1, 0);
    int currentView = 0;
    for(int idx=0; idx<views.size(); ++idx){
        if(views[idx] == currentView){
            continue;
        }
        // The view id changed at idx: record the boundary and track the new id.
        boundaries.push_back(idx);
        currentView = views[idx];
    }
    boundaries.push_back(views.size());
    return boundaries;
}
} // namespace match