import {
  RangeSelection,
  ElementNode,
  $getSelection,
  $isRangeSelection,
  $isTextNode,
  $isElementNode,
  $isRootNode,
  LexicalNode,
  GridSelection,
  NodeSelection,
} from 'lexical'

import { getElementNodesInSelection } from './getElementNodesInSelection'

/**
 * Returns the element node covered by `selection` when exactly one element
 * is covered, otherwise `null`.
 *
 * @param selection - The range selection to inspect.
 * @returns The sole element reported by `getElementNodesInSelection`, or
 *   `null` when that collection is empty or holds more than one entry.
 */
export function getSelectedElement(selection: RangeSelection): ElementNode | null {
  // NOTE(review): assumes getElementNodesInSelection returns a Set-like
  // collection of ElementNodes (it exposes `.size` and `.values()`).
  const elements = getElementNodesInSelection(selection)
  if (elements.size !== 1) {
    return null
  }
  // Under TS >= 5.6 `strictBuiltinIteratorReturn`, `.next().value` is typed
  // `T | undefined`; normalize to the declared `null` sentinel.
  const { value } = elements.values().next()
  return value ?? null
}

/**
 * Finds the nearest non-inline element containing the selection's anchor.
 *
 * Starting from the anchor node, walks up the parent chain and returns the
 * first node that is an element and reports `isInline() === false`.
 *
 * NOTE(review): like other `$`-prefixed Lexical helpers, this presumably must
 * run inside an editor read/update context — confirm with callers.
 * NOTE(review): if the root node counts as a non-inline element, it may be
 * returned when no closer block ancestor exists — confirm callers expect that.
 *
 * @param selection - Selection to inspect; defaults to `$getSelection()`.
 * @returns The block-level ancestor element, or `null` when the selection is
 *   not a range selection or no such ancestor is found.
 */
export function $getSelectedBlockElement(
  selection: RangeSelection | NodeSelection | GridSelection | null = $getSelection(),
): ElementNode | null {
  // Only range selections have an anchor to climb from.
  if (!$isRangeSelection(selection)) {
    return null
  }

  let current: LexicalNode | null = selection.anchor.getNode()
  while (current !== null) {
    // Text nodes, decorators and inline elements are skipped; stop at the
    // first element that lays out as a block.
    if ($isElementNode(current) && !current.isInline()) {
      return current
    }
    current = current.getParent()
  }
  return null
}
