import { $wrapSelectionInMarkNode, MarkNode, SerializedMarkNode } from '@lexical/mark'
import {
  $applyNodeReplacement,
  $isElementNode,
  DOMConversionFn,
  DOMConversionMap,
  EditorConfig,
  ElementNode,
  LexicalNode,
  RangeSelection,
} from 'lexical'

import { DraggableNode } from './types'

/**
 * DOM conversion for `<mark>` elements.
 *
 * Every CSS class found on the element is turned into a mark id by prefixing
 * it with `cls-`, and a StyledMarkNode carrying those ids is returned as the
 * converted node. An element without classes yields a node with no ids.
 */
const convertStyledMarkNodeElement: DOMConversionFn<HTMLElement> = (
  element,
  parent,
  preformatted,
) => {
  // One id per class name, in classList order.
  const ids = Array.from(element.classList, className => `cls-${className}`)
  const node = $createStyledMarkNode(ids)
  return { node }
}

/**
 * Serialized (JSON) shape of a StyledMarkNode.
 *
 * Declares no fields of its own: it is structurally the same as
 * SerializedMarkNode.
 */
export interface SerializedStyledMarkNode extends SerializedMarkNode {}

/**
 * Derives CSS class names from mark ids: keeps only ids that start with
 * `cls-` and strips that prefix.
 */
function classNamesFromMarkIds(ids: readonly string[]): string[] {
  const prefix = 'cls-'
  return ids.filter(id => id.startsWith(prefix)).map(id => id.slice(prefix.length))
}

/**
 * Mark node whose ids double as CSS classes.
 *
 * Any id of the form `cls-<name>` is rendered as the class `<name>` on the
 * node's DOM element; ids without the prefix are ignored for styling. The
 * node reports that it cannot be dragged.
 */
// @ts-ignore importDOM is typed as null on MarkNode.
export class StyledMarkNode extends MarkNode implements DraggableNode {
  /** Node type identifier used for registration and serialization. */
  static getType(): string {
    return 'mark'
  }

  /** Creates a copy of `node` with its own id array and the same key. */
  static clone(node: StyledMarkNode): StyledMarkNode {
    return new StyledMarkNode(Array.from(node.__ids), node.__key)
  }

  /**
   * Registers the conversion used when `<mark>` elements are imported from
   * DOM; the element's classes become `cls-` prefixed ids.
   */
  static importDOM(): DOMConversionMap {
    return {
      mark: () => ({
        conversion: convertStyledMarkNodeElement,
        priority: 3,
      }),
    }
  }

  /**
   * Rebuilds a node from its serialized form, restoring ids, format, indent
   * and direction.
   */
  static importJSON(serializedNode: SerializedMarkNode): StyledMarkNode {
    const node = $createStyledMarkNode(serializedNode.ids)
    node.setFormat(serializedNode.format)
    node.setIndent(serializedNode.indent)
    node.setDirection(serializedNode.direction)
    return node
  }

  /** Serializes the node; defers entirely to the base implementation. */
  exportJSON(): SerializedMarkNode {
    return super.exportJSON()
  }

  /** This node is never draggable. */
  canDrag(): boolean {
    return false
  }

  /**
   * Asks the parent to insert a new element after itself and, when that
   * yields an element node, appends a fresh StyledMarkNode with the same ids
   * to it.
   *
   * @returns the new mark node, or `null` when the parent did not produce an
   *   element node.
   */
  insertNewAfter(selection: RangeSelection, restoreSelection = true): null | ElementNode {
    const element = this.getParentOrThrow().insertNewAfter(selection, restoreSelection)
    if (!$isElementNode(element)) {
      return null
    }
    // Hand the new node its own copy of the ids (as `clone` does) so the two
    // nodes never share one array instance.
    const markNode = $createStyledMarkNode([...this.__ids])
    element.append(markNode)
    return markNode
  }

  /** Creates the DOM element via the base class and adds the node's classes. */
  createDOM(config: EditorConfig): HTMLElement {
    const dom = super.createDOM(config)
    for (const cls of this.getClasses()) {
      dom.classList.add(cls)
    }

    return dom
  }

  /**
   * Synchronizes the element's classes with this node's `cls-` ids.
   *
   * Only classes that came from the previous node's ids are removed, so
   * classes this node did not add (for example any set by the base class in
   * `createDOM`/`updateDOM`) stay on the element.
   *
   * @returns always `false`.
   */
  updateDOM(prevNode: StyledMarkNode, element: HTMLElement, config: EditorConfig): boolean {
    super.updateDOM(prevNode, element, config)

    // Read the previous node's own id field: it describes what is currently
    // on the element.
    const previousClasses = classNamesFromMarkIds(prevNode.__ids)
    const nextClasses = this.getClasses()
    const keep = new Set(nextClasses)

    for (const cls of previousClasses) {
      if (!keep.has(cls)) {
        element.classList.remove(cls)
      }
    }
    for (const cls of nextClasses) {
      element.classList.add(cls)
    }

    return false
  }

  /** Returns the class names encoded in this node's ids (prefix stripped). */
  getClasses() {
    return classNamesFromMarkIds(this.getIDs())
  }

  /** Removes the id that encodes class `cls`. */
  deleteClass(cls: string) {
    return this.deleteID(`cls-${cls}`)
  }

  /** Adds the id that encodes class `cls`. */
  addClass(cls: string) {
    return this.addID(`cls-${cls}`)
  }
}

/**
 * Builds a StyledMarkNode and passes it through `$applyNodeReplacement`.
 *
 * @param ids - mark ids for the new node; an empty list is used when omitted.
 * @returns the node returned by `$applyNodeReplacement`.
 */
export function $createStyledMarkNode(ids?: string[]): StyledMarkNode {
  const initialIds = ids || []
  const node = new StyledMarkNode(initialIds)
  return $applyNodeReplacement(node)
}

/**
 * Type guard for StyledMarkNode.
 *
 * @param node - node to test; `null` is accepted and yields `false`.
 * @returns `true` when `node` is an instance of StyledMarkNode.
 */
export function $isStyledMarkNode(node: LexicalNode | null): node is StyledMarkNode {
  const isStyledMark = node instanceof StyledMarkNode
  return isStyledMark
}

/**
 * Wraps the selection in StyledMarkNodes tagged with the id for class `cls`.
 *
 * Delegates to `$wrapSelectionInMarkNode`, passing `cls-<cls>` as the id and
 * a factory that creates StyledMarkNodes for the ids it is given.
 *
 * @param selection - range to wrap.
 * @param isBackward - forwarded unchanged to `$wrapSelectionInMarkNode`.
 * @param cls - CSS class name to encode as a mark id.
 * @returns whatever `$wrapSelectionInMarkNode` returns.
 */
export function $wrapSelectionInStyledMarkNode(
  selection: RangeSelection,
  isBackward: boolean,
  cls: string,
) {
  const classId = `cls-${cls}`
  return $wrapSelectionInMarkNode(selection, isBackward, classId, ids =>
    $createStyledMarkNode(ids),
  )
}

/**
 * Collects the StyledMarkNodes related to a selection.
 *
 * The result contains, in this order:
 *  1. every StyledMarkNode ancestor of the selection's anchor node, from the
 *     innermost outwards;
 *  2. every StyledMarkNode returned by `selection.getNodes()`.
 *
 * A node that qualifies under both rules is returned only once (nodes are
 * de-duplicated by key, keeping the first occurrence).
 *
 * @param selection - the range selection to inspect.
 * @returns the matching mark nodes, without duplicates.
 */
export function $getAllMarkNodes(selection: RangeSelection): StyledMarkNode[] {
  const seenKeys = new Set<string>()
  const marks: StyledMarkNode[] = []

  // Records `candidate` when it is a StyledMarkNode not collected before.
  const collect = (candidate: LexicalNode | null): void => {
    if (!$isStyledMarkNode(candidate)) {
      return
    }
    const key = candidate.getKey()
    if (!seenKeys.has(key)) {
      seenKeys.add(key)
      marks.push(candidate)
    }
  }

  // Walk up from the anchor node's parent to the root.
  let ancestor = selection.anchor.getNode().getParent()
  while (ancestor) {
    collect(ancestor)
    ancestor = ancestor.getParent()
  }

  // The selected nodes may include ancestors already collected above.
  for (const selected of selection.getNodes()) {
    collect(selected)
  }

  return marks
}
