package org.mule.weave.v2.parser.phase

import org.mule.weave.v2.grammar.AdditionOpId
import org.mule.weave.v2.grammar.AsOpId
import org.mule.weave.v2.grammar.AttributeValueSelectorOpId
import org.mule.weave.v2.grammar.DescendantsSelectorOpId
import org.mule.weave.v2.grammar.DivisionOpId
import org.mule.weave.v2.grammar.DynamicSelectorOpId
import org.mule.weave.v2.grammar.EqOpId
import org.mule.weave.v2.grammar.FilterSelectorOpId
import org.mule.weave.v2.grammar.GreaterOrEqualThanOpId
import org.mule.weave.v2.grammar.GreaterThanOpId
import org.mule.weave.v2.grammar.IsOpId
import org.mule.weave.v2.grammar.LeftShiftOpId
import org.mule.weave.v2.grammar.LessOrEqualThanOpId
import org.mule.weave.v2.grammar.LessThanOpId
import org.mule.weave.v2.grammar.MultiAttributeValueSelectorOpId
import org.mule.weave.v2.grammar.MultiValueSelectorOpId
import org.mule.weave.v2.grammar.MultiplicationOpId
import org.mule.weave.v2.grammar.NotEqOpId
import org.mule.weave.v2.grammar.ObjectKeyValueSelectorOpId
import org.mule.weave.v2.grammar.RightShiftOpId
import org.mule.weave.v2.grammar.SchemaValueSelectorOpId
import org.mule.weave.v2.grammar.SubtractionOpId
import org.mule.weave.v2.grammar.ValueSelectorOpId
import org.mule.weave.v2.parser.ComplexType
import org.mule.weave.v2.parser.DynamicAccess
import org.mule.weave.v2.parser.DynamicFunctionsCanNotBeChecked
import org.mule.weave.v2.parser.ExpressionPatternForceMaterialize
import org.mule.weave.v2.parser.MessageCollector
import org.mule.weave.v2.parser.NativeFunctionParameterNotAnnotated
import org.mule.weave.v2.parser.NegativeIndexAccess
import org.mule.weave.v2.parser.NonStreamableParameter
import org.mule.weave.v2.parser.NotStreamableOperator
import org.mule.weave.v2.parser.RevertSelection
import org.mule.weave.v2.parser.UsingUpsert
import org.mule.weave.v2.parser.annotation.MaterializeVariableAnnotation
import org.mule.weave.v2.parser.annotation.StreamingCapableVariableAnnotation
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.AstNodeHelper
import org.mule.weave.v2.parser.ast.functions.FunctionCallNode
import org.mule.weave.v2.parser.ast.functions.FunctionCallParametersNode
import org.mule.weave.v2.parser.ast.functions.FunctionNode
import org.mule.weave.v2.parser.ast.functions.OverloadedFunctionNode
import org.mule.weave.v2.parser.ast.functions.UsingNode
import org.mule.weave.v2.parser.ast.functions.UsingVariableAssignment
import org.mule.weave.v2.parser.ast.header.directives.FunctionDirectiveNode
import org.mule.weave.v2.parser.ast.header.directives.InputDirective
import org.mule.weave.v2.parser.ast.header.directives.VarDirective
import org.mule.weave.v2.parser.ast.operators.BinaryOpNode
import org.mule.weave.v2.parser.ast.operators.UnaryOpNode
import org.mule.weave.v2.parser.ast.patterns.DeconstructArrayPatternNode
import org.mule.weave.v2.parser.ast.patterns.DeconstructObjectPatternNode
import org.mule.weave.v2.parser.ast.patterns.DefaultPatternNode
import org.mule.weave.v2.parser.ast.patterns.EmptyArrayPatternNode
import org.mule.weave.v2.parser.ast.patterns.EmptyObjectPatternNode
import org.mule.weave.v2.parser.ast.patterns.PatternMatcherNode
import org.mule.weave.v2.parser.ast.patterns.TypePatternNode
import org.mule.weave.v2.parser.ast.structure.DocumentNode
import org.mule.weave.v2.parser.ast.structure.NumberNode
import org.mule.weave.v2.parser.ast.types.WeaveTypeNode
import org.mule.weave.v2.parser.ast.updates.UpdateNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.scope.AstNavigator
import org.mule.weave.v2.scope.Reference
import org.mule.weave.v2.scope.ScopesNavigator
import org.mule.weave.v2.ts.ScopeGraphTypeReferenceResolver
import org.mule.weave.v2.ts.TypeHelper
import org.mule.weave.v2.ts.WeaveType

import scala.collection.mutable
import scala.util.Try

/**
  * Constants shared by the streaming-capability marker phase.
  */
object StreamingCapableVariableMarkerPhase {

  /**
    * Name of the code annotation that a native function parameter must carry to be considered stream capable.
    */
  val STREAM_ANNOTATION: String = "StreamCapable"
}

/**
  * Marks input variables and function parameters with a [[StreamingCapableVariableAnnotation]] telling whether
  * the value bound to them can be consumed as a stream, together with the reasons when it can not.
  *
  * It requires `org.mule.weave.v2.parser.phase.MaterializeVariableMarkerPhase` to be executed first: every analyzed
  * (non native) variable is expected to already carry a [[MaterializeVariableAnnotation]].
  */
class StreamingCapableVariableMarkerPhase[R <: AstNode, T <: AstNodeResultAware[R] with ScopeNavigatorResultAware]() extends CompilationPhase[T, T] {

  override def doCall(source: T, context: ParsingContext): PhaseResult[T] = {
    val functionCallStack = new FunctionCallStack()
    val astNavigator: AstNavigator = source.scope.astNavigator()
    source.astNode match {
      case dn: DocumentNode => {
        // Inputs are visible in the whole document, so their usages are analyzed up to the document node.
        dn.header.directives
          .collect({ case id: InputDirective => id })
          .foreach((id) => markVariable(id.variable, dn, astNavigator, source.scope, context, functionCallStack))
      }
      case _ =>
    }

    // Functions already reached through a call site (see `canStreamArgument`) are skipped here.
    astNavigator
      .allWithType(classOf[FunctionNode])
      .foreach((fn) => markWithStreamingFunctionNode(fn, astNavigator, source.scope, context, functionCallStack))

    SuccessResult(source, context)
  }

  /**
    * Annotates `variable` with a [[StreamingCapableVariableAnnotation]].
    *
    * A variable flagged for materialization by the previous phase is never streamable and inherits its reasons.
    * Otherwise every usage of the variable inside `scope` is checked and the blocking reasons (if any) are recorded.
    * Expects the [[MaterializeVariableAnnotation]] to be present (`.get` throws otherwise).
    */
  private def markVariable(variable: NameIdentifier, scope: AstNode, astNavigator: AstNavigator, scopesNavigator: ScopesNavigator, ctx: ParsingContext, fcs: FunctionCallStack): Unit = {
    val materializeAnnotation: MaterializeVariableAnnotation = variable.annotation(classOf[MaterializeVariableAnnotation]).get
    if (materializeAnnotation.needMaterialize) {
      variable.annotate(new StreamingCapableVariableAnnotation(false, materializeAnnotation.reasons))
    } else {
      val collector = new MessageCollector
      val streamable = canStreamVariable(variable, scope, astNavigator, scopesNavigator, collector, ctx, fcs)
      variable.annotate(new StreamingCapableVariableAnnotation(streamable, collector.errorMessages))
    }
  }

  /**
    * Marks every parameter of `fn`, once per function.
    *
    * Native functions can not be inspected, so their parameters are streamable only when declared with the
    * `@StreamCapable` annotation. This lives here (and not only in `doCall`) so that a native function reached
    * first through a call site is marked the same way regardless of traversal order.
    */
  private def markWithStreamingFunctionNode(fn: FunctionNode, astNavigator: AstNavigator, scopesNavigator: ScopesNavigator, ctx: ParsingContext, fcs: FunctionCallStack): Unit = {
    if (!fcs.alreadyProcessed(fn)) {
      fcs.push(fn)
      if (AstNodeHelper.isNativeCall(fn.body, scopesNavigator)) {
        fn.params.paramList.foreach((fp) => {
          val variable: NameIdentifier = fp.variable
          val declaredStreamCapable = fp.codeAnnotations.exists((an) => an.name.name.equals(StreamingCapableVariableMarkerPhase.STREAM_ANNOTATION))
          if (declaredStreamCapable) {
            variable.annotate(new StreamingCapableVariableAnnotation(true))
          } else {
            variable.annotate(new StreamingCapableVariableAnnotation(false, Seq((variable.location(), NativeFunctionParameterNotAnnotated()))))
          }
        })
      } else {
        fn.params.paramList.foreach((fp) => markVariable(fp.variable, fn.body, astNavigator, scopesNavigator, ctx, fcs))
      }
      fcs.pop()
    }
  }

  /**
    * Returns true when every local reference to `variable` is used in a streamable way within `scope`.
    * A variable whose scope can not be found has no references and is therefore considered streamable.
    */
  private def canStreamVariable(variable: NameIdentifier, scope: AstNode, astNavigator: AstNavigator, scopesNavigator: ScopesNavigator, messageCollector: MessageCollector, ctx: ParsingContext, fcs: FunctionCallStack): Boolean = {
    val references: Seq[Reference] =
      scopesNavigator
        .scopeOf(variable)
        .map(_.resolveLocalReferenceTo(variable))
        .getOrElse(Seq())

    references.forall((ref) => canStream(ref.referencedNode, scope, astNavigator, scopesNavigator, messageCollector, ctx, fcs))
  }

  /**
    * Validates whether the value of `node` can be streamed by walking from `node` up through its ancestors until
    * `rootBody` is reached, checking at each step how the parent consumes the child.
    */
  private def canStream(node: AstNode, rootBody: AstNode, astNavigator: AstNavigator, scopesNavigator: ScopesNavigator, messageCollector: MessageCollector, ctx: ParsingContext, fcs: FunctionCallStack): Boolean = {
    if (node eq rootBody) {
      // Reached the top of the analyzed scope (this also covers a reference that *is* the whole body, such as
      // `fun f(x) = x` or `case x -> x`): nothing else inside the scope consumes the value.
      true
    } else {
      astNavigator.parentOf(node) match {
        case Some(parent) =>
          isStreamableUsage(node, parent, rootBody, astNavigator, scopesNavigator, messageCollector, ctx, fcs) &&
            canStream(parent, rootBody, astNavigator, scopesNavigator, messageCollector, ctx, fcs)
        case None =>
          // Walked past the tree root without meeting `rootBody`: be conservative instead of failing.
          false
      }
    }
  }

  /**
    * Checks whether `parent` (the direct parent of `node`) consumes `node` in a way compatible with streaming,
    * reporting the blocking reason to `messageCollector` when it does not.
    */
  private def isStreamableUsage(node: AstNode, parent: AstNode, rootBody: AstNode, astNavigator: AstNavigator, scopesNavigator: ScopesNavigator, messageCollector: MessageCollector, ctx: ParsingContext, fcs: FunctionCallStack): Boolean = {
    parent match {
      case fcpn: FunctionCallParametersNode => {
        val argIndex: Int = fcpn.args.indexOf(node)
        astNavigator.parentOf(fcpn) match {
          case Some(fcn: FunctionCallNode) => handleFunctionCallNode(fcn, argIndex, scopesNavigator, messageCollector, ctx, fcs)
          case _                           => false
        }
      }
      case pmn: PatternMatcherNode if (pmn.lhs eq node) => {
        // If we are branching on a reference then validate the patterns. Otherwise we are coming from an
        // expression node of the matcher and we should not handle this case.
        isPatternStream(pmn, rootBody, astNavigator, scopesNavigator, messageCollector, ctx, fcs)
      }
      case up: UpdateNode => {
        up.matchers.expressions.find(_.forceCreate) match {
          case Some(upsert) => {
            messageCollector.error(UsingUpsert(), upsert.location())
            false
          }
          case None => true
        }
      }
      case vd: VarDirective => {
        // The value flows into a new variable: that variable's usages must be streamable too.
        val needsMaterialize = vd.variable.annotation(classOf[MaterializeVariableAnnotation]).exists(_.needMaterialize)
        !needsMaterialize &&
          canStreamVariable(vd.variable, astNavigator.granParentOf(vd).get, astNavigator, scopesNavigator, messageCollector, ctx, fcs)
      }
      case vd: UsingVariableAssignment => {
        val needsMaterialize = vd.name.annotation(classOf[MaterializeVariableAnnotation]).exists(_.needMaterialize)
        !needsMaterialize &&
          canStreamVariable(vd.name, astNavigator.parentWithType(vd, classOf[UsingNode]).get, astNavigator, scopesNavigator, messageCollector, ctx, fcs)
      }
      // Referenced from inside another function (e.g. captured by a lambda): can not stream.
      case _: FunctionNode => false
      case bon: BinaryOpNode => isStreamableBinaryOperation(node, bon, scopesNavigator, messageCollector, ctx)
      case un: UnaryOpNode => {
        un.opId match {
          case DescendantsSelectorOpId => {
            messageCollector.error(NotStreamableOperator("Descendant Selector"), un.location())
            false
          }
          case _ => true
        }
      }
      case _ => true
    }
  }

  /**
    * Checks whether the binary operation `bon`, which has `node` as one of its operands, can consume it as a stream.
    */
  private def isStreamableBinaryOperation(node: AstNode, bon: BinaryOpNode, scopesNavigator: ScopesNavigator, messageCollector: MessageCollector, ctx: ParsingContext): Boolean = {
    bon.opId match {
      case AsOpId | IsOpId => {
        val typeReferenceResolver = new ScopeGraphTypeReferenceResolver(scopesNavigator)
        val needsMaterialize = typeNeedsMaterialize(bon.rhs, typeReferenceResolver)
        if (needsMaterialize) {
          messageCollector.error(ComplexType(), bon.rhs.location())
        }
        !needsMaterialize
      }
      case DynamicSelectorOpId if (bon.lhs eq node) => isStreamableIndexAccess(bon, scopesNavigator, messageCollector, ctx)
      case AdditionOpId | GreaterThanOpId | GreaterOrEqualThanOpId | AttributeValueSelectorOpId | ValueSelectorOpId | MultiAttributeValueSelectorOpId | MultiValueSelectorOpId | SubtractionOpId | DivisionOpId | EqOpId | LeftShiftOpId | RightShiftOpId | MultiplicationOpId | NotEqOpId | SchemaValueSelectorOpId |
        ObjectKeyValueSelectorOpId | FilterSelectorOpId | LessThanOpId | LessOrEqualThanOpId =>
        true
      case op => {
        messageCollector.error(NotStreamableOperator(op.name), bon.location())
        false
      }
    }
  }

  /**
    * Validates a dynamic selection `value[index]` on the streamed value. Only a non-negative literal index, or a
    * literal ascending range built with the core `to` function, keeps the access sequential.
    */
  private def isStreamableIndexAccess(bon: BinaryOpNode, scopesNavigator: ScopesNavigator, messageCollector: MessageCollector, ctx: ParsingContext): Boolean = {
    bon.rhs match {
      case NumberNode(literalValue, _) => {
        val positiveIndex = isPositiveNumber(literalValue)
        if (!positiveIndex) {
          messageCollector.error(NegativeIndexAccess(literalValue), bon.location())
        }
        positiveIndex
      }
      case fcn @ FunctionCallNode(VariableReferenceNode(to, _), FunctionCallParametersNode(Seq(NumberNode(leftNumber, _), NumberNode(rightNumber, _))), _, _) => {
        val isCoreTo = scopesNavigator
          .resolveVariable(to)
          .exists((ref) => ref.moduleSource.getOrElse(ctx.nameIdentifier).equals(NameIdentifier.CORE_MODULE) && ref.referencedNode.name.equals("to"))
        if (!isCoreTo) {
          messageCollector.error(DynamicAccess(), fcn.location())
          false
        } else if (!isPositiveNumber(leftNumber)) {
          messageCollector.error(NegativeIndexAccess(leftNumber), fcn.location())
          false
        } else if (!isPositiveNumber(rightNumber)) {
          messageCollector.error(NegativeIndexAccess(rightNumber), fcn.location())
          false
        } else {
          // A descending range would require going back in the stream.
          val sequentialAccess = java.lang.Long.parseLong(leftNumber) <= java.lang.Long.parseLong(rightNumber)
          if (!sequentialAccess) {
            messageCollector.error(RevertSelection(), fcn.location())
          }
          sequentialAccess
        }
      }
      case _ => {
        messageCollector.error(DynamicAccess(), bon.location())
        false
      }
    }
  }

  /**
    * Returns true when `literalValue` parses as a `Long` greater than or equal to zero (zero included); false for
    * anything that does not parse as a `Long`.
    */
  private def isPositiveNumber(literalValue: String): Boolean = {
    Try(java.lang.Long.parseLong(literalValue) >= 0).getOrElse(false)
  }

  /**
    * Validates whether matching on a streamed value can be done without materializing it: every pattern must be
    * checkable on a stream and every variable bound by a pattern must itself be streamable inside its branch.
    */
  private def isPatternStream(pmn: PatternMatcherNode, functionBody: AstNode, astNavigator: AstNavigator, scopesNavigator: ScopesNavigator, messageCollector: MessageCollector, ctx: ParsingContext, fcs: FunctionCallStack): Boolean = {
    def canStreamVariableInScope(name: NameIdentifier, onMatch: AstNode): Boolean = {
      // NOTE(review): a missing annotation counts as "needs materialization" (`forall`), and a binding that needs
      // materialization is then treated as streamable. This is the opposite of the `VarDirective` /
      // `UsingVariableAssignment` handling in `isStreamableUsage`; confirm it is intended.
      val needsMaterialization = name.annotation(classOf[MaterializeVariableAnnotation]).forall(_.needMaterialize)
      if (!needsMaterialization) {
        canStreamVariable(name, onMatch, astNavigator, scopesNavigator, messageCollector, ctx, fcs)
      } else {
        true
      }
    }
    // Streamable unless some pattern forces materialization.
    !pmn.patterns.patterns
      .exists({
        case tpn: TypePatternNode => {
          val typeReferenceResolver = new ScopeGraphTypeReferenceResolver(scopesNavigator)
          val needsMaterialize = typeNeedsMaterialize(tpn.pattern, typeReferenceResolver)
          if (needsMaterialize) {
            messageCollector.error(ComplexType(), tpn.location())
            true
          } else {
            !canStreamVariableInScope(tpn.name, tpn.onMatch)
          }
        }
        case dpn: DefaultPatternNode          => !canStreamVariableInScope(dpn.name, dpn.onMatch)
        case dapn: DeconstructArrayPatternNode => !canStreamVariableInScope(dapn.tail, dapn.onMatch)
        case dopn: DeconstructObjectPatternNode => !canStreamVariableInScope(dopn.tail, dopn.onMatch)
        case _: EmptyObjectPatternNode        => false
        case _: EmptyArrayPatternNode         => false
        case pn => {
          messageCollector.error(ExpressionPatternForceMaterialize(), pn.location())
          true
        }
      })
  }

  /**
    * Validates passing the streamed value as argument `argIndex` of the call `fcn`. Only calls to statically linked
    * functions (declared with a function directive) can be checked; anything else is treated as not streamable.
    */
  private def handleFunctionCallNode(fcn: FunctionCallNode, argIndex: Int, scopesNavigator: ScopesNavigator, messageCollector: MessageCollector, ctx: ParsingContext, fcs: FunctionCallStack): Boolean = {
    fcn.function match {
      case vrn: VariableReferenceNode => {
        scopesNavigator.resolveVariable(vrn.variable) match {
          case Some(ref) => {
            // NOTE(review): the callee is navigated with its own module's AstNavigator but the current
            // `scopesNavigator` is still passed down; confirm this is right for functions from other modules.
            val astNavigator = ref.scope.astNavigator()
            astNavigator.parentOf(ref.referencedNode) match {
              case Some(fdn: FunctionDirectiveNode) => {
                canStreamFunctionNode(fdn.literal, argIndex, fcn, scopesNavigator, messageCollector, ctx, fcs, astNavigator)
              }
              case _ => {
                messageCollector.error(DynamicFunctionsCanNotBeChecked(), vrn.location())
                false
              }
            }
          }
          case _ => false
        }
      }
      case _ => false
    }
  }

  /**
    * Returns whether `functionNode` accepts a streamed value at `argIndex`, reporting `NonStreamableParameter` on
    * the call-site argument only when it does not.
    */
  private def canStreamFunctionNode(functionNode: AstNode, argIndex: Int, fcn: FunctionCallNode, scopesNavigator: ScopesNavigator, messageCollector: MessageCollector, ctx: ParsingContext, fcs: FunctionCallStack, astNavigator: AstNavigator): Boolean = {
    val doCanStream = acceptsStreamedArgument(functionNode, argIndex, scopesNavigator, ctx, fcs, astNavigator)
    if (!doCanStream) {
      messageCollector.error(NonStreamableParameter(), fcn.args.args(argIndex).location())
    }
    doCanStream
  }

  /**
    * Checks the parameter at `argIndex` of a plain function, or of every alternative of an overloaded function.
    */
  private def acceptsStreamedArgument(functionNode: AstNode, argIndex: Int, scopesNavigator: ScopesNavigator, ctx: ParsingContext, fcs: FunctionCallStack, astNavigator: AstNavigator): Boolean = {
    functionNode match {
      // The function is currently being analyzed (recursive call): assume streamable to break the cycle.
      case fn: FunctionNode if (fcs.isRecursive(fn)) => true
      case fn: FunctionNode                          => canStreamArgument(argIndex, fn, astNavigator, scopesNavigator, ctx, fcs)
      case ofd: OverloadedFunctionNode => {
        ofd.functions.forall((fd) => acceptsStreamedArgument(fd, argIndex, scopesNavigator, ctx, fcs, astNavigator))
      }
      case _ => false
    }
  }

  /**
    * Reads the streaming annotation of parameter `argIndex` of `fn`, marking the function first if it has not been
    * analyzed yet.
    */
  private def canStreamArgument(argIndex: Int, fn: FunctionNode, astNavigator: AstNavigator, scopesNavigator: ScopesNavigator, ctx: ParsingContext, fcs: FunctionCallStack): Boolean = {
    val paramList = fn.params.paramList
    if (paramList.size > argIndex) {
      val parameter = paramList(argIndex).variable
      parameter.annotation(classOf[StreamingCapableVariableAnnotation]) match {
        case Some(streamAnnotation) => streamAnnotation.canStream
        case None => {
          markWithStreamingFunctionNode(fn, astNavigator, scopesNavigator, ctx, fcs)
          parameter.annotation(classOf[StreamingCapableVariableAnnotation]).exists(_.canStream)
        }
      }
    } else {
      // No declared parameter at this position, so nothing in the callee is checked.
      true
    }
  }

  /**
    * Returns whether checking a value against `typeNode` requires materializing it. Any node that is not a
    * [[WeaveTypeNode]] is conservatively treated as requiring materialization.
    */
  private def typeNeedsMaterialize(typeNode: AstNode, typeReferenceResolver: ScopeGraphTypeReferenceResolver): Boolean = {
    typeNode match {
      case weaveTypeNode: WeaveTypeNode => TypeHelper.requiredMaterialize(WeaveType(weaveTypeNode, typeReferenceResolver))
      case _                            => true
    }
  }

}

/**
  * Tracks the functions currently under analysis (used to detect recursive calls) and the functions whose
  * parameters have already been marked.
  *
  * Functions are compared by reference identity (`eq`), never by structural equality.
  */
class FunctionCallStack {
  // Functions currently under analysis, most recently pushed first.
  private var stack: List[FunctionNode] = Nil
  // Identity-based set: O(1) lookups instead of a linear scan that grows with every processed function.
  private val processed: java.util.Set[FunctionNode] =
    java.util.Collections.newSetFromMap(new java.util.IdentityHashMap[FunctionNode, java.lang.Boolean]())

  /** Returns true when `fn` is currently on the stack, i.e. it is being analyzed. */
  def isRecursive(fn: AstNode): Boolean = {
    stack.exists((sfn) => sfn eq fn)
  }

  /** Starts the analysis of `fn`. */
  def push(fn: FunctionNode): Unit = {
    stack = fn :: stack
  }

  /**
    * Ends the analysis of the most recently pushed function and records it as processed.
    * Throws `NoSuchElementException` when the stack is empty.
    */
  def pop(): Unit = {
    stack match {
      case head :: tail => {
        stack = tail
        processed.add(head)
      }
      case Nil => throw new NoSuchElementException("empty stack")
    }
  }

  /** Returns true when `fn` has already been pushed and popped. */
  def alreadyProcessed(fn: FunctionNode): Boolean = {
    processed.contains(fn)
  }
}
