247 lines
8.6 KiB
C++
247 lines
8.6 KiB
C++
/***
 * @Date: 2020-09-12 19:01:56
 * @Author: Qing Shuai
 * @LastEditors: Qing Shuai
 * @LastEditTime: 2022-07-29 22:38:40
 * @FilePath: /EasyMocapPublic/library/pymatch/include/matchSVT.hpp
 */
|
||
#pragma once
|
||
#include "base.h"
|
||
#include "projfunc.hpp"
|
||
#include "Timer.hpp"
|
||
#include "visualize.hpp"
|
||
|
||
namespace match
|
||
{
|
||
struct Config
|
||
{
|
||
bool debug;
|
||
int max_iter;
|
||
float tol;
|
||
float w_rank;
|
||
float w_sparse;
|
||
Config(Control& control){
|
||
debug = (control["debug"] > 0.);
|
||
max_iter = int(control["maxIter"]);
|
||
tol = control["tol"];
|
||
w_rank = control["w_rank"];
|
||
w_sparse = control["w_sparse"];
|
||
}
|
||
};
|
||
|
||
// dimGroups: [0, nF1, nF1 + nF2, ...]
|
||
ListList getBlocksFromDimGroups(const List& dimGroups){
|
||
ListList block;
|
||
for(int i=0;i<dimGroups.size() - 1;i++){
|
||
// 这个视角没有找到人的情况
|
||
if(dimGroups[i] == dimGroups[i+1])continue;
|
||
for(int j=0;j<dimGroups.size()-1;j++){
|
||
if(i==j)continue;
|
||
if(dimGroups[j] == dimGroups[j+1])continue;
|
||
block.push_back({dimGroups[i], dimGroups[i+1] - dimGroups[i],
|
||
dimGroups[j], dimGroups[j+1] - dimGroups[j]});
|
||
}
|
||
}
|
||
return block;
|
||
}
|
||
// matchSVT with constraint and observation
|
||
// M_aff: (N, N): affinity matrix
|
||
// M_constr: =0, when (i, j) cannot be the same person
|
||
// if not consider this, set to 1(N, N)
|
||
// M_obs: =0, when (i, j) cannot be observed
|
||
// if not consider this, set to 1(N, N)
|
||
Mat matchSVT(Mat M_aff, List dimGroups, Mat M_constr, Mat M_obs, Control control)
|
||
{
|
||
bool debug = (control["debug"] > 0.);
|
||
int max_iter = int(control["maxIter"]);
|
||
float tol = control["tol"];
|
||
|
||
int N = M_aff.rows();
|
||
auto dual_blocks = getBlocksFromDimGroups(dimGroups);
|
||
// 对角线约束
|
||
for(int i=0;i<dimGroups.size() - 1;i++){
|
||
M_constr.block(dimGroups[i], dimGroups[i], dimGroups[i+1] - dimGroups[i], dimGroups[i+1]-dimGroups[i]).setZero();
|
||
}
|
||
M_constr.diagonal().setOnes();
|
||
// 将affinity乘一下constraint,保证满足约束
|
||
M_aff = (M_aff.array() * M_constr.array()).matrix();
|
||
// check一下所有区块,如果最大值和最小值差异过小的,直接认为是错误观测
|
||
for (auto block : dual_blocks)
|
||
{
|
||
Mat mat = M_aff.block(block[0], block[2], block[1], block[3]);
|
||
if(debug){
|
||
std::cout << "(" << block[0] << ", " << block[2] << "), ";
|
||
std::cout << "min: " << mat.minCoeff() << ", max: " << mat.maxCoeff() << std::endl;
|
||
}
|
||
if(mat.minCoeff() > 0.7 && block[1] > 1 && block[3] > 1){
|
||
// 如果大于0.9,说明区分度不够高啊,认为观测是虚假的
|
||
M_obs.block(block[0], block[2], block[1], block[3]).setZero();
|
||
}
|
||
}
|
||
// set the diag of M_aff to zeros
|
||
M_aff.diagonal().setConstant(0);
|
||
Mat X = M_aff;
|
||
Mat Y = Mat::Zero(N, N);
|
||
Mat Q = M_aff;
|
||
Mat W = (control["w_sparse"] - M_aff.array()).matrix();
|
||
float mu = 64;
|
||
Timer timer;
|
||
timer.tic();
|
||
for (int iter = 0; iter < max_iter; iter++)
|
||
{
|
||
Mat X0 = X;
|
||
// update Q with SVT
|
||
Q = 1.0 / mu * Y + X;
|
||
Eigen::BDCSVD<Mat> UDV(Q.bdcSvd(Eigen::ComputeThinU | Eigen::ComputeThinV));
|
||
Array Ds(Dsoft(UDV.singularValues(), control["w_rank"] / mu));
|
||
Mat Qnew(UDV.matrixU() * Ds.matrix().asDiagonal() * UDV.matrixV().adjoint());
|
||
Q = Qnew;
|
||
// update X
|
||
X = Q - (M_obs.array() * W.array() + Y.array()).matrix() / mu;
|
||
X = (X.array() * M_constr.array()).matrix();
|
||
// set the diagonal
|
||
X.diagonal().setOnes();
|
||
// 注意这个min,max
|
||
X = X.cwiseMin(1.f).cwiseMax(0.f);
|
||
#pragma omp parallel for
|
||
for(int i=0;i<dual_blocks.size();i++)
|
||
{
|
||
auto& block = dual_blocks[i];
|
||
X.block(block[0], block[2], block[1], block[3]) = myproj2dpam(X.block(block[0], block[2], block[1], block[3]), 1e-2);
|
||
}
|
||
X = (X + X.transpose().eval()) / 2;
|
||
Y = Y + mu * (X - Q);
|
||
float pRes = (X - Q).norm() / N;
|
||
float dRes = mu * (X - X0).norm() / N;
|
||
if(debug){
|
||
#ifdef _USE_OPENCV_
|
||
cv::imshow("Q", eigen2mat(Q));
|
||
cv::imshow("X", eigen2mat(X));
|
||
cv::waitKey(100);
|
||
#endif
|
||
std::cout << "Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
|
||
}
|
||
if (pRes < tol && dRes < tol)
|
||
{
|
||
std::cout << "End " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
|
||
break;
|
||
}
|
||
|
||
if (pRes > 10 * dRes)
|
||
{
|
||
mu *= 2;
|
||
}
|
||
else if (dRes > 10 * pRes)
|
||
{
|
||
mu /= 2;
|
||
}
|
||
}
|
||
if(debug){
|
||
#ifdef _USE_OPENCV_
|
||
timer.toc("solving svt");
|
||
cv::imshow("X", eigen2mat(X));
|
||
cv::imshow("Q", eigen2mat(Q));
|
||
cv::waitKey(0);
|
||
#endif
|
||
}
|
||
return X;
|
||
}
|
||
|
||
// matchALS: alternating least squares variant of the matching solver.
//   M_aff    : square affinity matrix; it is symmetrised before use
//   dimGroups: [0, nF1, nF1 + nF2, ...]
//   M_constr, M_obs: accepted for signature parity with matchSVT, not read here
//   control  : parsed through Config ("debug", "maxIter", "tol", "w_rank", "w_sparse")
// Returns the symmetrised product A * B^T after the iteration stops.
Mat matchALS(Mat M_aff, List dimGroups, Mat M_constr, Mat M_obs, Control control)
{
    // This function is to solve
    //   min <W, X> + alpha||x||_* + beta||x||_1, st. X \in C
    //   <beta - W, AB^T> + alpha/2||A||^2 + alpha/2||B||^2
    //   st AB^T = Z, Z\in \Omega
    const Config cfg(control);
    Mat W = (M_aff + M_aff.transpose())/2;
    // set the diag of W to zeros
    W.diagonal().setZero();
    Mat X = W;
    Mat Z = W;
    Mat Y = Mat::Zero(W.rows(), W.cols());
    // mu must be a floating-point value: with an integer, repeated `mu /= 2`
    // reaches 0 and every later division by mu breaks down.
    float mu = 64;
    const int n = X.rows();
    // The factor rank is the size of the largest view.
    // `i + 1 < size()` avoids the unsigned wrap-around of `size() - 1` when
    // dimGroups is empty.
    int maxRank = 0;
    for(size_t i = 0; i + 1 < dimGroups.size(); i++){
        const int groupSize = dimGroups[i+1] - dimGroups[i];
        if(groupSize > maxRank){
            maxRank = groupSize;
        }
    }
    std::cout << "[matchALS] set the max rank = " << maxRank << std::endl;
    Mat eyeRank = Mat::Identity(maxRank, maxRank);
    // initial value
    Mat A = Mat::Random(n, maxRank);
    Mat B;
    for(int iter = 0; iter < cfg.max_iter; iter++){
        Mat X0 = X;
        X = Z - (((Y - W).array() + cfg.w_sparse)/mu).matrix();
        // Alternate regularised least-squares solves for B, then for A.
        B = ((A.transpose() * A + cfg.w_rank/mu * eyeRank).ldlt().solve(A.transpose() * X)).transpose();
        A = ((B.transpose() * B + cfg.w_rank/mu * eyeRank).ldlt().solve(B.transpose() * X.transpose())).transpose();

        X = A * B.transpose();
        Z = X + Y/mu;
        // Every within-view block of Z is forced to the identity.
        for(size_t i = 0; i + 1 < dimGroups.size(); i++){
            const int start = dimGroups[i];
            const int end = dimGroups[i+1];
            Z.block(start, start, end-start, end-start).setIdentity();
        }
        // note the order: clamp from above with 1 first, then from below with 0
        Z = Z.cwiseMin(1.f).cwiseMax(0.f);
        Y = Y + mu*(X - Z);

        float pRes = (X - Z).norm()/n;
        float dRes = mu*(X - X0).norm()/n;

        if(cfg.debug){
#ifdef _USE_OPENCV_
            cv::imshow("Z", eigen2mat(Z));
            cv::imshow("X", eigen2mat(X));
            cv::waitKey(10);
#endif
            std::cout << "Iter " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
        }

        if (pRes < cfg.tol && dRes < cfg.tol)
        {
            std::cout << "End " << iter << ", Res = (" << pRes << ", " << dRes << "), mu = " << mu << std::endl;
            break;
        }

        // Rebalance the penalty when one residual dominates the other.
        if(pRes > 10*dRes){
            mu *= 2;
        }else if(dRes > 10*pRes){
            mu /= 2;
        }
    }
    // X appears on both sides together with its transpose; .eval() materialises
    // the transpose first so the assignment does not read half-updated entries.
    X = (X + X.transpose().eval()) / 2;
    return X;
}
|
||
|
||
// Expands view boundaries into a per-element view index.
//   dimGroups: [0, nF1, nF1 + nF2, ...]
// Returns a vector of length dimGroups.back() in which entry c holds the index
// i of the view with dimGroups[i] <= c < dimGroups[i+1]; entries not covered by
// any view keep the initial value -1.
// An empty dimGroups yields an empty result.
Vec<int> getViewsFromDim(List& dimGroups){
    // back() on an empty container is undefined, so handle that case first.
    if(dimGroups.empty()){
        return Vec<int>();
    }
    Vec<int> lists(dimGroups.back(), -1);
    for(size_t i = 0; i + 1 < dimGroups.size(); i++){
        for(int c = dimGroups[i]; c < dimGroups[i+1]; c++){
            lists[c] = static_cast<int>(i);
        }
    }
    return lists;
}
|
||
|
||
// Collapses a per-element view label list into boundary positions.
// The result starts with 0, gets the position of every element whose label
// differs from the label currently being tracked (tracking starts at label 0),
// and ends with views.size().
// NOTE(review): a label change always adds exactly one boundary, however far
// the label jumps - confirm whether skipped labels should add boundaries too.
Vec<int> getDimsFromViews(List& views){
    Vec<int> boundaries = {0};
    int currentView = 0;
    int position = 0;
    for(const auto view : views){
        if(view != currentView){
            boundaries.push_back(position);
            currentView = view;
        }
        ++position;
    }
    boundaries.push_back(static_cast<int>(views.size()));
    return boundaries;
}
|
||
} // namespace match
|