#!/usr/bin/env python3
"""
Calculate the origin of a laser scan using normal vectors.
Improved version with better normal orientation handling and debugging.
"""

import numpy as np
import argparse
from pathlib import Path


def load_point_cloud(filepath):
    """
    Load a point cloud from a whitespace-delimited text file.

    Every row must contain nine numeric columns: X Y Z R G B Nx Ny Nz.

    Args:
        filepath: File name or file-like object, passed straight to
            ``numpy.loadtxt``.

    Returns:
        Tuple ``(points, colors, normals)``; each is an array with one row
        per point and three columns.

    Raises:
        ValueError: If the input holds no data rows or a row does not have
            exactly nine columns.
    """
    print(f"Loading point cloud from {filepath}...")
    # ndmin=2 keeps a single-row file as shape (1, 9). Without it loadtxt
    # returns a 1-D array and the column check below dies with IndexError.
    data = np.loadtxt(filepath, ndmin=2)

    if data.size == 0:
        raise ValueError(f"No points found in {filepath}")

    if data.shape[1] != 9:
        raise ValueError(f"Expected 9 columns (X Y Z R G B Nx Ny Nz), got {data.shape[1]}")

    points = data[:, 0:3]
    colors = data[:, 3:6]
    normals = data[:, 6:9]

    print(f"Loaded {len(points)} points")
    return points, colors, normals


def normalize_vectors(vectors):
    """
    Scale every row of ``vectors`` to unit Euclidean length.

    Rows whose length is exactly zero are divided by one instead, so they
    come back as all zeros rather than NaN.
    """
    lengths = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Swap zero lengths for 1 in place so the division below is always defined.
    np.putmask(lengths, lengths == 0, 1)
    return vectors / lengths


def find_ray_intersection_point(points, normals, max_pairs=5000, n_samples=1000, rng=None):
    """
    Estimate the point where lines through the points along their normals converge.

    Random pairs of points are drawn. For each pair the closest points of the
    two lines ``P + t * N`` are computed and their midpoint is stored as a
    candidate. Candidates are then filtered per axis with a
    median-absolute-deviation (MAD) test and the survivors are averaged.

    The parameter ``t`` is not restricted in sign: the "rays" are treated as
    infinite lines, so negating the normals does not change the candidate
    obtained for a given pair.

    Args:
        points: (N, 3) array-like of point coordinates.
        normals: (N, 3) array-like of normal vectors; they are normalised here.
        max_pairs: Maximum number of points kept. Larger inputs are randomly
            sub-sampled to this many points before pairs are drawn.
        n_samples: Upper bound on the number of random pairs. The number
            actually drawn is ``min(n_samples, number_of_points // 2)``.
        rng: Optional random source exposing ``choice`` (for example a
            ``numpy.random.Generator`` or ``RandomState``). When ``None`` the
            global ``numpy.random`` state is used.

    Returns:
        Array of shape (3,) with the estimated convergence point, or ``None``
        when no usable (non-parallel) pair could be formed. If fewer than 10
        candidates pass the outlier test, the per-axis median of all
        candidates is returned instead of the inlier mean.
    """
    print("Finding ray intersections...")

    points = np.asarray(points, dtype=float)
    normals = np.asarray(normals, dtype=float)
    rng = np.random if rng is None else rng

    # Use random subset of points for efficiency
    if len(points) > max_pairs:
        indices = rng.choice(len(points), max_pairs, replace=False)
        points = points[indices]
        normals = normals[indices]

    normals = normalize_vectors(normals)

    n_points = len(points)
    n_draws = min(n_samples, n_points // 2)
    if n_draws < 1:
        # Fewer than two points: no pair can be formed.
        return None

    # Each draw picks two distinct point indices.
    pairs = np.array([rng.choice(n_points, 2, replace=False) for _ in range(n_draws)])
    first, second = pairs[:, 0], pairs[:, 1]
    P1, N1 = points[first], normals[first]
    P2, N2 = points[second], normals[second]

    # Closest points between two 3D lines, evaluated for all pairs at once.
    # Based on: http://paulbourke.net/geometry/pointlineplane/
    w0 = P1 - P2
    a = np.sum(N1 * N1, axis=1)
    b = np.sum(N1 * N2, axis=1)
    c = np.sum(N2 * N2, axis=1)
    d = np.sum(N1 * w0, axis=1)
    e = np.sum(N2 * w0, axis=1)
    denom = a * c - b * b

    # A (near-)zero denominator means the two lines are parallel; this test
    # also discards pairs whose denominator is NaN.
    usable = np.abs(denom) >= 1e-10
    if not np.any(usable):
        return None

    t1 = ((b * e - c * d)[usable] / denom[usable])[:, np.newaxis]
    t2 = ((a * e - b * d)[usable] / denom[usable])[:, np.newaxis]

    closest_on_first = P1[usable] + t1 * N1[usable]
    closest_on_second = P2[usable] + t2 * N2[usable]

    # Midpoint of the shortest segment joining the two lines.
    intersections = (closest_on_first + closest_on_second) / 2

    # Remove outliers using median absolute deviation
    median = np.median(intersections, axis=0)
    deviation = np.abs(intersections - median)
    mad = np.median(deviation, axis=0)

    # Keep candidates within 3 MAD of the median on every axis; the MAD is
    # floored at 0.1 units so a very tight cluster does not reject everything.
    threshold = 3 * np.maximum(mad, 0.1)
    inliers = np.all(deviation < threshold, axis=1)
    n_inliers = int(np.sum(inliers))

    if n_inliers < 10:
        print(f"Warning: Only {n_inliers} inliers found")
        return median

    origin = np.mean(intersections[inliers], axis=0)
    print(f"Used {n_inliers}/{len(intersections)} intersection points")

    return origin


def calculate_origin_with_orientation_test(points, normals, max_points=10000):
    """
    Estimate the scan origin twice (normals as-is and negated) and keep the better fit.

    Each trial calls ``find_ray_intersection_point`` and is scored by the mean
    absolute cosine between the direction origin -> point and the point's
    (unflipped) normal. The trial with the highest score wins.

    NOTE(review): the intersection routine works on infinite lines, so the
    sign of the normals appears not to affect its result. The two trials then
    differ mainly through random sampling. Confirm whether a sign-sensitive
    test was intended.

    Args:
        points: (N, 3) array of point coordinates.
        normals: (N, 3) array of normal vectors.
        max_points: Inputs with more points are randomly sub-sampled to this
            many before the trials run.

    Returns:
        Array of shape (3,): the best-scoring origin estimate, or the mean of
        the (sub-sampled) points when neither trial produced an estimate.
    """
    print("\nTesting normal orientations...")

    # Use subset for efficiency
    if len(points) > max_points:
        indices = np.random.choice(len(points), max_points, replace=False)
        points = points[indices]
        normals = normals[indices]

    normals = normalize_vectors(normals)

    # Centre of the point cloud: reported for reference and used as fallback.
    center = np.mean(points, axis=0)
    print(f"Point cloud center: [{center[0]:.3f}, {center[1]:.3f}, {center[2]:.3f}]")

    # One (score, origin, label) entry per successful trial. The label is kept
    # with its result so the report stays right when a trial is skipped.
    candidates = []

    for flip, flip_str in ((1, "as-is"), (-1, "flipped")):
        origin = find_ray_intersection_point(points, normals * flip)

        if origin is None:
            continue

        # Score: how well do rays from origin through points align with normals?
        directions = normalize_vectors(points - origin)
        alignment = np.abs(np.sum(directions * normals, axis=1))
        score = np.mean(alignment)

        candidates.append((score, origin, flip_str))
        print(f"  Normals {flip_str}: origin=[{origin[0]:.3f}, {origin[1]:.3f}, {origin[2]:.3f}], score={score:.3f}")

    if not candidates:
        print("Warning: Could not find origin, using point cloud center")
        return center

    # Pick best orientation
    best_idx = int(np.argmax([score for score, _, _ in candidates]))
    _, best_origin, best_label = candidates[best_idx]

    print(f"\nBest orientation: {best_label}")

    return best_origin


def calculate_origin_weighted_centroid(points, normals, expected_origin=None, max_points=10000,
                                       ray_extent=10.0, n_steps=20, max_iterations=10,
                                       tolerance=0.01):
    """
    Estimate the origin as an inverse-distance-weighted centroid of votes along the normals.

    Every point casts ``n_steps`` votes at evenly spaced offsets in
    ``[-ray_extent, +ray_extent]`` along its unit normal. Starting from the
    per-axis median of all votes, the estimate is repeatedly replaced by the
    weighted mean of the votes, with weights ``1 / (distance + 0.1)``. This
    stops when the update moves less than ``tolerance`` or after
    ``max_iterations`` updates.

    NOTE(review): the offsets are symmetric about zero, so negating the
    normals yields the same set of votes. The orientation step driven by
    ``expected_origin`` therefore appears not to influence the estimate.
    Confirm whether one-sided sampling was intended.

    Args:
        points: (N, 3) array of point coordinates.
        normals: (N, 3) array of normal vectors; they are normalised here.
        expected_origin: Optional 3-element sequence. If the normals point
            away from it on average, they are negated.
        max_points: Inputs with more points are randomly sub-sampled to this
            many.
        ray_extent: Largest offset, in both directions, sampled along each
            normal.
        n_steps: Number of offsets sampled along each normal.
        max_iterations: Maximum number of refinement updates.
        tolerance: Refinement stops once an update moves the estimate by less
            than this distance.

    Returns:
        Array of shape (3,) with the estimated origin.
    """
    print("\nUsing weighted centroid method...")

    if len(points) > max_points:
        indices = np.random.choice(len(points), max_points, replace=False)
        points = points[indices]
        normals = normals[indices]

    normals = normalize_vectors(normals)

    # If we have an expected origin, use it to determine normal orientation
    if expected_origin is not None:
        expected_origin = np.array(expected_origin)
        vectors_to_expected = expected_origin - points
        dot_products = np.sum(vectors_to_expected * normals, axis=1)

        # If most normals point away from expected origin, flip them
        if np.mean(dot_products) < 0:
            print("Flipping normals based on expected origin")
            normals = -normals

    # One block of candidate positions per sampled offset along the normals.
    offsets = np.linspace(-ray_extent, ray_extent, n_steps)
    votes = np.vstack([points + offset * normals for offset in offsets])

    # Start from the per-axis median of all votes.
    origin = np.median(votes, axis=0)

    converged = False
    iterations_run = 0
    for iterations_run in range(1, max_iterations + 1):
        # Weight votes by inverse distance; the 0.1 keeps the weight finite
        # for a vote that coincides with the current estimate.
        distances_to_origin = np.linalg.norm(votes - origin, axis=1)
        weights = 1.0 / (distances_to_origin + 0.1)
        weights /= np.sum(weights)

        new_origin = np.sum(votes * weights.reshape(-1, 1), axis=0)

        if np.linalg.norm(new_origin - origin) < tolerance:
            converged = True
            break

        origin = new_origin

    if converged:
        print(f"Converged after {iterations_run} iterations")
    else:
        # Report the iteration cap honestly instead of claiming convergence.
        print(f"Warning: Did not converge within {iterations_run} iterations")

    return origin


def analyze_results(points, normals, origin, expected_origin=None):
    """
    Print a quality report for an estimated scan origin.

    The report contains the estimate itself, its error against
    ``expected_origin`` when one is given, statistics of the origin-to-point
    distances, and how the normals align with the directions from the origin
    to the points. Nothing is returned.
    """
    rule = "="*60
    print("\n" + rule)
    print("RESULTS")
    print(rule)
    print(f"Estimated scan origin: [{origin[0]:.3f}, {origin[1]:.3f}, {origin[2]:.3f}]")

    if expected_origin is not None:
        expected_origin = np.array(expected_origin)
        error = np.linalg.norm(origin - expected_origin)
        print(f"Expected origin:       [{expected_origin[0]:.3f}, {expected_origin[1]:.3f}, {expected_origin[2]:.3f}]")
        print(f"Error:                 {error:.3f} meters")

    # Range from the estimated origin to every point.
    vectors_to_points = points - origin
    distances = np.linalg.norm(vectors_to_points, axis=1)

    print(f"\nDistance statistics (origin to points):")
    print(f"  Mean:   {np.mean(distances):.3f} m")
    print(f"  Median: {np.median(distances):.3f} m")
    print(f"  Min:    {np.min(distances):.3f} m")
    print(f"  Max:    {np.max(distances):.3f} m")
    print(f"  Std:    {np.std(distances):.3f} m")

    # Cosine between each unit normal and the unit direction origin -> point.
    directions = normalize_vectors(vectors_to_points)
    normals_normalized = normalize_vectors(normals)
    dot_products = (normals_normalized * directions).sum(axis=1)

    print(f"\nNormal alignment (dot product with direction from origin):")
    print(f"  Mean:   {np.mean(dot_products):.3f}")
    print(f"  Median: {np.median(dot_products):.3f}")

    # Bucket the cosines: clearly away, clearly towards, or roughly sideways.
    pointing_away = (dot_products > 0.5).sum()
    pointing_towards = (dot_products < -0.5).sum()
    perpendicular = (np.abs(dot_products) < 0.5).sum()

    total = len(points)
    print(f"\n  Pointing away from origin:    {pointing_away:6d} ({100*pointing_away/total:5.1f}%)")
    print(f"  Pointing towards origin:      {pointing_towards:6d} ({100*pointing_towards/total:5.1f}%)")
    print(f"  Perpendicular to view:        {perpendicular:6d} ({100*perpendicular/total:5.1f}%)")


def _build_parser():
    """Create the command-line argument parser for the scan-origin tool."""
    parser = argparse.ArgumentParser(
        description="Calculate laser scan origin from point cloud with normals",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python calculate_scan_origin.py scan.txt
  python calculate_scan_origin.py scan.txt --expected -56 -9 1
  python calculate_scan_origin.py scan.txt --method weighted --expected -56 -9 1
        """
    )

    parser.add_argument(
        "input_file",
        type=str,
        help="Path to input text file (X Y Z R G B Nx Ny Nz)",
    )
    parser.add_argument(
        "--method",
        type=str,
        choices=["ray_intersection", "weighted"],
        default="ray_intersection",
        help="Method to use (default: ray_intersection)",
    )
    parser.add_argument(
        "--expected",
        type=float,
        nargs=3,
        metavar=("X", "Y", "Z"),
        help="Expected origin coordinates for comparison and normal orientation",
    )
    parser.add_argument(
        "--max-points",
        type=int,
        default=10000,
        help="Maximum number of points to use (default: 10000)",
    )
    return parser


def main():
    """Command-line entry point: parse arguments, estimate the origin, print the report."""
    args = _build_parser().parse_args()

    # Colours are loaded with the rest of the file but not used by either method.
    points, _colors, normals = load_point_cloud(args.input_file)

    # Only two methods exist, so anything that is not "weighted" is ray intersection.
    if args.method == "weighted":
        origin = calculate_origin_weighted_centroid(points, normals, args.expected, args.max_points)
    else:
        origin = calculate_origin_with_orientation_test(points, normals, args.max_points)

    analyze_results(points, normals, origin, args.expected)

    print("\n" + "="*60)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
