import React, {
  createContext,
  useCallback,
  useContext,
  useEffect,
  useMemo,
  useRef,
  useState,
} from "react";
import { WEBSOCKET_URL } from "source/envConstants";
import { setLeaseId } from "source/redux/matrix";
import { useSelector } from "react-redux";
import logger from "source/utils/logger";
import { getUser } from "source/redux/user";
import { PartialPayload } from "./types";
import { useQueryMatrixMetadata } from "source/api/matrix/useQueryReports";
import {
  MatrixChatWebsocketEvent,
  MCAskToolIncomingResponsePayload,
  MCGetToolIncomingResponsePayload,
  ReportShouldClearPayload,
  ReportUserMessagePayload,
  ReportSaveReportMessagePayload,
  ReportWebsocketEvent,
  ReportWebsocketResponse,
  ReportToggleReadOnlyPayload,
  ToolReadyPayload,
} from "../types/websocket.types";
import { setAlert } from "source/redux/ui";
import { omit } from "lodash";
import { v4 as uuidv4 } from "uuid";
import { ReportsGridContext } from "./ReportsGridContext";
import { useRenameReport } from "source/hooks/matrix/useRenameReport";
import useWebSocket, { ReadyState } from "react-use-websocket";
import { getCurrentOrg } from "source/redux/organization";
import { getAuthToken, setAuthToken } from "source/auth/localStorage";
import fetchAccessToken from "source/auth/fetchAccessToken";
import {
  handleCachedCellsEvent,
  handleGatherCompleteEvent,
  handleLoadingCellsEvent,
  handleProgressUpdate,
  handleToolCompleteEvent,
  handleToolReadyEvent,
  initializeLeaseId,
} from "source/thunks/matrix/websocket";
import {
  handleLoadHistoryMessage,
  handleMCAskToolEvent,
  handleMCCitationInfoEvent,
  handleMCColumnConversionEvent,
  handleMCFastBuildDocIds,
  handleMCGetToolEvent,
  handleMCMessage,
  handleMCSizeError,
} from "source/thunks/matrix-chat/websocket";
import { useAppDispatch } from "source/redux";
import { InvokeRestFunc, useSocketProtocol } from "./useSocketProtocol";
import { isAnswerToolType } from "../../../utils/matrix/tools";
import { handleApplyMatrixChatGridFilter } from "source/thunks/matrix-chat/matrixChat";
import { useLogPosthog } from "../../../hooks/tracking/usePosthogTracking";
import { useActiveMatrixId } from "source/hooks/matrix/useActiveMatrixId";

/** Sheets websocket endpoint; the provider appends `?sheet_id=<matrixId>` when connecting. */
const baseURL = WEBSOCKET_URL + "/sheets/ws";

/**
 * Every payload shape a consumer may hand to the context's `sendMessage`.
 * Member order is irrelevant; this is a plain union.
 */
export type ReportPayloadContext =
  | PartialPayload
  | ReportUserMessagePayload
  | ReportSaveReportMessagePayload
  | ReportShouldClearPayload
  | ReportToggleReadOnlyPayload;

/** Options for `resetWebsocket`: which identifiers survive the forced reconnect. */
interface ResetWebsocketParams {
  /** When truthy, the current session id is kept instead of being cleared. */
  keepSessionId?: boolean;
  /** When truthy, the lease id in the store is kept instead of being reset to null. */
  keepLeaseId?: boolean;
}

/** Shape of the value published through `ReportsWebsocketContext`. */
interface Contextable {
  /** True while the socket's ready state is OPEN. */
  isConnected: boolean;
  /** Closes the current socket so it reconnects, optionally dropping session/lease ids. */
  resetWebsocket: (params: ResetWebsocketParams) => void;
  /** Sends a payload over the socket, tagged with `event`. */
  sendMessage: (payload: ReportPayloadContext, event: string) => void;
  /** Request/response helper obtained from `useSocketProtocol`. */
  invokeRest: InvokeRestFunc;
}

/**
 * Context carrying the reports websocket API. Its default value is an empty
 * object cast to `Contextable`, so reading it outside of a
 * `ReportsWebsocketProvider` yields no usable members.
 */
export const ReportsWebsocketContext = createContext({} as Contextable);

// ---- Websocket close codes recognised by this client ----
/** Standard "normal closure" code. */
const CLOSE_NORMAL_CODE = 1000;
/** Application code, presumably sent by the server when auth is rejected — confirm with backend. */
const CLOSE_INVALID_AUTH_CODE = 4003;
/** Application code the client uses when it deliberately closes the socket to force a reconnect. */
const CLOSE_REFRESH_CONNECTION = 4444;

// ---- Reconnection policy ----
/** Maximum number of automatic reconnect attempts. */
const RETRY_TIMES = 10;
/** Close codes after which no automatic reconnect is attempted. */
const CODES_TO_STOP_RETRYING: number[] = [
  CLOSE_NORMAL_CODE,
  CLOSE_INVALID_AUTH_CODE,
];

/** Props of `ReportsWebsocketProvider`. */
interface Props {
  children: React.ReactNode;
}

/** Upper bound on the number of request ids kept for time-to-cell tracking. */
const MAX_TRACKED_REQUESTS = 100;

/**
 * Provides the reports / matrix-chat websocket to its descendants through
 * `ReportsWebsocketContext`.
 *
 * - Connects to `${baseURL}?sheet_id=<matrixId>` only when a matrix id is
 *   active, its metadata query succeeded and it has no `expired_at`.
 * - On open, sends a handshake with the session id, lease id, access token
 *   and (on the first connection for a matrix) `fetchHistory: true`.
 * - Routes incoming events to the matching redux thunks.
 * - Retries with a quadratic backoff and shows the "websocketDisconnected"
 *   alert when it gives up; the alert is cleared and the grid purged on the
 *   next successful open.
 * - Logs time-to-cell metrics for request ids it has seen sent.
 */
const ReportsWebsocketProvider = ({ children }: Props) => {
  const matrixId = useActiveMatrixId();

  const { gridApi, refreshServerSide } = useContext(ReportsGridContext);

  // Hack to allow reconnection: https://github.com/robtaussig/react-use-websocket/issues/217#issuecomment-1937461277
  // Set to false once automatic reconnects are exhausted; set back to true by the
  // visibility listener below so a fresh connection is attempted.
  const [shouldConnectOverride, setShouldConnectOverride] = useState(true);

  // On the initial connection to the websocket (brand new page load)
  // we want to fetch history, this flag determines whether this is an initial page load
  const shouldFetchHistory = useRef<boolean | undefined>(true);
  // Mirrors of the latest grid values so long-lived socket callbacks never read stale ones.
  const gridApiRef = useRef(gridApi);
  // True after reconnects were exhausted; the next successful open clears the alert and refreshes the grid.
  const connectionLostRef = useRef(false);
  const sessionId = useRef<string | undefined>(undefined);
  const refreshServerSideRef = useRef(refreshServerSide);

  // For tracking time to cell: request_id -> send time and answer cell ids seen so far.
  const requestTimerMap = useRef<{
    [p: string]: { start: Date; cellIds: Set<string> };
  }>({});

  const dispatch = useAppDispatch();

  const orgId = useSelector(getCurrentOrg)?.id;
  const userEmail = useSelector(getUser)?.email;

  const { generateReportTitle } = useRenameReport();

  const metadataResponse = useQueryMatrixMetadata(matrixId ?? "");

  const shouldConnect =
    shouldConnectOverride &&
    !!matrixId &&
    metadataResponse.isSuccess &&
    !metadataResponse.data.expired_at;

  const { logPosthog } = useLogPosthog();

  useEffect(() => {
    gridApiRef.current = gridApi;
  }, [gridApi]);

  // Every time the page navs (matrix ID changes), we want to fetch history on first websocket connection
  useEffect(() => {
    shouldFetchHistory.current = true;
  }, [matrixId]);

  useEffect(() => {
    refreshServerSideRef.current = refreshServerSide;
  }, [refreshServerSide]);

  useEffect(() => {
    logger.info("Resetting requestTimerMap due to page change");
    requestTimerMap.current = {};
  }, [matrixId]);

  /**
   * Starts the time-to-cell timer for `requestId` (no-op when it is missing).
   * At most MAX_TRACKED_REQUESTS entries are kept; when the cap is exceeded the
   * entry with the earliest start time is evicted.
   */
  const monitorSentRequest = useCallback((requestId?: string) => {
    if (!requestId) {
      return;
    }

    logger.info("Capturing request_id for request", {
      request_id: requestId,
    });

    const timerMap: typeof requestTimerMap.current = {
      ...requestTimerMap.current,
      [requestId]: {
        start: new Date(),
        cellIds: new Set<string>(),
      },
    };

    // Measure the map *after* inserting; checking the pre-insert map let it
    // settle one entry above the intended cap.
    if (Object.keys(timerMap).length > MAX_TRACKED_REQUESTS) {
      const [oldestKey] = Object.entries(timerMap).reduce((oldest, entry) =>
        entry[1].start.getTime() < oldest[1].start.getTime() ? entry : oldest
      );
      requestTimerMap.current = omit(timerMap, oldestKey);
      return;
    }

    requestTimerMap.current = timerMap;
  }, []);

  /**
   * For a tracked request, logs the elapsed time since it was sent: once when
   * the first answer cell arrives and once per newly seen answer cell. The seen
   * cell ids are accumulated so repeated cells are not logged twice.
   */
  const monitorReceivedRequest = useCallback((payload?: ToolReadyPayload) => {
    const timerMap = requestTimerMap.current;
    const requestId = payload?.request_id;
    const now = new Date();

    if (!requestId) {
      return;
    }

    const requestSentData = timerMap[requestId];

    if (!requestSentData) {
      return;
    }

    const cells = payload?.cells;
    if (!cells) {
      return;
    }

    const timeDiff = (now.getTime() - requestSentData.start.getTime()) / 1000;

    const newMatchingCells = cells
      .filter((cell) => isAnswerToolType(cell.tool))
      .filter((cell) => !requestSentData.cellIds.has(cell.id));

    if (!requestSentData.cellIds.size && newMatchingCells.length) {
      logger.info("First response to tool request received", {
        requestId: requestId,
        startTime: requestSentData.start,
        responseTime: now,
        timeElapsed: timeDiff,
        metricVersion: 2,
      });
    }

    newMatchingCells.forEach(() =>
      logger.info("Cell response received", {
        requestId: requestId,
        startTime: requestSentData.start,
        responseTime: now,
        timeElapsed: timeDiff,
        metricVersion: 2,
      })
    );

    requestTimerMap.current[requestId] = {
      start: requestSentData.start,
      cellIds: new Set([
        ...Array.from(requestSentData.cellIds),
        ...newMatchingCells.map((cell) => cell.id),
      ]),
    };
  }, []);

  /**
   * Parses an incoming frame and routes it to the handler for its event.
   * Frames that are not valid JSON, or carry an unknown event, are logged and dropped.
   */
  const onWebsocketMessage = useCallback(
    (message: MessageEvent<string>) => {
      let responseData: ReportWebsocketResponse;
      try {
        responseData = JSON.parse(message.data) as ReportWebsocketResponse;
      } catch (error) {
        // A malformed frame would otherwise throw out of the socket message handler.
        logger.error("Failed to parse websocket message", {
          session_id: sessionId.current,
          error,
        });
        return;
      }

      switch (responseData.event) {
        case ReportWebsocketEvent.GATHER_COMPLETE:
          dispatch(
            handleGatherCompleteEvent(responseData.payload, gridApiRef.current)
          );
          break;
        case ReportWebsocketEvent.REST:
          handleRestResponse(responseData);
          break;
        case ReportWebsocketEvent.TOOL_COMPLETE:
          dispatch(
            handleToolCompleteEvent(
              responseData.payload,
              gridApiRef.current,
              generateReportTitle,
              refreshServerSideRef.current
            )
          );
          break;
        case ReportWebsocketEvent.TOOL_READY:
          dispatch(
            handleToolReadyEvent(
              responseData.payload,
              gridApiRef.current,
              logPosthog
            )
          );
          monitorReceivedRequest(responseData.payload);
          break;
        case ReportWebsocketEvent.PROGRESS_UPDATE:
          dispatch(handleProgressUpdate(responseData.payload));
          break;
        case ReportWebsocketEvent.LOADING_CELLS:
          dispatch(
            handleLoadingCellsEvent(responseData.payload, gridApiRef.current)
          );
          break;
        case ReportWebsocketEvent.CACHED_CELLS:
          dispatch(
            handleCachedCellsEvent(responseData.payload, gridApiRef.current)
          );
          break;
        case ReportWebsocketEvent.PRE_COMPUTE_COMPLETE:
          break;
        case MatrixChatWebsocketEvent.MC_MESSAGE:
          dispatch(handleMCMessage(responseData.payload));
          break;
        case MatrixChatWebsocketEvent.MC_ASK_TOOL_INCOMING:
          monitorSentRequest(
            (responseData.payload as MCAskToolIncomingResponsePayload)
              ?.request_id
          );
          dispatch(
            handleMCAskToolEvent(responseData.payload, gridApiRef.current)
          );
          break;
        case MatrixChatWebsocketEvent.MC_COL_TYPE_CONVERSION:
          dispatch(handleMCColumnConversionEvent(responseData.payload));
          break;
        case MatrixChatWebsocketEvent.MC_GET_TOOL_INCOMING:
          monitorSentRequest(
            (responseData.payload as MCGetToolIncomingResponsePayload)
              ?.request_id
          );
          dispatch(
            handleMCGetToolEvent(
              responseData.payload,
              matrixId,
              refreshServerSideRef.current
            )
          );
          break;
        case MatrixChatWebsocketEvent.MC_CITATION_INFO:
          dispatch(handleMCCitationInfoEvent(responseData.payload));
          break;
        case MatrixChatWebsocketEvent.MC_SIZE_ERROR:
          dispatch(handleMCSizeError());
          break;
        case MatrixChatWebsocketEvent.MC_FAST_BUILD_DOC_IDS:
          dispatch(
            handleMCFastBuildDocIds(
              responseData.payload,
              matrixId,
              refreshServerSideRef.current
            )
          );
          break;
        case MatrixChatWebsocketEvent.MC_COL_FILTER: {
          dispatch(
            handleApplyMatrixChatGridFilter({
              gridApi: gridApiRef.current,
              payload: responseData.payload,
            })
          );
          break;
        }
        case MatrixChatWebsocketEvent.MC_LOAD_CHAT_HISTORY:
          dispatch(handleLoadHistoryMessage(responseData.payload));
          break;
        default:
          logger.error("Invalid websocket event type received", {
            response: responseData,
          });
          return;
      }
    },
    // refreshServerSide is read through refreshServerSideRef, so it is not a dependency.
    [
      matrixId,
      dispatch,
      logPosthog,
      generateReportTitle,
      monitorSentRequest,
      monitorReceivedRequest,
    ]
  );
  // We have specifically NOT included handleRestResponse as a dep
  // handleRestResponse has NO stateful variables, all ref based, so we are ok here

  // Keep the sessionId ref in sync with the server-provided sheet_session_id
  // (falling back to a fresh uuid when it is absent). This only updates the id
  // used by the next handshake; it does not itself reset the socket.
  useEffect(() => {
    if (
      metadataResponse.isSuccess &&
      metadataResponse.data?.sheet_session_id !== sessionId.current
    ) {
      sessionId.current = metadataResponse.data?.sheet_session_id ?? uuidv4();
    }
  }, [metadataResponse.isSuccess, metadataResponse.data?.sheet_session_id]);

  const { readyState, sendJsonMessage, getWebSocket } = useWebSocket(
    `${baseURL}?sheet_id=${matrixId}`,
    {
      // Retry Config
      retryOnError: true,
      reconnectAttempts: RETRY_TIMES,
      // Quadratic backoff: attemptNumber² seconds before each retry.
      reconnectInterval: (attemptNumber) => Math.pow(attemptNumber, 2) * 1000,
      // Only reconnect on irregular codes
      shouldReconnect: (ev) => {
        return !CODES_TO_STOP_RETRYING.includes(ev.code);
      },
      // Once the reconnect attempts have failed
      onReconnectStop: () => {
        logger.error(
          `Websocket client failed to reconnect after ${RETRY_TIMES} attempts`,
          {
            session_id: sessionId.current,
            user: userEmail,
            org: orgId,
          }
        );
        // Let the user know that the websocket is disconnected and to refresh
        dispatch(setAlert({ alert: "websocketDisconnected" }));

        // Track that we lost connection
        connectionLostRef.current = true;
        setShouldConnectOverride(false);
      },
      // Unfortunately, we need to do this inline function to access
      // the websocket variables in this closure.
      onOpen: async () => {
        // Grab existing sessionId or create a new one.
        const newSessionId = sessionId.current ?? uuidv4();

        // Initialize a new lease id if one isn't set
        const newLeaseId = dispatch(initializeLeaseId());

        // Get or fetch access token
        let accessToken = getAuthToken();
        if (!accessToken) {
          try {
            const response = await fetchAccessToken();
            accessToken = response.accessToken;
            setAuthToken(accessToken);
          } catch (error) {
            // The promise returned by onOpen is presumably not awaited by
            // react-use-websocket, so an uncaught rejection here would go
            // unhandled. Log it; the handshake is not sent.
            logger.error("Failed to fetch access token for websocket handshake", {
              session_id: newSessionId,
              user: userEmail,
              org: orgId,
              error,
            });
            return;
          }
        }

        // Send auth handshake
        const additionalInfo = shouldFetchHistory.current
          ? { fetchHistory: true }
          : {};
        sendJsonMessage({
          session_id: newSessionId,
          lease_id: newLeaseId,
          access_token: accessToken,
          ...additionalInfo,
        });
        sessionId.current = newSessionId;

        // Initial connection is established, no reason to fetch history again
        shouldFetchHistory.current = false;

        logger.info("Opened new websocket session", {
          session_id: newSessionId,
          user: userEmail,
          org: orgId,
        });

        // If we previously lost connection then clear the disconnected websocket alert and refresh the grid
        if (connectionLostRef.current) {
          connectionLostRef.current = false;

          dispatch(
            setAlert({
              alert: null,
            })
          );

          refreshServerSideRef.current({ purge: true });
        }

        // TODO: Ammar is this needed?
        // if (!activeReportRef.current || !metadataResponse.isSuccess) return;
        // Fetch Report
        // fetchReport({
        //   reportId: activeReportRef.current?.id,
        //   role: metadataResponse.data.role,
        //   buildStatus: metadataResponse.data.build_status,
        //   matrixName: metadataResponse.data.name,
        // });
      },
      // Message handler
      onMessage: onWebsocketMessage,
      // Don't set message state, we don't use it
      filter: () => false,
      onError(event) {
        logger.error("Socket encountered error, closing socket", {
          session_id: sessionId.current,
          user: userEmail,
          org: orgId,
          event,
        });
      },
      onClose: (event: CloseEvent) => {
        if (CODES_TO_STOP_RETRYING.includes(event.code)) {
          // Not going to reconnect
          // NOTE(review): connectionLostRef is not set here, so if a later
          // connection opens, this alert is not cleared by onOpen — confirm intent.
          dispatch(setAlert({ alert: "websocketDisconnected" }));
        }
        // Websocket connection is closed
        logger.info("Socket closed.", {
          session_id: sessionId.current,
          user: userEmail,
          org: orgId,
          event,
        });
      },
    },
    shouldConnect
  );

  // Flag for determining if the websocket is connected or not.
  const isConnected = readyState === ReadyState.OPEN;
  const { handleResponse: handleRestResponse, invokeRest } = useSocketProtocol(
    sendJsonMessage,
    isConnected
  );

  /**
   * Sends a payload over the socket. With an `event`, the payload is wrapped as
   * `{ event, payload }` (sheets format); without one it is sent as-is (chat format).
   */
  const sendMatrixMessage = useCallback(
    (payload: ReportPayloadContext, event?: string) => {
      // Sheets takes event and payload, chat only takes the payload
      const formattedPayload = event ? { event, payload } : payload;

      logger.info("Sending message to socket", {
        event: event ?? null,
        resource: "client",
        user: userEmail,
        org: orgId,
        sessionId: sessionId.current,
      });

      sendJsonMessage(formattedPayload);
    },
    [orgId, userEmail, sendJsonMessage]
  );

  /**
   * Message interceptor to store request_ids for time to cell tracking.
   * @param payload - payload to send; its `meta.request_id`, if any, starts a timer.
   * @param event - event name the payload is wrapped with.
   */
  const sendMessageInterceptor = useCallback(
    (payload: ReportPayloadContext, event: string) => {
      monitorSentRequest((payload as PartialPayload)?.meta?.request_id);
      sendMatrixMessage(payload, event);
    },
    [monitorSentRequest, sendMatrixMessage]
  );

  /**
   * Grab the existing websocket and close it with CLOSE_REFRESH_CONNECTION,
   * a code that is not in CODES_TO_STOP_RETRYING, so the socket reconnects.
   * @param keepSessionId Determines whether to clear the sessionId or not.
   * @param keepLeaseId Determines whether to clear the leaseId or not.
   */
  const resetWebsocket = useCallback(
    ({ keepSessionId, keepLeaseId }: ResetWebsocketParams) => {
      if (!keepLeaseId) dispatch(setLeaseId(null));
      if (!keepSessionId) {
        sessionId.current = undefined;
      }
      getWebSocket()?.close(CLOSE_REFRESH_CONNECTION);
    },
    [dispatch, getWebSocket]
  );

  // If the websocket is closed (i.e. inactive browser tab) and the page becomes
  // visible again, attempt to reconnect.
  useEffect(() => {
    if (typeof document !== "undefined" && readyState === ReadyState.CLOSED) {
      const listener = () => {
        // visibilitychange also fires when the tab is hidden; only reconnect on show.
        if (document.visibilityState !== "visible") {
          return;
        }

        // NOTE(review): if the socket closed with a stop code while the override
        // was still true, this setter is a no-op and no reconnect happens.
        setShouldConnectOverride(true);

        logger.info(
          "Matrix page re-focused after websocket was disconnected. Attempting to reconnect..."
        );
      };

      document.addEventListener("visibilitychange", listener, false);

      return () => {
        // Unsubscribe whenever readyState changes or the provider unmounts
        document.removeEventListener("visibilitychange", listener, false);
      };
    }
  }, [readyState]);

  // Memoized so consumers re-render only when one of these members changes,
  // not on every provider render.
  const contextValue = useMemo<Contextable>(
    () => ({
      isConnected,
      resetWebsocket,
      sendMessage: sendMessageInterceptor,
      invokeRest,
    }),
    [isConnected, resetWebsocket, sendMessageInterceptor, invokeRest]
  );

  return (
    <ReportsWebsocketContext.Provider value={contextValue}>
      {children}
    </ReportsWebsocketContext.Provider>
  );
};

export default ReportsWebsocketProvider;
