package org.mule.weave.v2.ts.resolvers

import org.mule.weave.v2.parser.ast.structure.KeyValuePairNode
import org.mule.weave.v2.ts.Edge
import org.mule.weave.v2.ts.EdgeLabels
import org.mule.weave.v2.ts.IntersectionType
import org.mule.weave.v2.ts.KeyType
import org.mule.weave.v2.ts.KeyValuePairType
import org.mule.weave.v2.ts.NameType
import org.mule.weave.v2.ts.ObjectType
import org.mule.weave.v2.ts.ReferenceType
import org.mule.weave.v2.ts.TypeHelper
import org.mule.weave.v2.ts.TypeNode
import org.mule.weave.v2.ts.UnionType
import org.mule.weave.v2.ts.WeaveType
import org.mule.weave.v2.ts.WeaveTypeResolutionContext
import org.mule.weave.v2.ts.WeaveTypeResolver

/**
  * Type resolver for a key/value pair node of an object.
  *
  * The pair type is assembled from two incoming edges of the node:
  * the [[EdgeLabels.NAME]] edge carries the key type and the [[EdgeLabels.VALUE]] edge carries
  * the value type.
  */
object KeyValuePairTypeResolver extends WeaveTypeResolver {

  /**
    * Builds the [[KeyValuePairType]] for `node`.
    *
    * The key type is normalized to a [[KeyType]] (see `calculateKeyType`). The third constructor
    * argument is whether the AST [[KeyValuePairNode]] has a `cond`. Any weave doc on the AST node
    * is passed to `withDocumentation` on the new type.
    *
    * @param node the type-graph node; its AST node is expected to be a [[KeyValuePairNode]]
    *             (a plain cast is used, so anything else throws `ClassCastException`)
    * @param ctx  resolution context (not used here)
    * @return always `Some` key/value pair type
    */
  override def resolveReturnType(node: TypeNode, ctx: WeaveTypeResolutionContext): Option[WeaveType] = {
    val keyType: KeyType = calculateKeyType(node)
    val valueType = valueEdge(node).incomingType()
    val hasCondition = node.astNode.asInstanceOf[KeyValuePairNode].cond.isDefined
    val pairType = KeyValuePairType(keyType, valueType, hasCondition)

    // The value returned by withDocumentation is discarded, so this relies on the call annotating
    // pairType in place. NOTE(review): confirm against WeaveType.withDocumentation.
    for (doc <- node.astNode.weaveDoc) {
      pairType.withDocumentation(doc.literalValue, doc.location())
    }
    Some(pairType)
  }

  /** The edge feeding this node its value type (the first [[EdgeLabels.VALUE]] edge; throws if absent). */
  private def valueEdge(node: TypeNode): Edge =
    node.incomingEdges(EdgeLabels.VALUE).head

  /**
    * The key type arriving on the key edge, coerced to a [[KeyType]].
    *
    * If the incoming type is not already a [[KeyType]], it is replaced by `KeyType(NameType())`.
    * Callers can therefore always inspect a name type on the key.
    */
  private def calculateKeyType(node: TypeNode): KeyType =
    keyEdge(node).incomingType() match {
      case alreadyKey: KeyType => alreadyKey
      case _                   => KeyType(NameType())
    }

  /**
    * Pushes an expected pair type down to this node's inputs.
    *
    * When the expectation is a [[KeyValuePairType]], its key type is sent along the key edge and
    * its value type along the value edge. Any other (or missing) expectation propagates nothing.
    */
  override def resolveExpectedType(node: TypeNode, incomingExpectedType: Option[WeaveType], ctx: WeaveTypeResolutionContext): Seq[(Edge, WeaveType)] =
    incomingExpectedType match {
      case Some(expectedPair: KeyValuePairType) =>
        Seq((keyEdge(node), expectedPair.key), (valueEdge(node), expectedPair.value))
      case _ =>
        Seq()
    }

  /**
    * Looks up, inside an expected type, the property whose key matches this node's key name.
    *
    *  - [[ObjectType]]: searched only when the key edge already has a type. The node's key must
    *    have a concrete name, and a property matches when that name is `selectedBy` the
    *    property's key name.
    *  - [[ReferenceType]]: the reference is resolved and the lookup retried.
    *  - [[IntersectionType]]: simplified with `TypeHelper.resolveIntersection`. If it is still an
    *    intersection, the result is `None`; otherwise the lookup is retried on the simplified type.
    *  - [[UnionType]]: simplified with `TypeHelper.resolveUnion`. If it is still a union, its
    *    members are tried in order and the first match wins (later members are not evaluated).
    *  - anything else (or `None`): `None`.
    *
    * @return the matching property, or `None`
    */
  def selectKeyValuePair(node: TypeNode, incomingExpectedType: Option[WeaveType]): Option[KeyValuePairType] =
    incomingExpectedType match {
      case Some(objectType: ObjectType) if keyEdge(node).incomingTypeDefined() =>
        calculateKeyType(node).name match {
          case NameType(Some(keyName)) =>
            objectType.properties.find(property =>
              property.key match {
                case KeyType(NameType(Some(propertyName)), _) => keyName.selectedBy(propertyName)
                case _                                        => false
              })
          case _ =>
            None
        }

      case Some(referenceType: ReferenceType) =>
        selectKeyValuePair(node, Some(referenceType.resolveType()))

      case Some(intersectionType: IntersectionType) =>
        TypeHelper.resolveIntersection(intersectionType.of) match {
          // Still an intersection: give up, since retrying would land in this same branch.
          case _: IntersectionType => None
          case simplified          => selectKeyValuePair(node, Some(simplified))
        }

      case Some(unionType: UnionType) =>
        TypeHelper.resolveUnion(unionType) match {
          case remainingUnion: UnionType =>
            // Lazily probe each member; stop at the first one that yields a property.
            remainingUnion.of.iterator
              .map(member => selectKeyValuePair(node, Some(member)))
              .collectFirst({ case Some(found) => found })
          case simplified =>
            selectKeyValuePair(node, Some(simplified))
        }

      case _ =>
        None
    }

  /** The edge feeding this node its key type (the first [[EdgeLabels.NAME]] edge; throws if absent). */
  private def keyEdge(node: TypeNode): Edge =
    node.incomingEdges(EdgeLabels.NAME).head
}
