EasyMocap/easymocap/affinity/matchSVT.py
2021-06-28 10:38:36 +08:00

63 lines
1.8 KiB
Python

'''
@ Date: 2021-06-04 20:47:38
@ Author: Qing Shuai
@ LastEditors: Qing Shuai
@ LastEditTime: 2021-06-15 17:30:16
@ FilePath: /EasyMocap/easymocap/affinity/matchSVT.py
'''
import numpy as np
def matchSVT(M_aff, dimGroups, M_constr=None, M_obs=None, control={}):
    """Refine a square affinity matrix by iterative singular value thresholding.

    Starting from the symmetrised input (with a zeroed diagonal), each
    iteration:

    1. soft-thresholds the singular values of ``Y / mu + X`` by
       ``w_rank / mu`` to obtain ``Q``;
    2. sets ``X = Q - (W + Y) / mu`` where ``W = w_sparse - X_initial``;
    3. projects ``X``: same-group diagonal blocks are zeroed, the diagonal
       is set to 1, entries are clipped to ``[0, 1]``, the result is
       multiplied element-wise by ``M_constr`` and symmetrised;
    4. updates the multiplier ``Y`` and rescales the penalty ``mu`` from the
       two residuals (this looks like ADMM residual balancing).

    Args:
        M_aff: ``(N, N)`` array. It is copied; the caller's array is not
            modified.
        dimGroups: Sequence of block boundaries. Each consecutive pair
            ``dimGroups[i], dimGroups[i + 1]`` delimits a diagonal block whose
            entries are forced to zero (except the main diagonal).
        M_constr: Optional ``(N, N)`` multiplicative mask applied in the
            projection step. When ``None``, a mask of ones is built with the
            same-group blocks zeroed and the diagonal set to one. A mask
            supplied by the caller is used as-is and not modified.
        M_obs: Unused; kept for interface compatibility.
        control: Mapping of solver settings. Required keys: ``'maxIter'``
            (iteration cap), ``'w_rank'`` (threshold weight on singular
            values), ``'tol'`` (stop when both residuals fall below it) and
            ``'w_sparse'`` (constant offset used to build ``W``). Optional
            key: ``'log'`` (print per-iteration residuals; defaults to
            ``False``). The mapping is only read, never mutated.

    Returns:
        The ``(N, N)`` matrix ``X`` after the last iteration. It is symmetric
        and, once at least one iteration has run, has entries in ``[0, 1]``.

    Raises:
        KeyError: If a required key is missing from ``control``.
    """
    max_iter = control['maxIter']
    w_rank = control['w_rank']
    tol = control['tol']
    w_sparse = control['w_sparse']
    # Logging is purely diagnostic, so a missing flag must not abort the solve.
    verbose = control.get('log', False)

    X = M_aff.copy()
    N = X.shape[0]
    index_diag = np.arange(N)
    X[index_diag, index_diag] = 0.

    # The same-group blocks never change, so build their slices once and
    # reuse them for the default mask and for every projection step.
    blocks = [slice(dimGroups[i], dimGroups[i + 1])
              for i in range(len(dimGroups) - 1)]

    if M_constr is None:
        M_constr = np.ones_like(M_aff)
        for blk in blocks:
            M_constr[blk, blk] = 0
        M_constr[index_diag, index_diag] = 1

    X = (X + X.T) / 2
    Y = np.zeros((N, N))
    mu = 64
    W = w_sparse - X

    for iter_ in range(max_iter):
        X0 = X.copy()

        # Update Q: soft-threshold the singular values of Y/mu + X.
        Q = 1.0 / mu * Y + X
        U, s, VT = np.linalg.svd(Q)
        diagS = np.maximum(s - w_rank / mu, 0)
        # Scaling U's columns equals U @ diag(diagS) without the dense
        # N x N diagonal matrix.
        Q = (U * diagS) @ VT

        # Update X.
        X = Q - (W + Y) / mu

        # Project X onto the constraints.
        for blk in blocks:
            X[blk, blk] = 0
        X[index_diag, index_diag] = 1.
        np.clip(X, 0, 1, out=X)
        X = X * M_constr
        X = (X + X.T) / 2

        # Update the multiplier Y.
        Y = Y + mu * (X - Q)

        pRes = np.linalg.norm(X - Q) / N
        dRes = mu * np.linalg.norm(X - X0) / N
        if verbose:
            print('[Match] {}, Res = ({:.4f}, {:.4f}), mu = {}'.format(iter_, pRes, dRes, mu))
        if pRes < tol and dRes < tol:
            break

        # Keep the two residuals within a factor of 10 of each other.
        if pRes > 10 * dRes:
            mu = 2 * mu
        elif dRes > 10 * pRes:
            mu = mu / 2
    return X