import React from 'react';

import { TableProps } from '@material-ui/core';

import '@tanstack/react-table';
import {
  ExpandedState,
  RowSelectionState,
  VisibilityState,
  getCoreRowModel,
  getExpandedRowModel,
  getGroupedRowModel,
  getSortedRowModel,
  useReactTable,
} from '@tanstack/react-table';
import _ from 'lodash';

import { ResultVisualizationType } from './Visualizations';
import { ResultColumnType, defaultColumns } from './columns';
import {
  MetricResultTableGrouping,
  ResultData,
  TableMetaOptions,
} from './types';

/**
 * Props accepted by `useResultTable`.
 *
 * Any `TableProps` / `TableMetaOptions` fields that the hook does not
 * destructure explicitly are forwarded into the table's `meta` object.
 *
 * NOTE(review): confirm `'results'` is actually a key of `TableProps`; if it
 * is not, omitting it is a no-op.
 */
type MetricResultTableProps = Omit<TableProps, 'children' | 'results'> &
  Partial<TableMetaOptions> & {
    /** Rows handed to `useReactTable` as its `data`. */
    data: ResultData[];
    /** Spread into the table `meta` (see `getRelativeBounds`). */
    relativeBounds?: { maxUpper: number; minLower: number };
    /** Controlled TanStack grouping state; defaults to no grouping. */
    grouping?: MetricResultTableGrouping;
    /**
     * Leaf column ids to show. Empty (the default) shows every column; the
     * 'Graph' id is dropped from this list when `visualizationType` is 'none'.
     */
    visibleColumns?: string[];
    /** Column definitions; defaults to `defaultColumns`. */
    columns?: ResultColumnType[];
    /** Active visualization; defaults to 'none'. */
    visualizationType?: ResultVisualizationType;
  };

/**
 * Builds a TanStack `VisibilityState` that hides every leaf column whose id
 * is not listed in `visibleColumns`.
 *
 * @param visibleColumns - Ids of leaf columns to keep visible. An empty list
 *   means "show everything" and produces an empty (all-visible) state.
 * @param columns - Column definitions. Group columns (those with a nested
 *   `columns` array) are walked recursively; only their leaves get entries.
 * @returns A map from each hidden leaf column id to `false`.
 */
function getVisibilityState(
  visibleColumns: string[],
  columns: ResultColumnType[],
): VisibilityState {
  // All columns visible by default
  if (visibleColumns.length === 0) {
    return {};
  }

  // Column definitions are read structurally. The inferred ResultColumnType
  // union does not expose `columns`/`accessorKey` on every member, so the
  // value is narrowed from `unknown` rather than typed as `any`.
  function collectLeafIds(column: unknown): string[] {
    if (typeof column !== 'object' || column === null) {
      return [];
    }
    const def = column as {
      id?: unknown;
      accessorKey?: unknown;
      columns?: unknown;
    };
    if (Array.isArray(def.columns)) {
      return def.columns.flatMap(collectLeafIds);
    }
    // An explicit id wins over accessorKey. Definitions with neither are
    // skipped instead of producing an "undefined" visibility key.
    const leafId = def.id || def.accessorKey;
    return typeof leafId === 'string' ? [leafId] : [];
  }

  // Set lookup keeps this linear in the number of leaf columns.
  const visible = new Set(visibleColumns);
  const hiddenEntries = columns
    .flatMap(collectLeafIds)
    .filter(id => !visible.has(id))
    .map((id): [string, boolean] => [id, false]);

  return Object.fromEntries(hiddenEntries);
}

/**
 * Computes symmetric bounds around zero from a list of values.
 *
 * `maxUpper` is the largest value in `bounds`, floored at 0; `minLower` is its
 * negation. Negative inputs only matter in that they cannot raise `maxUpper`
 * above 0. A NaN anywhere in `bounds` makes both results NaN.
 *
 * @param bounds - Values to take the maximum over (may be empty).
 * @returns `{ maxUpper, minLower }` with `minLower === -maxUpper`.
 */
export const getRelativeBounds = (
  bounds: number[],
): { maxUpper: number; minLower: number } => {
  // `Math.max(...bounds)` passes every element as a call argument and throws a
  // RangeError once the array exceeds the engine's argument limit; a reduce
  // has no such limit and yields the same result (including NaN propagation).
  const maxUpper = bounds.reduce((max, bound) => Math.max(max, bound), 0);
  const minLower = -maxUpper;

  return { maxUpper, minLower };
};

// Stable fallbacks for the optional array props. Inline destructuring defaults
// (`grouping = []`) would allocate a new array on every render. That defeats
// the memoization below and hands the table a new `state.grouping` reference
// each time.
const NO_GROUPING: MetricResultTableGrouping = [];
const NO_VISIBLE_COLUMNS: string[] = [];

/**
 * Creates the TanStack table instance backing the metric results table.
 *
 * Grouping, column visibility, expansion, row selection and sorting are all
 * controlled from here. Props not destructured below (remaining `TableProps`
 * / `TableMetaOptions` fields) are forwarded, together with `relativeBounds`,
 * into the table's `meta`.
 *
 * Row selection:
 * - While a visualization is active, every grouped row is selected once.
 * - Switching back to `'none'` clears the selection.
 * - Changing `grouping` (by value) also clears it. With a visualization
 *   active, all rows are then re-selected under the new grouping.
 *
 * @returns The table instance returned by `useReactTable`.
 */
export const useResultTable = ({
  data,
  relativeBounds,
  grouping = NO_GROUPING,
  visibleColumns = NO_VISIBLE_COLUMNS,
  visualizationType = 'none',
  columns = defaultColumns,
  ...metaOptions
}: MetricResultTableProps) => {
  // The 'Graph' column is dropped from the visible list when no visualization
  // is active. `columns` is a dependency because the hidden ids are derived
  // from it.
  // NOTE(review): if that filtering empties the list, the result is "all
  // columns visible"; confirm that is intended.
  const columnVisibility = React.useMemo(
    () =>
      getVisibilityState(
        visibleColumns.filter(c =>
          visualizationType === 'none' ? c !== 'Graph' : true,
        ),
        columns,
      ),
    [visibleColumns, visualizationType, columns],
  );

  const [expanded, setExpanded] = React.useState<ExpandedState>({});
  const selectionInitialized = React.useRef<boolean>(false);
  const [rowSelection, setRowSelection] = React.useState<RowSelectionState>({});

  // Grouping is compared by value so that a caller passing a fresh-but-equal
  // array on each render does not repeatedly wipe the selection. The ref holds
  // the last grouping that was acted on; initializing it to the current value
  // means the initial mount is not treated as a change.
  const groupingKey = JSON.stringify(grouping);
  const previousGroupingKey = React.useRef(groupingKey);

  // Fixed (controlled) sorting: newest time label first, then segment A-Z,
  // for whichever of those columns exists at the top level of `columns`.
  const sorting = React.useMemo(() => {
    const newSorting: { id: string; desc: boolean }[] = [];
    if (columns.some(c => c.id === 'timeLabel')) {
      newSorting.push({ id: 'timeLabel', desc: true });
    }
    if (columns.some(c => c.id === 'segment')) {
      newSorting.push({ id: 'segment', desc: false });
    }
    return newSorting;
  }, [columns]);

  const table = useReactTable({
    data,
    columns,
    state: {
      grouping,
      columnVisibility,
      expanded,
      rowSelection,
      sorting,
    },
    meta: {
      ...relativeBounds,
      ...metaOptions,
    },
    enableRowSelection: true,
    onExpandedChange: setExpanded,
    getExpandedRowModel: getExpandedRowModel(),
    getGroupedRowModel: getGroupedRowModel(),
    getSortedRowModel: getSortedRowModel(),
    getCoreRowModel: getCoreRowModel(),
    onRowSelectionChange: setRowSelection,
  });

  React.useEffect(() => {
    // If a visualization is enabled, select all rows by default; otherwise reset.
    // `rowSelection` is a dependency so that after a reset (e.g. from a
    // grouping change) this runs again and re-selects under the new grouping.
    // NOTE(review): `data` is not a dependency, so rows that arrive after the
    // first selection pass are not auto-selected. Confirm this is intended.
    const initialized = selectionInitialized.current;
    if (visualizationType === 'none') {
      if (initialized) {
        setRowSelection({});
        selectionInitialized.current = false;
      }
      return;
    }
    if (!initialized) {
      // Used instead of "table.toggleAllRowsSelected" to properly select grouped rows.
      // Passing `true` selects deterministically instead of inverting any rows
      // the user had already selected.
      table.getGroupedRowModel().rows.forEach(row => row.toggleSelected(true));
      selectionInitialized.current = true;
    }
  }, [rowSelection, table, visualizationType]);

  React.useEffect(() => {
    // Reset selection only when grouping actually changes value.
    if (previousGroupingKey.current === groupingKey) {
      return;
    }
    previousGroupingKey.current = groupingKey;
    if (selectionInitialized.current) {
      setRowSelection({});
      selectionInitialized.current = false;
    }
  }, [groupingKey]);

  return table;
};

/**
 * Creates a result record with blank placeholder values: empty strings for
 * text fields and 0 for sample sizes.
 *
 * @param overrides - Fields that replace the corresponding blank values.
 * @returns A freshly allocated object; callers may mutate it freely.
 */
export function createBlankResultData(overrides: Partial<ResultData> = {}) {
  const blank = {
    metric: '',
    comparison: '',
    requiredSampleSize: 0,
    currentSampleSize: 0,
    baselineGroupId: '',
    comparedGroupId: '',
    timeLabel: '',
  };
  // Overrides are spread last so they take precedence over the placeholders.
  return { ...blank, ...overrides };
}
