camera_calibrate/check_calibrate.py
2024-12-05 21:27:45 +08:00

232 lines
9.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2 as cv
import numpy as np
import os.path as osp
import argparse
from calib_tools import read_json, DataPath
from calib_tools import read_img_paths
from calibrate_extri import findChessboardCorners
# ////////////////////////////////////////////////////////////////////////////////////////////////////////////
# detect_chessboard
# Detect the 2D chessboard corners in every check image and collect them in a
# dict whose key is the camera name and whose value is the list of 2D corners.
def detect_chessboard(check_img_path, pattern):
    """Run chessboard-corner detection on every image under ``check_img_path``.

    Args:
        check_img_path: location handed to ``read_img_paths`` to enumerate images.
        pattern: chessboard pattern size, forwarded to ``findChessboardCorners``.

    Returns:
        dict mapping the image file stem (used as the camera name) to the
        detected corners converted with ``.tolist()``. Images for which the
        detector returns ``None`` are reported on stdout and left out. An
        empty dict is returned when no images are found.
    """
    image_files = read_img_paths(check_img_path)
    if not image_files:
        print("No images found!")
        return {}

    corners_by_camera = {}
    for image_file in image_files:
        # The file name without its extension doubles as the camera name.
        stem, _ext = osp.splitext(osp.basename(image_file))
        # NOTE(review): meaning of the third positional argument (False) is
        # defined by calibrate_extri.findChessboardCorners — confirm there.
        corners = findChessboardCorners(image_file, pattern, False)
        if corners is None:
            print(f"Failed to find chessboard corners in image: {image_file}")
            continue
        corners_by_camera[stem] = corners.tolist()
    return corners_by_camera
# //////////////////////////////////////////////////////////////////////////////////////////////////////////////
def read_cameras(intri_path, extri_path):
    """Load intrinsics and extrinsics and merge them per camera.

    Args:
        intri_path: path of the JSON file holding the intrinsic parameters.
        extri_path: path of the JSON file holding the extrinsic parameters.

    Returns:
        dict keyed by every key of the intrinsics file; each value is
        ``{'intri': <intrinsics entry>, 'extri': <extrinsics entry>}``.

    Raises:
        KeyError: if a key of the intrinsics file is missing from the
            extrinsics file (assuming ``read_json`` returns plain dicts).
    """
    intrinsics = read_json(intri_path)
    extrinsics = read_json(extri_path)
    return {
        name: {
            'intri': intrinsics[name],
            'extri': extrinsics[name],
        }
        for name in intrinsics
    }
def plot_line(img, pt1, pt2, lw, col):
    """Draw a straight segment from ``pt1`` to ``pt2`` on ``img`` (in place).

    Args:
        img: image passed to ``cv.line``.
        pt1, pt2: indexable points; only elements 0 (x) and 1 (y) are used.
        lw: line thickness passed to ``cv.line``.
        col: line colour passed to ``cv.line``.
    """
    # Adding 0.5 before truncating rounds non-negative coordinates to the
    # nearest integer pixel.
    start = (int(pt1[0] + 0.5), int(pt1[1] + 0.5))
    end = (int(pt2[0] + 0.5), int(pt2[1] + 0.5))
    cv.line(img, start, end, col, lw)
def plot_cross(img, x, y, col, width=-1, lw=-1):
    """Draw a '+' marker centred on ``(x, y)`` on ``img`` (in place).

    Args:
        img: image to draw on; ``img.shape[0]`` is used to size the marker
            when ``lw`` is left at its default.
        x, y: centre of the cross.
        col: colour passed to ``cv.line``.
        width: half-length of each arm in pixels; ``-1`` means "derive it
            from the line width" (``5 * lw``).
        lw: line thickness; ``-1`` means "derive it from the image height".
    """
    if lw == -1:
        # Automatic sizing: at least 1 px, roughly 1 px per 1000 image rows.
        # As before, the arm length is derived from the automatic thickness
        # in this mode, overriding any ``width`` that was passed.
        lw = max(1, int(round(img.shape[0] / 1000)))
        width = lw * 5
    elif width == -1:
        # An explicit thickness with the default width used to leave the -1
        # sentinel in place, producing a degenerate cross; derive it instead.
        width = lw * 5
    cv.line(img, (int(x - width), int(y)), (int(x + width), int(y)), col, lw)
    cv.line(img, (int(x), int(y - width)), (int(x), int(y + width)), col, lw)
def plot_points2d(img, points2d, lines, lw=-1, col=(0, 255, 0), putText=True, style='+'):
    """Draw 2D points, optional index labels and connecting lines on ``img``.

    Args:
        img: image to draw on (modified in place).
        points2d: array of shape (N, 2) or (N, 3). With three columns the
            last one is a visibility value; with two columns every point is
            treated as visible.
        lines: iterable of ``(i, j)`` index pairs into ``points2d``; a segment
            is drawn for each pair whose two endpoints are visible.
        lw: base line width; ``-1`` derives it from the image height.
        col: colour tuple for markers and lines. Labels use the reversed tuple.
        putText: when true, write each point's index next to it.
        style: string of marker flags: ``'+'`` draws a cross, ``'o'`` draws
            a circle of radius 10 plus a small circle of radius ``lw``.
    """
    if points2d.shape[1] == 2:
        # No visibility column supplied: append one filled with ones.
        points2d = np.hstack([points2d, np.ones((points2d.shape[0], 1))])
    if lw == -1:
        # Clamp to 1: for images under 200 rows the integer division gives 0,
        # and OpenCV's line drawing rejects a non-positive thickness.
        lw = max(1, img.shape[0] // 200)
    # Label settings do not depend on the point, so compute them once.
    text_col = col[::-1] if putText else None
    font_scale = img.shape[0] / 1000
    for i, (x, y, v) in enumerate(points2d):
        if v < 0.01:
            # Treated as not visible.
            continue
        if '+' in style:
            plot_cross(img, x, y, width=10, col=col, lw=lw * 2)
        if 'o' in style:
            cv.circle(img, (int(x), int(y)), 10, col, lw * 2)
            cv.circle(img, (int(x), int(y)), lw, col, lw)
        if putText:
            cv.putText(img, '{}'.format(i), (int(x), int(y)), cv.FONT_HERSHEY_SIMPLEX, font_scale, text_col, 2)
    for i, j in lines:
        if points2d[i][2] < 0.01 or points2d[j][2] < 0.01:
            continue
        plot_line(img, points2d[i], points2d[j], max(1, lw // 2), col)
# Triangulate a batch of keypoints.
def batch_triangulate(keypoints_, Pall, min_view=2):
    """Triangulate 3D points from multi-view 2D detections.

    For every joint a confidence-weighted homogeneous linear system is built
    from all views and solved with an SVD (the right singular vector of the
    smallest singular value), then de-homogenised.

    Args:
        keypoints_ (nViews, nJoints, 3): 2D detections as (x, y, confidence).
        Pall (nViews, 3, 4) | (nViews, nJoints, 3, 4): projection matrix of
            each view, optionally given separately for every joint.
        min_view (int, optional): minimum number of views with a positive
            confidence required for a joint to be triangulated. Defaults to 2.

    Returns:
        keypoints3d (nJoints, 4): (x, y, z, confidence). The confidence is the
        sum of the 2D confidences divided by the number of views that saw the
        joint. Rows of joints seen in fewer than ``min_view`` views stay zero.
    """
    # Number of views in which each joint has a positive confidence.
    n_seen = (keypoints_[:, :, -1] > 0).sum(axis=0)
    # Indices of the joints seen by at least ``min_view`` views.
    valid_joint = np.where(n_seen >= min_view)[0]
    keypoints = keypoints_[:, valid_joint]
    conf3d = keypoints[:, :, -1].sum(axis=0) / n_seen[valid_joint]

    # Rows of the projection matrices, broadcastable to (nValid, nViews, 4).
    if len(Pall.shape) == 3:
        P0 = Pall[None, :, 0, :]
        P1 = Pall[None, :, 1, :]
        P2 = Pall[None, :, 2, :]
    else:
        # Per-joint matrices must be restricted to the same joints as the
        # keypoints; otherwise the shapes disagree whenever a joint is dropped.
        P_valid = Pall[:, valid_joint]
        P0 = P_valid[:, :, 0, :].swapaxes(0, 1)
        P1 = P_valid[:, :, 1, :].swapaxes(0, 1)
        P2 = P_valid[:, :, 2, :].swapaxes(0, 1)

    # Each view contributes two rows per joint: conf * (x * P2 - P0) and
    # conf * (y * P2 - P1). A has shape (nValid, 2 * nViews, 4).
    uP2 = keypoints[:, :, 0].T[:, :, None] * P2
    vP2 = keypoints[:, :, 1].T[:, :, None] * P2
    conf = keypoints[:, :, 2].T[:, :, None]
    Au = conf * (uP2 - P0)
    Av = conf * (vP2 - P1)
    A = np.hstack([Au, Av])

    # The last right singular vector is the homogeneous least-squares solution.
    _, _, vh = np.linalg.svd(A)
    X = vh[:, -1, :]
    X = X / X[:, 3:]

    # out: (nJoints, 4)
    result = np.zeros((keypoints_.shape[1], 4))
    result[valid_joint, :3] = X[:, :3]
    result[valid_joint, 3] = conf3d
    return result
def reprojectN3(kpts3d, Pall):
    """Project 3D points into every view.

    Args:
        kpts3d: array of shape (N, 3) or (N, 4); columns 0-2 are x, y, z and
            an optional fourth column is a per-point confidence.
        Pall: sequence of (3, 4) projection matrices, one per view.

    Returns:
        Array of shape (nViews, N, 3). Columns 0-1 are the projected x, y after
        division by the third homogeneous coordinate; column 2 is that third
        coordinate itself. When ``kpts3d`` has four columns, column 2 is set
        to zero for points whose confidence is not positive.
    """
    # Append a column of ones to obtain homogeneous coordinates (N, 4).
    unit_column = np.ones((kpts3d.shape[0], 1))
    homogeneous = np.hstack((kpts3d[:, :3], unit_column))

    per_view = []
    for P in Pall:
        projected = P @ homogeneous.T            # (3, N)
        projected[:2, :] /= projected[2:, :]     # de-homogenise x and y only
        per_view.append(projected.T)             # (N, 3)
    kp2ds = np.stack(per_view, axis=0)           # (nViews, N, 3)

    has_confidence = kpts3d.shape[-1] == 4
    if has_confidence:
        # Zero the last channel of points with a non-positive confidence.
        positive = kpts3d[None, :, -1] > 0.
        kp2ds[..., -1] = kp2ds[..., -1] * positive
    return kp2ds
# Input: intrinsics, extrinsics and the 2D points detected on the images.
# The 2D points are triangulated into 3D points, which are projected back onto
# each image and compared with the detected 2D points to get the reprojection
# error.
# Output: mean and maximum reprojection error.
# One image per camera view is used.
def check_match(pattern):
    """Validate a multi-camera calibration through its reprojection error.

    Chessboard corners are detected in one image per camera, undistorted,
    triangulated with the calibrated projection matrices and projected back
    into every view. The distance between each reprojected corner and its
    undistorted detection is the reprojection error. An overlay of the
    reprojected corners is written to ``DataPath.check_vis`` for every camera
    whose ``<camera>.jpg`` image can be read.

    Args:
        pattern: chessboard pattern size forwarded to ``detect_chessboard``.

    Returns:
        dict with ``'mean_error'`` and ``'max_error'`` over all corners of all
        cameras that were used.

    Raises:
        ValueError: if fewer than two calibrated cameras have a chessboard
            detection, since triangulation needs at least two views.
    """
    # Read intrinsics and extrinsics.
    cameras = read_cameras(DataPath.intri_json_path, DataPath.extri_json_path)
    # Expected format, one image per camera:
    # {"cam1": [[x1, y1, conf], [x2, y2, conf], ...], "cam2": [...], ...}
    # NOTE(review): the per-point layout comes from findChessboardCorners —
    # a third confidence column is required by batch_triangulate; confirm.
    kpts2d = detect_chessboard(DataPath.check_data, pattern)

    # detect_chessboard omits images in which detection failed, so only keep
    # the calibrated cameras that actually have corners.
    cam_names = []
    for cam in cameras:
        if cam in kpts2d:
            cam_names.append(cam)
        else:
            print(f"No chessboard corners available for camera: {cam}, skipping")
    if len(cam_names) < 2:
        raise ValueError(
            "Need chessboard detections from at least 2 calibrated cameras, "
            f"got {len(cam_names)}"
        )

    # Undistort the detections and build the projection matrices.
    Pall = []
    keypoints = []
    for cam in cam_names:
        K = np.array(cameras[cam]['intri']['K'], dtype=float)
        dist = np.array(cameras[cam]['intri']['dist'], dtype=float)
        detected = np.array(kpts2d[cam], dtype=float)
        # undistortPoints expects the points as (N, 1, 2).
        pixels = np.expand_dims(detected[:, :2], axis=1)
        # P=K maps the result back to pixel coordinates. Without it OpenCV
        # returns normalized coordinates, which do not match the pixel-space
        # projection matrices K[R|T] used below.
        undistorted = cv.undistortPoints(pixels, K, dist, P=K)
        # reshape (not squeeze) keeps the (N, 2) shape even when N == 1.
        # Re-attach the remaining columns (confidence) of the detections.
        keypoints.append(np.hstack((undistorted.reshape(-1, 2), detected[:, 2:])))

        R = np.array(cameras[cam]['extri']['R'])
        T = np.array(cameras[cam]['extri']['T']).reshape(3, 1)
        Pall.append(K @ np.hstack((R, T)))
    Pall = np.array(Pall)
    keypoints = np.array(keypoints)

    # Triangulate 3D points.
    keypoints3d = batch_triangulate(keypoints, Pall)

    # Compute the reprojection error and draw the results.
    reprojection_errors = []
    for i, cam in enumerate(cam_names):
        P = Pall[i]
        # Project the 3D points into this view and normalise by depth.
        kpts2d_proj = keypoints3d[:, :3] @ P[:, :3].T + P[:, 3]
        kpts2d_proj /= kpts2d_proj[:, 2:3]
        # Compare with the (undistorted) detected 2D points.
        kpts2d_actual = keypoints[i, :, :2]
        kpts2d_error = np.linalg.norm(kpts2d_proj[:, :2] - kpts2d_actual, axis=1)
        reprojection_errors.append(kpts2d_error)

        # Draw on the original image.
        # NOTE(review): the overlay uses undistorted pixel coordinates on the
        # raw image, so markers can be slightly off under strong distortion.
        img_path = osp.join(DataPath.check_data, f"{cam}.jpg")
        img = cv.imread(img_path)
        if img is None:
            # imread returns None when the file is missing or unreadable
            # (e.g. the check image is not a .jpg); keep the error, skip the plot.
            print(f"Failed to read image for visualization: {img_path}")
            continue
        plot_points2d(img, np.hstack((kpts2d_proj[:, :2], np.ones((kpts2d_proj.shape[0], 1)))), [], col=(0, 255, 0))
        mean_error_per_image = np.mean(kpts2d_error)
        font_scale = img.shape[0] / 1000
        cv.putText(img, f'Mean Error: {mean_error_per_image:.2f}', (50, 50), cv.FONT_HERSHEY_SIMPLEX, font_scale,
                   (0, 0, 255), 2)
        cv.imwrite(osp.join(DataPath.check_vis, f"{cam}.jpg"), img)

    # Combine the errors of all cameras for the statistics.
    reprojection_errors = np.hstack(reprojection_errors)
    mean_error = np.mean(reprojection_errors)
    max_error = np.max(reprojection_errors)
    return {
        'mean_error': mean_error,
        'max_error': max_error
    }
if __name__ == "__main__":
    # Command-line entry point: parse the chessboard size, run the check and
    # print the resulting error statistics.
    cli = argparse.ArgumentParser(description="Check camera calibration")
    cli.add_argument(
        "--pattern",
        type=str,
        default="11,8",
        help="Chessboard pattern size (columns, rows), e.g., '11,8'",
    )
    options = cli.parse_args()
    # Turn the comma-separated string into a tuple of ints.
    board_size = tuple(int(token) for token in options.pattern.split(','))
    stats = check_match(board_size)
    print(f"Mean Reprojection Error: {stats['mean_error']:.2f}")
    print(f"Max Reprojection Error: {stats['max_error']:.2f}")