import { GridPaginationModel } from "@mui/x-data-grid-premium";
import { CustomDatasetList } from "data/models";
import { useEffect, useState } from "react";
import {
  QueryClient,
  UseQueryResult,
  useQuery,
  useQueryClient,
} from "react-query";

/**
 * Loads one page of table data.
 *
 * @typeParam TPage - Value the returned promise resolves to for a single page.
 * @param limit - Page size requested by the table.
 * @param offset - Index of the first row of the requested page.
 */
type QueryFn<TPage> = (limit: number, offset: number) => Promise<TPage>;

/**
 * Server-side pagination state plus data fetching for a DataGrid table.
 *
 * Holds the grid's pagination model (starting at page 0, 10 rows per page),
 * fetches the current page with `queryFn(pageSize, offset)` through
 * react-query, resets to the first page whenever the filters change, and
 * prefetches the two pages after the current one so paging forward is served
 * from cache.
 *
 * @param queryFn - Loads one page given a page size and a row offset.
 * @param queryKey - Base react-query key for this table.
 * @param userKey - Included in every query key, so each user gets separate cache entries.
 * @param filterKey - Current filters. Included in every query key, and compared by
 *   JSON value (not identity) to decide when to reset to the first page.
 * @param queryFnOpts - Extra options forwarded to `useQuery` for the current page.
 * @returns The pagination model and its change handler (to wire into the grid),
 *   the current page's query result, and that page's full query key.
 */
function usePaginatedTableQuery<T>(
  queryFn: QueryFn<T>,
  queryKey: string,
  userKey: string,
  filterKey: object,
  queryFnOpts: object = {}
) {
  const [paginationModel, setPaginationModel] = useState({
    page: 0,
    pageSize: 10,
  });
  const { page, pageSize } = paginationModel;
  const offset = page * pageSize;

  const baseQueryId = [queryKey, userKey, filterKey];
  const thisPageQueryId = [...baseQueryId, pageSize, offset];
  const queryClient: QueryClient = useQueryClient();
  // NOTE(review): the result is annotated as CustomDatasetList regardless of T
  // (TS infers TData from this annotation), so `data` is only really a
  // CustomDatasetList when queryFn returns one. Kept for caller compatibility.
  const searchQuery: UseQueryResult<CustomDatasetList> = useQuery(
    thisPageQueryId,
    () => queryFn(pageSize, offset),
    queryFnOpts
  );

  // Warms the cache for the page starting at `pageOffset`. The key has the same
  // shape as `thisPageQueryId`, so navigating there hits this cache entry.
  const prefetchPage = (pageOffset: number) => {
    // The returned promise is deliberately not awaited (fire-and-forget).
    void queryClient.prefetchQuery(
      [...baseQueryId, pageSize, pageOffset],
      () => queryFn(pageSize, pageOffset),
      // Infinity: a page that is already cached is not fetched again by the prefetch.
      { staleTime: Infinity }
    );
  };

  const handleChangePaginationModel = (newModel: GridPaginationModel) => {
    if (newModel.pageSize === pageSize) {
      setPaginationModel(newModel);
      return;
    }
    // Page size changed: pick the page that contains the row currently at the
    // top, by mapping the current row offset onto the new page size.
    setPaginationModel({
      page: Math.floor(offset / newModel.pageSize),
      pageSize: newModel.pageSize,
    });
  };

  // Compare filters by value rather than identity. Keyed on the object itself,
  // an inline literal from the caller would re-run the reset on every render,
  // and the state update would trigger yet another render.
  const filterHash = JSON.stringify(filterKey);

  useEffect(() => {
    // If the user changes the filters, go back to the first page. Returning
    // `prev` when already on page 0 lets React skip the re-render.
    setPaginationModel((prev) =>
      prev.page === 0 ? prev : { ...prev, page: 0 }
    );
  }, [filterHash]);

  useEffect(() => {
    // Keep the next two pages cached. The API's `next` link only tells us
    // whether a following page exists; the offsets come from the grid's own
    // page/size so the prefetched keys are exactly the ones the grid will
    // request (a parsed server offset can differ, e.g. if the server caps
    // `limit`, and a missing `offset` param would have parsed as 0).
    if (!searchQuery.data?.links?.next) {
      return;
    }
    prefetchPage(offset + pageSize);
    prefetchPage(offset + 2 * pageSize);
  }, [searchQuery.data]);

  return {
    paginationModel,
    handleChangePaginationModel,
    searchQuery,
    queryKey: thisPageQueryId,
  };
}

export default usePaginatedTableQuery;
