package org.mule.weave.v2.utils

import org.mule.weave.v2.parser.ast.variables.NameIdentifier

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.annotation.tailrec
import scala.collection.mutable

/**
  * `ReadWriteLock` implementation backed by a JVM `ReentrantReadWriteLock`.
  *
  * Both entry points evaluate the by-name `provider` while holding the corresponding side of the
  * underlying lock, and release that side afterwards even when `provider` throws.
  */
class JVMReadWriteLock extends ReadWriteLock {

  private val underlying: ReentrantReadWriteLock = new ReentrantReadWriteLock()
  private val sharedSide: ReentrantReadWriteLock.ReadLock = underlying.readLock()
  private val exclusiveSide: ReentrantReadWriteLock.WriteLock = underlying.writeLock()

  /**
    * Evaluates `provider` while holding the read side of the lock.
    *
    * @param provider computation to run under the read lock
    * @return the value produced by `provider`
    */
  override def readLock[T](provider: => T): T = holding(sharedSide)(provider)

  /**
    * Evaluates `provider` while holding the write side of the lock.
    *
    * @param provider computation to run under the write lock
    * @return the value produced by `provider`
    */
  override def writeLock[T](provider: => T): T = holding(exclusiveSide)(provider)

  /**
    * Acquires `side`, evaluates `action` and releases `side` again, whatever the outcome.
    * The acquisition happens before the `try` so that a failed `lock()` is never followed by `unlock()`.
    */
  private def holding[T](side: java.util.concurrent.locks.Lock)(action: => T): T = {
    side.lock()
    try {
      action
    } finally {
      side.unlock()
    }
  }
}

/**
  * JVM implementation of `Lock` that runs a provider while holding a monitor dedicated to the
  * requested `NameIdentifier`, unless doing so would close a cycle of dependencies.
  *
  * This lock is also used inside turtles (sfdc-core-dataweave), and turtles doesn't guarantee that
  * `finally` blocks are executed. This is why the `synchronized` statement is used instead of a
  * ReentrantLock.
  *
  * Every thread that is inside `lock` has a stack with the identifiers it requested, oldest first.
  * Those stacks are read and modified only while holding the monitor of this instance.
  */
class JVMLock extends Lock {
  // One monitor object per identifier. Entries are created on demand and never removed.
  private val locks = new ConcurrentHashMap[NameIdentifier, LockInformation]()
  // Identifiers requested by each thread that is currently inside `lock`, oldest first.
  private val stacks = new ConcurrentHashMap[Thread, mutable.ArrayBuffer[NameIdentifier]]()

  /**
    * Evaluates `provider` while holding the monitor associated with `nameIdentifier`.
    *
    * When waiting for that monitor would close a dependency cycle (see `isCircularDep`), the
    * provider is evaluated without taking the monitor so that the threads involved don't deadlock.
    *
    * @param nameIdentifier identifier whose monitor guards the evaluation
    * @param provider       computation to evaluate
    * @return the value produced by `provider`
    */
  override def lock[T](nameIdentifier: NameIdentifier, provider: => T): T = {
    val lockInformation = locks.computeIfAbsent(nameIdentifier, _ => new LockInformation)
    val currentThread = Thread.currentThread()
    // The circularity check and the push onto this thread's stack must happen in the same
    // critical section so that no other thread can observe, or change, the stacks in between.
    // This is done before the `try`: if it fails nothing was pushed, so there is nothing to undo
    // and the original failure is not masked by the cleanup.
    val (isCircular, identifiers) = this.synchronized {
      val circular = isCircularDep(nameIdentifier)
      val stack = stacks.computeIfAbsent(currentThread, _ => new mutable.ArrayBuffer[NameIdentifier]())
      stack += nameIdentifier
      (circular, stack)
    }
    try {
      if (isCircular) {
        provider
      } else {
        lockInformation.synchronized {
          provider
        }
      }
    } finally {
      // Removal has to be synchronized with isCircularDep: otherwise an entry could pass the
      // check there and be emptied by this thread before its last element is read.
      this.synchronized {
        // Nested calls made by `provider` on this thread push and pop in pairs, so the last
        // element is the identifier pushed by this call.
        identifiers.remove(identifiers.length - 1)
        if (identifiers.isEmpty) {
          stacks.remove(currentThread)
        }
      }
    }
  }

  /**
    * Tells whether waiting for `nameIdentifier` would close a cycle that includes the calling thread.
    *
    * Starting from `nameIdentifier`, the walk looks for a thread that holds the identifier and
    * continues with the identifier that thread requested last:
    *
    *   - an identifier counts as held by another thread only when it is not the last entry of
    *     that thread's stack, because the last entry is the one it is still acquiring or running;
    *   - every entry of the calling thread's stack counts as held, including the last one, since
    *     the calling thread is the one about to wait. Without this a direct cycle between two
    *     threads (T1 holds A and waits for B, T2 holds B and requests A) is not detected.
    *
    * The walk stops with `false` as soon as an identifier is seen twice: such a cycle does not go
    * through the calling thread, and following it again would never terminate.
    *
    * Must be called while holding the monitor of this instance.
    *
    * @param nameIdentifier identifier the calling thread is about to wait for
    * @return `true` if the walk reaches the calling thread, `false` otherwise
    */
  private def isCircularDep(nameIdentifier: NameIdentifier): Boolean = {
    val currentThread = Thread.currentThread()
    val visited = mutable.HashSet.empty[NameIdentifier]

    // A -> B -> C
    @tailrec
    def follow(identifier: NameIdentifier): Boolean = {
      if (!visited.add(identifier)) {
        false
      } else {
        val maybeLocking = stacks
          .entrySet()
          .stream()
          .filter(entry => {
            val index = entry.getValue.indexOf(identifier)
            index >= 0 && ((entry.getKey eq currentThread) || index != entry.getValue.size - 1)
          })
          .findFirst()

        if (maybeLocking.isPresent) {
          val locking = maybeLocking.get()
          if (locking.getKey eq currentThread) {
            true
          } else {
            // The matching stack contains `identifier`, so it is not empty and `last` is safe.
            follow(locking.getValue.last)
          }
        } else {
          false
        }
      }
    }

    follow(nameIdentifier)
  }

  // Monitor object: instances are only used as the target of `synchronized`.
  private case class LockInformation()
}

/**
  * Entry point for obtaining the JVM-backed lock implementations.
  */
object LockFactory {

  /** Creates a new `JVMReadWriteLock`, exposed through the `ReadWriteLock` type. */
  def createReadWriteLock(): ReadWriteLock = {
    new JVMReadWriteLock()
  }

  /** Creates a new `JVMLock`, exposed through the `Lock` type. */
  def createLock(): Lock = {
    new JVMLock()
  }
}
