package org.mule.weave.v2.parser.phase

import org.mule.weave.v2.grammar.BinaryOpIdentifier
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.DirectivesCapableNode
import org.mule.weave.v2.parser.ast.MutableAstNode
import org.mule.weave.v2.parser.ast.annotation.AnnotationArgumentsNode
import org.mule.weave.v2.parser.ast.annotation.AnnotationNode
import org.mule.weave.v2.parser.ast.functions.DoBlockNode
import org.mule.weave.v2.parser.ast.functions.FunctionNode
import org.mule.weave.v2.parser.ast.header.HeaderNode
import org.mule.weave.v2.parser.ast.header.directives.DirectiveNode
import org.mule.weave.v2.parser.ast.header.directives.NamespaceDirective
import org.mule.weave.v2.parser.ast.header.directives.VarDirective
import org.mule.weave.v2.parser.ast.operators.BinaryOpNode
import org.mule.weave.v2.parser.ast.selectors.NullSafeNode
import org.mule.weave.v2.parser.ast.structure.DocumentNode
import org.mule.weave.v2.parser.ast.structure.NameNode
import org.mule.weave.v2.parser.ast.structure.NamespaceNode
import org.mule.weave.v2.parser.ast.structure.NumberNode
import org.mule.weave.v2.parser.ast.structure.StringNode
import org.mule.weave.v2.parser.ast.structure.UriNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.scope.AstNavigator
import org.mule.weave.v2.scope.ScopesNavigator
import org.mule.weave.v2.scope.VariableScope

import java.util.concurrent.atomic.AtomicInteger
import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
  * Runs Common Subexpression Elimination inside scopes
  *
  * https://en.wikipedia.org/wiki/Common_subexpression_elimination
  *
  * @tparam T The Type of Node
  */
class CommonSubexpressionReductionPhase[T <: AstNode]() extends CompilationPhase[ScopeGraphResult[T], ScopeGraphResult[T]] {

  private val counter = new AtomicInteger()

  private var mutated = false

  /**
    * Builds the chain of factorable selections that corresponds to a selection expression.
    *
    * The expression is walked from the outermost selection down to the variable reference it starts from.
    * The variable reference becomes the top of the chain and every selection is appended as a child of the
    * selection it is applied on.
    *
    * @param nsn             The expression to analyze
    * @param astNavigator    Used to locate the enclosing nodes
    * @param scopesNavigator Used to resolve namespace prefixes
    * @param leaf            True when `nsn` is the outermost selection of the expression; the recursive calls over
    *                        the left hand side always pass false
    * @return The deepest node that was created (use `root()` to reach the top of the chain), or None when the
    *         expression is not a chain of static name/index selections over a variable reference
    */
  def newExpressionTree(nsn: AstNode, astNavigator: AstNavigator, scopesNavigator: ScopesNavigator, leaf: Boolean = true): Option[FactorableExpressionNode] = {
    nsn match {
      //The null safe wrapper is transparent, so the leaf flag is propagated untouched
      case nullSafeNode: NullSafeNode =>
        newExpressionTree(nullSafeNode.selector, astNavigator, scopesNavigator, leaf)

      //Selection by name with a static (or no) namespace
      case bn @ BinaryOpNode(opId, lhs, NameNode(sn @ StringNode(value, _), maybeNs, _), _) if (isStaticNamespace(maybeNs, scopesNavigator, astNavigator)) =>
        newExpressionTree(lhs, astNavigator, scopesNavigator, leaf = false).map((leftNode) => {
          val uri: Option[String] = resolveNamespaceUri(maybeNs, astNavigator, scopesNavigator)
          val rightNode = ValueSelectorExpressionNode(NameSelectorExpression(value, sn.quotedBy(), uri), opId, ArrayBuffer(ReplacementNode(bn, leaf)))
          leftNode.addChild(rightNode)
          rightNode
        })

      //Selection by a number literal
      case bn @ BinaryOpNode(opId, lhs, NumberNode(value, _), _) =>
        newExpressionTree(lhs, astNavigator, scopesNavigator, leaf = false).map((leftNode) => {
          val rightNode = ValueSelectorExpressionNode(NumberSelectorExpression(value), opId, ArrayBuffer(ReplacementNode(bn, leaf)))
          leftNode.addChild(rightNode)
          rightNode
        })

      //Start of the chain. A reference that is not enclosed by a NullSafeNode can not be factored, so it is
      //reported as None instead of failing with a NoSuchElementException
      case variableReferenceNode: VariableReferenceNode =>
        astNavigator
          .parentWithType(variableReferenceNode, classOf[NullSafeNode])
          .map((enclosingNullSafe) => VariableReferenceExpressionNode(variableReferenceNode, enclosingNullSafe))

      case _ => None
    }
  }

  /**
    * Resolves the uri of the namespace used by a name selection.
    *
    * @param maybeNs         The optional namespace node of the selection
    * @param astNavigator    Used to locate the NamespaceDirective that encloses the prefix declaration
    * @param scopesNavigator Used to resolve the prefix to its declaration
    * @return The literal value of the uri of the enclosing NamespaceDirective, or None when there is no namespace,
    *         the namespace is not a NamespaceNode, the prefix can not be resolved or no NamespaceDirective
    *         encloses its declaration
    */
  private def resolveNamespaceUri(maybeNs: Option[AstNode], astNavigator: AstNavigator, scopesNavigator: ScopesNavigator): Option[String] = {
    maybeNs match {
      case Some(NamespaceNode(prefix)) =>
        for {
          prefixReference <- scopesNavigator.resolveVariable(prefix)
          namespaceDirective <- astNavigator.parentWithType(prefixReference.referencedNode, classOf[NamespaceDirective])
        } yield namespaceDirective.uri.literalValue
      //Covers both the absence of a namespace and any namespace node that is not a NamespaceNode
      case _ => None
    }
  }

  /**
    * Tells whether the namespace of a name selection can be determined statically.
    *
    * A selection without namespace is always static. A NamespaceNode is static only when its prefix resolves to a
    * local reference whose referenced node has a NamespaceDirective as parent, searching with a max level of 3.
    * Any other kind of namespace node is not static.
    *
    * @param maybeNamespaceNode The optional namespace node of the selection
    * @param scopesNavigator    Used to resolve the prefix
    * @param astNavigator       Used to look for the enclosing NamespaceDirective
    * @return True if the namespace is static
    */
  private def isStaticNamespace(maybeNamespaceNode: Option[AstNode], scopesNavigator: ScopesNavigator, astNavigator: AstNavigator): Boolean = {
    maybeNamespaceNode match {
      case None => true
      case Some(NamespaceNode(prefix)) =>
        scopesNavigator
          .resolveVariable(prefix)
          .exists((ref) => {
            ref.isLocalReference && astNavigator.parentWithTypeMaxLevel(ref.referencedNode, classOf[NamespaceDirective], 3).isDefined
          })
      case Some(_) => false
    }
  }

  /**
    * Runs the elimination over every scope, starting from the root scope, when the context enables it.
    *
    * @param input The scope graph to process
    * @param ctx   The parsing context
    * @return The input untouched when the phase is disabled or nothing was extracted; otherwise the result of
    *         running a new ScopeGraphPhase over the input
    */
  override def doCall(input: ScopeGraphResult[T], ctx: ParsingContext): PhaseResult[ScopeGraphResult[T]] = {
    if (ctx.shouldRunCommonSubExpressionElimination()) {
      val astNavigator: AstNavigator = input.scope.rootScope.astNavigator()
      val scopeNavigator: ScopesNavigator = input.scope
      runOnScope(scopeNavigator.rootScope, astNavigator, scopeNavigator)
      //NOTE(review): `mutated` is instance state that is only ever set to true in this class, so a reused
      //instance keeps rebuilding the scope graph after its first mutation - confirm instances are not shared
      if (mutated) {
        //The AST was modified so the scope graph needs to be calculated again
        new ScopeGraphPhase[T]().call(input, ctx)
      } else {
        SuccessResult(input, ctx)
      }
    } else {
      SuccessResult(input, ctx)
    }
  }

  /**
    * Validates that an expression is a chain of static selections that starts from a variable reference.
    *
    * A binary operation is accepted when its right hand side is a string, a number or a name whose key is a
    * string, and its left hand side is valid as well. Null safe nodes are validated through their selector.
    *
    * @param parent The expression to validate
    * @return True if the whole chain is made of static selections over a variable reference
    */
  def validReplaceExpression(parent: AstNode): Boolean = {
    parent match {
      case bo: BinaryOpNode =>
        val staticSelector = bo.rhs match {
          case _: StringNode | _: NumberNode => true
          case nn: NameNode                  => nn.keyName.isInstanceOf[StringNode]
          case _                             => false
        }
        staticSelector && validReplaceExpression(bo.lhs)
      case nsn: NullSafeNode        => validReplaceExpression(nsn.selector)
      case _: VariableReferenceNode => true
      case _                        => false
    }
  }

  /**
    * Merges the selection chain of `v` into the accumulated tree `acc`.
    *
    * Only the first child of each node of `v` is followed. While the accumulated tree already has a child with a
    * mergeable selector and the same operator, the replacements are appended to that child and the walk goes one
    * level down. The first selection with no match is attached, with whatever hangs from it, to the accumulated
    * tree and the merge stops.
    *
    * @param v   The chain to merge
    * @param acc The accumulated tree, it is modified in place
    * @return The root of the accumulated tree
    */
  def mergeTree(v: VariableReferenceExpressionNode, acc: VariableReferenceExpressionNode): VariableReferenceExpressionNode = {

    @tailrec
    def doMergeTree(currentNode: FactorableExpressionNode, accNode: FactorableExpressionNode): FactorableExpressionNode = {
      currentNode.children.headOption match {
        case Some(incoming: ValueSelectorExpressionNode) =>
          val maybeExisting = accNode.children.collectFirst({
            case candidate @ ValueSelectorExpressionNode(selector, candidateOpId, _) if (selector.canBeMerged(incoming.selectorExpression) && candidateOpId == incoming.opId) => candidate
          })
          maybeExisting match {
            case Some(existing) =>
              //Same selection: share the node and keep going with the next level
              existing.replacements ++= incoming.replacements
              doMergeTree(incoming, existing)
            case None =>
              //New branch
              accNode.addChild(incoming)
              accNode
          }
        case _ => accNode
      }
    }

    doMergeTree(v, acc).root()
  }

  /**
    * Runs the elimination on the given scope and then, recursively, on all of its children scopes.
    *
    * A scope is processed only when its node is a DocumentNode, a DoBlockNode or a node whose parent is a
    * FunctionNode. For every declaration of the scope the local references that are part of a valid selection
    * are collected, merged into a single selection tree and reduced. The extracted declarations are then added
    * to the header of the scope; for a function a DoBlockNode is created around its body to hold that header.
    *
    * @param rootScope      The scope to process
    * @param astNavigator   Used to navigate to the parents of a node
    * @param scopeNavigator Used to resolve the references to a declaration
    */
  private def runOnScope(rootScope: VariableScope, astNavigator: AstNavigator, scopeNavigator: ScopesNavigator): Unit = {
    val okScope = rootScope.astNode match {
      case _: DocumentNode | _: DoBlockNode => true
      case node =>
        astNavigator.parentOf(node) match {
          case Some(_: FunctionNode) => true
          case _                     => false
        }
    }
    if (okScope) {
      val declarations = rootScope.declarations()
      val collector = new DeclarationCollector(counter, rootScope)
      declarations.foreach((decl) => {

        val nodes: Seq[NullSafeNode] = scopeNavigator
          .resolveLocalReferencedBy(decl)
          //We ignore cross module references and also namespace references
          .filter((ref) => {
            ref.isLocalReference && astNavigator.parentWithTypeMaxLevel(ref.referencedNode, classOf[NamespaceNode], 3).isEmpty
          })
          //We look for selection expression only
          .flatMap((ref) => {
            astNavigator
              .parentWithType(ref.referencedNode, classOf[NullSafeNode])
              .filter((selection) => validReplaceExpression(selection))
          })

        val variableReferenceExpressionNodes = nodes
          .flatMap((n) => {
            val maybeExpressionNode: Option[FactorableExpressionNode] = newExpressionTree(n, astNavigator, scopeNavigator)
            maybeExpressionNode.map(_.root())
          })
        if (variableReferenceExpressionNodes.nonEmpty) {
          //All the selections over the same declaration are merged into a single tree
          val expressionTree: FactorableExpressionNode = variableReferenceExpressionNodes
            .reduce((acc, v) => mergeTree(v, acc))

          if (expressionTree.children.nonEmpty) {
            //May be null when the declaration has no DirectiveNode parent in the searched levels
            val declarationDirectiveNode = astNavigator.parentWithTypeMaxLevel(decl, classOf[DirectiveNode], 3).orNull
            expressionTree.reduceToCommonSubExpression(null, collector, astNavigator, declarationDirectiveNode)
          }
        }
      })
      if (collector.hasDeclaration()) {
        val maybeHeaderNode = rootScope.astNode match {
          case dn: DocumentNode => Some(dn.header)
          case dn: DoBlockNode  => Some(dn.header)
          case node =>
            astNavigator.parentOf(node) match {
              case Some(fn: FunctionNode) =>
                //Functions have no header, so the body is wrapped with a do block that provides one
                val doBlockNode = DoBlockNode(HeaderNode(Seq()), fn.body)
                fn.body = doBlockNode
                Some(doBlockNode.header)
              case _ => None
            }
        }
        maybeHeaderNode match {
          case Some(header) =>
            collector.createDeclarations(header)
            mutated = true
          case None =>
            //Not expected: the scope was already validated to be one of the cases that provide a header
            throw new RuntimeException(s"Unable to find a header for the extracted declarations on scope node of type ${rootScope.astNode.getClass.getName}")
        }
      }
    }
    rootScope
      .children()
      .foreach((scope) => {
        runOnScope(scope, astNavigator, scopeNavigator)
      })
  }
}

/**
  * A node of the selection tree that is built to detect the selections that are shared by several expressions.
  */
trait FactorableExpressionNode {

  /** The selections that are applied on top of this node */
  val children: ArrayBuffer[FactorableExpressionNode] = ArrayBuffer.empty[FactorableExpressionNode]

  /** The node this one was added to. It is not set until this node is passed to `addChild` */
  var parent: FactorableExpressionNode = _

  /**
    * Extracts the shared selections of this sub tree into new declarations.
    *
    * @param variableReferenceNode The node that the selection represented by this node is applied on
    * @param collector             Collects the new declarations
    * @param astNavigator          Used to locate the parent of the nodes to be replaced
    * @param varDeclarationNode    The directive that is handed to the collector as the node after which new declarations go
    */
  def reduceToCommonSubExpression(variableReferenceNode: AstNode, collector: DeclarationCollector, astNavigator: AstNavigator, varDeclarationNode: DirectiveNode): Unit

  /**
    * Appends the given node to the children of this node and marks this node as its parent.
    *
    * @param expressionNode The new child
    * @return This node
    */
  def addChild(expressionNode: FactorableExpressionNode): FactorableExpressionNode = {
    children += expressionNode
    expressionNode.parent = this
    this
  }

  /**
    * @return The top node of the tree this node belongs to
    */
  def root(): VariableReferenceExpressionNode
}

/**
  * Collects all the variables that were extracted as common subexpressions, together with the namespace
  * directives that had to be created for them.
  *
  * @param counter The counter used to generate the names of the new variables and namespace prefixes
  * @param scope   The scope from where the lookup of already declared namespaces starts
  */
class DeclarationCollector(counter: AtomicInteger, scope: VariableScope) {

  private val namespacesToPrefix = mutable.Map[String, String]()
  private val newNamespaceDirectives = ArrayBuffer[NamespaceDirective]()
  private val newVarDirectives = ArrayBuffer[InsertAfter]()

  /**
    * Returns the prefix to use for the given namespace uri.
    *
    * The answer is cached per uri. The first time a uri is requested the prefix is taken from a namespace
    * directive with that uri declared in the scope of this collector or in any of its parents; if there is
    * none a new namespace is declared.
    *
    * @param uri The namespace uri
    * @return The prefix bound to the uri
    */
  def nsPrefix(uri: String): String = {
    namespacesToPrefix.getOrElseUpdate(uri, findDeclaredPrefix(uri).getOrElse(declareNamespace(uri)))
  }

  /**
    * Looks for a namespace directive with the given uri, from the scope of this collector up through its parents.
    */
  private def findDeclaredPrefix(uri: String): Option[String] = {
    @tailrec
    def search(maybeScope: Option[VariableScope]): Option[String] = {
      maybeScope match {
        case Some(current) =>
          val declared = current.astNode match {
            case dcn: DirectivesCapableNode =>
              dcn.directives.collectFirst({
                case nd: NamespaceDirective if (nd.uri.literalValue == uri) => nd.prefix.name
              })
            case _ => None
          }
          if (declared.isDefined) declared else search(current.parentScope)
        case None => None
      }
    }

    search(Some(scope))
  }

  /**
    * Adds a new variable
    *
    * @param value     The variable expression node
    * @param afterNode The directive after which the new variable has to be inserted
    * @return The directive of the new variable
    */
  def addDeclaration(value: AstNode, afterNode: DirectiveNode): VarDirective = {
    val variableName = NameIdentifier(s"__fakeVariable${counter.incrementAndGet()}")
    val varDirective = VarDirective(variableName, value)
    val lazyAnnotation = AnnotationNode(NameIdentifier("dw::Core::Lazy"), Some(AnnotationArgumentsNode(Seq())))
    varDirective.setAnnotations(Seq(lazyAnnotation))
    newVarDirectives += InsertAfter(varDirective, afterNode)
    varDirective
  }

  /**
    * Registers a new namespace directive for the uri using a generated prefix.
    */
  private def declareNamespace(uri: String) = {
    val generatedPrefix = s"__ns${counter.incrementAndGet()}"
    newNamespaceDirectives += NamespaceDirective(NameIdentifier(generatedPrefix), UriNode(uri))
    generatedPrefix
  }

  /**
    * Adds all the variable declarations to the HeaderNode. The new namespace directives are added only when
    * there is at least one variable to add.
    *
    * @param header The header target node to where all the declarations are going to be inserted
    * @return Returns true if variables were added
    */
  def createDeclarations(header: HeaderNode): Boolean = {
    val hasVariables = newVarDirectives.nonEmpty
    if (hasVariables) {
      for (pending <- newVarDirectives) {
        header.addDirectiveAfter(pending.directiveToInsert, pending.afterNode)
      }
      for (namespaceDirective <- newNamespaceDirectives) {
        header.addDirectiveAfter(namespaceDirective, null)
      }
    }
    hasVariables
  }

  /**
    * @return True if at least one variable was collected
    */
  def hasDeclaration(): Boolean = newVarDirectives.nonEmpty

}

/**
  * A directive that is pending to be added to a header
  *
  * @param directiveToInsert The directive to add
  * @param afterNode         The directive that is passed to the header as the node to insert after
  */
case class InsertAfter(directiveToInsert: DirectiveNode, afterNode: DirectiveNode)

/**
  * The top node in a selection path
  *
  * @param variableReferenceNode The variable reference
  * @param nullSafeNode          The parent root null safe node
  */
case class VariableReferenceExpressionNode(variableReferenceNode: VariableReferenceNode, nullSafeNode: NullSafeNode) extends FactorableExpressionNode {

  /**
    * Reduces every child using the variable reference of this node as the node the selections are applied on.
    */
  def reduceToCommonSubExpression(collector: DeclarationCollector, astNavigator: AstNavigator, varDeclarationNode: DirectiveNode): Unit = {
    for (child <- children) {
      child.reduceToCommonSubExpression(variableReferenceNode, collector, astNavigator, varDeclarationNode)
    }
  }

  /**
    * Same as the overload without `node`; the given node is ignored as the top node has nothing above it.
    */
  override def reduceToCommonSubExpression(node: AstNode, collector: DeclarationCollector, astNavigator: AstNavigator, varDeclarationNode: DirectiveNode): Unit =
    reduceToCommonSubExpression(collector, astNavigator, varDeclarationNode)

  override def root(): VariableReferenceExpressionNode = this
}

/**
  * The selector part (the right hand side) of a selection in the expression tree
  */
trait SelectorExpression {

  /**
    * Returns true if this expression and the other selector select over the same thing
    *
    * @param otherSelector The selector to compare against
    * @return True if both selectors can share the same node in the expression tree
    */
  def canBeMerged(otherSelector: SelectorExpression): Boolean

  /**
    * Creates an AST Node that represents this selector Node.
    * For example for Index based it will create a Number Node and For Name selectors will create a NameNode
    *
    * @param collector The declaration collector, name selectors use it to obtain the prefix of their namespace
    * @return The new AstNode
    */
  def toSelectorNode(collector: DeclarationCollector): AstNode
}

/**
  * A selection by name
  *
  * @param name   The name being selected
  * @param quoted The quotation char of the name if any, it is not taken into account when merging
  * @param uri    The uri of the namespace of the name if any
  */
case class NameSelectorExpression(name: String, quoted: Option[Char], uri: Option[String]) extends SelectorExpression {

  override def toSelectorNode(collector: DeclarationCollector): AstNode = {
    //The collector provides the prefix bound to the uri
    val maybeNamespaceNode = uri.map((namespaceUri) => NamespaceNode(NameIdentifier(collector.nsPrefix(namespaceUri))))
    val keyNode = StringNode(name)
    //NOTE(review): the result of withQuotation is discarded, this relies on it updating the node in place - confirm
    for (quote <- quoted) {
      keyNode.withQuotation(quote)
    }
    NameNode(keyNode, maybeNamespaceNode)
  }

  override def canBeMerged(childSelector: SelectorExpression): Boolean = {
    childSelector match {
      case other: NameSelectorExpression => name == other.name && uri == other.uri
      case _                             => false
    }
  }
}

/**
  * A selection by a number literal
  *
  * @param number The text of the number literal, two selectors are merged only if their text is the same
  */
case class NumberSelectorExpression(number: String) extends SelectorExpression {

  override def toSelectorNode(collector: DeclarationCollector): AstNode = NumberNode(number)

  override def canBeMerged(childSelector: SelectorExpression): Boolean = {
    childSelector match {
      case other: NumberSelectorExpression => number == other.number
      case _                               => false
    }
  }
}

/**
  * Represents a node in the expression tree
  *
  * @param selectorExpression The selector of the selection
  * @param opId               The operator of the selection
  * @param replacements       The replacements are all the nodes that point to this selection path
  */
case class ValueSelectorExpressionNode(selectorExpression: SelectorExpression, opId: BinaryOpIdentifier, replacements: ArrayBuffer[ReplacementNode[BinaryOpNode]]) extends FactorableExpressionNode {

  /**
    * Extracts this selection into a new variable when it is shared, and keeps reducing the children.
    *
    * The selection is extracted when it is used more than once and the path either ends or forks here, or when
    * it is the end of some expression while other expressions keep selecting over it. Otherwise the selection is
    * handed to the children as the node they are applied on.
    *
    * @param contextNode       The node this selection is applied on
    * @param collector         Collects the new declarations
    * @param astNavigator      Used to locate the parent of the nodes to be replaced
    * @param contextExpression The directive after which a new declaration has to be inserted
    */
  override def reduceToCommonSubExpression(contextNode: AstNode, collector: DeclarationCollector, astNavigator: AstNavigator, contextExpression: DirectiveNode): Unit = {
    val selectionNode = NullSafeNode(BinaryOpNode(opId, contextNode, selectorExpression.toSelectorNode(collector)))
    val sharedEndOrFork = replacements.size > 1 && children.size != 1
    val endWithDeeperSelections = replacements.exists(_.leaf) && children.nonEmpty
    if (sharedEndOrFork || endWithDeeperSelections) {
      val extractedVariable: VarDirective = collector.addDeclaration(selectionNode, contextExpression)
      //Every usage of the selection is replaced by a reference to the new variable
      for (replacement <- replacements) {
        val nodeToReplace = replacement.node
        astNavigator.parentOf(nodeToReplace) match {
          case Some(owner: MutableAstNode) => owner.update(nodeToReplace, VariableReferenceNode(extractedVariable.variable.name))
          case _                           =>
        }
      }
      //The children now select over the new variable and declare after it
      for (child <- children) {
        child.reduceToCommonSubExpression(VariableReferenceNode(extractedVariable.variable.name), collector, astNavigator, extractedVariable)
      }
    } else {
      for (child <- children) {
        child.reduceToCommonSubExpression(selectionNode, collector, astNavigator, contextExpression)
      }
    }
  }

  override def root(): VariableReferenceExpressionNode = parent.root()
}

/**
  * A node of the AST that is a candidate to be replaced by a reference to an extracted variable
  *
  * @param node The node to replace
  * @param leaf True if the node is the outermost selection of the expression it was taken from
  * @tparam T The type of the node
  */
case class ReplacementNode[T <: AstNode](node: T, leaf: Boolean)
