import { TreeGraphNode } from 'types/backend/response/TreeGraph';
import React, { FunctionComponent, useContext } from 'react';
import { mean } from 'd3-array';
import { useAvgKeywordEmbeddingColor } from 'hooks/useAvgKeywordEmbeddingColor';
import { rgb } from 'd3-color';
import { scaleLinear } from '@visx/scale';
import { interpolateLab } from 'd3-interpolate';
import NodeHighlightContext from 'App/NodeHighlightContext';
import TreeDisplayStyleContext from 'App/TreeDisplayStyleContext';
import { TreeDisplayStyle } from 'types/common/TreeDisplayStyle';

// Maps a node probability onto a stroke color: 0 -> gray (#6c757d), 1 -> blue (#007bff),
// linearly interpolated in between. The scale is not clamped, so inputs outside [0, 1] extrapolate.
const nodeProbabilityColorScale = scaleLinear<string>({
    domain: [0, 1],
    range: ['#6c757d', '#007bff'],
});

/**
 * Blends from white (at `keywordScore` 0) toward `meanEmbeddingColor` (at `keywordScore` 1),
 * interpolating in CIELAB space. Higher scores yield a fill closer to the embedding color.
 */
const keywordColorScale = (keywordScore: number, meanEmbeddingColor: string) => {
    const whiteToEmbedding = interpolateLab('#fff', meanEmbeddingColor);
    return whiteToEmbedding(keywordScore);
};

/** Props accepted by {@link NodeBox}: box geometry plus the node whose data drives its colors. */
interface Props {
    /** Forwarded unchanged to the rendered `<rect>`'s `x` attribute. */
    x: number;
    /** Forwarded unchanged to the rendered `<rect>`'s `y` attribute. */
    y: number;
    /** Forwarded unchanged to the rendered `<rect>`'s `width` attribute. */
    width: number;
    /** Forwarded unchanged to the rendered `<rect>`'s `height` attribute. */
    height: number;
    /** Node whose keywords, `parentId` and `nodeProbability` determine the box's fill and stroke. */
    treeGraphNode: TreeGraphNode;
}

/**
 * Context-aware entry point: pulls the two coloring flags out of `TreeDisplayStyleContext`
 * and hands them, together with all of its own props, to `NodeBoxWithTreeStyle`.
 */
const NodeBox: FunctionComponent<Props> = (props) => {
    const {
        treeDisplayStyle: { nodeProbabilityColoring, nodeEmbeddingColoring },
    } = useContext(TreeDisplayStyleContext);

    return (
        <NodeBoxWithTreeStyle
            {...props}
            nodeEmbeddingColoring={nodeEmbeddingColoring}
            nodeProbabilityColoring={nodeProbabilityColoring}
        />
    );
};

/** {@link Props} extended with the two coloring flags taken from the tree display style. */
interface NodeBoxWithTreeStyleProps extends Props {
    /** Same type as `TreeDisplayStyle['nodeProbabilityColoring']`; when truthy, non-root strokes follow node probability. */
    nodeProbabilityColoring: TreeDisplayStyle['nodeProbabilityColoring'];
    /** Same type as `TreeDisplayStyle['nodeEmbeddingColoring']`; when truthy, the fill is tinted by the keyword embedding color. */
    nodeEmbeddingColoring: TreeDisplayStyle['nodeEmbeddingColoring'];
}

/**
 * Renders one tree node as a rounded SVG `<rect>`.
 *
 * - Keywords: taken from the highlighted sub-tree's copy of this node (matched by `id` in
 *   `NodeHighlightContext.subTreeGraph`) when present, otherwise from `treeGraphNode`.
 * - Fill: when `nodeEmbeddingColoring` is on and the node has a mean keyword score, white is blended
 *   toward the mean keyword-embedding color by that score; otherwise the fill is white.
 * - Stroke: root nodes (no `parentId`) use `--bs-gray-800`; other nodes use the node-probability
 *   scale when `nodeProbabilityColoring` is on, else the mean keyword-embedding color.
 * - Nodes that are part of the highlighted sub-tree get a thicker stroke (3 instead of 2).
 */
const NodeBoxWithTreeStyle: FunctionComponent<NodeBoxWithTreeStyleProps> = ({
    x,
    y,
    width,
    height,
    treeGraphNode,
    nodeEmbeddingColoring,
    nodeProbabilityColoring,
}) => {
    const { subTreeGraph } = useContext(NodeHighlightContext);

    // Defined only when this node belongs to the currently highlighted sub-tree.
    const subTreeGraphNode = subTreeGraph?.nodes.find((n) => n.id === treeGraphNode.id);

    const keywords = subTreeGraphNode ? subTreeGraphNode.keywords : treeGraphNode.keywords;

    // Mean of the extracted keyword scores, or null when there are no keywords.
    const meanKeywordScore = keywords.length > 0 ? mean(keywords.map((kw) => kw.score)) : null;

    // Average keyword embedding mapped to a color. The transparent black argument is presumably the
    // hook's fallback when no color can be derived — semantics live in useAvgKeywordEmbeddingColor.
    const meanEmbeddingColor = useAvgKeywordEmbeddingColor(keywords, rgb(0, 0, 0, 0), [keywords]).formatRgb();

    // Fill is a mixture of keyword importance and node embedding:
    //   - the projected embedding determines the hue
    //   - keyword importance determines the mix between embedding color and white (=> brightness)
    // Compare against null explicitly: 0 is a valid score and must not be treated like "no keywords".
    const fillColor =
        nodeEmbeddingColoring && meanKeywordScore != null
            ? keywordColorScale(meanKeywordScore, meanEmbeddingColor)
            : '#fff';

    // `== null` treats both an omitted and a JSON-`null` parentId as the root node.
    const strokeColor =
        treeGraphNode.parentId == null
            ? 'var(--bs-gray-800)'
            : nodeProbabilityColoring
            ? nodeProbabilityColorScale(treeGraphNode.nodeProbability)
            : meanEmbeddingColor;

    return (
        <rect
            rx={5}
            x={x}
            y={y}
            width={width}
            height={height}
            style={{
                stroke: strokeColor,
                fill: fillColor,
                strokeWidth: subTreeGraphNode ? 3 : 2,
                filter: 'drop-shadow( 2px 2px 2px var(--bs-gray))',
            }}
        />
    );
};

// Public entry point is the context-reading wrapper; NodeBoxWithTreeStyle stays module-private.
export default NodeBox;
