import {
  $getNodeByKey,
  $getRoot,
  $isElementNode,
  ElementNode,
  LexicalEditor,
  LexicalNode,
} from 'lexical'

import { Point } from './point'
import { Rect } from './rect'

// Search directions for getResizableElement. The value is added to the current
// index into the filtered key list on every iteration of the scan loop.
// Step of +1: continue towards later entries of the key list.
export const Downward = 1
// Step of -1: continue towards earlier entries of the key list.
export const Upward = -1
// Initial state: the scan direction has not been decided yet.
export const Indeterminate = 0
// Fixed slack subtracted from the left edge and added to the right edge of
// every candidate rectangle before hit-testing (same units as the DOM rect).
const SPACE = 32

/**
 * Recursively walks the subtree below `element` and records the key of every
 * node accepted by `filterNode` into `result`.
 *
 * The walk is pre-order: a node's own key is recorded before the keys of any
 * of its descendants. `element` itself is never tested, only its descendants.
 *
 * @param element - Node whose children are visited.
 * @param result - Set that receives the matching keys (mutated in place).
 * @param filterNode - Predicate deciding whether a node's key is recorded.
 */
function _getElementNodeKeys(element: ElementNode, result: Set<string>, filterNode: FilterNode) {
  element.getChildren().forEach((node) => {
    const accepted = filterNode(node)
    if (accepted) {
      result.add(node.getKey())
    }
    // Only element nodes can have children of their own to descend into.
    if (!$isElementNode(node)) {
      return
    }
    _getElementNodeKeys(node, result, filterNode)
  })
}

/**
 * Collects, inside a read of the editor's current state, the keys of all
 * nodes below the root that pass `filterNode`.
 *
 * @param editor - Editor whose current state is inspected.
 * @param filterNode - Predicate selecting which node keys to include.
 * @returns The matching keys in insertion (traversal) order, without duplicates.
 */
function getElementNodeKeys(editor: LexicalEditor, filterNode: FilterNode): string[] {
  const collected = new Set<string>()
  const editorState = editor.getEditorState()
  const collect = (): string[] => {
    _getElementNodeKeys($getRoot(), collected, filterNode)
    return Array.from(collected)
  }
  return editorState.read(collect)
}

/**
 * Picks the index at which a scan over a key list of length `keysLength`
 * should start.
 *
 * @param prevIndex - Index remembered from an earlier lookup.
 * @param keysLength - Number of keys available.
 * @returns `Infinity` for an empty list; `prevIndex` when it is a valid
 *   position in the list; otherwise the middle of the list.
 */
function getCurrentIndex(prevIndex: number, keysLength: number): number {
  if (keysLength === 0) {
    return Infinity
  }
  const isWithinBounds = prevIndex >= 0 && prevIndex < keysLength
  return isWithinBounds ? prevIndex : Math.floor(keysLength / 2)
}

/** Predicate deciding whether a node is a candidate for getResizableElement. */
type FilterNode = (node: LexicalNode) => boolean
/** Options accepted by getResizableElement. */
export interface GetResizableElementOpts {
  /** Only nodes for which this returns true are hit-tested. */
  filterNode: FilterNode
  /**
   * Position in the filtered key list at which the search starts. A value
   * outside the list makes the search start from the middle of the list.
   */
  prevIndex: number
  /**
   * Extra horizontal slack applied to both the left and right edge of each
   * candidate rectangle, on top of the built-in spacing. Defaults to 0.
   */
  offset?: number
}
/**
 * Finds the filtered node whose DOM box contains the mouse pointer.
 *
 * Candidates are the nodes accepted by `options.filterNode`, in traversal
 * order. The scan starts at `options.prevIndex` (or the middle of the list
 * when that index is out of range) and then walks in a single direction that
 * is chosen from where the pointer lies relative to the first tested box.
 * Each box is the element's DOM rect grown by its computed margins, plus
 * `SPACE + offset` on the left and right.
 *
 * On a hit, the matching index is stored in `options.prevIndex` so the caller
 * can pass the same options object again and resume from that position.
 *
 * @param editor - Editor providing the node tree and the key-to-DOM mapping.
 * @param event - Mouse event; its `x`/`y` are the point being tested.
 * @param options - Filter, starting index and optional horizontal slack.
 *   `options.prevIndex` is updated in place when a match is found.
 * @returns The matching DOM element and its node, or `null` when no candidate
 *   contains the pointer.
 */
export function getResizableElement(
  editor: LexicalEditor,
  event: MouseEvent,
  options: GetResizableElementOpts,
): { blockElement: HTMLElement; blockNode: ElementNode | null } | null {
  // Keep a reference to the caller's object: destructuring with a rest
  // pattern in the parameter list would copy it, and the prevIndex update
  // below would then be written to a throwaway object.
  const { filterNode, offset = 0 } = options
  const elementNodeKeys = getElementNodeKeys(editor, filterNode)
  // The pointer position is the same for every candidate, so build it once.
  const point = new Point(event.x, event.y)

  return editor.getEditorState().read(() => {
    let index = getCurrentIndex(options.prevIndex, elementNodeKeys.length)
    let direction = Indeterminate

    while (index >= 0 && index < elementNodeKeys.length) {
      const key = elementNodeKeys[index]
      const element = editor.getElementByKey(key)
      // A node without a DOM element ends the search without a result.
      if (element === null) break

      const domRect = Rect.fromDOM(element)
      const { marginTop, marginBottom, marginLeft, marginRight } = window.getComputedStyle(element)

      // Grow the box by the margins on all sides and by the extra slack horizontally.
      const rect = domRect.generateNewRect({
        bottom: domRect.bottom + parseFloat(marginBottom),
        left: domRect.left - parseFloat(marginLeft) - SPACE - offset,
        right: domRect.right + parseFloat(marginRight) + SPACE + offset,
        top: domRect.top - parseFloat(marginTop),
      })

      const {
        result,
        reason: { isOnTopSide, isOnBottomSide, isOnLeftSide, isOnRightSide },
      } = rect.contains(point)

      if (result) {
        // Remember where the match was found for the caller's next lookup.
        options.prevIndex = index
        return {
          blockElement: element,
          // NOTE(review): unchecked narrowing kept from the original code; the
          // node is not verified to be an ElementNode — confirm against callers.
          blockNode: $getNodeByKey(key) as ElementNode | null,
        }
      }

      // The direction is decided once, from the first box that was missed.
      if (direction === Indeterminate) {
        if (isOnTopSide || isOnLeftSide) {
          direction = Upward
        } else if (isOnBottomSide || isOnRightSide) {
          direction = Downward
        } else {
          // No usable hint: an infinite step pushes index out of range and ends the loop.
          direction = Infinity
        }
      }

      index += direction
    }

    return null
  })
}
