import { useState, useCallback, useEffect, useRef } from "react";
import * as THREE from "three";
import { throttle } from "lodash";

const useScalpelMode = (
  sceneRef,
  cameraRef,
  containerRef,
  modelRef,
  controlsRef,
  selectedModels,
  undoStack,
  undoIndex
) => {
  // Whether the scalpel tool is active; the pointer/touch listeners are only
  // attached to the container while this is true.
  const [isScalpelMode, setIsScalpelMode] = useState(false);
  // True from a press that hit a visible model until the matching release.
  const [isDrawing, setIsDrawing] = useState(false);
  // Raycast hit points sampled along the user's stroke.
  const [drawPoints, setDrawPoints] = useState([]);
  // Busy flag exposed to the caller; modifyGeometryAndAddToScene resets it to
  // false when it finishes or aborts (presumably the caller sets it to true).
  const [processing, setProcessing] = useState(false);
  // Passed to the cut and stored on each scalpel undo entry.
  const [cutHeight, setCutHeight] = useState(10);

  const MIN_DISTANCE = 0.2; // Minimum distance between consecutive sampled points

  // Throttled move handlers; rebuilt by an effect whenever their inputs change.
  const throttledMouseMove = useRef(null);
  const throttledTouchMove = useRef(null);

  // --------------------------
  // Utility functions
  // --------------------------

  // Returns a fresh unit vector pointing where the camera is looking, as
  // reported by the camera's getWorldDirection.
  const getCameraNormalVector = useCallback(() => {
    const camera = cameraRef.current;
    const viewDirection = camera.getWorldDirection(new THREE.Vector3());
    return viewDirection.normalize();
  }, [cameraRef]);

  // Delegates to `vector.normalize()` and returns its result (for a
  // THREE.Vector3 that is the same vector, now of unit length).
  const normalizeVector = useCallback(function toUnitLength(vector) {
    return vector.normalize();
  }, []);

  // Projects every drawn base point onto the plane that passes through
  // `vertexPoint` and is perpendicular to `normal`.
  // Each point is first lifted by `height` along the unit normal. The
  // projection then removes the whole normal component, so mathematically
  // `height` cancels out of the result.
  // Base points may be THREE.Vector3 instances or plain {x, y, z} objects.
  // `normal` is not modified; a normalized clone is used instead.
  const projectBaseToVertexZ = useCallback(
    (basePoints, vertexPoint, normal, height) => {
      const unitNormal = normalizeVector(normal.clone());
      const planeOffset = vertexPoint.dot(unitNormal);
      const lift = unitNormal.clone().multiplyScalar(height);

      const projected = [];
      for (const { x, y, z } of basePoints) {
        const lifted = new THREE.Vector3(x, y, z).add(lift);
        const correction = unitNormal
          .clone()
          .multiplyScalar(planeOffset - lifted.dot(unitNormal));
        projected.push(lifted.add(correction));
      }
      return projected;
    },
    [normalizeVector]
  );

  // Even-odd (ray casting) point-in-polygon test in 2D. A horizontal ray is
  // cast from (px, py) towards +x, and the polygon edges it crosses are
  // counted; an odd count means the point is inside. `poly` is a list of
  // {x, y} vertices, and the closing edge (last -> first) is implied.
  const pointInPolygon2D = useCallback((px, py, poly) => {
    let crossings = 0;
    let prev = poly[poly.length - 1];
    for (const curr of poly) {
      const straddles = curr.y > py !== prev.y > py;
      if (straddles) {
        const edgeX =
          ((prev.x - curr.x) * (py - curr.y)) / (prev.y - curr.y) + curr.x;
        if (px < edgeX) crossings += 1;
      }
      prev = curr;
    }
    return crossings % 2 === 1;
  }, []);

  // Reads every entry of the geometry's position attribute into its own
  // THREE.Vector3 and returns them in attribute order.
  const getVerticesFromBufferGeometry = useCallback(
    (geometry) => {
      const { position } = geometry.attributes;
      return Array.from({ length: position.count }, (_, vertexIndex) =>
        new THREE.Vector3().fromBufferAttribute(position, vertexIndex)
      );
    },
    []
  );

  // For a given geometry, return the indices of all vertices that lie inside
  // the drawn shape when both are viewed along `normal`.
  //
  // The shape is flattened once onto the plane through the origin that is
  // perpendicular to `normal`. The shape and every vertex are then expressed
  // in an orthonormal basis (u, v) of that plane and tested in 2D. Moving a
  // point along `normal` does not change its (u, v) coordinates, so this is
  // equivalent to projecting the shape to each vertex's depth.
  // Testing raw world x/y instead would only be valid when `normal` is
  // parallel to the z axis; the (u, v) basis works for any view direction.
  // `height` is forwarded to projectBaseToVertexZ but cannot change which
  // vertices are selected.
  const getVerticesWithinField = useCallback(
    (geometry, basePoints, normal, height) => {
      const axis = normal.clone().normalize();
      // Any reference direction that is not (nearly) parallel to the axis
      // yields a valid in-plane basis.
      const reference =
        Math.abs(axis.y) < 0.9
          ? new THREE.Vector3(0, 1, 0)
          : new THREE.Vector3(1, 0, 0);
      const u = new THREE.Vector3().crossVectors(reference, axis).normalize();
      const v = new THREE.Vector3().crossVectors(axis, u);

      // Flatten the outline once (instead of once per vertex), then
      // convert it to plane coordinates.
      const flatShape = projectBaseToVertexZ(
        basePoints,
        new THREE.Vector3(),
        axis,
        height
      );
      const polygon2D = flatShape.map(
        (p) => new THREE.Vector2(p.dot(u), p.dot(v))
      );

      const verticesWithinField = [];
      getVerticesFromBufferGeometry(geometry).forEach((vertex, index) => {
        if (pointInPolygon2D(vertex.dot(u), vertex.dot(v), polygon2D)) {
          verticesWithinField.push(index);
        }
      });
      return verticesWithinField;
    },
    [getVerticesFromBufferGeometry, projectBaseToVertexZ, pointInPolygon2D]
  );

  // --------------------------
  // Triangle removal helper
  // --------------------------
  // Removes every triangle that uses at least one vertex listed in
  // `vertexIndices`, and collects the removed triangles for undo.
  //
  // For non-indexed geometry, triangle j uses vertices 3j, 3j+1 and 3j+2.
  // For indexed geometry, it uses the three `geometry.index` entries at
  // 3j..3j+2. In both cases the vertex indices refer to the position
  // attribute.
  //
  // The result is always a non-indexed geometry with "position" and
  // "normal" attributes; other attributes are not carried over. When the
  // source has no normal attribute, a flat face normal is generated for each
  // triangle, so the result and the undo records always contain normals.
  //
  // Returns { newGeometry, deletedTrianglesList }. Each deleted entry is
  // { vertices: Float32Array(9), normals: Float32Array(9), indices: [a, b, c] }.
  const removeTrianglesAndCollect = useCallback((geometry, vertexIndices) => {
    const positions = geometry.attributes.position.array;
    const normals = geometry.attributes.normal?.array ?? null;
    const index = geometry.index?.array ?? null;
    const triangleCount = Math.floor(
      (index ? index.length : positions.length / 3) / 3
    );
    const vertexIndicesSet = new Set(vertexIndices);

    // Appends the xyz triple stored for `vertex` in `source` to `out`.
    const pushTriple = (source, vertex, out) => {
      const offset = 3 * vertex;
      out.push(source[offset], source[offset + 1], source[offset + 2]);
    };

    // Appends the three corner normals of triangle (a, b, c) to `out`. They
    // are copied from the normal attribute when present; otherwise the unit
    // face normal is repeated for each corner.
    const pushTriangleNormals = (a, b, c, out) => {
      if (normals) {
        pushTriple(normals, a, out);
        pushTriple(normals, b, out);
        pushTriple(normals, c, out);
        return;
      }
      const e1x = positions[3 * b] - positions[3 * a];
      const e1y = positions[3 * b + 1] - positions[3 * a + 1];
      const e1z = positions[3 * b + 2] - positions[3 * a + 2];
      const e2x = positions[3 * c] - positions[3 * a];
      const e2y = positions[3 * c + 1] - positions[3 * a + 1];
      const e2z = positions[3 * c + 2] - positions[3 * a + 2];
      const nx = e1y * e2z - e1z * e2y;
      const ny = e1z * e2x - e1x * e2z;
      const nz = e1x * e2y - e1y * e2x;
      // Degenerate triangles get a zero normal instead of NaN.
      const length = Math.hypot(nx, ny, nz) || 1;
      for (let k = 0; k < 3; k++) {
        out.push(nx / length, ny / length, nz / length);
      }
    };

    const newVertices = [];
    const newNormals = [];
    const deletedTrianglesList = [];

    for (let j = 0; j < triangleCount; j++) {
      const a = index ? index[3 * j] : 3 * j;
      const b = index ? index[3 * j + 1] : 3 * j + 1;
      const c = index ? index[3 * j + 2] : 3 * j + 2;

      if (
        vertexIndicesSet.has(a) ||
        vertexIndicesSet.has(b) ||
        vertexIndicesSet.has(c)
      ) {
        // Save the triangle's data for undo.
        const triVertices = [];
        const triNormals = [];
        pushTriple(positions, a, triVertices);
        pushTriple(positions, b, triVertices);
        pushTriple(positions, c, triVertices);
        pushTriangleNormals(a, b, c, triNormals);
        deletedTrianglesList.push({
          vertices: new Float32Array(triVertices),
          normals: new Float32Array(triNormals),
          indices: [a, b, c],
        });
      } else {
        // Keep the triangle in the new geometry.
        pushTriple(positions, a, newVertices);
        pushTriple(positions, b, newVertices);
        pushTriple(positions, c, newVertices);
        pushTriangleNormals(a, b, c, newNormals);
      }
    }

    const newGeometry = new THREE.BufferGeometry();
    newGeometry.setAttribute(
      "position",
      new THREE.Float32BufferAttribute(newVertices, 3)
    );
    newGeometry.setAttribute(
      "normal",
      new THREE.Float32BufferAttribute(newNormals, 3)
    );

    return { newGeometry, deletedTrianglesList };
  }, []);

  // Cuts `geometry` with the drawn outline: selects the vertices inside the
  // outline as seen along `normal`, then removes every triangle touching one
  // of them. Returns removeTrianglesAndCollect's { newGeometry,
  // deletedTrianglesList } unchanged.
  const cutGeometryWithShape = useCallback(
    (geometry, basePoints, height, normal) =>
      removeTrianglesAndCollect(
        geometry,
        getVerticesWithinField(geometry, basePoints, normal, height)
      ),
    [getVerticesWithinField, removeTrianglesAndCollect]
  );

  // --------------------------
  // Apply cut
  // --------------------------
  // Applies the scalpel cut to each selected model and records a single undo
  // entry of the form:
  // {
  //   type: "scalpel",
  //   meshes: [meshName, ...],     // meshes that actually lost triangles
  //   drawnShape: [...],           // closed outline used for the cut
  //   cutHeight: number,
  //   normal: THREE.Vector3,       // camera direction at cut time
  //   deletedTriangles: { [meshName]: deletedTrianglesList, ... },
  // }
  // A mesh the outline misses keeps its existing geometry. No undo entry is
  // pushed when nothing was removed, so undo never gets a no-op step.
  const modifyGeometryAndAddToScene = useCallback(() => {
    const modelGroup = modelRef.current;
    const scene = sceneRef.current;
    if (!modelGroup || !scene) {
      setProcessing(false);
      console.error("Model or scene not found.");
      return;
    }
    if (!drawPoints || drawPoints.length < 3) {
      setProcessing(false);
      console.error(
        "Invalid drawPoints. Ensure there are at least 3 points to form a shape."
      );
      return;
    }
    // Ensure the drawn shape is closed.
    const points = drawPoints[0].equals(drawPoints[drawPoints.length - 1])
      ? drawPoints
      : [...drawPoints, drawPoints[0]];

    const normal = getCameraNormalVector();
    const scalpelUndo = {
      type: "scalpel",
      meshes: [],
      drawnShape: points,
      cutHeight,
      normal,
      deletedTriangles: {},
    };

    // selectedModels holds indices into the model group's children.
    selectedModels.forEach((modelIndex) => {
      const model = modelGroup.children[modelIndex];
      // Skip missing children and children without geometry.
      if (!model?.geometry) return;

      const { newGeometry, deletedTrianglesList } = cutGeometryWithShape(
        model.geometry,
        points,
        cutHeight,
        normal
      );

      if (deletedTrianglesList.length === 0) {
        // The outline missed this mesh: keep its original geometry (and any
        // attributes the rebuilt copy would drop) and discard the copy.
        newGeometry.dispose();
        return;
      }

      scalpelUndo.meshes.push(model.name);
      scalpelUndo.deletedTriangles[model.name] = deletedTrianglesList;

      // Replace the model's geometry.
      model.geometry.dispose();
      model.geometry = newGeometry;
    });

    // Clear the drawn shape and remove the temporary drawing line.
    setDrawPoints([]);
    const previewLine = scene.userData.currentDrawingLine;
    if (previewLine) {
      scene.remove(previewLine);
      previewLine.geometry.dispose();
      previewLine.material.dispose();
      scene.userData.currentDrawingLine = null;
    }

    if (scalpelUndo.meshes.length > 0) {
      // Drop any redo history beyond the current index, then push.
      undoStack.current = undoStack.current.slice(0, undoIndex.current);
      undoStack.current.push(scalpelUndo);
      undoIndex.current += 1;
    }

    setProcessing(false);
  }, [
    drawPoints,
    modelRef,
    sceneRef,
    selectedModels,
    cutHeight,
    cutGeometryWithShape,
    undoStack,
    undoIndex,
    setDrawPoints,
    setProcessing,
    getCameraNormalVector,
  ]);

  // --------------------------
  // Event Handlers & Throttling
  // --------------------------

  // Builds the throttled mouse/touch move handlers used while drawing.
  // A sample is accepted when it is a hit on a visible model farther than
  // MIN_DISTANCE from the last point. Each accepted sample is appended to
  // drawPoints, and the red preview tube is rebuilt.
  // The handlers are recreated whenever their inputs change, so the cleanup
  // cancels the previous pair. Otherwise a pending trailing call could still
  // run with a stale drawPoints.
  useEffect(() => {
    // Raycasts the visible models at a client-space position and returns the
    // first hit point, or null when nothing is hit.
    const pickModelPoint = (clientX, clientY) => {
      const rect = containerRef.current.getBoundingClientRect();
      const x = ((clientX - rect.left) / rect.width) * 2 - 1;
      const y = -((clientY - rect.top) / rect.height) * 2 + 1;
      const raycaster = new THREE.Raycaster();
      raycaster.setFromCamera({ x, y }, cameraRef.current);
      const visibleModels = modelRef.current.children.filter(
        (child) => child.visible
      );
      if (visibleModels.length === 0) return null;
      const intersects = raycaster.intersectObjects(visibleModels, true);
      return intersects.length > 0 ? intersects[0].point : null;
    };

    // Removes the current preview tube from the scene and disposes its
    // geometry and material; a removed-but-undisposed tube would leak GPU
    // buffers on every accepted sample.
    const disposePreviewLine = () => {
      const scene = sceneRef.current;
      const previous = scene.userData.currentDrawingLine;
      if (!previous) return;
      scene.remove(previous);
      previous.geometry.dispose();
      previous.material.dispose();
      scene.userData.currentDrawingLine = null;
    };

    // Records `point` when it is far enough from the last sample, then
    // redraws the closed preview tube through all samples.
    const addSample = (point, tubeRadius) => {
      const isFarEnough =
        drawPoints.length === 0 ||
        drawPoints[drawPoints.length - 1].distanceTo(point) > MIN_DISTANCE;
      if (!isFarEnough) return;

      setDrawPoints((prevPoints) => [...prevPoints, point]);

      const points = [...drawPoints, point];
      // A Catmull-Rom curve needs at least two points to be evaluated.
      if (points.length < 2) return;
      const curve = new THREE.CatmullRomCurve3(points);
      const tubeGeometry = new THREE.TubeGeometry(
        curve,
        64,
        tubeRadius,
        8,
        true
      );
      const tubeMaterial = new THREE.MeshBasicMaterial({
        color: 0xff0000,
      });
      const line = new THREE.Mesh(tubeGeometry, tubeMaterial);
      disposePreviewLine();
      sceneRef.current.add(line);
      sceneRef.current.userData.currentDrawingLine = line;
    };

    const mouseMove = throttle((event) => {
      if (!isDrawing || event.target.closest(".ui-outside-scene")) return;
      const point = pickModelPoint(event.clientX, event.clientY);
      if (point) addSample(point, 0.5);
    }, 100);

    const touchMove = throttle((event) => {
      if (!isDrawing || event.target.closest(".ui-outside-scene")) return;
      event.preventDefault();
      const touch = event.touches[0];
      const point = pickModelPoint(touch.clientX, touch.clientY);
      // Touch strokes use a thinner preview tube than mouse strokes.
      if (point) addSample(point, 0.1);
    }, 100);

    throttledMouseMove.current = mouseMove;
    throttledTouchMove.current = touchMove;

    return () => {
      mouseMove.cancel();
      touchMove.cancel();
    };
  }, [isDrawing, cameraRef, containerRef, modelRef, sceneRef, drawPoints]);

  // Forwards mousemove events to the current throttled handler, if one has
  // been built yet. Stable identity (no deps) so listeners need no rebinding.
  const handleMouseMove = useCallback((event) => {
    throttledMouseMove.current?.(event);
  }, []);

  // Forwards touchmove events to the current throttled handler, if one has
  // been built yet. Stable identity (no deps) so listeners need no rebinding.
  const handleTouchMove = useCallback((event) => {
    throttledTouchMove.current?.(event);
  }, []);

  // Starts a stroke on a primary-button press over a visible model: records
  // the hit point as the first sample, enters drawing mode and disables
  // camera rotation. Ignored outside scalpel mode and over
  // ".ui-outside-scene" elements.
  const handleMouseDown = useCallback(
    (event) => {
      if (!isScalpelMode) return;
      if (event.target.closest(".ui-outside-scene")) return;
      if (event.button !== 0) return;

      const bounds = containerRef.current.getBoundingClientRect();
      const pointer = {
        x: ((event.clientX - bounds.left) / bounds.width) * 2 - 1,
        y: -((event.clientY - bounds.top) / bounds.height) * 2 + 1,
      };
      const raycaster = new THREE.Raycaster();
      raycaster.setFromCamera(pointer, cameraRef.current);

      const candidates = modelRef.current.children.filter(
        (child) => child.visible
      );
      if (candidates.length === 0) return;

      const [firstHit] = raycaster.intersectObjects(candidates, true);
      if (!firstHit) return;

      setDrawPoints([firstHit.point]);
      setIsDrawing(true);
      // Disable rotation while drawing.
      controlsRef.current.noRotate = true;
    },
    [isScalpelMode, cameraRef, containerRef, modelRef, controlsRef]
  );

  // Ends the current stroke and re-enables camera rotation.
  // The release is honoured wherever it happens, including over a
  // ".ui-outside-scene" element. Ignoring it there would leave isDrawing set
  // and rotation locked, and moving the already-released mouse back over the
  // model would keep adding points until the next mouseup on the scene.
  const handleMouseUp = useCallback(() => {
    if (!isDrawing) return;
    setIsDrawing(false);
    controlsRef.current.noRotate = false;
  }, [isDrawing, controlsRef]);

  // Starts a stroke when the first touch lands on a visible model: records
  // the hit point, enters drawing mode and disables camera rotation. Touches
  // on ".ui-outside-scene" elements are ignored; all others have their
  // default behaviour prevented.
  const handleTouchStart = useCallback(
    (event) => {
      if (event.target.closest(".ui-outside-scene")) return;
      event.preventDefault();

      const { clientX, clientY } = event.touches[0];
      const bounds = containerRef.current.getBoundingClientRect();
      const pointer = {
        x: ((clientX - bounds.left) / bounds.width) * 2 - 1,
        y: -((clientY - bounds.top) / bounds.height) * 2 + 1,
      };
      const raycaster = new THREE.Raycaster();
      raycaster.setFromCamera(pointer, cameraRef.current);

      const candidates = modelRef.current.children.filter(
        (child) => child.visible
      );
      if (candidates.length === 0) return;

      const [firstHit] = raycaster.intersectObjects(candidates, true);
      if (!firstHit) return;

      setDrawPoints([firstHit.point]);
      setIsDrawing(true);
      controlsRef.current.noRotate = true;
    },
    [containerRef, cameraRef, modelRef, controlsRef]
  );

  // Ends the current touch stroke and re-enables camera rotation. Touches
  // whose target is inside ".ui-outside-scene" are left alone entirely
  // (no preventDefault).
  const handleTouchEnd = useCallback(
    (event) => {
      const endedOnUi = event.target.closest(".ui-outside-scene");
      if (endedOnUi) return;
      event.preventDefault();
      if (isDrawing) {
        setIsDrawing(false);
        controlsRef.current.noRotate = false;
      }
    },
    [isDrawing, controlsRef]
  );

  // Discards the current stroke. Resets drawPoints, then removes the preview
  // tube (if any) from the scene and disposes its geometry and material.
  const clearDrawPoints = useCallback(() => {
    setDrawPoints([]);
    const scene = sceneRef.current;
    const previewLine = scene.userData.currentDrawingLine;
    if (!previewLine) return;
    scene.remove(previewLine);
    previewLine.geometry.dispose();
    previewLine.material.dispose();
    scene.userData.currentDrawingLine = null;
  }, [sceneRef]);

  // Undo function for scalpel mode.
  // Expects an undo entry as recorded by modifyGeometryAndAddToScene:
  // {
  //   type: "scalpel",
  //   meshes: [meshName, ...],
  //   deletedTriangles: { [meshName]: deletedTrianglesList, ... },
  //   ...
  // }
  // Each deleted triangle carries 9 position values (`vertices`) and 9
  // normal values (`normals`). They are appended to the end of the mesh's
  // current geometry: the triangle set is restored, but not the original
  // triangle order. A geometry without a normal attribute stays without one.
  const undoScalpel = useCallback(
    (scalpelUndo) => {
      // Concatenates `current` with every triangle's `key` array into one
      // Float32Array, without converting the buffer to a plain JS array.
      const appendTriangleData = (current, triangles, key) => {
        const extraLength = triangles.reduce(
          (sum, triangle) => sum + triangle[key].length,
          0
        );
        const merged = new Float32Array(current.length + extraLength);
        merged.set(current);
        let offset = current.length;
        triangles.forEach((triangle) => {
          merged.set(triangle[key], offset);
          offset += triangle[key].length;
        });
        return merged;
      };

      scalpelUndo.meshes.forEach((meshName) => {
        const deletedTrianglesList = scalpelUndo.deletedTriangles[meshName];
        if (!Array.isArray(deletedTrianglesList)) {
          console.error(
            `Expected deletedTrianglesList to be an array for mesh ${meshName}, but got:`,
            deletedTrianglesList
          );
          return;
        }

        const mesh = modelRef.current.children.find(
          (child) => child.name === meshName
        );
        if (!mesh) {
          console.error(`Mesh ${meshName} not found for undoScalpel`);
          return;
        }

        const { position, normal } = mesh.geometry.attributes;
        const newGeometry = new THREE.BufferGeometry();
        newGeometry.setAttribute(
          "position",
          new THREE.BufferAttribute(
            appendTriangleData(position.array, deletedTrianglesList, "vertices"),
            3
          )
        );
        if (normal) {
          newGeometry.setAttribute(
            "normal",
            new THREE.BufferAttribute(
              appendTriangleData(normal.array, deletedTrianglesList, "normals"),
              3
            )
          );
        }

        // Dispose of the old geometry and update the mesh.
        mesh.geometry.dispose();
        mesh.geometry = newGeometry;
      });
    },
    [modelRef]
  );

  // Re-applies a recorded scalpel cut (e.g. for redo). `scalpelData` is
  // expected to look like:
  // {
  //   meshes: [meshName1, meshName2, ...],
  //   drawnShape: [/* the closed set of drawPoints */],
  //   cutHeight: number,
  //   normal: {x, y, z}  // copied into a fresh THREE.Vector3 per mesh
  // }
  // Meshes that cannot be found by name are reported and skipped.
  const performScalpelAutoMultiple = useCallback(
    (scalpelData) => {
      const {
        meshes,
        drawnShape,
        cutHeight: recordedHeight,
        normal: recordedNormal,
      } = scalpelData;

      const hasRequiredFields =
        meshes && drawnShape && recordedHeight !== undefined && recordedNormal;
      if (!hasRequiredFields) {
        console.error(
          "Invalid scalpelData. Ensure that meshes, drawnShape, cutHeight, and normal are provided."
        );
        return;
      }

      meshes.forEach((meshName) => {
        const target = modelRef.current.children.find(
          (child) => child.name === meshName
        );
        if (!target) {
          console.error(`Mesh ${meshName} not found.`);
          return;
        }

        const { newGeometry } = cutGeometryWithShape(
          target.geometry,
          drawnShape,
          recordedHeight,
          new THREE.Vector3(recordedNormal.x, recordedNormal.y, recordedNormal.z)
        );

        target.geometry.dispose();
        target.geometry = newGeometry;

        console.log(`Scalpel operation reapplied on mesh: ${meshName}`);
      });
    },
    [modelRef, cutGeometryWithShape]
  );

  // --------------------------
  // Register Event Listeners
  // --------------------------
  // While scalpel mode is on, attach the pointer and touch handlers to the
  // container. touchstart/touchmove are registered as non-passive so their
  // handlers may call preventDefault. The cleanup detaches every handler and
  // cancels any pending throttled move call.
  useEffect(() => {
    if (!isScalpelMode) return undefined;

    const target = containerRef.current;
    const nonPassive = { passive: false };
    const bindings = [
      ["mousedown", handleMouseDown],
      ["mousemove", handleMouseMove],
      ["mouseup", handleMouseUp],
      ["touchstart", handleTouchStart, nonPassive],
      ["touchmove", handleTouchMove, nonPassive],
      ["touchend", handleTouchEnd],
    ];

    for (const [type, listener, options] of bindings) {
      target.addEventListener(type, listener, options);
    }

    return () => {
      for (const [type, listener] of bindings) {
        target.removeEventListener(type, listener);
      }
      throttledMouseMove.current?.cancel();
      throttledTouchMove.current?.cancel();
    };
  }, [
    isScalpelMode,
    handleMouseDown,
    handleMouseMove,
    handleMouseUp,
    handleTouchMove,
    handleTouchStart,
    handleTouchEnd,
    containerRef,
  ]);

  // Public API of the hook. isDrawing and drawPoints are not exposed
  // directly; callers can reset the stroke via setDrawPoints or
  // clearDrawPoints.
  return {
    isScalpelMode,
    setIsScalpelMode,
    setDrawPoints,
    processing,
    setProcessing,
    cutHeight,
    setCutHeight,
    modifyGeometryAndAddToScene,
    clearDrawPoints,
    performScalpelAutoMultiple,
    undoScalpel,
  };
};

// The scalpel-mode hook is this module's default export.
export default useScalpelMode;
