import React, { HTMLAttributes, memo, useCallback, useContext, useMemo, useState } from 'react';

import { merge } from '../utils';

/**
 * Everything a wizard exposes to its steps and to `useWizard` consumers.
 *
 * @typeParam T - Shape of the data collected across all wizard steps.
 */
export interface WizardState<T> {
  /** Zero-based index of the step currently being shown. */
  step: number;
  /** Number of `Wizard.Step` children the wizard found among its children. */
  maxSteps: number;
  /** Data accumulated so far from all steps. */
  state: Partial<T>;
  /** Jumps straight to the given step index. */
  setStep: (step: number) => void;
  /** Merges `data` into the accumulated state and moves forward; returns (or resolves to) the resulting state. */
  next: (data?: Partial<T>) => Partial<T> | Promise<Partial<T>>;
  /** Merges `data` into the accumulated state and moves back; returns (or resolves to) the resulting state. */
  previous: (data?: Partial<T>) => Partial<T> | Promise<Partial<T>>;
}

/**
 * Builds a typed wizard: a private context, a `useWizard` hook and a `Wizard`
 * component with `Wizard.Step` attached.
 *
 * `Wizard.Step` children are shown one at a time, in order; any other element
 * children of `Wizard` are always rendered. Data passed to `next`/`previous` is
 * shallow-merged into the wizard's accumulated `state`.
 *
 * @typeParam T - Shape of the data collected across all steps.
 * @returns `useWizard`, which reads the nearest wizard context, and `Wizard`.
 */
export function createWizard<T>() {
  // Fallback used when a consumer renders outside any <Wizard>: every operation is a no-op.
  const WizardContext = React.createContext<WizardState<T>>({
    step: 0,
    next: () => ({}),
    previous: () => ({}),
    state: {},
    setStep: () => {},
    maxSteps: 0,
  });

  /** Returns the state of the nearest enclosing `Wizard`, or of the `Wizard.Step` wrapping the caller. */
  const useWizard = () => {
    return useContext(WizardContext);
  };

  type StepProps = {
    /**
     * Runs after the wizard has moved forward, receiving the merged state. A returned
     * value replaces what the step's `next` resolves to; `undefined` keeps the merged state.
     */
    onNext?: (result?: Partial<T>) => Partial<T> | undefined | Promise<Partial<T>>;
    /** Same contract as `onNext`, but runs after the wizard has moved back. */
    onPrevious?: (result?: Partial<T>) => Partial<T> | undefined | Promise<Partial<T>>;
    /** Render prop receiving the step-scoped wizard (with `next`/`previous` wrapped). */
    children: (wizard: WizardState<T>) => React.ReactNode;
  } & Omit<HTMLAttributes<HTMLDivElement>, 'children'>;

  /**
   * One page of the wizard. Re-provides the context with `next`/`previous` wrapped, so
   * `onNext`/`onPrevious` also run when a descendant calls them through `useWizard`.
   */
  const Step = memo(function WizardStep({ children, onNext, onPrevious, className, ...rest }: StepProps) {
    const wizard = useContext(WizardContext);
    const { next, previous } = wizard;

    const onNextWrapper = useCallback(
      async (data?: Partial<T>) => {
        const newState = await next(data);
        // Awaiting the handler makes sync and async handlers behave alike: an `undefined`
        // result (immediate or resolved) falls back to the merged state.
        return (await onNext?.(newState)) ?? newState;
      },
      [next, onNext]
    );

    const onPreviousWrapper = useCallback(
      async (data?: Partial<T>) => {
        const newState = await previous(data);
        return (await onPrevious?.(newState)) ?? newState;
      },
      [previous, onPrevious]
    );

    // Memoized so the nested provider does not hand its consumers a new object every render.
    const wrappedWizard = useMemo<WizardState<T>>(
      () => ({ ...wizard, next: onNextWrapper, previous: onPreviousWrapper }),
      [wizard, onNextWrapper, onPreviousWrapper]
    );

    return (
      <WizardContext.Provider value={wrappedWizard}>
        <div className={merge('flex flex-col gap-4 bg-white', className)} {...rest}>
          {children(wrappedWizard)}
        </div>
      </WizardContext.Provider>
    );
  });

  // True for `Wizard.Step` elements among the wizard's children.
  const isStepElement = (child: React.ReactNode): boolean => React.isValidElement(child) && child.type === Step;

  type WizardProps = { children: React.ReactNode[] } & Omit<HTMLAttributes<HTMLDivElement>, 'children'>;

  /**
   * Container that renders the current `Wizard.Step` plus all non-step element children,
   * and provides the wizard state to them.
   *
   * `step` starts at 0 and is bounded by `maxSteps`. Calling `next` on the last step moves
   * to index `maxSteps`, where no step is rendered.
   */
  const Wizard = ({ children, className, ...rest }: WizardProps) => {
    const [step, setStep] = useState(0);
    const [state, setState] = useState<Partial<T>>({});

    // `toArray` flattens nested arrays (e.g. `{steps.map(...)}`) and drops null/boolean
    // children, so step counting matches what is actually rendered.
    const childrenArray = React.Children.toArray(children);
    const maxSteps = childrenArray.filter(isStepElement).length;

    // NOTE: `next`/`previous` read `step` and `state` from the render that created them, so
    // two calls before a re-render both start from the same snapshot.
    const next = useCallback(
      (data?: Partial<T>) => {
        // Bound by the number of Step children rather than `children.length`, which also
        // counts non-step children and treats nested arrays as a single entry.
        setStep(Math.min(step + 1, maxSteps));
        const newState = { ...state, ...data };
        setState(newState);
        return newState;
      },
      [maxSteps, state, step]
    );

    const previous = useCallback(
      (data?: Partial<T>) => {
        setStep(Math.max(step - 1, 0));
        const newState = { ...state, ...data };
        setState(newState);
        return newState;
      },
      [state, step]
    );

    let stepIndex = 0;
    const visibleChildren = childrenArray.filter((child) => {
      // Non-element children (strings, numbers) are not rendered.
      if (!React.isValidElement(child)) return false;
      // Non-step elements are always rendered.
      if (child.type !== Step) return true;
      const isCurrent = stepIndex === step;
      stepIndex += 1;
      return isCurrent;
    });

    // Memoized so consumers only re-render when the wizard state actually changes.
    const contextValue = useMemo<WizardState<T>>(
      () => ({ step, maxSteps, next, previous, setStep, state }),
      [step, maxSteps, next, previous, state]
    );

    return (
      <WizardContext.Provider value={contextValue}>
        <div className={merge('flex flex-col rounded-lg bg-white p-6 shadow', className)} {...rest}>
          {visibleChildren}
        </div>
      </WizardContext.Provider>
    );
  };

  Wizard.Step = Step;

  return { useWizard, Wizard };
}
