import cv2 as cv
import numpy as np
import os.path as osp
import argparse

from calib_tools import read_json, DataPath
from calib_tools import read_img_paths

from calibrate_extri import findChessboardCorners


# ////////////////////////////////////////////////////////////////////////////////////////////////////////////
# detect_chessboard
# Detect all 2D chessboard corners and store them in a dict: keys are camera names,
# values are the lists of 2D corners.
def detect_chessboard(check_img_path, pattern):
    imgPaths = read_img_paths(check_img_path)
    if not imgPaths:
        print("No images found!")
        return {}

    data = {}
    for imgPath in imgPaths:
        camname = osp.splitext(osp.basename(imgPath))[0]
        keypoints2d = findChessboardCorners(imgPath, pattern, False)
        if keypoints2d is not None:
            data[camname] = keypoints2d.tolist()
        else:
            print(f"Failed to find chessboard corners in image: {imgPath}")
    return data

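# The returned dict maps each camera name (taken from the image file basename) to an
# (N, 3) list of [x, y, conf] corner detections, as consumed by check_match() below.
# Illustrative values only:
#   {"cam1": [[512.3, 288.1, 1.0], [540.7, 289.4, 1.0], ...],
#    "cam2": [[498.0, 301.2, 1.0], ...]}
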
# //////////////////////////////////////////////////////////////////////////////////////////////////////////////

def read_cameras(intri_path, extri_path):
    cameras = {}
    intri = read_json(intri_path)
    extri = read_json(extri_path)
    for key in intri:
        cameras[key] = {
            'intri': intri[key],
            'extri': extri[key]
        }
    return cameras

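# Assumed layout of the calibration JSON files, inferred from how check_match() reads them:
#   intri_path: {"cam1": {"K": 3x3 intrinsic matrix, "dist": distortion coefficients}, ...}
#   extri_path: {"cam1": {"R": 3x3 rotation, "T": 3x1 translation}, ...}
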
def plot_line(img, pt1, pt2, lw, col):
    cv.line(img, (int(pt1[0] + 0.5), int(pt1[1] + 0.5)), (int(pt2[0] + 0.5), int(pt2[1] + 0.5)),
            col, lw)


def plot_cross(img, x, y, col, width=-1, lw=-1):
    if lw == -1:
        lw = max(1, int(round(img.shape[0] / 1000)))
        width = lw * 5
    cv.line(img, (int(x - width), int(y)), (int(x + width), int(y)), col, lw)
    cv.line(img, (int(x), int(y - width)), (int(x), int(y + width)), col, lw)


def plot_points2d(img, points2d, lines, lw=-1, col=(0, 255, 0), putText=True, style='+'):
    # Draw 2D points on the image
    if points2d.shape[1] == 2:
        points2d = np.hstack([points2d, np.ones((points2d.shape[0], 1))])
    if lw == -1:
        lw = img.shape[0] // 200
    for i, (x, y, v) in enumerate(points2d):
        if v < 0.01:
            continue
        c = col
        if '+' in style:
            plot_cross(img, x, y, width=10, col=c, lw=lw * 2)
        if 'o' in style:
            cv.circle(img, (int(x), int(y)), 10, c, lw * 2)
            cv.circle(img, (int(x), int(y)), lw, c, lw)
        if putText:
            c = col[::-1]
            font_scale = img.shape[0] / 1000
            cv.putText(img, '{}'.format(i), (int(x), int(y)), cv.FONT_HERSHEY_SIMPLEX, font_scale, c, 2)
    for i, j in lines:
        if points2d[i][2] < 0.01 or points2d[j][2] < 0.01:
            continue
        plot_line(img, points2d[i], points2d[j], max(1, lw // 2), col)

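# Note: plot_points2d expects points2d as an (N, 2) or (N, 3) array (third column = confidence;
# points with confidence < 0.01 are skipped) and lines as an iterable of (i, j) index pairs to
# connect; check_match() below calls it with an empty lines list for the reprojection overlay.
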
# Triangulate a batch of keypoints
def batch_triangulate(keypoints_, Pall, min_view=2):
    """ triangulate the keypoints of whole body

    Args:
        keypoints_ (nViews, nJoints, 3): 2D detections
        Pall (nViews, 3, 4) | (nViews, nJoints, 3, 4): projection matrix of each view
        min_view (int, optional): min view for visible points. Defaults to 2.

    Returns:
        keypoints3d: (nJoints, 4)
    """
    # keypoints: (nViews, nJoints, 3)
    # Pall: (nViews, 3, 4)
    # A: (nJoints, nViews x 2, 4), x: (nJoints, 4, 1); b: (nJoints, nViews x 2, 1)
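    # DLT sketch: for every view in which a joint is visible, the projection x = P X gives two
    # homogeneous equations, u * (P2 . X) - (P0 . X) = 0 and v * (P2 . X) - (P1 . X) = 0, each
    # weighted by the detection confidence. Stacking them over all views yields A X = 0, which
    # is solved below by taking the right singular vector of A with the smallest singular value
    # and de-homogenising it.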
    # Compute the visibility of each keypoint and keep only the valid ones
    v = (keypoints_[:, :, -1] > 0).sum(axis=0)  # number of views in which each keypoint is detected
    valid_joint = np.where(v >= min_view)[0]  # indices of points seen by at least min_view views
    keypoints = keypoints_[:, valid_joint]  # keep only the valid keypoints
    conf3d = keypoints[:, :, -1].sum(axis=0) / v[valid_joint]
    # P2: the last row of each P matrix: (1, nViews, 1, 4)
    if len(Pall.shape) == 3:
        P0 = Pall[None, :, 0, :]
        P1 = Pall[None, :, 1, :]
        P2 = Pall[None, :, 2, :]
    else:
        P0 = Pall[:, :, 0, :].swapaxes(0, 1)
        P1 = Pall[:, :, 1, :].swapaxes(0, 1)
        P2 = Pall[:, :, 2, :].swapaxes(0, 1)
    # uP2: the x coordinate multiplied by P2: (nJoints, nViews, 1, 4)
    uP2 = keypoints[:, :, 0].T[:, :, None] * P2
    vP2 = keypoints[:, :, 1].T[:, :, None] * P2
    conf = keypoints[:, :, 2].T[:, :, None]
    Au = conf * (uP2 - P0)
    Av = conf * (vP2 - P1)
    A = np.hstack([Au, Av])
    u, s, v = np.linalg.svd(A)
    X = v[:, -1, :]
    X = X / X[:, 3:]
    # out: (nJoints, 4)
    result = np.zeros((keypoints_.shape[1], 4))
    result[valid_joint, :3] = X[:, :3]
    result[valid_joint, 3] = conf3d  # * (conf[..., 0].sum(axis=-1) > min_view)
    return result

def reprojectN3(kpts3d, Pall):
    # kpts3d: (N, 3) or (N, 4)
    # Pall: (nViews, 3, 4), projection matrices [R|T] premultiplied by K
    nViews = len(Pall)
    # Append a 1 after the xyz coordinates to obtain homogeneous coordinates
    kp3d = np.hstack((kpts3d[:, :3], np.ones((kpts3d.shape[0], 1))))  # homogeneous coordinates (N, 4)
    kp2ds = []
    for nv in range(nViews):
        kp2d = Pall[nv] @ kp3d.T  # project to 2D (3, N)
        kp2d[:2, :] /= kp2d[2:, :]  # normalize the homogeneous coordinates
        kp2ds.append(kp2d.T[None, :, :])  # add a view dimension (1, N, 3)
    kp2ds = np.vstack(kp2ds)  # stack all views (nViews, N, 3)
    if kpts3d.shape[-1] == 4:
        kp2ds[..., -1] = kp2ds[..., -1] * (kpts3d[None, :, -1] > 0.)  # keep the confidence information
    return kp2ds

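# Note: reprojectN3 is the batched counterpart of the per-camera projection done by hand in
# check_match() below (keypoints3d[:, :3] @ P[:, :3].T + P[:, 3], followed by division by z);
# it can be used to reproject the triangulated result into every view in a single call.
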
# Input: intrinsic/extrinsic parameters and the 2D points annotated on the images
# The input 2D points are triangulated into 3D points, projected back onto the images,
# and compared with the annotated 2D points to compute the reprojection error
# Output: reprojection error statistics (mean error, max error)
# One image per camera view; the reprojection error is computed for each view
def check_match(pattern):
    # Read the intrinsic and extrinsic parameters
    cameras = read_cameras(DataPath.intri_json_path, DataPath.extri_json_path)
    # Format: {"cam1": [[x1, y1, conf], [x2, y2, conf], ...], "cam2": [[x1, y1, conf], ...], ...},
    # one image per camera
    kpts2d = detect_chessboard(DataPath.check_data, pattern)

    # Undistort the detected corners
    for cam in cameras:
        K = np.array(cameras[cam]['intri']['K'])
        dist = np.array(cameras[cam]['intri']['dist'])
        points2d = np.array(kpts2d[cam])[:, :2]
        # Add an axis so the shape goes from (N, 2) to (N, 1, 2), as cv.undistortPoints expects;
        # pass P=K so the undistorted points stay in pixel coordinates, consistent with the
        # projection matrices built below as K @ [R|T]
        points2d_undistorted = cv.undistortPoints(np.expand_dims(points2d, axis=1), K, dist, P=K)
        # Horizontally stack the undistorted points with the confidence column of the original data
        kpts2d[cam] = np.hstack((points2d_undistorted.squeeze(), np.array(kpts2d[cam])[:, 2:]))

    # Triangulation
    # Prepare projection matrices (Pall)
    Pall = []
    keypoints = []
    for cam in cameras:
        K = np.array(cameras[cam]['intri']['K'])
        R = np.array(cameras[cam]['extri']['R'])
        T = np.array(cameras[cam]['extri']['T']).reshape(3, 1)
        P = K @ np.hstack((R, T))
        Pall.append(P)
        keypoints.append(kpts2d[cam])

    Pall = np.array(Pall)
    keypoints = np.array(keypoints)

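    # At this point Pall has shape (nViews, 3, 4) and keypoints has shape (nViews, nCorners, 3),
    # matching the input expected by batch_triangulate().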
    # Triangulate 3D points
    keypoints3d = batch_triangulate(keypoints, Pall)

    # Calculate the reprojection error and plot the results
    reprojection_errors = []
    for i, cam in enumerate(cameras):
        P = Pall[i]
        kpts2d_proj = keypoints3d[:, :3] @ P[:, :3].T + P[:, 3]  # project the 3D keypoints onto the image plane
        kpts2d_proj /= kpts2d_proj[:, 2:3]  # Normalize by z

        # Compare with original 2D keypoints
        kpts2d_actual = keypoints[i, :, :2]
        kpts2d_error = np.linalg.norm(kpts2d_proj[:, :2] - kpts2d_actual, axis=1)
        reprojection_errors.append(kpts2d_error)

        # Plot reprojection results on the original image
        # img = np.zeros((480, 640, 3), dtype=np.uint8)  # Placeholder for the actual image
        img_path = osp.join(DataPath.check_data, f"{cam}.jpg")
        img = cv.imread(img_path)
        plot_points2d(img, np.hstack((kpts2d_proj[:, :2], np.ones((kpts2d_proj.shape[0], 1)))), [], col=(0, 255, 0))
        mean_error_per_image = np.mean(kpts2d_error)
        font_scale = img.shape[0] / 1000
        cv.putText(img, f'Mean Error: {mean_error_per_image:.2f}', (50, 50), cv.FONT_HERSHEY_SIMPLEX, font_scale,
                   (0, 0, 255), 2)
        cv.imwrite(osp.join(DataPath.check_vis, f"{cam}.jpg"), img)

    # Combine errors for statistics
    reprojection_errors = np.hstack(reprojection_errors)
    mean_error = np.mean(reprojection_errors)
    max_error = np.max(reprojection_errors)

    return {
        'mean_error': mean_error,
        'max_error': max_error
    }

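# Example invocation (the script name is illustrative; the calibration and image paths are
# taken from calib_tools.DataPath):
#   python check_calibration.py --pattern 11,8
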
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Check camera calibration")
    parser.add_argument("--pattern", type=str, default="11,8",
                        help="Chessboard pattern size (columns, rows), e.g., '11,8'")
    args = parser.parse_args()

    pattern = tuple(map(int, args.pattern.split(',')))  # Convert pattern string to tuple
    result = check_match(pattern)

    print(f"Mean Reprojection Error: {result['mean_error']:.2f}")
    print(f"Max Reprojection Error: {result['max_error']:.2f}")