import { SendJsonMessage } from "react-use-websocket/dist/lib/types";
import { v4 as uuidv4 } from "uuid";
import {
  SocketRequest,
  JSONType,
  SocketResponse,
  SocketResponseError,
  RequestMap,
  Version,
  SocketStatusCode,
  RestSocketResponsePayload,
} from "./protocol.types";
import { useCallback, useEffect, useRef } from "react";
import logger from "source/utils/logger";
import { ReportWebsocketEvent } from "../types/websocket.types";

/**
 * Default time, in milliseconds, that `invokeRest` waits for a matching
 * response before rejecting with `SocketStatusCode.TIMEOUT`.
 */
export const GLOBAL_REQUEST_TIMEOUT_MS = 30_000;

/**
 * Soft cap on in-flight requests. Exceeding it only logs a warning; the
 * request is still tracked and not rejected because of the cap.
 */
export const MAX_CONCURRENT_REQUESTS = 100;

/**
 * Signature of the `invokeRest` function returned by `useSocketProtocol`.
 *
 * @typeParam TRequest - JSON shape of the request body.
 * @typeParam TResponse - JSON shape of the expected response data.
 * @param path - server route the request is addressed to.
 * @param payload - optional JSON body, forwarded as the request's `data`.
 * @param signal - optional signal; aborting it rejects the returned promise.
 * @param timeout - milliseconds to wait for a response before rejecting.
 * @returns a promise that resolves with the response payload when its status
 *   is 200 and rejects with a `SocketResponseError` otherwise.
 */
export type InvokeRestFunc = <
  TRequest extends JSONType,
  TResponse extends JSONType,
>(
  path: string,
  payload?: TRequest,
  signal?: AbortSignal,
  timeout?: number
) => Promise<RestSocketResponsePayload<TResponse>>;

/**
 * React hook implementing a request/response protocol on top of a websocket.
 *
 * Each `invokeRest` call mints a `request_id`, sends a
 * `ReportWebsocketEvent.REST` message and returns a promise. The promise
 * settles when one of the following happens:
 * - `handleResponse` receives a message carrying the same `request_id`;
 * - the timeout elapses;
 * - the caller's `AbortSignal` fires;
 * - the socket disconnects or the hook unmounts.
 *
 * @param sendJSONMessage - function used to put requests on the socket.
 * @param isWebsocketConnected - current connection state. While `false`, new
 *   requests are rejected immediately with `SERVICE_UNAVAILIBLE` and are not
 *   sent. Pending requests are rejected with `CLIENT_CLOSED_REQUEST`.
 * @returns `invokeRest` to issue requests, and `handleResponse`, which must be
 *   given every incoming REST response message.
 */
export const useSocketProtocol = (
  sendJSONMessage: SendJsonMessage,
  isWebsocketConnected: boolean
) => {
  // In-flight requests keyed by request_id. Each entry's resolve/reject are
  // wrappers that also release the request's timer, abort listener and entry.
  const activeRequests = useRef<RequestMap>({});

  /** Rejects every pending request with CLIENT_CLOSED_REQUEST and empties the map. */
  const clearAllRequests = useCallback(() => {
    // Swap the map out before rejecting, so that cleanup triggered by the
    // rejections operates on the fresh map rather than the one being iterated.
    const pending = activeRequests.current;
    activeRequests.current = {};
    Object.values(pending).forEach((activeRequest) => {
      window.clearTimeout(activeRequest.timeoutId);
      activeRequest.reject(
        new SocketResponseError({
          path: activeRequest.path,
          request_id: activeRequest.request_id,
          status: SocketStatusCode.CLIENT_CLOSED_REQUEST,
          version: Version.v1,
        })
      );
    });
  }, []);

  // Cancel all requests on disconnect or cleanup
  useEffect(() => {
    if (!isWebsocketConnected) clearAllRequests();
    return () => clearAllRequests();
  }, [isWebsocketConnected, clearAllRequests]);

  const invokeRest: InvokeRestFunc = useCallback(
    async <RequestBody extends JSONType, ResponseData extends JSONType>(
      path: string,
      data?: RequestBody,
      signal?: AbortSignal,
      timeout: number = GLOBAL_REQUEST_TIMEOUT_MS
    ): Promise<RestSocketResponsePayload<ResponseData>> => {
      // Mint a request ID, this will be used to identify requests when they come back
      const request_id = uuidv4();

      return new Promise<RestSocketResponsePayload<ResponseData>>(
        (resolve, reject) => {
          // Refuse up front when the socket is down. Nothing is registered or
          // sent, so no timer leaks and the server never receives a request
          // whose caller was already told it failed.
          if (!isWebsocketConnected) {
            reject(
              new SocketResponseError<ResponseData>({
                request_id,
                path,
                status: SocketStatusCode.SERVICE_UNAVAILIBLE,
                version: Version.v1,
              })
            );
            return;
          }

          // An already-aborted signal never dispatches "abort" again, so the
          // listener below would never fire; handle it explicitly.
          if (signal?.aborted) {
            reject(
              new SocketResponseError({
                request_id,
                path,
                version: Version.v1,
                status: SocketStatusCode.CLIENT_CLOSED_REQUEST,
              })
            );
            return;
          }

          const onAbort = () => {
            settleReject(
              new SocketResponseError({
                request_id,
                path,
                version: Version.v1,
                status: SocketStatusCode.CLIENT_CLOSED_REQUEST,
              })
            );
          };

          // Frees everything tied to this request. Safe to call more than once.
          const release = () => {
            window.clearTimeout(timeoutId);
            signal?.removeEventListener("abort", onAbort);
            delete activeRequests.current[request_id];
          };

          // Every settle path (response, timeout, abort, disconnect) goes
          // through these wrappers, so cleanup cannot be forgotten.
          const settleResolve: typeof resolve = (value) => {
            release();
            resolve(value);
          };
          const settleReject: typeof reject = (reason) => {
            release();
            reject(reason);
          };

          // Auto-reject if no response arrives in time.
          const timeoutId = window.setTimeout(() => {
            settleReject(
              new SocketResponseError<ResponseData>({
                request_id,
                path,
                status: SocketStatusCode.TIMEOUT,
                version: Version.v1,
              })
            );
          }, timeout);

          signal?.addEventListener("abort", onAbort, { once: true });

          activeRequests.current[request_id] = {
            path,
            request_id,
            resolve: settleResolve,
            reject: settleReject,
            signal,
            timeoutId,
            timestamp: Date.now(),
          };

          if (
            Object.keys(activeRequests.current).length > MAX_CONCURRENT_REQUESTS
          ) {
            logger.warn("Too many concurrent requests");
          }

          // Send request over websocket. Register first so the entry exists
          // before any response could be delivered.
          try {
            sendJSONMessage<SocketRequest<RequestBody>>({
              event: ReportWebsocketEvent.REST,
              payload: { version: Version.v1, request_id, path, data },
            });
          } catch (error: unknown) {
            settleReject(error);
            return;
          }

          // Log for observability.
          logger.info(`Sending request for ${path}`, {
            request_id,
            path,
            data,
          });
        }
      );
    },
    [isWebsocketConnected, sendJSONMessage]
  );

  /**
   * Settles the pending request matching `response.payload.request_id`.
   * A 200 status resolves with the payload; any other status rejects with a
   * `SocketResponseError`. Responses without a pending entry are logged and
   * dropped. That includes requests that already timed out, were aborted, or
   * were cleared on disconnect.
   */
  const handleResponse = useCallback((response: SocketResponse) => {
    const request_id = response.payload.request_id;
    const activeRequest = activeRequests.current[request_id];

    if (!activeRequest) {
      logger.error("Error looking up active request.", {
        payload: response.payload,
      });
      return;
    }

    // The stored resolve/reject also release the request. Clearing here as
    // well keeps this handler correct on its own.
    window.clearTimeout(activeRequest.timeoutId);
    delete activeRequests.current[request_id];

    if (response.payload.status === 200) {
      activeRequest.resolve(response.payload);
    } else {
      activeRequest.reject(new SocketResponseError(response.payload));
    }
  }, []);

  return { invokeRest, handleResponse };
};
