package org.mule.weave.v2.interpreted.node.structure

import org.mule.weave.v2.core.RuntimeConfigProperties.CPU_LIMIT_CHECK_FREQUENCY
import org.mule.weave.v2.interpreted.ExecutionContext
import org.mule.weave.v2.interpreted.Frame
import org.mule.weave.v2.interpreted.node.ValueNode
import org.mule.weave.v2.parser.location.LocationCapable
import org.mule.weave.v2.parser.location.UnknownLocation
import org.mule.weave.v2.model
import org.mule.weave.v2.model.service.FrequencyBasedCpuLimitService
import org.mule.weave.v2.model.structure.KeyValuePair
import org.mule.weave.v2.model.structure.ObjectSeq
import org.mule.weave.v2.model.types.ArrayType
import org.mule.weave.v2.model.types.KeyValuePairType
import org.mule.weave.v2.model.types.NullType
import org.mule.weave.v2.model.types.ObjectType
import org.mule.weave.v2.model.values.ObjectValue
import org.mule.weave.v2.model.values.Value
import org.mule.weave.v2.runtime.exception.InvalidObjectExpansionException

import java.util.concurrent.atomic.AtomicInteger
import scala.collection.AbstractIterator
import scala.collection.mutable.ArrayBuffer

/**
  * Object node built by the `ObjectNode` companion when at least one entry is not a [[KeyValuePairNode]]
  * (for example an expression whose result is expanded into the object).
  *
  * On execution, entries that are [[ConditionalCapableNode]]s with a false condition are dropped. The remaining
  * entries are expanded by a [[DynamicKeyValuePairIterator]] that runs them in the frame that was active when this
  * node executed.
  */
class DynamicObjectNode(elements: Seq[ValueNode[_]]) extends ValueNode[ObjectSeq] {

  override def doExecute(implicit ctx: ExecutionContext): Value[ObjectSeq] = {
    // Only conditional entries can be dropped; every other entry is always kept.
    val enabled: Seq[ValueNode[_]] = elements.filter {
      case conditional: ConditionalCapableNode => conditional.condition
      case _                                   => true
    }
    val frame: Frame = ctx.executionStack().activeFrame()
    ObjectValue(new DynamicKeyValuePairIterator(enabled.toIterator, frame, ctx, this), this)
  }

  override def productElement(n: Int): Any = elements(n)

  override def productArity: Int = elements.size
}

/**
  * Object node that evaluates its key-value pairs on the first execution and returns that same cached value on every
  * later execution. The result is wrapped as a materialized [[ObjectSeq]].
  */
class LiteralObjectNode(elements: Array[KeyValuePairNode]) extends ValueNode[ObjectSeq] {

  // Cached result of the first execution. Marked @volatile because doExecute uses double-checked initialization:
  // without it, a thread that reads the field outside the lock may observe a non-null reference to a value whose
  // construction is not yet visible to it.
  @volatile var value: Value[ObjectSeq] = _

  override def productElement(n: Int): Any = elements.apply(n)

  override def productArity: Int = elements.length

  override protected def doExecute(implicit ctx: ExecutionContext): Value[ObjectSeq] = {
    // Fast path: a single volatile read once the value has been published.
    val cached = value
    if (cached != null) {
      cached
    } else {
      synchronized {
        // Re-check under the lock: another thread may have initialized it meanwhile.
        if (value == null) {
          val result = new Array[KeyValuePair](elements.length)
          var i = 0
          while (i < elements.length) {
            result.update(i, elements(i).toKeyValuePair)
            i = i + 1
          }
          value = ObjectValue(ObjectSeq(result, materialized = true), this)
        }
        value
      }
    }
  }
}

/**
  * Object node made of a fixed list of [[KeyValuePairNode]]s. Every execution evaluates each entry, in declaration
  * order, and wraps the pairs as a non-materialized [[ObjectSeq]].
  */
class ObjectNode(elements: Array[KeyValuePairNode]) extends ValueNode[ObjectSeq] {

  override protected def doExecute(implicit ctx: ExecutionContext): Value[ObjectSeq] = {
    val pairs: Array[KeyValuePair] = elements.map(node => node.toKeyValuePair)
    ObjectValue(ObjectSeq(pairs, materialized = false), this)
  }

  override def productElement(n: Int): Any = elements(n)

  override def productArity: Int = elements.length
}

/**
  * Object node whose [[KeyValuePairNode]] entries may carry conditions. Each execution keeps only the entries whose
  * `condition` holds and evaluates them into key-value pairs.
  */
class FilteredObjectNode(elements: Array[KeyValuePairNode]) extends ValueNode[ObjectSeq] {

  override protected def doExecute(implicit ctx: ExecutionContext): Value[ObjectSeq] = {
    val included = new ArrayBuffer[KeyValuePair]()
    elements.foreach { node =>
      // Condition and pair are evaluated entry by entry, in declaration order.
      if (node.condition) {
        included += node.toKeyValuePair
      }
    }
    ObjectValue(ObjectSeq(included), this)
  }

  override def productElement(n: Int): Any = elements(n)

  override def productArity: Int = elements.length
}

/**
  * Object built from one leading key-value pair (`headKey`: `headValue`) followed by the pairs of the object that
  * `tail` evaluates to.
  *
  * The head pair is evaluated eagerly. `tail` is executed, in the frame that was active when this node executed,
  * only when the resulting iterator is advanced past the head.
  */
class HeadTailObjectNode(val headKey: ValueNode[_], headValue: ValueNode[_], val tail: ValueNode[_]) extends ValueNode[ObjectSeq] {

  override def productElement(n: Int): Any = n match {
    case 0 => headKey
    case 1 => headValue
    case 2 => tail
    case _ => throw new IndexOutOfBoundsException(n.toString)
  }

  // Must agree with productElement: headKey, headValue and tail.
  override def productArity: Int = 3

  override def doExecute(implicit ctx: ExecutionContext): Value[ObjectSeq] = {
    val frame = ctx.executionStack().activeFrame()
    // Deferred: only invoked by HeadTailKeyValuePairIterator once the head has been consumed.
    val valueProducer: () => Iterator[KeyValuePair] = () => {
      ctx.runInFrame(frame, ObjectType.coerce(tail.execute, this).evaluate.toIterator())
    }
    val kvpValue = KeyValuePairNode(headKey, headValue).execute.evaluate
    val headTailObjectSeq = new HeadTailKeyValuePairIterator(kvpValue, valueProducer)
    ObjectValue(headTailObjectSeq, this)
  }
}

/**
  * Iterator over an already evaluated `head` pair followed by the pairs of a lazily produced tail.
  *
  * `tail` is a value producer: it is only invoked once `head` has been consumed and more elements are requested.
  * When the produced iterator is itself a [[HeadTailKeyValuePairIterator]], its head and tail are adopted directly
  * instead of wrapping it, so chains of head/tail objects are walked without nesting iterator calls.
  *
  * @param head the next pair to return, or null once it has been consumed
  * @param tail producer of the remaining pairs
  */
class HeadTailKeyValuePairIterator(var head: KeyValuePair, var tail: () => Iterator[KeyValuePair]) extends Iterator[KeyValuePair] {

  /**
    * Invokes the tail producer and moves its first element into `head`.
    *
    * If the tail turns out to be empty, the producer is replaced by an empty one. Otherwise every later `hasNext`
    * would invoke the original producer again, re-executing whatever it computes.
    */
  def consumeTail(): Unit = {
    tail() match {
      case ht: HeadTailKeyValuePairIterator if ht.hasNext =>
        head = ht.head
        tail = ht.tail
      case seq if seq.hasNext =>
        head = seq.next()
        tail = () => seq
      case _ =>
        tail = () => Iterator.empty
    }
  }

  override def hasNext: Boolean = {
    if (head == null) {
      consumeTail()
    }
    head != null
  }

  /**
    * Returns the next pair.
    *
    * @throws NoSuchElementException when the iterator is exhausted (instead of returning null)
    */
  override def next(): KeyValuePair = {
    if (!hasNext) {
      Iterator.empty.next()
    }
    val result = head
    head = null
    result
  }
}

/**
  * An iterator that can expose its elements as a plain [[Iterator]] through `flatten()`.
  *
  * When a dynamic object is expanded inside another one, the enclosing [[DynamicKeyValuePairIterator]] consumes the
  * nested iterator's `flatten()` result directly, rather than driving the nested iterator itself, to reduce stack
  * calls.
  */
trait FlattenableIterator[+A] {
  // The implementations in this file return an iterator over the elements not yet consumed from this one.
  def flatten(): Iterator[A]
}

/**
  * Result of `map` on a [[DynamicKeyValuePairIterator]]: applies `mapper` to every pair it yields while staying
  * [[FlattenableIterator]], so a mapped dynamic object can still be unrolled by an enclosing one.
  *
  * @param iterator the underlying dynamic pair iterator
  * @param mapper   function applied to each pair
  */
class MappedDynamicKeyValuePairIterator[+B](iterator: DynamicKeyValuePairIterator, mapper: KeyValuePair => B) extends AbstractIterator[B] with FlattenableIterator[B] {

  override def hasNext: Boolean = iterator.hasNext

  override def next(): B = {
    val pair = iterator.next()
    mapper.apply(pair)
  }

  override def flatten(): Iterator[B] = {
    val unrolled = iterator.flatten()
    unrolled.map(mapper)
  }
}

/**
  * Lazily expands the entries of a dynamic object into key-value pairs.
  *
  * Each entry of `filtered` is executed inside `contextFrame` only when more pairs are needed. Depending on its
  * value, one entry contributes:
  *   - a single key-value pair,
  *   - the pairs of an object,
  *   - the pairs of every object in an array,
  *   - or nothing (null).
  *
  * Any other value raises [[InvalidObjectExpansionException]].
  *
  * @param filtered        the entries to expand, consumed lazily
  * @param contextFrame    frame the entries are executed in
  * @param ctx             execution context used for evaluation and coercion
  * @param locationCapable location passed to object/array coercions
  */
class DynamicKeyValuePairIterator(filtered: Iterator[ValueNode[_]], contextFrame: Frame, ctx: ExecutionContext, locationCapable: LocationCapable) extends Iterator[KeyValuePair] with FlattenableIterator[KeyValuePair] {
  private val cpuLimitService = new FrequencyBasedCpuLimitService(ctx.serviceManager.cpuLimitService, CPU_LIMIT_CHECK_FREQUENCY)
  // NOTE(review): not referenced in this file; kept because it is public.
  val count = new AtomicInteger()
  // Pairs of the entry currently being expanded.
  var currentIterator: Iterator[KeyValuePair] = Iterator.empty

  // Keep the mapped result flattenable so enclosing dynamic objects can still unroll it.
  override def map[B](f: KeyValuePair => B): Iterator[B] = new MappedDynamicKeyValuePairIterator[B](this, f)

  def flatten(): Iterator[KeyValuePair] = allIterators()

  override def hasNext: Boolean = {
    if (!currentIterator.hasNext) {
      cpuLimitService.check(UnknownLocation)
      currentIterator = loadIterator()
    }
    currentIterator.hasNext
  }

  /**
    * Returns every pair not yet consumed: what is left of `currentIterator`, followed by the lazy expansion of the
    * remaining entries.
    *
    * Each remaining entry goes through the CPU-limit check before it is loaded, mirroring `hasNext`. Otherwise
    * expansions consumed through `flatten()` (nested dynamic objects) would never be checked.
    */
  def allIterators(): Iterator[KeyValuePair] = {
    val remaining = filtered.flatMap { expression =>
      cpuLimitService.check(UnknownLocation)
      loadNext(expression)
    }
    currentIterator ++ remaining
  }

  /**
    * Executes `expression` in `contextFrame` and returns the pairs it contributes to the object.
    *
    * @throws InvalidObjectExpansionException if the value is not a key-value pair, object, array or null
    */
  def loadNext(expression: ValueNode[_]): Iterator[KeyValuePair] = {
    ctx.runInFrame(
      contextFrame, {
        val value: Value[_] = expression.execute(ctx)
        value match {
          case v if KeyValuePairType.accepts(v)(ctx) =>
            Iterator(v.evaluate(ctx).asInstanceOf[KeyValuePair])
          case v if ObjectType.accepts(v)(ctx) =>
            val objectSeq: model.types.ObjectType.T = ObjectType.coerce(v, locationCapable)(ctx).evaluate(ctx)
            val iterator: Iterator[KeyValuePair] = objectSeq.toIterator()(ctx)
            iterator match {
              // A nested FlattenableIterator (e.g. another DynamicKeyValuePairIterator) is unrolled so that its
              // pairs are consumed directly, removing stack calls.
              case flattenIterator: FlattenableIterator[KeyValuePair] =>
                flattenIterator.flatten()
              case _ => iterator
            }
          case v if ArrayType.accepts(v)(ctx) =>
            val arr = ArrayType.coerce(v, locationCapable)(ctx).evaluate(ctx)
            arr
              .toIterator()
              .flatMap((item: Value[_]) => {
                val objectSeq = ObjectType.coerce(item, locationCapable)(ctx).evaluate(ctx)
                objectSeq.toIterator()(ctx)
              })
          case v if NullType.accepts(v)(ctx) =>
            Iterator.empty
          case v =>
            throw InvalidObjectExpansionException(expression.location(), v)(ctx)
        }
      })
  }

  // Loads entries until one yields at least one pair, or the entries run out.
  private def loadIterator(): Iterator[KeyValuePair] = {
    var loaded: Iterator[KeyValuePair] = Iterator.empty
    while (filtered.hasNext && !loaded.hasNext) {
      loaded = loadNext(filtered.next())
    }
    loaded
  }

  override def next(): KeyValuePair = {
    if (hasNext) {
      currentIterator.next()
    } else {
      // Throws NoSuchElementException.
      Iterator.empty.next()
    }
  }
}

/** Chooses the most specific object node implementation for a list of object entries. */
object ObjectNode {

  /**
    * Builds the node for an object literal. The choice depends on the entries:
    *   - [[DynamicObjectNode]] if any entry is not a [[KeyValuePairNode]],
    *   - [[FilteredObjectNode]] if all entries are pairs and some have a condition,
    *   - [[ObjectNode]] otherwise.
    */
  def apply(elements: Seq[ValueNode[_]]): ValueNode[ObjectSeq] = {
    val onlyPairs: Boolean = elements.forall(_.isInstanceOf[KeyValuePairNode])
    if (!onlyPairs) {
      new DynamicObjectNode(elements)
    } else {
      val pairNodes: Array[KeyValuePairNode] = elements.collect { case pair: KeyValuePairNode => pair }.toArray
      if (pairNodes.exists(_.cond.isDefined)) new FilteredObjectNode(pairNodes)
      else new ObjectNode(pairNodes)
    }
  }
}
