import type { Vector2 } from 'three';
import { Matrix4, Vector4, Vector3, Quaternion, Matrix3 } from 'three';

import type { CameraIntrinsics } from '@sb/integrations/camera/types';

import type { CartesianPose } from './CartesianPose';
import { cartesianPoseToMatrix4 } from './CartesianPose';
import type { CartesianPosition } from './CartesianPosition';
import { convertPositionToVector3 } from './CartesianPosition';
import { getPlaneNormal, type Plane } from './Plane';

/**
 * Builds the 3x3 camera intrinsic matrix from the given intrinsics.
 *
 * The resulting matrix has the rows:
 *
 *   [ fx   0  ppx ]
 *   [  0  fy  ppy ]
 *   [  0   0    1 ]
 *
 * @param {CameraIntrinsics} intrinsics - Provides the `fx`, `fy`, `ppx` and `ppy` values (presumably focal lengths and principal point).
 *
 * @returns {Matrix3} A newly allocated matrix; `intrinsics` is not modified.
 */
export function getIntrinsicMatrix(intrinsics: CameraIntrinsics): Matrix3 {
  const { fx, fy, ppx, ppy } = intrinsics;

  // Matrix3.set takes its nine arguments in row-major order.
  return new Matrix3().set(fx, 0, ppx, 0, fy, ppy, 0, 0, 1);
}

/**
 * Computes the world-space direction of the ray leaving the camera through a pixel.
 *
 * @param {CameraIntrinsics} cameraIntrinsics - Intrinsics whose matrix is inverted to back-project the pixel.
 * @param {CartesianPose} cameraPose - Pose of the camera; its translation is discarded, so only orientation affects the result.
 * @param {Vector2} pixelCoordinates - The pixel (x, y) the ray passes through.
 *
 * @returns {Vector3} A new, normalized direction vector.
 */
export function castCameraRay(
  cameraIntrinsics: CameraIntrinsics,
  cameraPose: CartesianPose,
  pixelCoordinates: Vector2,
): Vector3 {
  // Pixel -> camera frame: apply the inverse intrinsic matrix to (x, y, 1).
  const pixelToCamera = getIntrinsicMatrix(cameraIntrinsics).invert();
  const direction = new Vector3(pixelCoordinates.x, pixelCoordinates.y, 1);
  direction.applyMatrix3(pixelToCamera);

  // Camera frame -> world frame: a direction must not be translated, so the
  // position column of the pose matrix is zeroed before it is applied.
  const cameraToWorld = cartesianPoseToMatrix4(cameraPose);
  cameraToWorld.setPosition(new Vector3(0, 0, 0));
  direction.applyMatrix4(cameraToWorld);

  return direction.normalize();
}

/**
 * Builds the transform that maps the local z = 0 plane onto the given plane.
 *
 * The transform first rotates the local +Z axis onto the plane normal and then
 * translates to the plane origin (i.e. translation * rotation).
 *
 * @param {Plane} plane - The plane whose `origin` and normal define the transform.
 *
 * @returns {Matrix4} A newly allocated transform matrix.
 */
export function buildPlaneTransformMatrix(plane: Plane): Matrix4 {
  const planeOrigin = plane.origin;

  // NOTE(review): setFromUnitVectors expects unit-length inputs; this assumes
  // getPlaneNormal returns a normalized vector — confirm.
  const alignZToNormal = new Quaternion().setFromUnitVectors(
    new Vector3(0, 0, 1),
    getPlaneNormal(plane),
  );

  const translation = new Matrix4().makeTranslation(
    planeOrigin.x,
    planeOrigin.y,
    planeOrigin.z,
  );
  const rotation = new Matrix4().makeRotationFromQuaternion(alignZToNormal);

  // `multiply` post-multiplies in place: translation = translation * rotation.
  return translation.multiply(rotation);
}

/**
 * Intersect a ray with the plane obtained by applying `planeTransform` to the local z = 0 plane
 * (the local XY plane, whose normal is the local +Z axis).
 *
 * The plane is treated as infinite: no bounds check is performed on the hit point.
 *
 * The ray is moved into the plane's local frame, where the intersection reduces to solving
 * `origin.z + t * direction.z = 0`. Because the same affine transform is applied to both the
 * origin and the direction, `t` is also valid for the untransformed ray, so
 * `rayOrigin + t * rayDirection` is the hit point.
 *
 * @param {Vector3} rayOrigin - The origin point of the ray in 3D space.
 * @param {Vector3} rayDirection - The direction vector of the ray in 3D space. When it is normalized, the returned value is a distance.
 * @param {Matrix4} planeTransform - The transformation matrix that places the local z = 0 plane in space. It is not modified.
 *
 * @returns {number | undefined} The ray parameter `t` of the intersection, or undefined if the ray is (nearly) parallel to the plane or the intersection lies behind the ray origin.
 */
export function intersectRayPlane(
  rayOrigin: Vector3,
  rayDirection: Vector3,
  planeTransform: Matrix4,
): number | undefined {
  // Below this magnitude the local z component of the direction is considered
  // zero, i.e. the ray is parallel to the plane.
  const parallelEpsilon = 1e-6;

  const worldToPlane = planeTransform.clone().invert();

  // Homogeneous coordinates: w = 1 for the origin so that it is translated,
  // w = 0 for the direction so that it is not.
  const localOrigin = new Vector4(
    rayOrigin.x,
    rayOrigin.y,
    rayOrigin.z,
    1,
  ).applyMatrix4(worldToPlane);

  const localDirection = new Vector4(
    rayDirection.x,
    rayDirection.y,
    rayDirection.z,
    0,
  ).applyMatrix4(worldToPlane);

  if (Math.abs(localDirection.z) < parallelEpsilon) {
    return undefined;
  }

  const distance = -localOrigin.z / localDirection.z;

  // A negative parameter means the plane is behind the ray origin.
  return distance < 0 ? undefined : distance;
}

/**
 * Compute the point where a ray hits a plane in 3D space.
 *
 * @param {CartesianPosition} rayOrigin - Starting point of the ray.
 * @param {CartesianPosition} rayDirection - Direction of the ray; it is normalized here before use.
 * @param {Plane} plane - The plane to intersect with.
 *
 * @returns {Vector3 | undefined} The intersection point, or undefined when `intersectRayPlane` reports no intersection.
 */
export function findRayPlaneIntersection(
  rayOrigin: CartesianPosition,
  rayDirection: CartesianPosition,
  plane: Plane,
): Vector3 | undefined {
  const planeTransform = buildPlaneTransformMatrix(plane);
  const start = convertPositionToVector3(rayOrigin);
  const unitDirection = convertPositionToVector3(rayDirection).normalize();

  const distanceAlongRay = intersectRayPlane(
    start,
    unitDirection,
    planeTransform,
  );

  // Walk `distanceAlongRay` along the ray from a copy of its start point.
  return distanceAlongRay === undefined
    ? undefined
    : start.clone().addScaledVector(unitDirection, distanceAlongRay);
}

/**
 * Maps an image pixel to the 3D point where its camera ray meets a plane.
 *
 * All spatial inputs are assumed to be expressed in the same coordinate system.
 *
 * @param {Object} options - The options object.
 * @param {Vector2} options.pixelCoordinates - The pixel coordinates of the point in the image.
 * @param {Plane} options.plane - The plane to deproject onto.
 * @param {CartesianPose} options.cameraPose - The pose of the camera in the world coordinate system; its x, y, z are used as the ray origin.
 * @param {CameraIntrinsics} options.cameraIntrinsics - The intrinsic parameters of the camera. Units are in pixels.
 *
 * @returns {Vector3 | undefined} - The 3D position on the plane, or undefined if no intersection is found.
 */
export function deprojectOntoPlane(options: {
  pixelCoordinates: Vector2;
  plane: Plane;
  cameraPose: CartesianPose;
  cameraIntrinsics: CameraIntrinsics;
}): Vector3 | undefined {
  const { pixelCoordinates, plane, cameraPose, cameraIntrinsics } = options;

  // The ray leaves the camera position in the direction through the pixel.
  const direction = castCameraRay(
    cameraIntrinsics,
    cameraPose,
    pixelCoordinates,
  );
  const cameraPosition = new Vector3(cameraPose.x, cameraPose.y, cameraPose.z);

  return findRayPlaneIntersection(cameraPosition, direction, plane);
}
