package org.mule.weave.v2.runtime.core.functions

import org.mule.weave.v2.core.exception.NotEnoughArgumentsException
import org.mule.weave.v2.core.exception.TooManyArgumentsException
import org.mule.weave.v2.core.exception.UnexpectedFunctionCallTypesException
import org.mule.weave.v2.interpreted.node.FunctionDispatchingHelper
import org.mule.weave.v2.interpreted.node.FunctionDispatchingHelper.expandArguments
import org.mule.weave.v2.interpreted.node.FunctionDispatchingHelper.findMatchingFunctionWithCoercion
import org.mule.weave.v2.interpreted.node.FunctionDispatchingHelper.sortByParameterTypeWeight
import org.mule.weave.v2.interpreted.node.structure.FunctionFilter
import org.mule.weave.v2.model.EvaluationContext
import org.mule.weave.v2.model.types.FunctionType
import org.mule.weave.v2.model.types.Type
import org.mule.weave.v2.model.types.UnionType
import org.mule.weave.v2.model.values.FunctionParameter
import org.mule.weave.v2.model.values.FunctionValue
import org.mule.weave.v2.model.values.Value
import org.mule.weave.v2.parser.ast.WeaveLocationCapable
import org.mule.weave.v2.parser.location._

import scala.collection.mutable

/**
  * A [[FunctionValue]] that groups several overloads and selects, at call time, which one
  * to invoke based on the actual arguments.
  *
  * @param overloadedFunctions all overloads grouped under this value
  * @param functionFilters     optional per-overload filters. When non-empty, `functionFilters(i)`
  *                            decides whether `overloadedFunctions(i)` is a dispatch candidate
  *                            (see `overloads`). NOTE(review): assumed to have at least as many
  *                            entries as `overloadedFunctions`; this is not checked here.
  * @param name                optional name of the function
  * @param valueLocation       location returned by `location()` and attached to dispatch errors
  * @param cacheable           value reported by `dispatchCanBeCached`
  */
class OverloadedFunctionValue(val overloadedFunctions: Array[_ <: FunctionValue], val functionFilters: Array[FunctionFilter], override val name: Option[String] = None, valueLocation: WeaveLocation, cacheable: Boolean) extends FunctionValue with WeaveLocationCapable {

  override def isOverloaded: Boolean = true

  override def dispatchCanBeCached: Boolean = cacheable

  /**
    * Returns the maximum number of arguments this function can be invoked with: the largest
    * `maxParams` of any overload. It is computed over all overloads and ignores `functionFilters`.
    * Assumes at least one overload (`max` on an empty array throws).
    */
  override lazy val maxParams: Int = overloadedFunctions.map(_.maxParams).max

  /**
    * Returns the minimum number of arguments this function can be invoked with: the smallest
    * `minParams` of any overload. It is computed over all overloads and ignores `functionFilters`.
    */
  override lazy val minParams: Int = overloadedFunctions.map(_.minParams).min

  /** True when at least one overload needs its arguments materialized before type checks. */
  override val paramsTypesRequiresMaterialize: Boolean = overloadedFunctions.exists(_.paramsTypesRequiresMaterialize)

  /**
    * Builds the [[FunctionType]] of this value on first use and caches it in `_type`.
    * The overload list comes from the filtered `overloads`.
    * NOTE(review): once cached, the type is reused for every later `ctx`, so a filter result
    * that depends on the context is only reflected for the first context seen. Confirm this is intended.
    */
  override def valueType(implicit ctx: EvaluationContext): Type = {
    if (_type == null) {
      _type = new FunctionType(Some(functionParamTypes), returnType, Some(overloads.map(_.valueType)))
    }
    _type
  }

  /**
    * Returns the overloads that are candidates for dispatch.
    * Without filters this is `overloadedFunctions` itself (no copy). With filters, a new array
    * holds the overloads whose filter at the same index accepts them, in their original order.
    */
  override def overloads(implicit ctx: EvaluationContext): Array[_ <: FunctionValue] = {
    if (functionFilters.nonEmpty) {
      val accepted = new mutable.ArrayBuffer[FunctionValue](overloadedFunctions.length)
      var i = 0
      while (i < overloadedFunctions.length) {
        val functionValue = overloadedFunctions(i)
        if (functionFilters(i).accept(functionValue)) {
          accepted += functionValue
        }
        i += 1
      }
      accepted.toArray
    } else {
      overloadedFunctions
    }
  }

  /**
    * Dispatches the call to one of the candidate overloads.
    *
    * Steps:
    *  1. Check the argument count against `minParams`/`maxParams`.
    *  2. Materialize the arguments if any overload requires it.
    *  3. Try `findMatchingFunction` without coercion.
    *  4. If nothing matches, fully materialize the arguments, sort the candidates with
    *     `sortByParameterTypeWeight` and retry with `findMatchingFunctionWithCoercion`.
    *
    * @throws TooManyArgumentsException            if more arguments than `maxParams` are given
    * @throws NotEnoughArgumentsException          if fewer arguments than `minParams` are given
    * @throws UnexpectedFunctionCallTypesException if no overload matches, even with coercion
    */
  override def call(args: Array[Value[_]])(implicit ctx: EvaluationContext): Value[_] = {

    // Validate arity first. The bounds are computed over all overloads, ignoring filters.
    if (args.length > maxParams) {
      val parameterValues: Seq[FunctionParameter] = parameters
      throw new TooManyArgumentsException(location(), args.length, parameterValues)
    } else if (args.length < minParams) {
      val parameterValues: Seq[FunctionParameter] = parameters
      throw new NotEnoughArgumentsException(location(), args.length, parameterValues)
    }

    // Resolve the candidate set once per call. With filters, `overloads` allocates a new array
    // on every evaluation, and every dispatch phase below should work against the same candidates.
    val candidates: Array[_ <: FunctionValue] = overloads

    val maybeMatArguments = if (paramsTypesRequiresMaterialize) {
      FunctionDispatchingHelper.materializeOverloadedFunctionArgs(candidates, args)
    } else {
      args
    }

    FunctionDispatchingHelper.findMatchingFunction(maybeMatArguments, candidates) match {
      case Some((_, dispatchFunction)) =>
        dispatchFunction.call(expandArguments(maybeMatArguments, dispatchFunction))
      case None =>
        // No direct match: materialize every argument and retry with coercion, trying the
        // candidates in the order produced by `sortByParameterTypeWeight` for the argument types.
        val materializedValues: Array[Value[Any]] = maybeMatArguments.map(_.materialize)
        val argTypes: Array[Type] = materializedValues.map(_.valueType)
        val sortedOperators: Array[FunctionValue] = sortByParameterTypeWeight(candidates, argTypes)
        findMatchingFunctionWithCoercion(materializedValues, sortedOperators, this) match {
          case Some((dispatchIndex, coercedArguments, _)) =>
            // The first tuple element is an index into `sortedOperators`.
            val functionValue = sortedOperators(dispatchIndex)
            functionValue.call(expandArguments(coercedArguments, functionValue))
          case None =>
            throw new UnexpectedFunctionCallTypesException(location(), this.label, maybeMatArguments, sortedOperators.map(_.parameterTypes))
        }
    }
  }

  /**
    * Synthetic parameter list covering every overload.
    * Position `i` is named `arg{i}`. Its type is the [[UnionType]] of the i-th parameter types of
    * all overloads that have at least `i + 1` parameters. The list is as long as the longest
    * overload parameter list. Computed over all overloads, ignoring `functionFilters`.
    */
  override lazy val parameters: Array[FunctionParameter] = {
    val paramsSize = overloadedFunctions.map(_.parameters.length).max
    Array.tabulate(paramsSize) { i =>
      val ithParamTypes: Seq[Type] = overloadedFunctions.flatMap { fun =>
        val funParams = fun.parameters
        if (funParams.length > i) Some(funParams(i).wtype) else None
      }
      FunctionParameter(s"arg$i", UnionType(ithParamTypes))
    }
  }

  override def location(): WeaveLocation = valueLocation

}

object OverloadedFunctionValue {

  /**
    * Factory that wraps a group of overloads in an [[OverloadedFunctionValue]].
    *
    * @param functions       the overloads to group
    * @param functionFilters per-overload filters, passed to the constructor unchanged
    * @param name            optional function name
    * @param location        location passed to the constructor as `valueLocation`
    * @param cacheabe        (sic) passed to the constructor as `cacheable`; the spelling is kept
    *                        because callers may pass it by name
    * @return the new overloaded function value
    */
  def createValue(
      functions: Array[_ <: FunctionValue],
      functionFilters: Array[FunctionFilter],
      name: Option[String] = None,
      location: WeaveLocation,
      cacheabe: Boolean): FunctionValue =
    new OverloadedFunctionValue(
      overloadedFunctions = functions,
      functionFilters = functionFilters,
      name = name,
      valueLocation = location,
      cacheable = cacheabe)
}
