package org.mule.weave.v2.parser.phase

import org.mule.weave.v2.grammar.literals.TypeLiteral
import org.mule.weave.v2.parser.ArrayFunctionInjectionNotPossible
import org.mule.weave.v2.parser.FunctionInjectionNotPossible
import org.mule.weave.v2.parser.annotation.InjectedNodeAnnotation
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.AstNodeHelper
import org.mule.weave.v2.parser.ast.functions.DoBlockNode
import org.mule.weave.v2.parser.ast.functions.FunctionCallNode
import org.mule.weave.v2.parser.ast.functions.FunctionNode
import org.mule.weave.v2.parser.ast.functions.FunctionParameter
import org.mule.weave.v2.parser.ast.functions.FunctionParameters
import org.mule.weave.v2.parser.ast.functions.OverloadedFunctionNode
import org.mule.weave.v2.parser.ast.functions.UsingNode
import org.mule.weave.v2.parser.ast.header.directives.FunctionDirectiveNode
import org.mule.weave.v2.parser.ast.header.directives.TypeDirective
import org.mule.weave.v2.parser.ast.structure.ArrayNode
import org.mule.weave.v2.parser.ast.types.FunctionTypeNode
import org.mule.weave.v2.parser.ast.types.TypeReferenceNode
import org.mule.weave.v2.parser.ast.types.WeaveTypeNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.parser.location.WeaveLocation
import org.mule.weave.v2.scope.Reference
import org.mule.weave.v2.scope.ScopesNavigator

/**
  * Compilation phase that makes implicit lambda arguments of infix function calls explicit.
  *
  * For every infix call `a f b` whose callee resolves to a [[FunctionDirectiveNode]], the type declared for the
  * callee's second parameter (following `type` aliases) is inspected:
  *  - if it is a function type, `b` is wrapped in an injected [[FunctionNode]] whose parameters are named
  *    `$`, `$$`, `$$$`, ...; if `b` already is a function literal with fewer parameters than the function type,
  *    the missing parameters are appended to it instead;
  *  - if it is an `Array` of a function type and `b` is an array literal, the same is done for every element.
  *
  * Overloaded callees are resolved by looking only at their binary overloads (see `injectInOverloadedFunction`).
  * The AST is mutated in place and the phase always reports success.
  *
  * @tparam R the AST node type exposed by `T`
  * @tparam T the phase input/output; it must expose the AST and a (dummy) scope navigator
  */
class ImplicitFunctionTransformer[R <: AstNode, T <: AstNodeResultAware[R] with DummyScopeNavigatorAware]() extends CompilationPhase[T, T] {

  /**
    * Collects every infix call with at least two arguments, resolves its callee and, when the callee is declared by
    * a function directive, injects implicit function parameters into the call's second argument.
    *
    * @param source the phase input; its AST is modified in place
    * @param ctx    the parsing context, used to report warnings
    * @return a [[SuccessResult]] wrapping `source`
    */
  override def doCall(source: T, ctx: ParsingContext): PhaseResult[T] = {
    val navigator: ScopesNavigator = source.dummyScopesNavigator
    // Only infix calls that actually have a second operand are candidates for injection.
    val infixCalls: Seq[FunctionCallNode] = AstNodeHelper
      .collectChildren(source.astNode, {
        case node: FunctionCallNode => AstNodeHelper.isInfixFunctionCall(node) && node.args.args.size > 1
        case _                      => false
      })
      .collect({ case call: FunctionCallNode => call })

    infixCalls.foreach((functionCallNode) => {
      val reference: Option[Reference] = functionCallNode.function match {
        case vr: VariableReferenceNode => navigator.resolveVariable(vr.variable)
        case _                         => None
      }
      reference.foreach((ref) => {
        ref.scope
          .astNavigator()
          .parentOf(ref.referencedNode)
          .foreach({
            case fdn: FunctionDirectiveNode => injectFunctionOnInfixCallTo(navigator, functionCallNode, fdn, ctx)
            case _                          =>
          })
      })
    })
    SuccessResult(source, ctx)
  }

  /**
    * Returns the array literal `astNode` ends in, looking through the expression of a [[UsingNode]] and the body of a
    * [[DoBlockNode]] (recursively). Returns `None` for any other node.
    */
  def getArrayNodeExpression(astNode: AstNode): Option[ArrayNode] = {
    astNode match {
      case arrayNode: ArrayNode   => Some(arrayNode)
      case usingNode: UsingNode   => getArrayNodeExpression(usingNode.expr)
      case blockNode: DoBlockNode => getArrayNodeExpression(blockNode.body)
      case _                      => None
    }
  }

  /**
    * Injects implicit parameters into the second argument of `functionCallNode`, driven by the declaration found in
    * `functionDirectiveNode`.
    *
    * For a single function with at least two parameters, the second parameter's type decides what is done: an array
    * literal argument is expanded element-wise when the type is an `Array` of functions; otherwise the argument is
    * treated as a single function when the type is a function type. Overloaded declarations are delegated to
    * `injectInOverloadedFunction`. Anything else is left untouched.
    */
  private def injectFunctionOnInfixCallTo(navigator: ScopesNavigator, functionCallNode: FunctionCallNode, functionDirectiveNode: FunctionDirectiveNode, ctx: ParsingContext): Unit = {
    val secondArgument: AstNode = functionCallNode.args.args(1)
    functionDirectiveNode.literal match {
      case fn: FunctionNode if fn.params.paramList.size > 1 => {
        val secondParamType = fn.params.paramList(1).wtype
        secondArgument match {
          case array: ArrayNode if isArrayOfFunctionType(navigator, secondParamType) =>
            collectArrayFunctionType(navigator, secondParamType).foreach((functionType) => injectArrayFunctions(functionType, array))
          case _ =>
            // collectFunctionType is defined exactly when isFunctionType holds; otherwise nothing is injected.
            collectFunctionType(navigator, secondParamType).foreach((functionType) => injectFunction(functionCallNode, functionType, secondArgument))
        }
      }
      case ofn: OverloadedFunctionNode => {
        injectInOverloadedFunction(navigator, functionCallNode, functionDirectiveNode, ctx, secondArgument, ofn)
      }
      case _ =>
    }
  }

  /**
    * Injection for an overloaded callee, considering only overloads with at least two parameters.
    *
    * When `secondArgument` is an array literal (possibly inside `using`/`do`) and some overload takes an `Array` of
    * functions, every element is expanded using the element function type of the first such overload. If some
    * overload takes an `Array` whose element is not a function type, an [[ArrayFunctionInjectionNotPossible]] warning
    * is reported instead.
    *
    * Otherwise, if every overload's second parameter is a function or an `Array` of functions and at least one is a
    * function type, the argument is injected using the function type with the fewest parameters (first one on ties).
    * If some overload's second parameter is neither, a [[FunctionInjectionNotPossible]] warning is reported.
    */
  private def injectInOverloadedFunction(navigator: ScopesNavigator, functionCallNode: FunctionCallNode, functionDirectiveNode: FunctionDirectiveNode, ctx: ParsingContext, secondArgument: AstNode, ofn: OverloadedFunctionNode): Unit = {
    // Only binary overloads can be the target of an infix call; we only care about their second parameter type.
    val secondParamTypes = ofn.functions
      .filter((function) => function.params.paramList.size > 1)
      .map((function) => function.params.paramList(1).wtype)

    val isArrayOfFunctionTypeKind = secondParamTypes.exists((paramType) => isArrayOfFunctionType(navigator, paramType))
    val isSimpleFunctionTypeKind = secondParamTypes.exists((paramType) => isFunctionType(navigator, paramType))

    if (isArrayOfFunctionTypeKind || isSimpleFunctionTypeKind) {
      getArrayNodeExpression(secondArgument) match {
        case Some(arrayNode) if isArrayOfFunctionTypeKind => {
          // An overload taking an Array of non-function elements makes element-wise injection ambiguous.
          val hasPlainArrayOverload = secondParamTypes.exists({
            case Some(tr: TypeReferenceNode) =>
              tr.variable.name == TypeLiteral.ARRAY_TYPE_NAME && !isArrayOfFunctionType(navigator, Some(tr))
            case _ => false
          })
          if (hasPlainArrayOverload) {
            ctx.messageCollector.warning(ArrayFunctionInjectionNotPossible(functionDirectiveNode.variable.name), secondArgument.location())
          } else {
            secondParamTypes
              .flatMap((paramType) => collectArrayFunctionType(navigator, paramType))
              .headOption
              .foreach((functionType) => injectArrayFunctions(functionType, arrayNode))
          }
        }
        case _ => {
          // Every binary overload must accept some kind of function, otherwise injecting would change the meaning.
          val allOverloadsTakeFunctions = secondParamTypes.forall((paramType) => {
            isFunctionType(navigator, paramType) || isArrayOfFunctionType(navigator, paramType)
          })
          if (!allOverloadsTakeFunctions) {
            ctx.messageCollector.warning(FunctionInjectionNotPossible(functionDirectiveNode.variable.name), secondArgument.location())
          } else if (isSimpleFunctionTypeKind) {
            val functionTypes = secondParamTypes.flatMap((paramType) => collectFunctionType(navigator, paramType))
            if (functionTypes.nonEmpty) {
              // Pick the function type with the lowest cardinality (first one on ties).
              injectFunction(functionCallNode, functionTypes.minBy((functionType) => functionType.args.size), secondArgument)
            }
          }
        }
      }
    }
  }

  /** True when `astNode` is an array literal, possibly wrapped in `using`/`do` (see `getArrayNodeExpression`). */
  private def isArrayNodeExpression(astNode: AstNode): Boolean = {
    getArrayNodeExpression(astNode).isDefined
  }

  /**
    * Resolves a type reference to the `type` directive that declares it, if the reference resolves and its declaring
    * parent node is a [[TypeDirective]].
    */
  private def resolveTypeDirective(navigator: ScopesNavigator, trn: TypeReferenceNode): Option[TypeDirective] = {
    navigator
      .resolveVariable(trn.variable)
      .flatMap((reference) => {
        reference.scope.rootScope().astNavigator().parentOf(reference.referencedNode) match {
          case Some(tdn: TypeDirective) => Some(tdn)
          case _                        => None
        }
      })
  }

  /** The first type argument of `tr` (e.g. `T` in `Array<T>`), or `None` when it has no type arguments. */
  private def firstTypeArgument(tr: TypeReferenceNode): Option[WeaveTypeNode] = {
    tr.typeArguments.flatMap(_.headOption)
  }

  /** True when `wtype` is a function type, directly or through a chain of `type` aliases. */
  private def isFunctionType(navigator: ScopesNavigator, wtype: Option[WeaveTypeNode]): Boolean = {
    collectFunctionType(navigator, wtype).isDefined
  }

  /** The [[FunctionTypeNode]] `wtype` denotes, following `type` aliases; `None` when it is not a function type. */
  private def collectFunctionType(navigator: ScopesNavigator, wtype: Option[WeaveTypeNode]): Option[FunctionTypeNode] = {
    wtype match {
      case Some(ft: FunctionTypeNode) => Some(ft)
      case Some(trn: TypeReferenceNode) =>
        resolveTypeDirective(navigator, trn).flatMap((tdn) => collectFunctionType(navigator, Some(tdn.typeExpression)))
      case _ => None
    }
  }

  /**
    * The element function type of `wtype` when it is an `Array` of a function type, directly or through a chain of
    * `type` aliases; `None` otherwise.
    */
  private def collectArrayFunctionType(navigator: ScopesNavigator, wtype: Option[WeaveTypeNode]): Option[FunctionTypeNode] = {
    wtype match {
      case Some(trn: TypeReferenceNode) =>
        if (isArrayOfFunction(navigator, trn)) {
          collectFunctionType(navigator, firstTypeArgument(trn))
        } else {
          resolveTypeDirective(navigator, trn).flatMap((tdn) => collectArrayFunctionType(navigator, Some(tdn.typeExpression)))
        }
      case _ => None
    }
  }

  /** True when `wtype` is an `Array` of a function type, directly or through a chain of `type` aliases. */
  private def isArrayOfFunctionType(navigator: ScopesNavigator, wtype: Option[WeaveTypeNode]): Boolean = {
    collectArrayFunctionType(navigator, wtype).isDefined
  }

  /** True when `tr` is literally `Array<F>` and its first type argument `F` is a function type. */
  private def isArrayOfFunction(navigator: ScopesNavigator, tr: TypeReferenceNode): Boolean = {
    tr.variable.name == TypeLiteral.ARRAY_TYPE_NAME && isFunctionType(navigator, firstTypeArgument(tr))
  }

  /**
    * Builds `count` synthetic parameters named `"$" * (offset + 1)` up to `"$" * (offset + count)`. Each one has a
    * zero-width location at the start of `anchor` and is tagged with [[InjectedNodeAnnotation]].
    */
  private def createInjectedParameters(count: Int, offset: Int, anchor: WeaveLocation): Seq[FunctionParameter] = {
    (1 to count).map((index) => {
      val nameIdentifier = NameIdentifier("$" * (offset + index))
      nameIdentifier._location = Some(WeaveLocation(anchor.startPosition, anchor.startPosition, anchor.resourceName))
      val parameter: FunctionParameter = FunctionParameter(nameIdentifier)
      parameter.annotate(InjectedNodeAnnotation())
      parameter
    })
  }

  /** Appends injected parameters to `fn` until it declares as many parameters as `functionType`. */
  private def completeParameters(fn: FunctionNode, functionType: FunctionTypeNode): Unit = {
    val declared = fn.params.paramList
    val missing = functionType.args.size - declared.size
    if (missing > 0) {
      // Continue the `$` naming after the already declared parameters.
      fn.params.paramList = declared ++ createInjectedParameters(missing, declared.size, fn.location())
    }
  }

  /**
    * Wraps `body` in an injected [[FunctionNode]] with one injected parameter per argument of `functionType`.
    * The parameters are located at the start of `parametersLocation`; the function node takes `body`'s location.
    */
  private def wrapInFunction(body: AstNode, functionType: FunctionTypeNode, parametersLocation: WeaveLocation): FunctionNode = {
    val parameters = createInjectedParameters(functionType.args.size, 0, parametersLocation)
    val injectedFunctionNode: FunctionNode = FunctionNode(FunctionParameters(parameters), body)
    injectedFunctionNode._location = Some(body.location())
    injectedFunctionNode.annotate(InjectedNodeAnnotation())
    injectedFunctionNode
  }

  /**
    * Makes `argument` (the second argument of `functionCallNode`) conform to `functionType`. A function literal gets
    * its missing parameters appended. Any other expression is wrapped in an injected function, which replaces the
    * call's second argument.
    */
  private def injectFunction(functionCallNode: FunctionCallNode, functionType: FunctionTypeNode, argument: AstNode): Unit = {
    argument match {
      case fn: FunctionNode =>
        completeParameters(fn, functionType)
      case _ =>
        val injectedFunctionNode = wrapInFunction(argument, functionType, functionCallNode.location())
        functionCallNode.args.args = functionCallNode.args.args.updated(1, injectedFunctionNode)
    }
  }

  /**
    * Makes every element of the array literal `argument` conform to `functionType`. Function literals get their
    * missing parameters appended. Any other element is replaced by an injected function wrapping it.
    */
  private def injectArrayFunctions(functionType: FunctionTypeNode, argument: ArrayNode): Unit = {
    argument.elements = argument.elements.map({
      case fn: FunctionNode =>
        completeParameters(fn, functionType)
        fn
      case element =>
        wrapInFunction(element, functionType, element.location())
    })
  }
}

