import * as d3 from 'd3';
import * as d3Sankey from 'd3-sankey';
import * as R from 'ramda';
import * as RA from 'ramda-adjunct';
import React, { useCallback, useEffect, useMemo } from 'react';
import styled from 'styled-components';

import SankeyDiagramLink from 'components/lib/charts/SankeyDiagramLink';
import SankeyDiagramNode from 'components/lib/charts/SankeyDiagramNode';

import useTheme from 'lib/hooks/useTheme';
import type { Identifiable, Link, LinkWithExtraProps, NodeWithExtraProps } from 'lib/ui/sankey';
import {
  computeColorForNodes,
  computeNodePercentages,
  computeNodeSpacing,
  getDiagramMaxHeight,
} from 'lib/ui/sankey';

/** Minimum requested diagram width in px (smaller `width` values are raised to it); also the default `width`. */
const DIAGRAM_MINIMUM_WIDTH_PX = 768;
/** Layout height in px used when no `height` is given; also the lower bound of the SVG height before margins. */
const DIAGRAM_MINIMUM_HEIGHT_PX = 520;
/** Margin in px used on both axes when the `margin` prop is omitted. */
export const DIAGRAM_DEFAULT_MARGIN_PX = 24;

/**
 * Container of the diagram.
 *
 * Applies the horizontal margin on both sides of the container and to the right of the SVG. Also defines, for the
 * `.node` / `.link` elements inside the SVG:
 * - the opacity transition;
 * - the hover-state classes (`dimmed`, `highlighted`);
 * - the pointer cursor for `.is-clickable`.
 *
 * NOTE(review): `$marginVertical` is accepted but not referenced by these styles; the vertical offset is applied via
 * an inline transform on the SVG groups instead.
 */
const Root = styled.div<{ $marginHorizontal: number; $marginVertical: number }>`
  font-family: inherit;
  position: relative;
  margin: 0 ${({ $marginHorizontal }) => $marginHorizontal}px;

  svg {
    margin-right: ${({ $marginHorizontal }) => $marginHorizontal}px;

    .node,
    .link {
      transition: opacity 0.1s ease-out;
    }

    .dimmed {
      opacity: 0.3;
    }

    .highlighted {
      opacity: 1;
    }

    .is-clickable {
      cursor: pointer;
    }

    .node:not(.is-clickable) {
      cursor: default;
    }
  }
`;

/**
 * Props of `SankeyDiagram`.
 *
 * @typeParam N - Input node type; its `id` is what links refer to (the layout's `nodeId` reads `node.id`).
 * @typeParam L - Input link type.
 * @typeParam NodeT - Node type seen by the post-layout callbacks (`onClickNode`, `formatValue`,
 *   `customNodesColorDeterminer`). Defaults to `any`.
 */
type Props<N extends Identifiable, L extends Link, NodeT extends NodeWithExtraProps<N, L> = any> = {
  /** Nodes and links to lay out. They are cloned before layout, so the caller's objects are not mutated. */
  data: { nodes: N[]; links: L[] };
  /** Requested width in px. Raised to the minimum diagram width, then the horizontal margins are subtracted. */
  width?: number;
  /**
   * Layout height in px. When omitted, the layout uses the minimum diagram height and the SVG height is derived from
   * the laid-out nodes (never below the minimum height).
   */
  height?: number;
  className?: string;
  /**
   * When true, hovering a node dims everything except:
   * - the hovered node;
   * - the nodes it is linked to;
   * - links touching any of those nodes.
   */
  highlightOnHover?: boolean;
  /** One margin in px for both axes, or a `[vertical, horizontal]` tuple. */
  margin?: number | [number, number];
  /** NOTE(review): not currently read by `SankeyDiagram`. */
  isLoading?: boolean;
  /**
   * Replaces the default node coloring. It is called after layout with the laid-out nodes and links. It returns
   * nothing, so it is expected to assign colors by mutating the nodes.
   */
  customNodesColorDeterminer?: (
    nodes: NodeT[],
    links: d3Sankey.SankeyLinkMinimal<any, any>[],
    theme: ReturnType<typeof useTheme>,
  ) => void;
  /** Forwarded to every node as its click handler. */
  onClickNode?: (node: NodeT) => void;
  /** Formats the node value. */
  formatValue?: (node: NodeT) => string | undefined;
  /**
   * Custom d3-sankey `nodeAlign` function (the default is `d3Sankey.sankeyRight`). It receives one node and the number
   * of columns, and returns that node's column index.
   * Using the original node type because it's called before computing the other properties.
   */
  onAlignNode?: (node: d3Sankey.SankeyNode<N, any>, numColumns: number) => number;
  /**
   * Custom d3-sankey `nodeSort` comparator (vertical order of nodes within a column).
   * Using the original node type because it's called before computing the other properties.
   */
  onSortNodes?: (a: d3Sankey.SankeyNode<N, L>, b: d3Sankey.SankeyNode<N, L>) => number;
  /**
   * Custom d3-sankey `linkSort` comparator (vertical order of links at a node).
   * Using the original node type because it's called before computing the other properties.
   */
  onSortLinks?: (
    a: d3Sankey.SankeyLinkMinimal<N, L>,
    b: d3Sankey.SankeyLinkMinimal<N, L>,
  ) => number;
};

/**
 * Sankey diagram rendered as SVG: d3-sankey computes the layout, and React renders the links and nodes.
 *
 * The layout is computed in a `requestAnimationFrame` callback after render, so the diagram is empty until that
 * frame runs. Hover highlighting (enabled with `highlightOnHover`) toggles the `dimmed` / `highlighted` classes on the
 * `.node` / `.link` elements inside this instance's own `<svg>` only.
 */
const SankeyDiagram = <N extends Identifiable, L extends Link>({
  data,
  width = DIAGRAM_MINIMUM_WIDTH_PX,
  height,
  className,
  highlightOnHover,
  margin = DIAGRAM_DEFAULT_MARGIN_PX,
  customNodesColorDeterminer,
  formatValue,
  onClickNode,
  onAlignNode,
  onSortNodes,
  onSortLinks,
}: Props<N, L>) => {
  const theme = useTheme();

  // SVG root of this instance. Hover effects only query inside it, so other diagrams (or unrelated `.node` / `.link`
  // elements) on the same page are not affected.
  const svgRef = React.useRef<SVGSVGElement>(null);

  // `margin` is either one value for both axes or a `[vertical, horizontal]` tuple.
  const [marginVertical, marginHorizontal] = useMemo(
    () => (RA.isArray(margin) ? margin : [margin, margin]),
    [margin],
  );

  // Width of the layout and of the SVG: the requested width, raised to at least `DIAGRAM_MINIMUM_WIDTH_PX`, minus the
  // horizontal margin on both sides (the margin itself is applied around the SVG by `Root`'s styles).
  const responsiveWidth = useMemo(
    () => Math.max(width, DIAGRAM_MINIMUM_WIDTH_PX) - marginHorizontal * 2,
    [width, marginHorizontal],
  );

  // Vertical offset for the `g` groups of the diagram, equal to `marginVertical`, so the content sits inside the
  // vertical margin that `responsiveHeight` reserves.
  const groupTransformStyle = useMemo(
    () => ({ transform: `translateY(${marginVertical}px)` }),
    [marginVertical],
  );

  const [computedSankeyLayout, setComputedSankeyLayout] = React.useState<
    d3Sankey.SankeyGraph<NodeWithExtraProps<N, L>, LinkWithExtraProps<N, L>>
  >({
    nodes: [],
    links: [],
  });

  /**
   * Builds a d3-sankey layout for the current props and stores it in `computedSankeyLayout`.
   *
   * The nodes and links are cloned first, because d3-sankey mutates its input.
   *
   * After the initial layout, it:
   * 1. computes node spacing;
   * 2. assigns node colors (the custom determiner if provided, otherwise `computeColorForNodes`);
   * 3. computes node percentages;
   * 4. calls `sankey.update` so the links are recalculated from the updated nodes.
   */
  const computeSankeyLayout = useCallback(() => {
    let sankey = d3Sankey
      .sankey<N, L>()
      .nodeId((node) => node.id)
      .nodeAlign(onAlignNode ?? d3Sankey.sankeyRight)
      .nodePadding(0)
      .nodeWidth(parseInt(theme.spacing.default, 10))
      .size([responsiveWidth, height ?? DIAGRAM_MINIMUM_HEIGHT_PX]);

    // Sort the nodes by a custom function, if provided
    if (onSortNodes) {
      sankey = sankey.nodeSort(onSortNodes);
    }

    // Sort the links by a custom function, if provided
    if (onSortLinks) {
      sankey = sankey.linkSort(onSortLinks);
    }

    // We need to clone nodes and links here because d3-sankey mutates the data
    const createdSankey = sankey({
      nodes: R.clone(data.nodes),
      links: R.clone(data.links),
    });

    computeNodeSpacing(createdSankey.nodes);

    if (customNodesColorDeterminer) {
      customNodesColorDeterminer(createdSankey.nodes, createdSankey.links, theme);
    } else {
      computeColorForNodes(createdSankey.nodes, createdSankey.links);
    }

    computeNodePercentages(createdSankey.nodes);

    // Recompute based on spacing and percentages
    const updatedSankey = sankey.update(createdSankey);

    // @ts-ignore - Types changed because of the above `compute...` functions
    setComputedSankeyLayout(updatedSankey);
  }, [
    data,
    theme,
    responsiveWidth,
    height,
    customNodesColorDeterminer,
    onAlignNode,
    onSortNodes,
    onSortLinks,
  ]);

  // Defer the layout to the next animation frame. `computeSankeyLayout` already changes whenever `data` or
  // `responsiveWidth` change. Cancelling the pending frame on cleanup has two effects:
  // - layouts scheduled for stale inputs are dropped;
  // - state is not set after the component has unmounted.
  useEffect(() => {
    const frameId = requestAnimationFrame(computeSankeyLayout);
    return () => cancelAnimationFrame(frameId);
  }, [computeSankeyLayout]);

  // TODO(@vanessa): It should be possible to use CSS here instead of d3
  const onNodeMouseEnter = useCallback(
    (node: NodeWithExtraProps<N, L>) => {
      const svg = svgRef.current;
      if (!highlightOnHover || !svg) {
        return;
      }

      const adjacentLinks = (node.targetLinks ?? [])
        .concat(node.sourceLinks ?? [])
        .filter(Boolean);
      // Ids are compared with DOM attribute values, which are always strings. Comparing values directly (instead of
      // interpolating ids into CSS selectors) keeps ids containing quotes or other special characters from breaking
      // the query.
      const nodeIdsToHighlight = new Set<string>([
        String(node.id),
        ...adjacentLinks.flatMap(({ source, target }) => [String(target.id), String(source.id)]),
      ]);
      const isHighlightedId = (value: string | null) =>
        value !== null && nodeIdsToHighlight.has(value);

      // Highlight a `.node` whose `data-id` is in the set, or a `.link` whose source or target id is in the set.
      // Dim everything else in this diagram.
      d3.select(svg)
        .selectAll<Element, unknown>('.link, .node')
        .each(function () {
          const isHighlighted =
            (this.classList.contains('node') && isHighlightedId(this.getAttribute('data-id'))) ||
            (this.classList.contains('link') &&
              (isHighlightedId(this.getAttribute('data-source-id')) ||
                isHighlightedId(this.getAttribute('data-target-id'))));
          d3.select(this).classed('highlighted', isHighlighted).classed('dimmed', !isHighlighted);
        });

      // Bring the hovered node to the front so it can be hovered
      const hoveredId = String(node.id);
      d3.select(svg)
        .selectAll<Element, unknown>('.node')
        .filter(function () {
          return this.getAttribute('data-id') === hoveredId;
        })
        .raise();
    },
    [highlightOnHover],
  );

  const onNodeMouseLeave = useCallback(() => {
    const svg = svgRef.current;
    if (!highlightOnHover || !svg) {
      return;
    }

    // Just reset all the nodes and links of this diagram to their original state.
    d3.select(svg).selectAll('.link, .node').classed('dimmed', false).classed('highlighted', false);
  }, [highlightOnHover]);

  // SVG height:
  // - the provided `height`, or, when it is omitted, `getDiagramMaxHeight` of the laid-out nodes;
  // - raised to at least `DIAGRAM_MINIMUM_HEIGHT_PX`;
  // - plus `marginVertical` on both top and bottom.
  const responsiveHeight = useMemo(
    () =>
      Math.max(
        height ?? getDiagramMaxHeight(computedSankeyLayout.nodes),
        DIAGRAM_MINIMUM_HEIGHT_PX,
      ) +
      marginVertical * 2,
    [computedSankeyLayout.nodes, height, marginVertical],
  );

  // NOTE(review): `diagramWidth` receives the constant minimum width rather than `responsiveWidth` (the layout's actual
  // width). Confirm this is what `SankeyDiagramNode` expects when `width` exceeds the minimum.
  // NOTE(review): the SVG `id` is fixed, so it is duplicated when several diagrams are rendered on one page.
  return (
    <Root
      className={className}
      $marginHorizontal={marginHorizontal}
      $marginVertical={marginVertical}
    >
      <svg
        ref={svgRef}
        width={responsiveWidth}
        height={responsiveHeight}
        id="monarch-sankey-diagram"
      >
        <g style={groupTransformStyle}>
          {computedSankeyLayout.links.map((link) => (
            <SankeyDiagramLink key={link.index} link={link} />
          ))}
        </g>
        <g style={groupTransformStyle}>
          {computedSankeyLayout.nodes.map((node) => (
            <SankeyDiagramNode
              key={node.id}
              node={node}
              diagramWidth={DIAGRAM_MINIMUM_WIDTH_PX}
              onClick={onClickNode}
              onMouseEnter={onNodeMouseEnter}
              onMouseLeave={onNodeMouseLeave}
              formatValue={formatValue}
            />
          ))}
        </g>
      </svg>
    </Root>
  );
};

/**
 * Memoized `SankeyDiagram`, re-rendered only when its props change under React's default shallow comparison.
 *
 * NOTE(review): `React.memo` does not carry over the component's generic parameters, so consumers likely see `N` / `L`
 * resolved to their constraints. Confirm whether callers need precise callback typing.
 */
const MemoizedSankeyDiagram = React.memo(SankeyDiagram);

export default MemoizedSankeyDiagram;
