import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import KDTree

def best_fit_transform(A, B):
    """Compute the least-squares rigid transform that maps point set A onto B.

    Rows of ``A`` and ``B`` are paired correspondences; both arrays must have
    the same shape ``(N, m)``. The rotation is obtained from the SVD of the
    cross-covariance of the centred point sets (Kabsch construction). If that
    rotation would be a reflection (determinant < 0), the last row of ``Vt``
    is negated before rebuilding it. ``np.linalg.svd`` orders singular values
    in descending order, so this is the row of the smallest singular value.

    Returns:
        T: ``(m+1, m+1)`` homogeneous matrix ``[[R, t], [0, 1]]``.
        R: ``(m, m)`` rotation matrix.
        t: ``(m,)`` translation vector, so that ``R @ a + t`` approximates ``b``.
    """
    assert A.shape == B.shape
    dim = A.shape[1]

    mu_a = A.mean(axis=0)
    mu_b = B.mean(axis=0)
    cross_cov = (A - mu_a).T @ (B - mu_b)  # (dim, dim)

    u, _, vt = np.linalg.svd(cross_cov)
    rotation = vt.T @ u.T
    if np.linalg.det(rotation) < 0:
        # Reflection detected: flip the axis of the smallest singular value.
        # vt is (dim, dim), so vt[-1] is the same row as vt[dim - 1].
        vt[-1, :] *= -1
        rotation = vt.T @ u.T

    translation = mu_b - rotation @ mu_a

    homogeneous = np.eye(dim + 1)
    homogeneous[:dim, :dim] = rotation
    homogeneous[:dim, dim] = translation
    return homogeneous, rotation, translation

def icp(source_points, target_points, init_pose=None, max_iterations=50, tolerance=1e-4):
    """Rigidly align ``source_points`` to ``target_points`` with point-to-point ICP.

    Each iteration does three things:

    1. Match every current source point to its nearest target point (KD-tree).
    2. Solve the best rigid transform for those pairs (``best_fit_transform``).
    3. Apply that transform to the source.

    Iteration stops after ``max_iterations``, or when the mean nearest-neighbour
    distance changes by less than ``tolerance`` between consecutive iterations.
    The distance is measured before the step is applied. The previous error
    starts at 0, so a first-iteration error below ``tolerance`` stops
    immediately.

    Args:
        source_points: ``(N, m)`` points to move.
        target_points: ``(M, m)`` reference points.
        init_pose: optional ``(m+1, m+1)`` homogeneous initial guess, applied
            to the source before the first iteration.
        max_iterations: upper bound on ICP iterations.
        tolerance: convergence threshold on the change in mean error.

    Returns:
        ``(m+1, m+1)`` homogeneous transform mapping the ORIGINAL source points
        onto their final aligned position. It includes ``init_pose`` and every
        step applied, including the one on the converging iteration.
    """
    source_points = np.asarray(source_points, dtype=float)
    target_points = np.asarray(target_points, dtype=float)
    m = source_points.shape[1]

    # Homogeneous coordinates, one point per column: shape (m+1, N).
    src = np.ones((m + 1, source_points.shape[0]))
    src[:m, :] = source_points.T

    # The accumulated transform must start from the initial pose, otherwise
    # the returned matrix would not account for it.
    if init_pose is None:
        T_cumulative = np.eye(m + 1)
    else:
        T_cumulative = np.array(init_pose, dtype=float)
        src = np.dot(T_cumulative, src)

    tree = KDTree(target_points)
    prev_error = 0
    for _ in range(max_iterations):
        current_src = src[:m, :].T
        distances, indices = tree.query(current_src)
        matched_tgt = target_points[indices]
        T_step, _, _ = best_fit_transform(current_src, matched_tgt)
        src = np.dot(T_step, src)
        # Accumulate before the convergence check so T_cumulative always
        # describes exactly what has been applied to src.
        T_cumulative = np.dot(T_step, T_cumulative)
        mean_error = np.mean(distances)
        if abs(prev_error - mean_error) < tolerance:
            break
        prev_error = mean_error
    return T_cumulative

def test(points_A, points_B, plot_path=None):
    """Rigidly align polygon ``points_B`` onto polygon ``points_A`` using ICP.

    Both polygons are densified first: evenly spaced samples are taken along
    every edge, including the closing edge from the last vertex back to the
    first. This gives ICP many more correspondences than the bare vertices
    would. The transform found on the dense sets is then applied to the
    original vertices of B.

    Args:
        points_A: ``(N, 2)`` vertices of the reference polygon (ICP target).
        points_B: ``(K, 2)`` vertices of the polygon to move (ICP source).
        plot_path: optional file path. When given, a plot of A and the aligned
            B is saved there; by default nothing is plotted.

    Returns:
        ``(K, 2)`` int32 array of B's aligned vertices, rounded to the nearest
        integer. Presumably intended for OpenCV drawing (``_cv`` naming) —
        confirm with callers.
    """
    points_A = np.asarray(points_A, dtype=float)
    points_B = np.asarray(points_B, dtype=float)

    def densify_polygon(points, samples_per_edge=10):
        """Sample ``samples_per_edge`` points per edge of the closed polygon."""
        points = np.asarray(points)
        t = np.linspace(0, 1, samples_per_edge, endpoint=False)
        nxt = np.roll(points, -1, axis=0)  # next vertex, wrapping to close the polygon
        diffs = nxt - points
        dense = points[:, None, :] + t[None, :, None] * diffs[:, None, :]
        return dense.reshape(-1, points.shape[1])

    dense_A = densify_polygon(points_A, 20)
    dense_B = densify_polygon(points_B, 20)

    # Run ICP (align B to A).
    T = icp(dense_B, dense_A)

    # Apply the transformation to the original B vertices.
    points_B_hom = np.ones((3, len(points_B)))
    points_B_hom[:2, :] = points_B.T
    aligned_B = np.dot(T, points_B_hom)[:2, :].T

    if plot_path is not None:
        fig, ax = plt.subplots(figsize=(8, 8))
        all_points = np.vstack([points_A, aligned_B])
        y_min, y_max = np.min(all_points[:, 1]), np.max(all_points[:, 1])
        ax.set_ylim(y_max, y_min)  # inverted y-axis (image-style coordinates)
        points_A_closed = np.vstack([points_A, points_A[0]])
        aligned_B_closed = np.vstack([aligned_B, aligned_B[0]])
        ax.plot(points_A_closed[:, 0], points_A_closed[:, 1], color='blue', linewidth=2, label='A (real points)')
        ax.plot(aligned_B_closed[:, 0], aligned_B_closed[:, 1], color='red', linewidth=2, label='B (pdf points)')
        ax.set_aspect('equal')
        ax.legend()
        fig.savefig(plot_path, dpi=300)
        plt.close(fig)

    # Round rather than truncate: a plain int32 cast truncates toward zero,
    # which biases every coordinate by up to one unit.
    aligned_B_cv = np.rint(aligned_B).astype(np.int32).reshape((len(aligned_B), 2))
    return aligned_B_cv