package org.mule.weave.v2.ts

import org.mule.weave.v2.scope.Reference

/**
  * A fact that narrows the type of a referenced variable.
  *
  * Constraints are grouped per reference in [[VariableConstraints]], which applies them when
  * enhancing a variable's type.
  */
trait VariableConstraint

/**
  * States that the variable is of type `typ`.
  *
  * When applied by [[VariableConstraints]], the variable's type is intersected with `typ`.
  *
  * @param typ the type the variable is known to have
  */
case class IsConstraint(typ: WeaveType) extends VariableConstraint

/**
  * States that the variable is NOT of type `typ`.
  *
  * When applied by [[VariableConstraints]], `typ` is subtracted from the variable's type.
  *
  * @param typ the type the variable is known not to have
  */
case class NotConstraint(typ: WeaveType) extends VariableConstraint

/**
  * A conditional constraint: it is applied only if the condition returns true when evaluated
  * against the referenced variable, its current type and the resolution context.
  *
  * When the condition holds, `const` is applied; it may itself be conditional, in which case its
  * own condition is checked too. When the condition does not hold, the constraint is discarded.
  *
  * @param const The constraint to apply when `cond` holds
  * @param cond The condition to check
  */
case class ConditionalConstraint(const: VariableConstraint, cond: (Reference, WeaveType, WeaveTypeResolutionContext) => Boolean) extends VariableConstraint

/**
  * Per-variable type constraints, kept as a positive and a negative set.
  *
  * The two sets mirror the true / false outcome of a condition:
  *   - `negate` swaps them;
  *   - `conjunction` and `disjunction` combine them following De Morgan's laws.
  *
  * A `None` set means there are no constraints on that side. Instances are built through the
  * companion object.
  *
  * @param posConstraints constraints applied by [[enhancePositive]]
  * @param negConstraints constraints applied by [[enhanceNegative]]
  */
class VariableConstraints private (protected val posConstraints: VariableConstraints.ConstraintsMap, protected val negConstraints: VariableConstraints.ConstraintsMap) {

  /**
    * Narrows `orig`, the current type of `ref`, with the positive constraints on `ref`.
    *
    * @param ref  the referenced variable
    * @param orig the type to narrow
    * @param ctx  the type resolution context
    * @return the narrowed type, or `orig` itself when no positive constraint targets `ref`
    */
  def enhancePositive(ref: Reference, orig: WeaveType, ctx: WeaveTypeResolutionContext): WeaveType = {
    enhanceWithConstraints(ref, orig, posConstraints, ctx)
  }

  /**
    * Narrows `orig`, the current type of `ref`, with the negative constraints on `ref`.
    *
    * @param ref  the referenced variable
    * @param orig the type to narrow
    * @param ctx  the type resolution context
    * @return the narrowed type, or `orig` itself when no negative constraint targets `ref`
    */
  def enhanceNegative(ref: Reference, orig: WeaveType, ctx: WeaveTypeResolutionContext): WeaveType = {
    enhanceWithConstraints(ref, orig, negConstraints, ctx)
  }

  /** Applies the constraints registered for `ref` in `constraints` to `orig`, or returns `orig` if there are none. */
  private def enhanceWithConstraints(ref: Reference, orig: WeaveType, constraints: VariableConstraints.ConstraintsMap, ctx: WeaveTypeResolutionContext): WeaveType = {
    constraints.flatMap(_.get(ref)).fold(orig)(applyConstraints(ref, orig, _, ctx))
  }

  /**
    * Narrows `orig` with `constraints`.
    *
    * The applicable `IsConstraint` types are intersected with `orig`, then the applicable
    * `NotConstraint` types are subtracted, and finally the result is simplified.
    */
  private def applyConstraints(ref: Reference, orig: WeaveType, constraints: Seq[VariableConstraint], ctx: WeaveTypeResolutionContext): WeaveType = {
    // Resolve conditional constraints; those whose condition fails are dropped here.
    val applicable = constraints.flatMap(validateConstraint(ref, orig, _, ctx))

    val isTypes = applicable.collect({ case IsConstraint(typ) => typ })
    val notTypes = applicable.collect({ case NotConstraint(typ) => typ })

    val intersected = TypeHelper.resolveAlgebraicIntersection(orig +: isTypes)
    val subtracted = notTypes.foldLeft(intersected)((acc, sub) => TypeHelper.subtractType(acc, sub, ctx))
    val result = TypeHelper.simplifyIntersections(TypeHelper.simplifyUnions(subtracted))

    // Carry the additional type information of `orig` over to the narrowed type.
    WeaveTypeCloneHelper.copyAdditionalTypeInformation(orig, result)
    result
  }

  /**
    * Resolves a single constraint for `ref`.
    *
    * `IsConstraint` and `NotConstraint` are returned as-is. A `ConditionalConstraint` yields its
    * inner constraint, resolved recursively, only when its condition holds for `ref` and `orig`;
    * otherwise it yields `None`. Only the constraint kinds declared alongside
    * `VariableConstraint` in this file are matched.
    */
  private def validateConstraint(ref: Reference, orig: WeaveType, constraint: VariableConstraint, ctx: WeaveTypeResolutionContext): Option[VariableConstraint] = {
    constraint match {
      case ConditionalConstraint(const, cond) =>
        if (cond(ref, orig, ctx)) validateConstraint(ref, orig, const, ctx) else None
      case x: IsConstraint  => Some(x)
      case x: NotConstraint => Some(x)
    }
  }

  /** The positive constraints per reference; empty when there are none. */
  def positiveConstrains(): Map[Reference, Seq[VariableConstraint]] = {
    posConstraints.getOrElse(Map())
  }

  /** The negative constraints per reference; empty when there are none. */
  def negativeConstrains(): Map[Reference, Seq[VariableConstraint]] = {
    negConstraints.getOrElse(Map())
  }

  /**
    * Combines two VariableConstraints into the constraints of "this AND other".
    *
    * - Positive constraints of both sides are merged, since both must hold.
    * - Negative constraints are unified, since only one of the two sides needs to fail.
    *
    * @param other the constraints to combine with
    * @return the constraints of the conjunction
    */
  def conjunction(other: VariableConstraints): VariableConstraints = {
    val positive = VariableConstraints.mergeConstraints(posConstraints, other.posConstraints)
    val negative = VariableConstraints.unifyConstraint(negConstraints, other.negConstraints)
    VariableConstraints(positive, negative)
  }

  /**
    * Combines two VariableConstraints into the constraints of "this OR other".
    *
    * - Positive constraints are unified, since only one of the two sides needs to hold.
    * - Negative constraints of both sides are merged, since both must fail.
    *
    * @param other the constraints to combine with
    * @return the constraints of the disjunction
    */
  def disjunction(other: VariableConstraints): VariableConstraints = {
    val positive = VariableConstraints.unifyConstraint(posConstraints, other.posConstraints)
    val negative = VariableConstraints.mergeConstraints(negConstraints, other.negConstraints)
    VariableConstraints(positive, negative)
  }

  /**
    * Negates these constraints by swapping the positive and negative sets.
    *
    * @return the constraints of "NOT this"
    */
  def negate(): VariableConstraints = {
    VariableConstraints(negConstraints, posConstraints)
  }

  /**
    * Merges all constraints of this and `other`, side by side.
    *
    * Positives are merged with positives and negatives with negatives; no unification is
    * performed.
    *
    * @param other the constraints to combine with
    * @return the merged constraints
    */
  def combine(other: VariableConstraints): VariableConstraints = {
    val positive = VariableConstraints.mergeConstraints(posConstraints, other.posConstraints)
    val negative = VariableConstraints.mergeConstraints(negConstraints, other.negConstraints)
    VariableConstraints(positive, negative)
  }
}

/** Factories and combinators for [[VariableConstraints]]. */
object VariableConstraints {

  /** Constraints per referenced variable; `None` stands for "no constraints". */
  type ConstraintsMap = Option[Map[Reference, Seq[VariableConstraint]]]

  /** Constraints with neither a positive nor a negative set. */
  def emptyConstraints(): VariableConstraints = new VariableConstraints(None, None)

  /** Constraints with only a positive set. */
  def onlyPositive(pos: Map[Reference, Seq[VariableConstraint]]): VariableConstraints = apply(Some(pos), None)

  /** Constraints with only a negative set. */
  def onlyNegative(neg: Map[Reference, Seq[VariableConstraint]]): VariableConstraints = apply(None, Some(neg))

  /**
    * Builds constraints from explicit positive and negative sets.
    *
    * @param pos the positive set, `None` for no constraints
    * @param neg the negative set, `None` for no constraints
    */
  def apply(pos: ConstraintsMap = None, neg: ConstraintsMap = None): VariableConstraints = new VariableConstraints(pos, neg)

  /**
    * Concatenates the constraints of `x` and `y` per reference, with those of `x` first.
    *
    * When only one side is defined, that side is returned unchanged.
    *
    * The map is built eagerly. `mapValues` would return a lazy view that re-runs the
    * concatenation on every lookup, and it is deprecated since Scala 2.13.
    */
  private def mergeConstraints(x: ConstraintsMap, y: ConstraintsMap): ConstraintsMap = {
    (x, y) match {
      case (Some(vx), Some(vy)) =>
        Some(vy.foldLeft(vx) {
          case (acc, (ref, constraints)) => acc.updated(ref, acc.getOrElse(ref, Seq.empty) ++ constraints)
        })
      case _ => x.orElse(y)
    }
  }

  /**
    * Computes the constraints of "`x` OR `y`".
    *
    * A result is produced only when all of the following hold:
    *   - both sides constrain exactly the same single reference;
    *   - each side has at least one constraint on it;
    *   - all constraints are `IsConstraint`s, or all are `NotConstraint`s.
    *
    * Otherwise nothing can be concluded and `None` is returned.
    *
    * @param x the constraints of one side of the disjunction
    * @param y the constraints of the other side
    * @return the unified constraints, or `None` when they cannot be expressed
    */
  private def unifyConstraint(x: ConstraintsMap, y: ConstraintsMap): ConstraintsMap = {
    val isType: PartialFunction[VariableConstraint, WeaveType] = { case IsConstraint(typ) => typ }
    val notType: PartialFunction[VariableConstraint, WeaveType] = { case NotConstraint(typ) => typ }

    (x, y) match {
      case (Some(vx), Some(vy)) if vx.size == 1 && vx.keySet == vy.keySet =>
        val (ref, xconstr) = vx.head
        val yconstr = vy(ref)
        val all = xconstr ++ yconstr

        if (xconstr.isEmpty || yconstr.isEmpty) {
          // One side puts no constraint on `ref`, so the disjunction cannot narrow it either.
          // Without this guard the vacuous `forall` below would build a constraint from an
          // empty intersection/union.
          None
        } else if (all.forall(isType.isDefinedAt)) {
          // (ref is X1 & X2 ...) or (ref is Y1 & Y2 ...)
          val xtypes = IntersectionType(xconstr.collect(isType))
          val ytypes = IntersectionType(yconstr.collect(isType))
          Some(Map(ref -> Seq(IsConstraint(UnionType(Seq(xtypes, ytypes))))))
        } else if (all.forall(notType.isDefinedAt)) {
          // (ref is not X1 | X2 ...) or (ref is not Y1 | Y2 ...)
          //   == ref is not ((X1 | X2 ...) & (Y1 | Y2 ...))
          val xtypes = UnionType(xconstr.collect(notType))
          val ytypes = UnionType(yconstr.collect(notType))
          Some(Map(ref -> Seq(NotConstraint(IntersectionType(Seq(xtypes, ytypes))))))
        } else {
          None
        }

      case _ => None
    }
  }

}
