import { flextree } from 'd3-flextree';
import { stratify } from 'd3-hierarchy';
import { useEffect } from 'react';
import { Node, Edge, ReactFlowState, useStore, useReactFlow } from 'reactflow';
import { useTreeStore } from '@/context/tree-context';
import { NODE_PAGINATION_HEIGHT } from '@/components/node-pagination';
import { DEFAULT_GRID_GUTTER } from '@/components/group-node';

/**
 * Shared d3-flextree layout instance used by `layoutNodes`.
 * Node sizes come from flextree's default `nodeSize` accessor (`node.data.size`);
 * `spacing` adds a fixed separation of 4 layout units between neighbouring nodes.
 */
const layout = flextree<Node>({ spacing: 4 });

/** React Flow store selector: number of nodes currently registered in `nodeInternals`. */
const nodeCountSelector = ({ nodeInternals }: ReactFlowState) => nodeInternals.size;
/**
 * React Flow store selector: `true` once every node of a layout-gating type
 * ('object', 'image', 'group', 'table') has a truthy measured `width` and `height`.
 * Nodes of any other type are ignored by this check. An empty store yields `true`.
 */
const nodesInitializedSelector = (state: ReactFlowState) => {
  const gatedTypes = ['object', 'image', 'group', 'table'];
  const gatedNodes = Array.from(state.nodeInternals.values()).filter((candidate) =>
    gatedTypes.some((gatedType) => gatedType === candidate.type)
  );
  return gatedNodes.every((candidate) => Boolean(candidate.width && candidate.height));
};

/**
 * Lays out the top-level React Flow nodes as a left-to-right tree.
 *
 * The tree structure comes from `edgesPositioning`: for each node, the `source`
 * of the first edge whose `target` is that node is used as its parent. Nodes are
 * sized from their `style` width/height, laid out with d3-flextree, and the
 * resulting coordinates are rotated (flextree x ↔ y) so the tree grows horizontally.
 *
 * Nodes nested in another node (`parentNode` set), nodes without measured
 * `width`/`height`, and nodes absent from the laid-out tree are returned unchanged.
 *
 * @param nodes - All React Flow nodes.
 * @param edgesPositioning - Edges describing the parent → child relations used for layout.
 * @returns A new array of nodes; laid-out nodes get a new `position`.
 * @throws Propagates errors from d3 `stratify` when the edges do not describe a single
 *   tree over the top-level nodes (e.g. multiple roots, a cycle, or a missing parent id).
 */
export function layoutNodes(nodes: Node[], edgesPositioning: Edge[]) {
  // Index parent ids by child id once. The first matching edge wins, matching the
  // semantics of the previous per-node `Array.find`.
  const parentByTarget = new Map<string, string>();
  for (const edge of edgesPositioning) {
    if (!parentByTarget.has(edge.target)) {
      parentByTarget.set(edge.target, edge.source);
    }
  }

  const sizedNodes = nodes
    .filter((node) => !node.parentNode)
    .map((node) => {
      let height = Number(node.style?.height) || 0;

      // Take into account the height of the pagination component for table and group nodes
      // (reserved twice), plus 20 -> spacing between nodes.
      if (node.type === 'table' || node.type === 'group') {
        height += 2 * NODE_PAGINATION_HEIGHT + 20;
      }

      // d3-flextree's default nodeSize accessor reads `data.size` as [x-size, y-size].
      // The result is rotated below, so height goes on flextree's x axis and
      // width (+100 gap) on its y axis.
      return {
        ...node,
        size: [height, node.style?.width ? Number(node.style.width) + 100 : 0],
      };
    });

  // `stratify` throws on empty input ("no root"); there is nothing to lay out.
  if (sizedNodes.length === 0) {
    return nodes;
  }

  const hierarchy = stratify<Node>()
    .id((d) => d.id)
    .parentId((d) => parentByTarget.get(d.id))(sizedNodes);

  const root = layout(hierarchy);

  // Index laid-out coordinates by id once instead of searching the tree per node.
  const laidOut = new Map<string, { x: number; y: number }>();
  root.each((d) => {
    if (d.id !== undefined) {
      laidOut.set(d.id, { x: d.x, y: d.y });
    }
  });

  return nodes.map((node) => {
    if (node.parentNode || !node.width || !node.height) {
      return node;
    }

    const point = laidOut.get(node.id);
    if (!point) {
      // Not part of the laid-out tree: keep the node exactly where it is
      // (previously its own position was run through the x/y rotation).
      return node;
    }

    // Rotate flextree's top-down coordinates into a left-to-right layout and
    // shift up by half the measured height so the node is centred on its tree point.
    const position = { x: point.y, y: point.x - node.height / 2 };

    return {
      ...node,
      position,
    };
  });
}

/**
 * Applies `layoutNodes` to the current React Flow nodes.
 *
 * The layout re-runs whenever the node count changes, whenever every layout-gating
 * node has been measured (see `nodesInitializedSelector`), or whenever the tree
 * store's `edgesPositioning` changes. It is skipped while there are no nodes or
 * while measurements are pending.
 *
 * Uses React Flow's `useStore` / `useReactFlow`, so it must run under a React Flow provider.
 */
export function useAutoLayout() {
  const nodeCount = useStore(nodeCountSelector);
  const nodesInitialized = useStore(nodesInitializedSelector);
  const { getNodes, setNodes } = useReactFlow();
  const edgesPositioning = useTreeStore((state) => state.edgesPositioning);

  useEffect(() => {
    // only run the layout if there are nodes and they have been initialized with their dimensions
    if (!nodeCount || !nodesInitialized) {
      return;
    }

    const nodes: Node[] = getNodes();

    setNodes(layoutNodes(nodes, edgesPositioning));
    // `edgesPositioning` is a dependency so that changes to the tree structure
    // trigger a re-layout instead of being picked up only on the next node change.
  }, [nodeCount, nodesInitialized, getNodes, setNodes, edgesPositioning]);
}
