Created
May 19, 2025 15:58
-
-
Save raven38/357053f50b0fa596e8f9ac34ef74eaa6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.optim as optim | |
import numpy as np | |
import cv2 # Ground truth生成や比較でのRodrigues変換に一部使用 | |
from scipy.spatial.transform import Rotation as R_scipy # Ground truthの回転生成用 | |
# PyTorchによるRodriguesの公式(回転ベクトルから回転行列へ) | |
def rodrigues_pytorch(rvecs_batch): | |
""" | |
Rodriguesの公式を用いて、回転ベクトルのバッチを回転行列のバッチに変換します。 | |
Args: | |
rvecs_batch (torch.Tensor): 回転ベクトルのバッチ (B, 3)。 | |
Returns: | |
torch.Tensor: 回転行列のバッチ (B, 3, 3)。 | |
""" | |
theta = torch.norm(rvecs_batch, dim=1, keepdim=True) # (B, 1) | |
# ゼロ回転の場合の除算エラーを避けるため、微小値を加算 | |
# thetaが0に近い場合、sin(theta)/theta -> 1, (1-cos(theta))/theta^2 -> 0.5 となる | |
# しかし、ここではthetaが0ならkも0になるので、直接計算しても問題ない場合が多い | |
k = rvecs_batch / (theta + 1e-9) # 回転軸 (B, 3) | |
K_cross = torch.zeros((rvecs_batch.shape[0], 3, 3), device=rvecs_batch.device, dtype=rvecs_batch.dtype) | |
K_cross[:, 0, 1] = -k[:, 2] | |
K_cross[:, 0, 2] = k[:, 1] | |
K_cross[:, 1, 0] = k[:, 2] | |
K_cross[:, 1, 2] = -k[:, 0] | |
K_cross[:, 2, 0] = -k[:, 1] | |
K_cross[:, 2, 1] = k[:, 0] | |
I_batch = torch.eye(3, device=rvecs_batch.device, dtype=rvecs_batch.dtype).unsqueeze(0).repeat(rvecs_batch.shape[0], 1, 1) | |
cos_theta = torch.cos(theta).unsqueeze(2) # (B, 1, 1) | |
sin_theta = torch.sin(theta).unsqueeze(2) # (B, 1, 1) | |
# R = I + sin(theta)*K + (1-cos(theta))*K^2 | |
R_matrices = I_batch + sin_theta * K_cross + (1 - cos_theta) * torch.bmm(K_cross, K_cross) | |
return R_matrices | |
# PyTorchによる3D点の2D投影 | |
def project_points_pytorch(points_3d, focals, rvecs, tvecs, cx, cy, cam_indices, pt_indices): | |
""" | |
PyTorchを使用して3D点を2D画像平面に投影します。 | |
Args: | |
points_3d (torch.Tensor): 3D点の座標 (N_points, 3)。 | |
focals (torch.Tensor): 各カメラの焦点距離 (N_cameras,). | |
rvecs (torch.Tensor): 各カメラの回転ベクトル (N_cameras, 3)。 | |
tvecs (torch.Tensor): 各カメラの並進ベクトル (N_cameras, 3)。 | |
cx (torch.Tensor): 主点のx座標 (スカラー)。 | |
cy (torch.Tensor): 主点のy座標 (スカラー)。 | |
cam_indices (torch.Tensor): 各観測に対応するカメラインデックス (N_observations,). | |
pt_indices (torch.Tensor): 各観測に対応する3D点インデックス (N_observations,). | |
Returns: | |
torch.Tensor: 投影された2D点の座標 (N_observations, 2)。 | |
""" | |
# 観測ごとのパラメータを取得 | |
f_obs = focals[cam_indices] # (N_obs,) | |
rvec_obs = rvecs[cam_indices] # (N_obs, 3) | |
tvec_obs = tvecs[cam_indices] # (N_obs, 3) | |
p3d_obs = points_3d[pt_indices] # (N_obs, 3) | |
# 回転ベクトルを回転行列に変換 | |
R_obs = rodrigues_pytorch(rvec_obs) # (N_obs, 3, 3) | |
# カメラ座標系への変換: P_cam = R * P_world + t | |
# P_world を (N_obs, 3, 1) に変形してバッチ行列積を計算 | |
P_cam_rotated = torch.bmm(R_obs, p3d_obs.unsqueeze(2)).squeeze(2) # (N_obs, 3) | |
P_cam = P_cam_rotated + tvec_obs # (N_obs, 3) | |
# Xc, Yc, Zc 成分の取得 | |
Xc = P_cam[:, 0] | |
Yc = P_cam[:, 1] | |
Zc = P_cam[:, 2] | |
# Zcが0または負の場合の除算エラーや不適切な投影を避ける | |
# 非常に小さい正の値をZcに加えるか、Zc > eps のマスクをかける | |
# ここでは単純化のため、そのまま計算を進めるが、実用上は注意が必要 | |
Zc_safe = torch.where(torch.abs(Zc) < 1e-6, torch.ones_like(Zc) * 1e-6 * torch.sign(Zc), Zc) | |
# Zc_safe = Zc.clone() | |
# Zc_safe[torch.abs(Zc_safe) < 1e-5] = 1e-5 # Avoid division by zero | |
# ピンホールカメラモデルによる投影 | |
u_proj = f_obs * (Xc / Zc_safe) + cx | |
v_proj = f_obs * (Yc / Zc_safe) + cy | |
projected_2d = torch.stack((u_proj, v_proj), dim=1) # (N_obs, 2) | |
return projected_2d | |
# テストケースのセットアップ | |
def setup_bundle_adjustment_test_case(device): | |
print("🧪 テストケースをセットアップ中 (バンドル調整用)...") | |
# 3D構造の真値 | |
points_3d_gt_np = np.array([ | |
[0,0,0], [1,0,0], [1,1,0], [0,1,0], [0.5,0.5,0], # Plane 1 (z=0) | |
[0,0,1], [1,0,1], [1,1,1], [0,1,1], [0.5,0.5,1], # Plane 2 (z=1) | |
[0.5, -0.5, 0.5], [0.5, 1.5, 0.5] # Additional points | |
], dtype=np.float64) | |
n_points = points_3d_gt_np.shape[0] | |
# カメラパラメータの真値 | |
n_cameras = 3 | |
img_width, img_height = 1280, 720 | |
cx_gt, cy_gt = img_width / 2.0, img_height / 2.0 | |
true_focals_np = np.array([1000.0, 1050.0, 950.0], dtype=np.float64) | |
rvec1_gt_np = R_scipy.from_euler('xyz', [0.05, -0.05, 0.02], degrees=False).as_rotvec() | |
tvec1_gt_np = np.array([-0.1, 0.05, 2.5], dtype=np.float64) | |
rvec2_gt_np = R_scipy.from_euler('xyz', [0.0, 0.1, -0.08], degrees=False).as_rotvec() | |
tvec2_gt_np = np.array([1.5, -0.1, 3.0], dtype=np.float64) | |
rvec3_gt_np = R_scipy.from_euler('xyz', [-0.05, -0.15, 0.1], degrees=False).as_rotvec() | |
tvec3_gt_np = np.array([0.5, 0.8, 2.0], dtype=np.float64) | |
true_rvecs_np = np.array([rvec1_gt_np, rvec2_gt_np, rvec3_gt_np]) | |
true_tvecs_np = np.array([tvec1_gt_np, tvec2_gt_np, tvec3_gt_np]) | |
# 2D観測点の生成 (ノイズ付加) | |
observations_list = [] | |
camera_indices_list = [] | |
point_indices_list = [] | |
visibility_threshold = 0.8 # 各点が各カメラで見える確率 | |
for cam_idx in range(n_cameras): | |
f_cam = true_focals_np[cam_idx] | |
rvec_cam = true_rvecs_np[cam_idx] | |
tvec_cam = true_tvecs_np[cam_idx] | |
K_cam = np.array([[f_cam, 0, cx_gt], [0, f_cam, cy_gt], [0, 0, 1]]) | |
projected_np, _ = cv2.projectPoints(points_3d_gt_np, rvec_cam, tvec_cam, K_cam, None) | |
projected_np = projected_np.reshape(-1, 2) | |
for pt_idx in range(n_points): | |
# 全点が全カメラで見えるわけではない状況をシミュレート | |
if np.random.rand() < visibility_threshold: | |
# 点がカメラの前方にあるか簡易チェック (Zc > 0) | |
pt_world_hom = np.append(points_3d_gt_np[pt_idx], 1) | |
R_mat_cam, _ = cv2.Rodrigues(rvec_cam) | |
P_ext = np.hstack((R_mat_cam, tvec_cam.reshape(3,1))) | |
pt_cam = P_ext @ pt_world_hom | |
if pt_cam[2] > 0: # Z座標が正である(カメラの前方) | |
observations_list.append(projected_np[pt_idx]) | |
camera_indices_list.append(cam_idx) | |
point_indices_list.append(pt_idx) | |
observations_gt_np = np.array(observations_list, dtype=np.float64) | |
noise_sigma = 0.75 # ピクセル単位の標準偏差 | |
observations_noisy_np = observations_gt_np + np.random.normal(scale=noise_sigma, size=observations_gt_np.shape) | |
camera_indices_np = np.array(camera_indices_list) | |
point_indices_np = np.array(point_indices_list) | |
print(f"📊 {len(observations_noisy_np)}点の2D観測点を生成 ({n_cameras}台のカメラ, {n_points}個の3D点, 可視性考慮)") | |
# パラメータの初期値を準備 (PyTorch Tensor) | |
# 3D点 (真値に大きなノイズ) | |
initial_points_3d_np = points_3d_gt_np + np.random.normal(0, 0.3, points_3d_gt_np.shape) | |
points_3d_param = torch.tensor(initial_points_3d_np, device=device, requires_grad=True, dtype=torch.float64) | |
# カメラパラメータ (真値にノイズ) | |
initial_focals_np = true_focals_np + np.random.normal(0, 70, n_cameras) | |
initial_rvecs_np = true_rvecs_np + np.random.normal(0, 0.15, true_rvecs_np.shape) # radian | |
initial_tvecs_np = true_tvecs_np + np.random.normal(0, 0.3, true_tvecs_np.shape) | |
focals_param = torch.tensor(initial_focals_np, device=device, requires_grad=True, dtype=torch.float64) | |
rvecs_param = torch.tensor(initial_rvecs_np, device=device, requires_grad=True, dtype=torch.float64) | |
tvecs_param = torch.tensor(initial_tvecs_np, device=device, requires_grad=True, dtype=torch.float64) | |
# 固定値 (cx, cy) と観測データもTensorに変換 | |
cx = torch.tensor(cx_gt, device=device, dtype=torch.float64) | |
cy = torch.tensor(cy_gt, device=device, dtype=torch.float64) | |
observations_2d_torch = torch.tensor(observations_noisy_np, device=device, dtype=torch.float64) | |
camera_indices_torch = torch.tensor(camera_indices_np, device=device, dtype=torch.long) | |
point_indices_torch = torch.tensor(point_indices_np, device=device, dtype=torch.long) | |
# 真値もTensorとして保持 (比較用) | |
true_points_3d_torch = torch.tensor(points_3d_gt_np, device=device, dtype=torch.float64) | |
true_focals_torch = torch.tensor(true_focals_np, device=device, dtype=torch.float64) | |
true_rvecs_torch = torch.tensor(true_rvecs_np, device=device, dtype=torch.float64) | |
true_tvecs_torch = torch.tensor(true_tvecs_np, device=device, dtype=torch.float64) | |
return (points_3d_param, focals_param, rvecs_param, tvecs_param, | |
observations_2d_torch, camera_indices_torch, point_indices_torch, cx, cy, | |
true_points_3d_torch, true_focals_torch, true_rvecs_torch, true_tvecs_torch) | |
# メイン処理 | |
def main_bundle_adjustment_adam(): | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# バンドル調整は数値的安定性が重要なので float64 (double) を推奨 | |
torch.set_default_dtype(torch.float64) | |
print(f"デバイス: {device}") | |
(points_3d_param, focals_param, rvecs_param, tvecs_param, | |
observations_2d, camera_indices, point_indices, cx, cy, | |
true_points_3d, true_focals, true_rvecs, true_tvecs) = setup_bundle_adjustment_test_case(device) | |
# 最適化対象のパラメータリスト | |
params_to_optimize = [points_3d_param, focals_param, rvecs_param, tvecs_param] | |
# Adamオプティマイザ | |
# 学習率は調整が非常に重要。パラメータ群ごとに異なる学習率を設定することも有効 | |
# 例: optimizer = optim.Adam([ | |
# {'params': points_3d_param, 'lr': 1e-3}, | |
# {'params': focals_param, 'lr': 1e-2}, | |
# {'params': rvecs_param, 'lr': 1e-3}, | |
# {'params': tvecs_param, 'lr': 1e-3} | |
# ]) | |
optimizer = optim.Adam(params_to_optimize, lr=5e-3) # 5e-3 to 1e-2 range often works for BA | |
num_epochs = 3000 # イテレーション回数 (問題の複雑さによる) | |
print("\n--- 真のパラメータ (一部) ---") | |
print(f"3D点0 (真値): {true_points_3d[0].cpu().numpy()}") | |
print(f"焦点距離0 (真値): {true_focals[0].item():.2f}") | |
print("\n--- 初期パラメータ (一部) ---") | |
print(f"3D点0 (初期値): {points_3d_param[0].detach().cpu().numpy()}") | |
print(f"焦点距離0 (初期値): {focals_param[0].item():.2f}") | |
print("\n⚙️ Adamでバンドル調整を実行中...") | |
for epoch in range(num_epochs): | |
optimizer.zero_grad() # 勾配をリセット | |
# 2D投影点を計算 | |
projected_2d_current = project_points_pytorch( | |
points_3d_param, focals_param, rvecs_param, tvecs_param, | |
cx, cy, camera_indices, point_indices | |
) | |
# 損失: 再投影誤差の二乗和 | |
reprojection_errors = projected_2d_current - observations_2d | |
loss = torch.sum(reprojection_errors**2) # or torch.mean(reprojection_errors**2) | |
loss.backward() # 勾配計算 | |
optimizer.step() # パラメータ更新 | |
if (epoch + 1) % 200 == 0: | |
with torch.no_grad(): | |
# 平均再投影誤差 (ピクセル単位のユークリッド距離) | |
current_pixel_errors = torch.norm(reprojection_errors, dim=1) | |
mean_reprojection_error_px = torch.mean(current_pixel_errors) | |
print(f"エポック [{epoch+1}/{num_epochs}], 損失: {loss.item():.4f}, 平均再投影誤差: {mean_reprojection_error_px.item():.4f} ピクセル") | |
print("最適化完了。") | |
# --- 結果の表示 --- | |
# 注意: バンドル調整の結果はスケール不定性やグローバルな剛体変換の不定性を持つため、 | |
# 真値との直接比較はアラインメント処理なしでは難しい場合があります。 | |
# ここでは単純な比較に留めます。 | |
print("\n--- 最適化後のパラメータ (一部) vs 真のパラメータ ---") | |
print(f"3D点0: 最適値={points_3d_param[0].detach().cpu().numpy()}, 真値={true_points_3d[0].cpu().numpy()}") | |
print(f" 差分ノルム: {torch.norm(points_3d_param[0] - true_points_3d[0]).item():.3f}") | |
# 全3D点の平均誤差 (アラインメントなし) | |
avg_point_diff_no_align = torch.mean(torch.norm(points_3d_param - true_points_3d, dim=1)) | |
print(f"3D点の平均L2差 (アラインメントなし): {avg_point_diff_no_align.item():.3f}") | |
for i in range(focals_param.shape[0]): # カメラごと | |
print(f"\nカメラ {i}:") | |
print(f" 焦点距離: 最適値={focals_param[i].item():.2f} (真値={true_focals[i].item():.2f})") | |
rvec_opt_np = rvecs_param[i].detach().cpu().numpy() | |
rvec_true_np = true_rvecs[i].cpu().numpy() | |
print(f" 回転ベクトル (rvec): 最適値={np.round(rvec_opt_np, 3)}, 真値={np.round(rvec_true_np, 3)}") | |
try: | |
r_mat_opt, _ = cv2.Rodrigues(rvec_opt_np) | |
r_mat_true, _ = cv2.Rodrigues(rvec_true_np) | |
diff_rot = R_scipy.from_matrix(r_mat_opt @ r_mat_true.T) # 回転差 | |
print(f" 回転差 (角度): {np.rad2deg(diff_rot.magnitude()):.3f}°") | |
except Exception as e: | |
print(f" 回転差の計算エラー: {e}") | |
tvec_opt_np = tvecs_param[i].detach().cpu().numpy() | |
tvec_true_np = true_tvecs[i].cpu().numpy() | |
print(f" 並進ベクトル (tvec): 最適値={np.round(tvec_opt_np, 3)}, 真値={np.round(tvec_true_np, 3)}") | |
print(f" 並進差 (L2ノルム): {np.linalg.norm(tvec_opt_np - tvec_true_np):.3f}") | |
with torch.no_grad(): | |
final_projected_2d = project_points_pytorch( | |
points_3d_param, focals_param, rvecs_param, tvecs_param, | |
cx, cy, camera_indices, point_indices | |
) | |
final_reprojection_errors = final_projected_2d - observations_2d | |
final_pixel_errors = torch.norm(final_reprojection_errors, dim=1) | |
mean_final_reprojection_error_px = torch.mean(final_pixel_errors) | |
print(f"\n📈 最適化後の最終平均再投影誤差: {mean_final_reprojection_error_px.item():.4f} ピクセル") | |
if __name__ == '__main__': | |
main_bundle_adjustment_adam() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment