import { useCallback, useContext } from 'react'
import { Node, Edge } from '@xyflow/react'
import { FlowContext } from './context'
import { NodeData, OutputField } from './model'

/**
 * React hook returning memoized handlers that edit or remove a single node.
 *
 * @param nodes - Current node list. Not read by this hook; kept for call-site compatibility.
 * @param setNodes - Node state setter; always invoked with an updater function.
 * @param setEdges - Edge state setter; always invoked with an updater function.
 * @returns `handleNodeChange(id, newData)`, which shallow-merges `newData` into the
 *   `data` of the node whose id matches, and `handleNodeDelete(id)`, which removes
 *   that node plus every edge that has it as source or target.
 */
export const useNodeOperations = (
  nodes: Node[],
  setNodes: Function,
  setEdges: Function
) => {
  const handleNodeChange = useCallback(
    (id: string, newData: any) => {
      // Only the matching node gets a new object; every other node keeps its identity.
      const mergeInto = (current: Node[]) =>
        current.map((candidate) => {
          if (candidate.id !== id) return candidate
          return { ...candidate, data: { ...candidate.data, ...newData } }
        })
      setNodes(mergeInto)
    },
    [setNodes]
  )

  const handleNodeDelete = useCallback(
    (id: string) => {
      const touchesNode = (edge: Edge) => edge.source === id || edge.target === id
      setNodes((current: Node[]) => current.filter((candidate) => candidate.id !== id))
      // Drop dangling edges so no edge points at the removed node.
      setEdges((current: Edge[]) => current.filter((edge) => !touchesNode(edge)))
    },
    [setNodes, setEdges]
  )

  return { handleNodeChange, handleNodeDelete }
}

/** A node's output fields, optionally scoped to one named branch. */
interface NodeOutput {
  /** Node that produces the output. */
  node: Node
  /** Branch name the output was read from; `null` selects the node's main output. */
  branch: string | null
  /** Output fields of the node (or of the selected branch). */
  output: OutputField[]
}

/** One node together with every branch output collected for it. */
interface GroupedNode {
  /** The node shared by all entries in `branches`. */
  node: Node
  /** Each branch name (`null` for the main output) paired with its output fields. */
  branches: Array<{ name: string | null; output: OutputField[] }>
}

/**
 * React hook exposing read-only helpers over the workflow graph held in `FlowContext`.
 *
 * The callbacks are memoized on JSON snapshots of the context's nodes and edges, so
 * their identities change only when the graph content changes. Every callback that
 * reads the graph lists its dependencies, so none of them sees stale data.
 */
export const useWorkflowFunctions = () => {
  const context = useContext(FlowContext)
  // Serialized snapshots used purely as memo keys (content equality, not reference equality).
  const strNodes = JSON.stringify(context?.nodes ?? [])
  const strEdges = JSON.stringify(context?.edges ?? [])

  /**
   * Looks up a node by id.
   * @returns the node, `undefined` when no node matches, or `null` when the context has no nodes.
   */
  const getNode = useCallback(
    (nodeId: string) => {
      if (!context?.nodes) return null
      return context.nodes.find((node: { id: string }) => node.id === nodeId)
    },
    [strNodes]
  )

  /**
   * Reads the output fields of a node, or of one of its named branches.
   *
   * @param nodeId - id of the node to read.
   * @param branchName - branch to read from `branchOutput`; a falsy value selects the main `output`.
   * @returns the node, branch and output (`[]` when the branch or output is absent), or
   *   `null` when the node or its nested `data.data` payload does not exist.
   */
  const getNodeOutput = useCallback(
    (nodeId: string, branchName: string | null = null): NodeOutput | null => {
      const node = getNode(nodeId)
      if (!node?.data) return null
      const nodeData = node.data.data as NodeData | undefined
      if (!nodeData) return null
      const output = branchName
        ? nodeData.branchOutput?.[branchName]?.output
        : nodeData.output
      // Normalize to an array so the declared `OutputField[]` type holds for consumers.
      return { node, branch: branchName, output: output ?? [] }
    },
    // Must track getNode; an empty list would freeze the first render's node data.
    [getNode]
  )

  /**
   * Collects the outputs of the nodes feeding into `nodeId`.
   *
   * Nodes whose `type` is in `ignoredNodeTypes` are transparent: their own output is
   * skipped, and the search continues through them to their parents.
   *
   * @param nodeId - node whose parents are wanted.
   * @param ignoredNodeTypes - node types to look through rather than report.
   * @param nextParents - when true, returns all ancestors (deduplicated by node id,
   *   first occurrence wins). When false, returns only the nearest non-ignored parents,
   *   with exact (node, branch) duplicates removed.
   */
  const getParentNodes = useCallback(
    (
      nodeId: string,
      ignoredNodeTypes: string[] = [],
      nextParents: boolean = false
    ): NodeOutput[] => {
      const edges = context?.edges
      if (!context?.nodes || !edges) return []

      const isIgnored = (output: NodeOutput) =>
        ignoredNodeTypes.includes(output.node.type ?? '')
      // Nodes whose parents were already expanded; breaks cycles and skips repeat work.
      const expanded = new Set<string>()

      const collect = (targetId: string): NodeOutput[] => {
        if (expanded.has(targetId)) return []
        expanded.add(targetId)
        const direct = edges
          .filter((edge) => edge.target === targetId)
          .map((edge) => getNodeOutput(edge.source, edge.sourceHandle))
          .filter((output): output is NodeOutput => output !== null)
        const valid = direct.filter((output) => !isIgnored(output))
        const result = [...valid]
        if (nextParents) {
          for (const output of valid) result.push(...collect(output.node.id))
        }
        for (const output of direct) {
          if (isIgnored(output)) result.push(...collect(output.node.id))
        }
        return result
      }

      const seen = new Set<string>()
      return collect(nodeId).filter((output) => {
        const key = nextParents
          ? output.node.id
          : JSON.stringify([output.node.id, output.branch])
        if (seen.has(key)) return false
        seen.add(key)
        return true
      })
    },
    [strNodes, strEdges, getNodeOutput]
  )

  /**
   * Returns the first output field named `outputName` across `nodes`, searched in order,
   * or `undefined` when none matches.
   */
  const getOutputField = useCallback(
    (nodes: NodeOutput[], outputName: string): OutputField | undefined => {
      for (const entry of nodes) {
        const field = entry?.output?.find((candidate) => candidate.name === outputName)
        if (field) return field
      }
      return undefined
    },
    []
  )

  /**
   * Groups outputs by node, ordering by node id and then branch name.
   * The input array is not modified.
   */
  const groupNodes = useCallback((nodes: NodeOutput[]): GroupedNode[] => {
    // Copy before sorting: Array.prototype.sort works in place.
    const sorted = [...nodes].sort((a, b) =>
      `${a.node.id},${a.branch}`.localeCompare(`${b.node.id},${b.branch}`)
    )
    // Keyed by id, so each node yields exactly one group even when the sort key
    // interleaves entries (e.g. ids containing commas). Map keeps first-seen order.
    const groups = new Map<string, GroupedNode>()
    for (const cur of sorted) {
      const branch = { name: cur.branch, output: cur.output }
      const existing = groups.get(cur.node.id)
      if (existing) {
        existing.branches.push(branch)
      } else {
        groups.set(cur.node.id, { node: cur.node, branches: [branch] })
      }
    }
    return Array.from(groups.values())
  }, [])

  return {
    getNode,
    getNodeOutput,
    getParentNodes,
    getOutputField,
    groupNodes,
  }
}
