package org.mule.weave.v2.scope

import org.mule.weave.v2.parser.annotation.InfixNotationFunctionCallAnnotation
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.AstNodeHelper
import org.mule.weave.v2.parser.ast.CommentNode
import org.mule.weave.v2.parser.ast.VirtualAstNode
import org.mule.weave.v2.parser.ast.functions.{ FunctionCallNode, FunctionCallParametersNode }
import org.mule.weave.v2.parser.ast.header.directives.ImportDirective
import org.mule.weave.v2.parser.ast.module.ModuleNode
import org.mule.weave.v2.parser.ast.structure.DocumentNode
import org.mule.weave.v2.parser.ast.structure.NameNode
import org.mule.weave.v2.parser.ast.structure.StringNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.parser.location.UnknownPosition
import org.mule.weave.v2.utils.IdentityHashMap

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
  * Navigation helpers over an AST rooted at `documentNode`.
  *
  * Child-to-parent links and the per-class node index are built lazily, on first use, by walking the tree
  * through `AstNode.children()`. Parent links are stored in an `IdentityHashMap`, i.e. keyed by node identity.
  *
  * @param documentNode The root of the tree to navigate
  */
class AstNavigator(val documentNode: AstNode) {

  private lazy val _childParent: IdentityHashMap[AstNode, AstNode] = loadChildParentRelationShip(documentNode)
  private lazy val _nodesByType: mutable.Map[Class[_], ArrayBuffer[AstNode]] = loadAllWithType(documentNode, new mutable.HashMap())

  /**
    * Indexes `node` and all of its descendants (reached via `children()`) by their exact runtime class.
    *
    * The traversal uses an explicit work list instead of recursion, so deep trees do not grow the call stack.
    *
    * @param node The root of the subtree to index (the root itself is indexed too)
    * @param into The map the nodes are appended to; it is mutated in place
    * @return `into`, after all nodes of the subtree were appended
    */
  def loadAllWithType(node: AstNode, into: mutable.HashMap[Class[_], ArrayBuffer[AstNode]]): mutable.Map[Class[_], ArrayBuffer[AstNode]] = {
    def register(astNode: AstNode): Unit = {
      into.getOrElseUpdate(astNode.getClass, new ArrayBuffer[AstNode]()) += astNode
    }

    register(node)
    // An immutable List used as a LIFO work list (mutable.Stack is deprecated in Scala 2.12).
    var pending: List[AstNode] = List(node)
    while (pending.nonEmpty) {
      val current = pending.head
      pending = pending.tail
      current.children().foreach { child =>
        register(child)
        pending = child :: pending
      }
    }
    into
  }

  /**
    * Returns every node of the tree whose runtime class is exactly `value`.
    *
    * Nodes are indexed by `getClass`, so instances of subclasses of `value` are NOT included.
    *
    * @param value The exact node class to look up
    * @return The matching nodes, or an empty sequence if there are none
    */
  def allWithType[T <: AstNode](value: Class[T]): Seq[T] = {
    _nodesByType.getOrElse(value, Seq.empty).asInstanceOf[Seq[T]]
  }

  /**
    * Returns true if `child` is `parent` itself or any node reachable from `parent` through `children()`.
    *
    * Nodes are compared by identity (`eq`).
    *
    * @param parent The root of the subtree to search
    * @param child  The node to look for
    * @return true if `child` is `parent` or one of its descendants
    */
  def isDescendantOf(parent: AstNode, child: AstNode): Boolean = {
    if (parent eq child) {
      true
    } else {
      parent.children().exists(isDescendantOf(_, child))
    }
  }

  /**
    * Returns true if the parent of `astNode` exists and is an instance of `parentType`.
    */
  def isParentOfType(astNode: AstNode, parentType: Class[_ <: AstNode]): Boolean = {
    parentOf(astNode).exists((parent) => parentType.isInstance(parent))
  }

  /**
    * Returns the parent of the parent of `node`, if any.
    */
  def granParentOf(node: AstNode): Option[AstNode] = {
    parentOf(node).flatMap(parentOf)
  }

  /**
    * Returns the ancestor three levels above `node`, if any.
    */
  def granGranParentOf(node: AstNode): Option[AstNode] = {
    parentOf(node).flatMap(parentOf).flatMap(parentOf)
  }

  /**
    * Returns the list of import directives found under the document node.
    */
  def importDirectives(): Seq[ImportDirective] = {
    AstNodeHelper.collectChildrenWith(documentNode, classOf[ImportDirective])
  }

  /**
    * Builds an identity map from every descendant of `astNode` to its direct parent.
    */
  private def loadChildParentRelationShip(astNode: AstNode): IdentityHashMap[AstNode, AstNode] = {
    def loadChildIn(node: AstNode, parents: IdentityHashMap[AstNode, AstNode]): IdentityHashMap[AstNode, AstNode] = {
      val childNodes = node.children()
      var i = 0
      while (i < childNodes.length) {
        val child = childNodes(i)
        parents.put(child, node)
        loadChildIn(child, parents)
        i = i + 1
      }
      parents
    }

    loadChildIn(astNode, IdentityHashMap())
  }

  private def childParent: IdentityHashMap[AstNode, AstNode] = {
    _childParent
  }

  /**
    * Returns the parent of the specified node if any.
    *
    * `DocumentNode`s, `ModuleNode`s and the navigator's own root never have a parent.
    *
    * @param node The node
    * @return The parent if any
    */
  def parentOf(node: AstNode): Option[AstNode] = {
    node match {
      case _: DocumentNode | _: ModuleNode => None
      case _ if node eq documentNode     => None
      case _                             => childParent.get(node)
    }
  }

  /**
    * Returns true if the parent of `node` exists and its class is `parentType` or a subtype of it.
    */
  def isChildOf(node: AstNode, parentType: Class[_]): Boolean = {
    parentOf(node).exists((parent) => parentType.isAssignableFrom(parent.getClass))
  }

  /**
    * Returns true if the parent of `node` exists and its class is any of `parentTypes` (or a subtype of one).
    */
  def isChildOfAny(node: AstNode, parentTypes: Class[_]*): Boolean = {
    parentOf(node).exists((parent) => parentTypes.exists(_.isAssignableFrom(parent.getClass)))
  }

  /**
    * Returns `node` itself if it is a `parentType`, otherwise its closest ancestor of that type.
    */
  final def nodeWith[T <: AstNode](node: AstNode, parentType: Class[T]): Option[T] = {
    if (parentType.isAssignableFrom(node.getClass)) {
      Some(parentType.cast(node))
    } else {
      parentWithType(node, parentType)
    }
  }

  /**
    * Returns the closest ancestor of `node` (excluding `node` itself) that is a `parentType`.
    */
  @scala.annotation.tailrec
  final def parentWithType[T <: AstNode](node: AstNode, parentType: Class[T]): Option[T] = {
    val maybeParent = parentOf(node)
    maybeParent match {
      case Some(parent) => {
        if (parentType.isAssignableFrom(parent.getClass)) {
          Some(parentType.cast(parent))
        } else {
          parentWithType(parent, parentType)
        }
      }
      case None => None
    }
  }

  /**
    * Like [[parentWithType]] but inspects at most `maxLevels` ancestors.
    *
    * The counter is only compared against zero, so a negative `maxLevels` never reaches the limit and
    * every ancestor is inspected.
    */
  @scala.annotation.tailrec
  final def parentWithTypeMaxLevel[T <: AstNode](node: AstNode, parentType: Class[T], maxLevels: Int): Option[T] = {
    if (maxLevels == 0) {
      None
    } else {
      val maybeParent = parentOf(node)
      maybeParent match {
        case Some(parent) => {
          if (parentType.isAssignableFrom(parent.getClass)) {
            Some(parentType.cast(parent))
          } else {
            parentWithTypeMaxLevel(parent, parentType, maxLevels - 1)
          }
        }
        case None => None
      }
    }
  }

  /**
    * Like [[parentWithType]] but stops (returning None) when `untilParent` is reached; `untilParent` itself
    * is never returned.
    */
  @scala.annotation.tailrec
  final def parentWithTypeUntil[T <: AstNode](node: AstNode, parentType: Class[T], untilParent: AstNode): Option[T] = {
    val maybeParent = parentOf(node)
    maybeParent match {
      case Some(parent) if parent eq untilParent => None
      case Some(parent) => {
        if (parentType.isAssignableFrom(parent.getClass)) {
          Some(parentType.cast(parent))
        } else {
          parentWithTypeUntil(parent, parentType, untilParent)
        }
      }
      case None => None
    }
  }

  /**
    * Returns true if `index` lies within the node's location (both ends inclusive).
    * Nodes with an unknown start or end position never contain any index.
    */
  private def nodeContains(astNode: AstNode, index: Int): Boolean = {
    val location = astNode.location()
    if (location.startPosition != UnknownPosition && location.endPosition != UnknownPosition) {
      index >= location.startPosition.index && index <= location.endPosition.index
    } else {
      false
    }
  }

  /**
    * Returns the node at the specified range.
    *
    * Starting at the document node, descends into the first child whose location contains both indices.
    * It stops at a node whose location matches the range exactly, or at a node none of whose children
    * contain the range.
    *
    * @param startIndex The start Index
    * @param endIndex   The end index
    * @return The element at that range, or None if the document node does not contain the range
    */
  def nodeAt(startIndex: Int, endIndex: Int): Option[AstNode] = {

    def matchesRange(x: AstNode): Boolean = {
      x.location().startPosition.index == startIndex && x.location().endPosition.index == endIndex
    }

    def containsRange(x: AstNode): Boolean = {
      nodeContains(x, startIndex) && nodeContains(x, endIndex)
    }

    @scala.annotation.tailrec
    def refineNode(container: AstNode): Option[AstNode] = {
      if (matchesRange(container)) {
        Some(container)
      } else if (containsRange(container)) {
        container.children().find(containsRange) match {
          case Some(child) => refineNode(child)
          case None        => Some(container)
        }
      } else {
        None
      }
    }

    refineNode(documentNode)
  }

  /**
    * Returns the most user interesting node at a given cursor position.
    *
    * Takes the deepest node at `index` and replaces name-like leaves with their enclosing node:
    * - a `NameIdentifier` or `NameNode` is replaced by its parent;
    * - a `StringNode` whose parent is a `NameNode` is replaced by that parent;
    * - a `VariableReferenceNode` whose parent is a `FunctionCallNode` is replaced by that parent.
    *
    * @param index Cursor index
    * @return The astNode
    */
  def nodeAtCursor(index: Int): Option[AstNode] = {
    val maybeNode = nodeAt(index)
    refineNode(maybeNode)
  }

  @scala.annotation.tailrec
  private def refineNode(maybeNode: Option[AstNode]): Option[AstNode] = {
    maybeNode match {
      case Some(ni: NameIdentifier) => {
        refineNode(parentOf(ni))
      }
      case Some(st: StringNode) if parentOf(st).exists(_.isInstanceOf[NameNode]) => {
        refineNode(parentOf(st))
      }
      case Some(vr: VariableReferenceNode) if parentOf(vr).exists(_.isInstanceOf[FunctionCallNode]) => {
        refineNode(parentOf(vr))
      }
      case Some(nn: NameNode) => refineNode(parentOf(nn))
      case _                  => maybeNode
    }
  }

  /**
    * Returns the node at the location with the specified class if specified.
    *
    * Descends from the document node into the first child containing `index`. Children that are
    * `VirtualAstNode`s are replaced by their own children. When no deeper result is found, a comment
    * containing `index` attached to one of the container's children is preferred. Otherwise the container
    * itself is returned, provided it matches `nodeType`.
    *
    * @param index    The index location
    * @param nodeType The node type to use as filter (the class or any of its subclasses)
    * @return The deepest matching node, if any
    */
  def nodeAt(index: Int, nodeType: Option[Class[_]] = None): Option[AstNode] = {
    def ifMatchesType(container: AstNode): Option[AstNode] = {
      if (nodeType.forall(_.isAssignableFrom(container.getClass))) Some(container) else None
    }

    def searchInComments(nodes: Seq[AstNode]): Option[CommentNode] = {
      nodes
        .toStream
        .flatMap((x) => x.comments.find(cn => nodeContains(cn, index)))
        .headOption
    }

    // Result used when nothing deeper matched: a comment under the cursor, else the container itself.
    def commentOrContainer(container: AstNode, nodes: Seq[AstNode]): Option[AstNode] = {
      searchInComments(nodes).orElse(ifMatchesType(container))
    }

    def refineNode(container: AstNode): Option[AstNode] = {
      val nodes = container
        .children()
        .flatMap({
          case vn: VirtualAstNode => vn.children()
          case n                  => Seq(n)
        })
      nodes.find(nodeContains(_, index)) match {
        /*
         * For an infix call, the FunctionCallParametersNode has the same position as the FunctionCallNode.
         * Any index inside the call therefore matches the arguments, which would make it impossible to
         * refine into the type parameter application. If refining the call only yields the parameters
         * node itself, we look into the type parameters instead.
         */
        case Some(fcn @ FunctionCallNode(_, _, Some(typeParameters), _)) if fcn.annotation(classOf[InfixNotationFunctionCallAnnotation]).isDefined =>
          val refined = refineNode(fcn) match {
            case Some(FunctionCallParametersNode(_)) => refineNode(typeParameters)
            case result                              => result
          }
          // Same fallback as any other child, so a matching container is not skipped when filtering by type.
          refined.orElse(commentOrContainer(container, nodes))
        case Some(child) =>
          refineNode(child).orElse(commentOrContainer(container, nodes))
        case None =>
          commentOrContainer(container, nodes)
      }
    }

    refineNode(documentNode)
  }

}

object AstNavigator {

  /**
    * Factory for [[AstNavigator]].
    *
    * @param rootNode The root node the navigator will walk
    * @return A new navigator over `rootNode`
    */
  def apply(rootNode: AstNode): AstNavigator = {
    val navigator = new AstNavigator(rootNode)
    navigator
  }
}
