import { RefreshedTokens } from "../queries";
import { TokenStore } from "./stores";
import { TokenEntry, Tokens } from "./types";

/**
 * Configuration accepted by the `TokenManager` constructor.
 */
export interface TokenManagerOptions {
  /** Storage the manager reads the current token entry from and writes refreshed tokens to. */
  tokenStore: TokenStore;

  /**
   * Exchanges a refresh token for a new set of tokens. The manager calls it
   * with the refresh token taken from the stored entry.
   */
  refreshFn: (refreshToken: Tokens["refreshToken"]) => Promise<RefreshedTokens>;

  /** Optional listener that receives the tokens returned by `refreshFn` once they have been stored. */
  onRefresh?: (newTokens: RefreshedTokens) => void;

  /**
   * Optional listener for failures: it receives the error when the store
   * holds no tokens or when a refresh attempt fails.
   */
  onError?: (error: unknown) => void;
}

/**
 * Thrown when an operation needs stored tokens but the token store is empty.
 *
 * The `name` property is set explicitly so the error can be identified in
 * logs and serialized output, not only through `instanceof`.
 */
export class NoTokensError extends Error {
  /**
   * @param message Optional custom message; a descriptive default is used
   *   when omitted.
   */
  constructor(message = "No tokens are available in the token store") {
    super(message);
    this.name = "NoTokensError";
  }
}

/**
 * Hands out the ID token held in a token store, refreshing the stored tokens
 * first when they have expired or have been flagged through `invalidate`.
 *
 * Only one refresh runs at a time: while a refresh is in flight, further
 * `fetchToken` calls wait for that same refresh instead of starting another.
 */
export class TokenManager {
  /** Exchanges a refresh token for new tokens. */
  readonly refreshFn: TokenManagerOptions["refreshFn"];
  /** Store that holds the current token entry. */
  readonly store: TokenManagerOptions["tokenStore"];
  /** Optional failure listener (missing tokens, failed refresh, throwing `onRefresh`). */
  readonly onError: TokenManagerOptions["onError"];
  /** Optional listener notified after refreshed tokens have been stored. */
  readonly onRefresh: TokenManagerOptions["onRefresh"];

  // Promise of the most recently started refresh; already resolved when no
  // refresh has been started yet.
  #promise: Promise<void>;
  // True from the start of a refresh until it has fully settled.
  #isRefreshing: boolean;
  // Set by `invalidate`; forces the next `fetchToken` to refresh.
  #isInvalid: boolean;

  constructor(options: TokenManagerOptions) {
    this.refreshFn = options.refreshFn;
    this.store = options.tokenStore;
    this.onError = options.onError;
    this.onRefresh = options.onRefresh;

    this.#promise = Promise.resolve();
    this.#isRefreshing = false;
    this.#isInvalid = false;
  }

  /**
   * Reads the current entry from the store.
   *
   * @throws NoTokensError when the store returns `null`; the error is also
   *   reported to `onError` before being thrown.
   */
  #getEnsuredEntry(): TokenEntry {
    const maybeEntry = this.store.get();

    if (maybeEntry === null) {
      const error = new NoTokensError();

      this.onError?.(error);
      throw error;
    }

    return maybeEntry;
  }

  /**
   * Decides whether a new refresh must be started for `entry`.
   */
  #shouldRefresh(entry: TokenEntry): boolean {
    if (this.#isRefreshing) {
      // Don't refresh if currently doing so
      return false;
    }

    if (this.#isInvalid) {
      return true;
    }

    // NOTE(review): `Date.now()` is in milliseconds, so this comparison
    // assumes `storedAt` and `expiresIn` are milliseconds too — confirm
    // against the `Tokens` type and the store implementation.
    const timeSinceLastRefresh = Date.now() - entry.storedAt;

    // Only start a new refresh if tokens are expired
    return timeSinceLastRefresh >= entry.tokens.expiresIn;
  }

  /**
   * Requests new tokens with the entry's refresh token and stores the merged
   * result. If the request or the store write fails, the stored tokens are
   * removed and the error is reported to `onError`; the returned promise
   * still resolves in that case.
   */
  async #performRefresh(entry: TokenEntry): Promise<void> {
    this.#isInvalid = false;
    this.#isRefreshing = true;

    const { tokens } = entry;

    try {
      const newTokens = await this.refreshFn(tokens.refreshToken);

      this.store.set({ ...tokens, ...newTokens });

      try {
        this.onRefresh?.(newTokens);
      } catch (listenerError) {
        // The refresh itself succeeded and the new tokens are already stored.
        // A failing listener must not fall through to the outer handler,
        // which would delete those valid tokens; just report the failure.
        this.onError?.(listenerError);
      }
    } catch (e) {
      this.store.remove();
      this.onError?.(e);
    } finally {
      // Reset last so the flag stays set while the callbacks above run.
      this.#isRefreshing = false;
    }
  }

  /**
   * Returns the ID token from the store, refreshing the tokens first when
   * they are expired or were invalidated.
   *
   * @returns The `idToken` of the entry found in the store after any pending
   *   refresh has settled.
   * @throws NoTokensError when the store is empty, either up front or after
   *   a failed refresh removed the tokens.
   */
  async fetchToken(): Promise<TokenEntry["tokens"]["idToken"]> {
    const entry = this.#getEnsuredEntry();

    if (this.#shouldRefresh(entry)) {
      // Preemptively try to refresh tokens if they're expired
      this.#promise = this.#performRefresh(entry);
    }

    // Pauses execution until fetching is complete
    await this.#promise;

    // Re-read the store: the refresh may have replaced or removed the entry.
    return this.#getEnsuredEntry().tokens.idToken;
  }

  /**
   * Flags the stored tokens as invalid so that the next `fetchToken` call
   * refreshes them. Does nothing while a refresh is in progress.
   *
   * @param token The ID token the caller was using.
   * @throws NoTokensError when no refresh is in progress and the store is empty.
   */
  invalidate(token: Tokens["idToken"]): void {
    if (this.#isRefreshing) {
      return;
    }

    const entry = this.#getEnsuredEntry();
    if (entry.tokens.idToken === token) {
      // Only invalidate if the token being used by the caller matches what's
      // in the store. This avoids race conditions if, for example, a refresh
      // happens after a network request starts and the refresh finishes
      // before the network request.
      this.#isInvalid = true;
    }
  }
}
