package org.mule.weave.v2.ts.resolvers

import org.mule.weave.v2.ts.Edge
import org.mule.weave.v2.ts.EdgeLabels
import org.mule.weave.v2.ts.FunctionType
import org.mule.weave.v2.ts.NothingType
import org.mule.weave.v2.ts.ReferenceType
import org.mule.weave.v2.ts.TypeHelper
import org.mule.weave.v2.ts.TypeNode
import org.mule.weave.v2.ts.TypeType
import org.mule.weave.v2.ts.WeaveType
import org.mule.weave.v2.ts.WeaveTypeResolutionContext
import org.mule.weave.v2.ts.WeaveTypeResolver

/**
  * Resolves the result type of a type-pattern case node: a branch whose pattern tests the matched value
  * against a type, and whose body is typed as a function applied to the (narrowed) matched value.
  */
object TypePatternTypeResolver extends WeaveTypeResolver {

  override def supportsPartialResolution() = true

  /**
    * Resolves the return type of a type-pattern case node.
    *
    * Reads three incoming edges of `node`:
    *  - `CASE_EXPRESSION`: the pattern's type, expected to be a [[TypeType]] (or a [[ReferenceType]] to one);
    *  - `PATTERN_EXPRESSION`: the type of the matched expression, which may not be available yet;
    *  - `MATCH_EXPRESSION`: a [[FunctionType]] that is invoked with the narrowed type to compute the branch result.
    *
    * @param node the case node being resolved
    * @param ctx  the type resolution context
    * @return `None` when the case or match types are not available yet, or when the match type is not a
    *         [[FunctionType]]; otherwise the branch type computed from the case type.
    */
  override def resolveReturnType(node: TypeNode, ctx: WeaveTypeResolutionContext): Option[WeaveType] = {
    val caseExpression: Edge = node.incomingEdges(EdgeLabels.CASE_EXPRESSION).head
    val patternExpression: Edge = node.incomingEdges(EdgeLabels.PATTERN_EXPRESSION).head
    val matchExpression: Edge = node.incomingEdges(EdgeLabels.MATCH_EXPRESSION).head
    if (caseExpression.mayBeIncomingType().isDefined && matchExpression.mayBeIncomingType().isDefined) {
      matchExpression.incomingType() match {
        case functionType: FunctionType =>
          val caseType: WeaveType = caseExpression.incomingType()
          resolveReturnType(node, ctx, patternExpression, functionType, caseType)
        case _ =>
          // Leave the node unresolved instead of failing with a ClassCastException on an unexpected type.
          None
      }
    } else {
      None
    }
  }

  /**
    * Computes the branch result for the given case type.
    *
    * When `caseType` is a [[TypeType]] wrapping `patternToMatchType`:
    *  - if the matched expression type is known and disjoint from the pattern type, the branch can never be
    *    taken. The function is still resolved against the pattern type so that its body gets validated, and
    *    [[NothingType]] is returned;
    *  - otherwise the function is resolved with the narrowed argument type. That is the algebraic intersection
    *    of the matched type (type parameters removed) and the pattern type, or just the pattern type when the
    *    matched type is not available.
    *
    * A [[ReferenceType]] is resolved and the computation retried; any other case type yields `None`.
    */
  private def resolveReturnType(node: TypeNode, ctx: WeaveTypeResolutionContext, patternExpression: Edge, functionType: FunctionType, caseType: WeaveType): Option[WeaveType] = {
    caseType match {
      case TypeType(patternToMatchType) =>
        // Read the matched expression type once; it may not be available yet.
        val maybeMatchingExprType: Option[WeaveType] =
          if (patternExpression.mayBeIncomingType().nonEmpty) Some(patternExpression.incomingType()) else None

        // An unknown matched type cannot be proven disjoint, so the branch stays active.
        val activatedBranch: Boolean =
          maybeMatchingExprType.forall(matchingExprType => !TypeHelper.areDisjointTypes(matchingExprType, patternToMatchType))

        if (activatedBranch) {
          // We need to do the algebraic intersection
          val matchingType: WeaveType = maybeMatchingExprType match {
            case Some(matchingExprType) =>
              // Remove type parameters
              val matchingExprTypeWithoutTypeParameters = TypeHelper.removeTypeParameters(matchingExprType)
              TypeHelper.resolveAlgebraicIntersection(Seq(matchingExprTypeWithoutTypeParameters, patternToMatchType))
            case None =>
              patternToMatchType
          }
          FunctionCallNodeResolver.resolveReturnType(functionType, Seq(matchingType), Seq(), node, ctx)
        } else {
          // Even if they are disjoint we should evaluate the branch to validate that the right hand side is still valid
          FunctionCallNodeResolver.resolveReturnType(functionType, Seq(patternToMatchType), Seq(), node, ctx)
          Some(NothingType())
        }
      case rt: ReferenceType => resolveReturnType(node, ctx, patternExpression, functionType, rt.resolveType())
      case _                 => None
    }
  }

}
