import type { SankeyNodeMinimal, SankeyLinkMinimal, SankeyNode } from 'd3-sankey';
import * as R from 'ramda';
import * as RA from 'ramda-adjunct';

/** Minimal d3-sankey node shape (extra node/link props left as `any`) that also carries a string `id`. */
type GenericSankeyNode = SankeyNodeMinimal<any, any> & { id: string };
/** Minimal d3-sankey link shape with extra node/link props left as `any`. */
type GenericSankeyLink = SankeyLinkMinimal<any, any>;

/** Any object identified by a string `id`. */
export type Identifiable = { id: string };
/**
 * Input link between two nodes. `source`/`target` are strings — presumably the
 * `id`s of `Identifiable` nodes (confirm against the graph builder) — and `value`
 * is the link's numeric magnitude (in d3-sankey this drives link thickness).
 */
export type Link = { source: string; target: string; value: number };

/**
 * Node type after processing: a d3-sankey `SankeyNode` whose `sourceLinks` /
 * `targetLinks` are re-typed to reference processed links, plus display props.
 */
export type NodeWithExtraProps<N extends Identifiable, L extends Link> = Omit<
  SankeyNode<N, L>,
  'sourceLinks' | 'targetLinks'
> & {
  /** Display color of the node. */
  color: string;
  /**
   * Share of the node's column total. Presumably filled by a helper such as
   * `computeNodePercentages`, which stores a fraction rather than 0–100 — confirm.
   */
  percent: number;
  /** Label text for the node. */
  label: string;
  /** Outgoing links (this node is their `source`), per d3-sankey convention. */
  sourceLinks: LinkWithExtraProps<N, L>[];
  /** Incoming links (this node is their `target`), per d3-sankey convention. */
  targetLinks: LinkWithExtraProps<N, L>[];
  /** Optional flag; presumably enables click handling in the renderer — confirm with consumers. */
  isClickable?: boolean;
  // NOTE(review): untyped; what `type` holds is not visible from this file.
  type: any;
};

/**
 * Link type after processing: `source` / `target` are resolved to processed
 * node objects instead of the id/index/node union allowed by `SankeyLinkMinimal`.
 */
export type LinkWithExtraProps<N extends Identifiable, L extends Link> = Omit<
  SankeyLinkMinimal<N, L>,
  'source' | 'target'
> & {
  source: NodeWithExtraProps<N, L>;
  target: NodeWithExtraProps<N, L>;
};

/** Ascending comparator on column index (`depth`); a missing depth sorts as 0. */
const byDepth = R.ascend(({ depth }: GenericSankeyNode) => depth ?? 0);
/** Ascending comparator on top edge (`y0`); a missing `y0` sorts as 0. */
const byY0 = R.ascend(({ y0 }: GenericSankeyNode) => y0 ?? 0);

/**
 * Push nodes apart vertically so there is room for their labels.
 *
 * Nodes are grouped into columns by `depth` (a missing depth counts as column 0)
 * and ordered top-to-bottom by `y0`. Walking down each column, whenever a node's
 * top edge (`y0`) lies less than `labelHeight + padding` below the bottom edge
 * (`y1`) of the node above it, the node is shifted down — both `y0` and `y1` —
 * by the shortfall. Shifts cascade: a moved node then serves as the reference
 * for the node below it. Nodes missing `y0` or `y1` are never moved.
 *
 * Mutates the node objects' `y0` / `y1`; the order of the `nodes` array itself
 * is left unchanged (sorting happens on a copy).
 *
 * NOTE(review): link geometry is not updated here; if links carry their own
 * positions (d3-sankey's `link.y0` / `link.y1`), they may need recomputing
 * afterwards — confirm with the caller.
 *
 * @param nodes - Laid-out nodes to adjust.
 * @param labelHeight - Vertical space reserved for a node's label.
 * @param padding - Extra gap added on top of `labelHeight`.
 */
export const computeNodeSpacing = <T extends GenericSankeyNode>(
  nodes: T[],
  labelHeight = 24,
  padding = 16,
) => {
  const minSpacing = labelHeight + padding;

  // Same default as `byDepth`, so nodes sorted into the same column are also
  // recognised as sharing it (a raw `===` would treat undefined and 0 as different).
  const columnOf = (node: T) => node.depth ?? 0;

  // R.sortWith returns a new array: by column, then top to bottom.
  const ordered: T[] = R.sortWith([byDepth, byY0], nodes);

  let above: T | undefined = undefined;
  for (const below of ordered) {
    if (above !== undefined && columnOf(above) === columnOf(below)) {
      // Positive when `below` starts closer than `minSpacing` to `above`'s bottom edge.
      const overlap = (above.y1 ?? 0) + minSpacing - (below.y0 ?? 0);

      if (overlap > 0 && RA.isNotNil(below.y0) && RA.isNotNil(below.y1)) {
        below.y0 += overlap;
        below.y1 += overlap;
      }
    }
    above = below;
  }
};

/**
 * Set each node's `percent` to its share of the total `value` of its column.
 *
 * Columns are keyed by `depth` (a missing depth counts as column 0) and a
 * missing `value` counts as 0. Despite the property name, the stored number is
 * a fraction (in [0, 1] for non-negative values), not 0–100. A column whose
 * values sum to 0 gives every node in it `percent = 0` rather than `NaN` (0 / 0).
 *
 * Mutates the nodes in place.
 *
 * @param nodes - Nodes to annotate; only `depth` and `value` are read, only
 *   `percent` is written.
 */
export const computeNodePercentages = <
  T extends {
    depth?: number | undefined;
    value?: number | undefined;
    percent?: number | undefined;
  },
>(
  nodes: T[],
): void => {
  const columnTotals = new Map<number, number>();
  for (const { depth, value } of nodes) {
    const column = depth ?? 0;
    columnTotals.set(column, (columnTotals.get(column) ?? 0) + (value ?? 0));
  }

  for (const node of nodes) {
    const total = columnTotals.get(node.depth ?? 0) ?? 0;
    // Guard the all-zero column, which would otherwise yield NaN.
    node.percent = total === 0 ? 0 : (node.value ?? 0) / total;
  }
};

/**
 * Fill in missing node colors from upstream neighbours.
 *
 * Every node whose `color` is nil takes the color of the source of its first
 * incoming link (via `setNodeColorFromSource`). Nodes that already have a
 * non-nil color are left untouched.
 *
 * Nodes are visited in ascending `depth` order (a missing depth counts as 0), so
 * a color can flow along a chain of uncolored nodes: the upstream node is colored
 * before the node that inherits from it. This relies on every link's source
 * having a smaller depth than its target, which holds for d3-sankey layouts —
 * confirm if depths come from elsewhere. The `nodes` array order is not changed.
 *
 * @param nodes - Nodes to color; their `color` property is mutated.
 * @param links - Links whose `source` / `target` reference the node objects.
 */
export const computeColorForNodes = <
  N extends GenericSankeyNode & { color?: string },
  L extends GenericSankeyLink,
>(
  nodes: N[],
  links: L[],
) => {
  // R.sort returns a sorted copy; `byDepth` treats a missing depth as 0.
  const upstreamFirst: N[] = R.sort(byDepth, nodes);

  for (const node of upstreamFirst) {
    // Colors that are already set win; only nil colors are inherited.
    if (RA.isNil(node.color)) {
      setNodeColorFromSource(node, links);
    }
  }
};

/**
 * Copy onto `node` the color of the source of the first link whose `target` is
 * `node` (matched by object identity).
 *
 * Any existing `node.color` is overwritten. If the matching link has no source,
 * or its source has no color, `node.color` becomes `undefined`. When no link
 * targets `node`, the node is left untouched.
 */
export const setNodeColorFromSource = <
  N extends SankeyNodeMinimal<any, any> & { color?: string },
  L extends { target?: N; source?: N },
>(
  node: N,
  links: L[],
) => {
  const incoming = links.find(({ target }) => target === node);
  if (incoming === undefined) {
    return;
  }
  node.color = incoming.source?.color;
};

/**
 * Give `node` the color of the target of the first link whose `source` is
 * `node` (matched by object identity), unless it already has a truthy color.
 *
 * An empty-string color counts as "no color" here. If the matching link has no
 * target, or its target has no color, `node.color` becomes `undefined`. When no
 * link starts at `node`, the node is left untouched.
 */
export const setColorNodeFromTarget = <
  N extends SankeyNodeMinimal<any, any> & { color?: string },
  L extends { target?: N; source?: N },
>(
  node: N,
  links: L[],
) => {
  if (!node.color) {
    const outgoing = links.find(({ source }) => source === node);
    if (outgoing !== undefined) {
      node.color = outgoing.target?.color;
    }
  }
};

/**
 * Return the largest `y1` among `nodes` (a missing `y1` counts as 0).
 *
 * Returns 0 for an empty list, where `Math.max()` with no arguments would give
 * `-Infinity`. Folds with `reduce` rather than spreading into `Math.max`, which
 * avoids the engine's argument-count limit on very large node lists.
 */
export const getDiagramMaxHeight = <T extends { y1?: number }>(nodes: readonly T[]): number => {
  if (nodes.length === 0) {
    return 0;
  }
  return nodes.reduce((max, { y1 }) => Math.max(max, y1 ?? 0), -Infinity);
};
