import * as d3 from 'd3';
import groupBy from 'lodash/groupBy';
import partition from 'lodash/partition';

import { Class } from '~/api/entities/entity-constants';
import { OptionalU } from '~/declarations/standard';
import { ExploreArticleDirection } from '~/declarations/toggle-api-enums';
import { ChartPOI, TSPoint } from '~/shared/components/chart/Chart';

import { MappedEntity } from '../../hooks/use-entities';

/** Ascending numeric comparator for `Array.prototype.sort` (the default sort compares elements as strings). */
const numSort = (left: number, right: number): number => {
  return left - right;
};

/**
 * Computes the y-domain one group of POI icons (all sharing the x value
 * `date`) needs so its stacked icons are not clipped by the chart edges.
 *
 * The series is sampled at both horizontal edges of the icon footprint
 * (`poiSize` pixels wide, centred on `date`). Bullish icons (plus a 10px
 * margin) are checked against `yRange[1]` and bearish icons against
 * `yRange[0]`. On overflow, the opposite end of the domain is solved so the
 * sampled value sits at a fraction `desired` of the axis (at most one half).
 *
 * @returns A copy of `yScale.domain()` with zero, one or both ends moved. It is
 *   returned unchanged when `ts` is empty.
 */
const iconGroupStats = (
  date: number,
  icons: ChartPOI[],
  bisect: d3.Bisector<TSPoint<number, number>, number>,
  ts: Array<TSPoint<number, number>>,
  poiSize: number,
  xScale: d3.ScaleLinear<number, number>,
  yScale: d3.ScaleLinear<number, number>
) => {
  // Nothing to sample against; `ts[ts.length - 1]` below would be undefined.
  if (ts.length === 0) {
    return yScale.domain().slice();
  }

  const midX = xScale(date);
  // x-domain values at the left and right pixel edges of the icon footprint.
  const valRange = [midX - poiSize / 2, midX + poiSize / 2].map(xScale.invert);
  const chartBounds = valRange.map(value => bisect.right(ts, value));
  const [bearish, bullish] = partition(
    icons,
    i => i.direction === ExploreArticleDirection.Bearish
  ).map(x => x.length);

  // Series values at both edges, sorted ascending by value, then projected to pixels.
  const [minY, maxY] = chartBounds
    .map(idx => ts[Math.min(idx, ts.length - 1)].value)
    .sort(numSort)
    .map(yScale);

  // Pixel extent of the chart, ascending regardless of the range's orientation.
  const yRange = yScale.range().slice().sort(numSort);
  const rangeSpan = yRange[1] - yRange[0];

  const newDomain = yScale.domain().slice();

  const bullishIconHeight = 10 + bullish * poiSize;
  if (maxY + bullishIconHeight > yRange[1]) {
    // Fraction of the axis the icon stack needs, capped at one half.
    const desired = Math.min(bullishIconHeight / rangeSpan, 0.5);
    const actual = yScale.invert(maxY);
    // Solve for newDomain[0] so that `actual` sits at fraction `desired` of
    // [newDomain[0], newDomain[1]], measured from newDomain[0].
    newDomain[0] = (actual - desired * newDomain[1]) / (1 - desired);
  }

  const bearishHeight = 10 + bearish * poiSize;
  if (minY - bearishHeight < yRange[0]) {
    // Mirror of the bullish case. This must use the bearish stack height and
    // the same `Math.min(…, 0.5)` cap: a `desired` of 1 or more would make the
    // denominator below zero or negative.
    const desired = Math.min(bearishHeight / rangeSpan, 0.5);
    const actual = yScale.invert(minY);
    // Solve for newDomain[1] so that `actual` sits at fraction `desired` of
    // the new domain, measured from newDomain[1].
    newDomain[1] = (actual - desired * newDomain[0]) / (1 - desired);
  }

  return newDomain;
};

/**
 * Builds a fresh linear y-scale whose domain is widened so that every group
 * of POI icons (grouped by `index`) fits inside the chart.
 *
 * Each group's required domain comes from `iconGroupStats`. The result keeps
 * the smallest lower bound and the largest upper bound across all groups and
 * `yScale.domain()`. The range is copied from `yScale`.
 */
export const extendYAxis = (
  yScale: d3.ScaleLinear<number, number>,
  xScale: d3.ScaleLinear<number, number>,
  ts: Array<TSPoint<number, number>>,
  pois: ChartPOI[],
  poiSize: number
): d3.ScaleLinear<number, number> => {
  const poisByIndex = groupBy(pois, poi => poi.index);
  const bisectByIndex = d3.bisector(
    (point: TSPoint<number, number>) => point.index
  );

  const perGroupDomains = Object.entries(poisByIndex).map(([key, group]) =>
    iconGroupStats(
      Number(key),
      group,
      bisectByIndex,
      ts,
      poiSize,
      xScale,
      yScale
    )
  );

  // Fold every group's requirement into the original domain.
  let merged = yScale.domain();
  for (const [lo, hi] of perGroupDomains) {
    merged = [Math.min(merged[0], lo), Math.max(merged[1], hi)];
  }

  return d3.scaleLinear().domain(merged).range(yScale.range());
};

/**
 * Lower bound for the y-axis.
 *
 * @param entity - The entity being charted; may be undefined.
 * @param isPriceSnake - Defaults to true; when false the bound is always -Infinity.
 * @returns 0 for ETF or stock entities when `isPriceSnake` is true; otherwise -Infinity.
 */
export const entityMinY = (
  entity: OptionalU<MappedEntity>,
  isPriceSnake = true
) => {
  if (!isPriceSnake) {
    return -Infinity;
  }

  const entityClass = entity?.class;
  const isEtfOrStock =
    entityClass === Class.ClassEtf || entityClass === Class.ClassStock;

  return isEtfOrStock ? 0 : -Infinity;
};

/**
 * Either a d3 transition or a d3 selection over the same element and datum
 * types, so one callback can accept both (see `maybeAnimated`).
 *
 * - TElement / TDatum: the selected elements and their bound data.
 * - TParentElement / TParentDatum: the parent of the plain selection.
 * - TAnimatedElement: the element type the transition was created on.
 */
type SelectionOrTransition<
  TElement extends d3.BaseType,
  TDatum,
  TParentElement extends d3.BaseType,
  TParentDatum,
  TAnimatedElement extends d3.BaseType
> =
  | d3.Transition<TElement, TDatum, TAnimatedElement, unknown>
  | d3.Selection<TElement, TDatum, TParentElement, TParentDatum>;

/**
 * Returns a helper that applies a callback to `selection`, animated through
 * `transition` when one is given.
 *
 * - With a transition, the callback receives a transition over `selection`'s
 *   nodes (via `transition.selectAll`).
 * - If binding that transition throws, or the callback throws while running
 *   against it, the callback is applied once to the plain `selection` instead.
 *   An error thrown from that fallback call propagates.
 * - Without a transition, the callback is applied to `selection` exactly once.
 *   Any error it throws propagates; it is not retried.
 */
export const maybeAnimated = <
  GElement extends d3.BaseType,
  Datum,
  PElement extends d3.BaseType,
  PDatum,
  AElement extends d3.BaseType
>(
  selection: d3.Selection<GElement, Datum, PElement, PDatum>,
  transition?: d3.Transition<AElement, unknown, null, undefined>
) => {
  type SelOrTransition = SelectionOrTransition<
    GElement,
    Datum,
    PElement,
    PDatum,
    AElement
  >;

  return (fn: (selection: SelOrTransition) => void) => {
    if (!transition) {
      selection.call(fn);
      return;
    }

    let animated: SelOrTransition;
    try {
      animated = transition.selectAll(() => selection.nodes());
    } catch {
      // Best effort: the transition could not be bound to these nodes, so
      // apply the change without animation.
      selection.call(fn);
      return;
    }

    try {
      animated.call(fn);
    } catch {
      // `fn` used something the transition does not support; re-apply it to
      // the plain selection.
      selection.call(fn);
    }
  };
};

/** Options accepted by `getRightAxisTicks`. */
interface GetRightAxisTicksProps {
  /** Scale whose domain (divided by `tsScale`) is used to generate ticks. */
  yScale: d3.ScaleLinear<number, number>;
  /** Tick count hint passed to `ticks()`; `getRightAxisTicks` defaults it to 12. */
  tickCount?: number;
  /** Divisor applied to the domain before ticking; `getRightAxisTicks` defaults it to 1. */
  tsScale?: number;
  /** When non-empty, 32px of extra headroom is reserved at the top of the axis. */
  currency?: string;
}

/**
 * Builds ticks and formatted labels for the right-hand y-axis.
 *
 * Ticks are generated on `yScale`'s domain divided by `tsScale`. Any tick
 * within the top 5% of the axis is dropped, plus a 32px-equivalent band when
 * `currency` is set, so the last label is not crowded against the top edge.
 * Labels are formatted with `window.navigator.language`. Their minimum fraction
 * digits come from the spacing of the first two ticks.
 *
 * @returns `ticks` in the divided domain, and `labels` whose `index` maps each
 *   tick back to the original domain (`t * tsScale`).
 */
export const getRightAxisTicks = ({
  yScale,
  tsScale = 1,
  currency,
  tickCount = 12,
}: GetRightAxisTicksProps) => {
  const domain = yScale.domain().map(x => x / tsScale);
  const range = yScale.range();
  // Currency headroom as a fraction of the axis height.
  // NOTE(review): assumes range[0] > range[1] (pixel y grows downward); confirm.
  const paddingTopPercent = currency ? 32 / (range[0] - range[1]) : 0;
  const axisHeight = Math.abs(domain[1] - domain[0]);
  const offsetTop = axisHeight * (paddingTopPercent + 0.05);
  // The last tick must sit at least 5% (plus any currency padding) below the top.
  const ticks = yScale
    .copy()
    .domain(domain)
    .ticks(tickCount)
    .filter(t => t < domain[1] - offsetTop);

  // With fewer than two ticks there is no spacing: `ticks[1] - ticks[0]` is
  // NaN, d3.precisionFixed(NaN) is NaN, and toLocaleString throws a RangeError
  // for a NaN minimumFractionDigits. Fall back to 0 in that case. Also clamp
  // to 20, the upper bound every Intl implementation accepts.
  const maxFractionDigits = 20;
  const stepDecimals =
    ticks.length > 1 ? d3.precisionFixed(ticks[1] - ticks[0]) : 0;
  const decimals = Number.isFinite(stepDecimals)
    ? Math.min(stepDecimals, maxFractionDigits)
    : 0;

  const format = (n: number) =>
    n.toLocaleString(window.navigator.language, {
      minimumFractionDigits: decimals,
    });

  return {
    ticks,
    labels: ticks.map(t => ({
      index: t * tsScale,
      label: format(t),
    })),
  };
};
