import { useTheme, useMediaQuery, Breakpoint } from "@mui/material"

/**
 * Rank of each breakpoint key: a later (larger) breakpoint has a larger number.
 *
 * Only the relative order is used by the comparisons in this file
 * (`ScreenSize[a] < ScreenSize[b]`).
 * NOTE(review): the values skip 4 (`lg = 3`, `xl = 5`). This is harmless for ordering, but confirm
 * nothing outside this file relies on the exact numbers before renumbering.
 * NOTE(review): `xxl` is never produced by `useCurrentBreakpoint` in this file; presumably it exists
 * for a theme with an extra breakpoint — confirm.
 */
export enum ScreenSize {
  xs = 0,
  sm = 1,
  md = 2,
  lg = 3,
  xl = 5,
  xxl = 6,
}

/**
 * Reports whether the viewport currently matches `theme.breakpoints.up(breakPoint)`, i.e. is at or
 * above the given breakpoint.
 *
 * Despite its name, this hook returns a boolean, not a breakpoint.
 *
 * @param breakPoint - The breakpoint key to test against.
 * @returns `true` while the media query matches.
 */
export const useScreenSize = (breakPoint: Breakpoint) => {
  const { breakpoints } = useTheme()
  const atOrAboveQuery = breakpoints.up(breakPoint)
  return useMediaQuery(atOrAboveQuery)
}

/** Primitive values that can be chosen per breakpoint. */
type ScreenSizeValue = undefined | boolean | number | string
/**
 * A responsive prop: either one value used at every breakpoint, or a partial map from breakpoint key
 * to value.
 */
export type ScreenSizeProps<T extends ScreenSizeValue> = T | Partial<Record<Breakpoint, T>>

/**
 * Resolves a responsive value for the breakpoint the viewport is currently in.
 *
 * Thin hook wrapper: reads the active breakpoint and delegates resolution to `getScreenSizeValue`.
 *
 * @param value - A single value, a partial breakpoint → value map, or `undefined`.
 * @returns The value chosen for the current breakpoint, or `undefined`.
 */
export const useScreenSizeValue = <T extends ScreenSizeValue>(
  value: ScreenSizeProps<T> | undefined
): T | undefined => getScreenSizeValue(value, useCurrentBreakpoint())

/**
 * Resolves a responsive value for a given breakpoint.
 *
 * - A plain (non-object) value is returned unchanged, including falsy values such as `0`, `""` and
 *   `false`.
 * - For a per-breakpoint object, the entry for `screenSize` is returned when it is defined.
 *   Otherwise the defined entry of the closest smaller breakpoint (highest `ScreenSize` rank below
 *   the current one) is used. If no smaller breakpoint has a value, the result is `undefined`.
 *
 * @param value - A single value, a partial breakpoint → value map, or `undefined`.
 * @param screenSize - The breakpoint to resolve for.
 * @returns The resolved value, or `undefined` when nothing applies.
 */
export function getScreenSizeValue<T extends ScreenSizeValue>(
  value: ScreenSizeProps<T> | undefined,
  screenSize: Breakpoint
): T | undefined {
  // Nullish check rather than a falsy check: 0, "" and false are legitimate values of T.
  if (value == null) return undefined

  if (typeof value === "object") {
    const exactValue = value[screenSize]
    if (exactValue !== undefined) {
      return exactValue
    }

    // Keep the defined value with the largest rank that is still below the current breakpoint.
    // Tracking the rank explicitly makes the result independent of the object's key insertion
    // order, which is what `for...in` iterates in.
    const currentRank = ScreenSize[screenSize]
    let closestRank = -Infinity
    let closestValue: T | undefined = undefined
    for (const keyStr in value) {
      const key = keyStr as Breakpoint
      const rank = ScreenSize[key]
      const candidate = value[key]
      if (candidate !== undefined && rank < currentRank && rank > closestRank) {
        closestRank = rank
        closestValue = candidate
      }
    }

    return closestValue
  }

  return value
}

/**
 * Returns a resolver bound to the breakpoint active during this render.
 *
 * Useful when many responsive values must be resolved without calling a hook per value. A new
 * function is created on every render; it is not memoized.
 *
 * @returns A function mapping a responsive value to the value for the current breakpoint.
 */
export function useGetScreenSizeValueFunction<T extends ScreenSizeValue>(): (
  value: ScreenSizeProps<T> | undefined
) => T | undefined {
  const activeBreakpoint = useCurrentBreakpoint()

  const resolveForActiveBreakpoint = (value: ScreenSizeProps<T> | undefined): T | undefined =>
    getScreenSizeValue(value, activeBreakpoint)

  return resolveForActiveBreakpoint
}

/**
 * Determines which breakpoint the viewport is currently in.
 *
 * Subscribes to `theme.breakpoints.only(...)` for xs, sm, md, lg and xl, checked in that order.
 * Returns the first one that matches, or "xl" when none match.
 * NOTE(review): with MUI's default `useMediaQuery` options, every query may report `false` on the
 * first render, which makes this return "xl" until the queries resolve — confirm whether that
 * matters to callers.
 */
function useCurrentBreakpoint(): Breakpoint {
  const { breakpoints } = useTheme()

  // Hooks are called unconditionally and in a fixed order, as required by the rules of hooks.
  const candidates: Array<[Breakpoint, boolean]> = [
    ["xs", useMediaQuery(breakpoints.only("xs"))],
    ["sm", useMediaQuery(breakpoints.only("sm"))],
    ["md", useMediaQuery(breakpoints.only("md"))],
    ["lg", useMediaQuery(breakpoints.only("lg"))],
    ["xl", useMediaQuery(breakpoints.only("xl"))],
  ]

  const matched = candidates.find(([, isMatch]) => isMatch)
  return matched ? matched[0] : "xl"
}
