import { useCallback, useMemo, useRef } from 'react'

import isEqual from 'fast-deep-equal'
import { useSyncExternalStore } from 'use-sync-external-store/shim'
import * as Y from 'yjs'

// Transaction origin for regular user edits. It is the only origin listed in the
// `trackedOrigins` of the UndoManager created by `useYjsState`.
const USER_TRANSACTION_ORIGIN = 'user'
// Transaction origin used by `applyTransaction` when `skipHistory` is set.
// It is not part of the UndoManager's `trackedOrigins`.
const USER_SKIP_HISTORY_TRANSACTION_ORIGIN = 'user_skip_history'
// Transaction origin used when the hook itself writes the whole `data` value
// (initial load and `replaceValue`). It is not part of the UndoManager's `trackedOrigins`.
const INTERNAL_TRANSACTION_ORIGIN = 'internal'

/**
 * Abstraction layer over Yjs, so we don't repeat handling common functionality like undo/redo.
 *
 * The hook owns a single `Y.Doc` whose root map stores the state under the `'data'` key,
 * and exposes a plain-JSON snapshot of that state which is kept in sync through
 * `useSyncExternalStore`.
 *
 * @param initialData - Value the document is seeded with. Only the value passed on the
 * first render is used; later changes to this argument are ignored (use `replaceValue`).
 * @returns A memoized object with:
 * - `yDoc`: the underlying Yjs document;
 * - `data`: the current state as plain JSON (referentially stable while deep-equal);
 * - `applyTransaction(tr, { skipHistory })`: runs `tr` against the `'data'` Yjs map inside a
 *   transaction; with `skipHistory` the transaction uses an origin the UndoManager does not track;
 * - `replaceValue(data)`: replaces the whole state and starts a fresh UndoManager;
 * - `undo` / `redo`: delegate to the current UndoManager.
 */
export function useYjsState<T extends Record<string, unknown>>(initialData: T) {
    type InternalType = Y.Map<T>

    const initialDataRef = useRef(initialData)

    const undoManagerRef = useRef<Y.UndoManager>()

    const setData = useCallback((doc: Y.Doc, data: T) => {
        const yData = toYType(data) as InternalType

        doc.transact(() => {
            initialDataRef.current = data
            doc.getMap<InternalType>().set('data', yData)
        }, INTERNAL_TRANSACTION_ORIGIN)

        // The previous manager is scoped to the value that was just replaced.
        // Destroy it so it stops listening to the document instead of leaking
        // one more listener per `replaceValue` call.
        undoManagerRef.current?.destroy()

        const yDataToWatch = doc.getMap<InternalType>().get('data')!
        undoManagerRef.current = new Y.UndoManager(yDataToWatch, {
            trackedOrigins: new Set([USER_TRANSACTION_ORIGIN]),
        })

        return yData
    }, [])

    const yDoc = useMemo(() => {
        const yDoc = new Y.Doc({ gc: true })

        const initialData = initialDataRef.current
        setData(yDoc, initialData)

        return yDoc
    }, [setData])

    // Memoized so React keeps a single subscription: a new `subscribe` identity on
    // every render would make useSyncExternalStore unsubscribe and re-subscribe each time.
    const subscribe = useCallback(
        (onStoreChange: () => void) => {
            const yRoot = yDoc.getMap<InternalType>()

            yRoot.observeDeep(onStoreChange)
            return () => yRoot.unobserveDeep(onStoreChange)
        },
        [yDoc]
    )

    const prevDataRef = useRef<T>()
    const getSnapshot = useCallback((): T => {
        // Keep the previous data if it hasn't changed,
        // to prevent unnecessary re-renders.
        const yData = yDoc.getMap<InternalType>().get('data')!

        const nextData = yData.toJSON() as T
        if (!isEqual(prevDataRef.current, nextData)) {
            prevDataRef.current = nextData
        }
        return prevDataRef.current!
    }, [yDoc])

    // The same cached getter is used for the server snapshot, so repeated calls
    // return the same reference rather than a fresh object each time.
    const data: T = useSyncExternalStore<T>(subscribe, getSnapshot, getSnapshot)

    const applyTransaction = useCallback(
        (
            tr: (data: Y.Map<any>) => void,
            options: {
                skipHistory?: boolean
            } = {}
        ) => {
            const { skipHistory } = options

            // Only USER_TRANSACTION_ORIGIN is tracked by the UndoManager.
            const trOrigin = skipHistory
                ? USER_SKIP_HISTORY_TRANSACTION_ORIGIN
                : USER_TRANSACTION_ORIGIN

            yDoc.transact(() => {
                const yRoot = yDoc.getMap<InternalType>()
                tr(yRoot.get('data') as Y.Map<any>)
            }, trOrigin)
        },
        [yDoc]
    )

    const replaceValue = useCallback(
        (data: T) => {
            return setData(yDoc, data)
        },
        [setData, yDoc]
    )

    const undo = useCallback(() => {
        undoManagerRef.current?.undo()
    }, [])

    const redo = useCallback(() => {
        undoManagerRef.current?.redo()
    }, [])

    return useMemo(
        () => ({
            yDoc,
            data,
            applyTransaction,
            replaceValue,
            undo,
            redo,
        }),
        [applyTransaction, data, redo, replaceValue, undo, yDoc]
    )
}

/**
 * Recursively converts a plain JavaScript value into Yjs shared types.
 *
 * - Arrays become a `Y.Array` whose items are converted recursively.
 * - Values whose `Object.prototype.toString` tag is `[object Object]` become a `Y.Map`
 *   whose entries are converted recursively.
 * - Values that are already Yjs shared types are returned untouched. Their tag is also
 *   `[object Object]`, so without this guard their internal fields would be copied
 *   into a new, meaningless `Y.Map`.
 * - Everything else (primitives, `null`, `undefined`, dates, typed arrays, ...) is
 *   returned unchanged.
 *
 * @param val - Value to convert.
 * @returns The converted value. The declared return type is `Y.AbstractType` for caller
 * convenience, but values that are not arrays or plain objects are returned as they are.
 */
export function toYType<Value = unknown>(val: Value): Y.AbstractType<any> {
    if (Array.isArray(val)) {
        const yArray = new Y.Array()
        yArray.push(val.map((item) => toYType(item)))
        return yArray
    }

    if (Object.prototype.toString.call(val) !== '[object Object]') {
        return val as Y.AbstractType<any>
    }

    // Already a shared type: hand it back instead of walking its internals.
    if (val instanceof Y.AbstractType) {
        return val
    }

    const yMap = new Y.Map()
    for (const [key, entry] of Object.entries(val as Record<string, unknown>)) {
        yMap.set(key, toYType(entry))
    }
    return yMap
}
