package org.mule.weave.v2.interpreted.transform

import org.mule.weave.v2.interpreted.marker.LazyVarDirectiveAnnotation
import org.mule.weave.v2.interpreted.node.NameSlot
import org.mule.weave.v2.interpreted.node.ValueNode
import org.mule.weave.v2.interpreted.node.expressions.InlineTailRecFunctionCall
import org.mule.weave.v2.interpreted.node.expressions.TailRecFunctionBodyNode
import org.mule.weave.v2.interpreted.node.pattern.{ PatternMatcherNode => XPatternMatcherNode }
import org.mule.weave.v2.interpreted.node.structure
import org.mule.weave.v2.interpreted.node.structure.header.directives
import org.mule.weave.v2.interpreted.node.structure.header.directives.Directive
import org.mule.weave.v2.interpreted.node.structure.header.directives.ExecutionDirectiveNode
import org.mule.weave.v2.interpreted.node.structure.header.directives.FunctionDirective
import org.mule.weave.v2.interpreted.node.structure.header.directives.{ VarDirective => XVarDirective }
import org.mule.weave.v2.interpreted.node.structure.{ DoBlockNode => XDoBlockNode }
import org.mule.weave.v2.interpreted.node.{ DefaultNode => XDefaultNode }
import org.mule.weave.v2.interpreted.node.{ IfNode => XIfNode }
import org.mule.weave.v2.interpreted.node.{ UnlessNode => XUnlessNode }
import org.mule.weave.v2.interpreted.node.{ UsingNode => XUsingNode }
import org.mule.weave.v2.model.values.math.Number
import org.mule.weave.v2.parser
import org.mule.weave.v2.parser.annotation.TailRecFunctionAnnotation
import org.mule.weave.v2.parser.annotation.TailRecFunctionCallAnnotation
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.WeaveLocationCapable
import org.mule.weave.v2.parser.ast.annotation.AnnotationNode
import org.mule.weave.v2.parser.ast.conditional.DefaultNode
import org.mule.weave.v2.parser.ast.conditional.IfNode
import org.mule.weave.v2.parser.ast.conditional.UnlessNode
import org.mule.weave.v2.parser.ast.functions
import org.mule.weave.v2.parser.ast.functions.DoBlockNode
import org.mule.weave.v2.parser.ast.functions.FunctionCallNode
import org.mule.weave.v2.parser.ast.functions.FunctionParameter
import org.mule.weave.v2.parser.ast.functions.FunctionParameters
import org.mule.weave.v2.parser.ast.functions.OverloadedFunctionNode
import org.mule.weave.v2.parser.ast.functions.UsingNode
import org.mule.weave.v2.parser.ast.functions.{ FunctionNode => AstFunctionNode }
import org.mule.weave.v2.parser.ast.header.directives._
import org.mule.weave.v2.parser.ast.patterns.PatternExpressionNode
import org.mule.weave.v2.parser.ast.patterns.PatternMatcherNode
import org.mule.weave.v2.parser.ast.structure.UriNode
import org.mule.weave.v2.parser.ast.types.TypeParametersListNode
import org.mule.weave.v2.parser.ast.types.WeaveTypeNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.runtime.exception.CompilationExecutionException

/**
  * Transforms parser-level header directives (version, output, namespace, option, variable and function
  * declarations) into their interpreter (engine) node counterparts.
  *
  * Functions whose name is annotated with [[TailRecFunctionAnnotation]] are compiled with tail-recursion
  * elimination. Their body is wrapped in a [[TailRecFunctionBodyNode]], and every recursive call annotated with
  * [[TailRecFunctionCallAnnotation]] is replaced by an [[InlineTailRecFunctionCall]]. That call is built from one
  * variable directive per parameter, binding the parameter to the corresponding call argument.
  */
trait EngineDirectiveTransformations extends AstTransformation with EnginePatternTransformations {

  /** Converts a parser format expression (content type or data format id) into its engine counterpart. */
  def transformFormat(format: FormatExpression): directives.FormatExpression = {
    format match {
      case ContentType(mime) => new directives.ContentType(mime)
      case DataFormatId(id)  => new directives.DataFormatId(id)
    }
  }

  /** Converts an optional parser content type into an optional engine content type (`None` stays `None`). */
  def transformContentType(mime: Option[ContentType]): Option[directives.ContentType] =
    mime.map(m => new directives.ContentType(m.mime))

  /**
    * Builds the engine namespace directive for `prefix` bound to `uri`.
    * The third constructor argument is `!transformingModule`, so it is `false` while a module is being transformed.
    */
  def transformNamespaceDirective(prefix: NameIdentifier, uri: UriNode): directives.NamespaceDirective = {
    new directives.NamespaceDirective(transform(prefix), transform(uri), !transformingModule)
  }

  /** Wraps the major version text, converted with `Number(v)`, into an engine version node. */
  def transformVersionMajor(v: String): directives.VersionMajor = new directives.VersionMajor(Number(v))

  /** Builds an [[ExecutionDirectiveNode]] from the optional directive options. */
  def transformAstDirectiveNode(options: Option[Seq[DirectiveOption]]): ExecutionDirectiveNode =
    new ExecutionDirectiveNode(transformOptionSeq(options))

  /** Converts one directive option; the option name becomes a string node and the value is transformed. */
  def transformDirectiveOption(name: DirectiveOptionName, value: AstNode): directives.DirectiveOption = {
    new directives.DirectiveOption(structure.StringNode(name.name), transform(value))
  }

  /** Combines the transformed major and minor version parts into a version directive. */
  def transformVersionDirective(major: VersionMajor, minor: VersionMinor): directives.VersionDirective =
    new directives.VersionDirective(transform(major), transform(minor))

  /** Builds the output directive from the optional data format id, content type and options. */
  def transformOutputDirective(id: Option[DataFormatId], mime: Option[ContentType], options: Option[Seq[DirectiveOption]]): directives.OutputDirective =
    new directives.OutputDirective(transformContentType(mime), transformOptionSeq(options), id.map(_.id))

  /** Wraps the minor version text, converted with `Number(v)`, into an engine version node. */
  def transformVersionMinor(v: String): directives.VersionMinor = new directives.VersionMinor(Number(v))

  /**
    * Transforms a `var` declaration.
    *
    * - A function literal value becomes a [[FunctionDirective]].
    * - A variable annotated with [[LazyVarDirectiveAnnotation]] becomes a `LazyVarDirective`.
    * - Anything else becomes a plain `VarDirective`.
    */
  def transformVarDirective(variable: NameIdentifier, value: AstNode, codeAnnotations: Seq[AnnotationNode]): Directive = {
    value match {
      case _: AstFunctionNode =>
        transformFunctionDirective(variable, value, codeAnnotations, value)
      case _ if variable.isAnnotatedWith(classOf[LazyVarDirectiveAnnotation]) =>
        new directives.LazyVarDirective(transform(variable), transform(value), needsMaterialization(variable))
      case _ =>
        new directives.VarDirective(transform(variable), transform(value), needsMaterialization(variable))
    }
  }

  def transformFunctionNode(args: Seq[functions.FunctionParameter], body: AstNode, returnType: Option[WeaveTypeNode], typeParametersListNode: Option[TypeParametersListNode], name: Option[String] = None): ValueNode[_]

  def createFunctionNode(args: Seq[FunctionParameter], body: AstNode, returnType: Option[WeaveTypeNode], name: Option[String], bodyValue: ValueNode[_]): ValueNode[_]

  def transformOverloadedFunctionNode(ofn: OverloadedFunctionNode, name: Option[String] = None): ValueNode[_]

  /** True when the function named by `functionIdentifier` was marked for tail-recursion elimination. */
  private def isRecursiveCallExpression(functionIdentifier: NameIdentifier): Boolean = {
    functionIdentifier.isAnnotatedWith(classOf[TailRecFunctionAnnotation])
  }

  /**
    * Transforms a function declaration (plain or overloaded) into a [[FunctionDirective]].
    *
    * @param functionName    name of the declared function
    * @param node            either a function node or an overloaded function node
    * @param codeAnnotations annotations handed to [[applyInterceptor]] for non-tail-recursive functions
    * @param location        location assigned to the tail-recursive body node
    */
  def transformFunctionDirective(functionName: NameIdentifier, node: AstNode, codeAnnotations: Seq[AnnotationNode], location: WeaveLocationCapable): FunctionDirective = {
    val functionValue: ValueNode[_] = node match {
      case parser.ast.functions.FunctionNode(params, body, returnType, typeParameters) =>
        if (isRecursiveCallExpression(functionName)) {
          // Tail recursion elimination: recursive tail calls are inlined into the body.
          // NOTE(review): this path does not go through applyInterceptor, so codeAnnotations are not applied
          // to tail-recursive functions — confirm this is intentional.
          val recursiveCallBody = new TailRecFunctionBodyNode(transformTailRecursionElimination(functionName, body, params))
          recursiveCallBody._location = Some(location.location())
          createFunctionNode(params.paramList, body, returnType, Some(functionName.name), recursiveCallBody)
        } else {
          val functionNode = transformFunctionNode(params.paramList, body, returnType, typeParameters, Some(functionName.name))
          applyInterceptor(functionNode, functionName, codeAnnotations)
        }
      case ofn: parser.ast.functions.OverloadedFunctionNode =>
        transformOverloadedFunctionNode(ofn, Some(functionName.name))
    }
    new FunctionDirective(transform(functionName), functionValue)
  }

  def applyInterceptor(functionValue: ValueNode[_], functionName: NameIdentifier, codeAnnotations: Seq[AnnotationNode]): ValueNode[_]

  def transformReference(reference: NameIdentifier): Option[NameSlot]

  /**
    * Transforms the body of a tail-recursive function.
    *
    * It descends into the branches of if/else, default, unless, do, using and match expressions, and replaces
    * calls annotated with [[TailRecFunctionCallAnnotation]] with an inlined call. Conditions, left-hand sides,
    * headers and any other node go through the regular `transform`.
    *
    * Only reached for functions already accepted by `isRecursiveCallExpression`, so the branches can go straight
    * to `transformRecursiveCall`. That method falls back to this one for nodes that are not tail calls.
    * Every produced node gets the location of the AST node it replaces.
    */
  private def transformTailRecursionElimination(functionName: NameIdentifier, body: AstNode, params: FunctionParameters): ValueNode[Any] = {
    val result = body match {
      case ifNode: IfNode =>
        // Transform order (then, else, condition) is kept from the original implementation.
        val ifExpression: ValueNode[Any] = transformRecursiveCall(functionName, ifNode.ifExpr, params)
        val elseExpression: ValueNode[Any] = transformRecursiveCall(functionName, ifNode.elseExpr, params)
        new XIfNode(ifExpression, transform(ifNode.condition), elseExpression)
      case defaultNode: DefaultNode =>
        val rhs: ValueNode[Any] = transformRecursiveCall(functionName, defaultNode.rhs, params)
        new XDefaultNode(transform(defaultNode.lhs), rhs)
      case unless: UnlessNode =>
        val ifExpression: ValueNode[Any] = transformRecursiveCall(functionName, unless.ifExpr, params)
        val elseExpression: ValueNode[Any] = transformRecursiveCall(functionName, unless.elseExpr, params)
        new XUnlessNode(ifExpression, transform(unless.condition), elseExpression)
      case doBlock: DoBlockNode =>
        new XDoBlockNode(transform(doBlock.header), transformTailRecursionElimination(functionName, doBlock.body, params))
      case usingNode: UsingNode =>
        val varDirectives = usingNode.assignments.assignmentSeq.map { assignment =>
          val engineVariable: NameSlot = transform(assignment.name)
          val engineAstNode: ValueNode[_] = transform(assignment.value)
          new directives.VarDirective(engineVariable, engineAstNode, needsMaterialization(assignment.name))
        }
        XUsingNode(varDirectives, transformTailRecursionElimination(functionName, usingNode.expr, params))
      case patternMatcher: PatternMatcherNode =>
        // Patterns are transformed before the matched expression, as in the original implementation.
        val patterns = patternMatcher.patterns.patterns.map(pattern => transformPattern(functionName, pattern, params))
        new XPatternMatcherNode(transform(patternMatcher.lhs), patterns.toArray)
      case fcn: FunctionCallNode if fcn.isAnnotatedWith(classOf[TailRecFunctionCallAnnotation]) =>
        transformRecursiveCall(functionName, fcn, params)
      case _ =>
        transform(body)
    }
    result._location = Some(body.location())
    result
  }

  /** Transforms one match case, applying tail-recursion elimination to the expression run when it matches. */
  private def transformPattern(functionName: NameIdentifier, body: PatternExpressionNode, params: FunctionParameters) = {
    body match {
      case parser.ast.patterns.ExpressionPatternNode(pattern, name, function) =>
        transformExpressionPatternNode(pattern, name, function, transformTailRecursionElimination(functionName, function, params))
      case parser.ast.patterns.LiteralPatternNode(pattern, name, function) =>
        transformLiteralPatternNode(pattern, name, function, transformTailRecursionElimination(functionName, function, params))
      case parser.ast.patterns.RegexPatternNode(pattern, name, function) =>
        transformRegexPatternNode(pattern, name, function, transformTailRecursionElimination(functionName, function, params))
      case parser.ast.patterns.DefaultPatternNode(value, name) =>
        transformDefaultPatternNode(name, value, transformTailRecursionElimination(functionName, value, params))
      case tpn: parser.ast.patterns.TypePatternNode =>
        transformTypePatternNode(tpn, transformTailRecursionElimination(functionName, tpn.onMatch, params))
      case parser.ast.patterns.EmptyArrayPatternNode(function) =>
        transformEmptyArrayNode(function, transformTailRecursionElimination(functionName, function, params))
      case parser.ast.patterns.DeconstructArrayPatternNode(head, tail, function) =>
        transformDeconstructArrayNode(head, tail, function, transformTailRecursionElimination(functionName, function, params))
      case parser.ast.patterns.EmptyObjectPatternNode(function) =>
        transformEmptyObjectNode(function, transformTailRecursionElimination(functionName, function, params))
      case parser.ast.patterns.DeconstructObjectPatternNode(headKey, headValue, tail, function) =>
        transformDeconstructObjectNode(headKey, headValue, tail, function, transformTailRecursionElimination(functionName, function, params))
    }
  }

  /**
    * Replaces a recursive call annotated with [[TailRecFunctionCallAnnotation]] with an [[InlineTailRecFunctionCall]]
    * that holds one variable directive per parameter. Any other node is handed back to
    * `transformTailRecursionElimination`.
    *
    * When the call passes fewer arguments than there are parameters, the missing ones are filled from default values:
    * - trailing defaults are used if the last parameter has one;
    * - otherwise leading defaults are used.
    *
    * @throws CompilationExecutionException if the argument count still differs from the parameter count
    */
  private def transformRecursiveCall(variable: NameIdentifier, node: AstNode, params: FunctionParameters): ValueNode[Any] = {
    node match {
      case functionCallNode: FunctionCallNode if functionCallNode.isAnnotatedWith(classOf[TailRecFunctionCallAnnotation]) =>
        val paramsSeq = params.paramList
        val providedArgs = functionCallNode.args.args
        val missing = paramsSeq.size - providedArgs.size
        // Default values are supported by injecting their expressions as explicit arguments.
        val args =
          if (missing <= 0) providedArgs
          else if (paramsSeq.last.defaultValue.isDefined) providedArgs ++ paramsSeq.takeRight(missing).flatMap(_.defaultValue)
          else paramsSeq.take(missing).flatMap(_.defaultValue) ++ providedArgs

        if (args.size != paramsSeq.size) {
          // Report what the caller actually passed, not the count after default injection.
          val message =
            if (providedArgs.size > paramsSeq.size)
              s"Too many arguments: `${providedArgs.size}`, function: `${variable.name}` accepts: `${paramsSeq.size}`."
            else
              s"Not enough arguments: `${providedArgs.size}`, function: `${variable.name}` requires: `${paramsSeq.size}`."
          throw new CompilationExecutionException(functionCallNode.location(), message)
        }

        val statements: Seq[XVarDirective] = paramsSeq.zip(args).map {
          case (param, arg) =>
            val paramName = param.variable
            new XVarDirective(transform(paramName), transform(arg), needsMaterialization(paramName))
        }
        new InlineTailRecFunctionCall(statements.toArray)
      case _ =>
        transformTailRecursionElimination(variable, node, params)
    }
  }
}
