package org.mule.weave.v2.runtime

import org.mule.weave.v2.interpreted.InterpreterMappingCompilerPhase
import org.mule.weave.v2.interpreted.InterpreterModuleCompilerPhase
import org.mule.weave.v2.interpreted.InterpreterPreCompilerPhase
import org.mule.weave.v2.interpreted.RuntimeModuleNodeCompiler
import org.mule.weave.v2.interpreted.transform.phase.ConstantFoldingPhase
import org.mule.weave.v2.parser.MappingParser
import org.mule.weave.v2.parser.ModuleParser
import org.mule.weave.v2.parser.ast.module.ModuleNode
import org.mule.weave.v2.parser.ast.structure.DocumentNode
import org.mule.weave.v2.parser.phase.CommonSubexpressionReductionPhase
import org.mule.weave.v2.parser.phase.CompilationPhase
import org.mule.weave.v2.parser.phase.LoggingContextInjectionPhase
import org.mule.weave.v2.parser.phase.ParsingContentInput
import org.mule.weave.v2.parser.phase.ParsingContext
import org.mule.weave.v2.parser.phase.PhaseResult
import org.mule.weave.v2.parser.phase.ScopeGraphResult
import org.mule.weave.v2.sdk.WeaveResource

/**
  * Entry points that compile a DataWeave [[WeaveResource]] (or an already scoped AST) into a [[CompilationResult]].
  *
  * Each source-level pipeline chains, in order:
  *  - a parser phase from [[MappingParser]] or [[ModuleParser]];
  *  - [[CommonSubexpressionReductionPhase]];
  *  - [[ConstantFoldingPhase]];
  *  - the runtime compiler phase from [[RuntimeDocumentNodeCompiler]] or [[RuntimeModuleCompiler]].
  *
  * [[LoggingContextInjectionPhase]] is deliberately not chained by the pipelines in this object. Both runtime
  * compiler phases already run it as their first step, so chaining it here too would execute it twice in a row.
  */
object WeaveCompiler {

  /**
    * Mapping (document) pipeline that starts from [[MappingParser.typeCheckPhase]].
    *
    * @param moduleNodeLoader passed through to [[RuntimeDocumentNodeCompiler.compilerPhase]]
    * @return a phase from raw parsing input to a compiled document
    */
  private def fullCompilerPhase(moduleNodeLoader: RuntimeModuleNodeCompiler): CompilationPhase[ParsingContentInput, CompilationResult[DocumentNode]] = {
    // No explicit LoggingContextInjectionPhase: the runtime compiler phase below starts with it.
    MappingParser.typeCheckPhase()
      .chainWith(new CommonSubexpressionReductionPhase[DocumentNode]())
      .chainWith(new ConstantFoldingPhase[DocumentNode]())
      .chainWith(RuntimeDocumentNodeCompiler.compilerPhase(moduleNodeLoader))
  }

  /**
    * Module pipeline that starts from [[ModuleParser.typeCheckPhase]].
    *
    * @param moduleNodeLoader passed through to [[RuntimeModuleCompiler.compilerPhase]]
    * @return a phase from raw parsing input to a compiled module
    */
  private def fullModuleCompilerPhase(moduleNodeLoader: RuntimeModuleNodeCompiler): CompilationPhase[ParsingContentInput, CompilationResult[ModuleNode]] = {
    // No explicit LoggingContextInjectionPhase: the runtime compiler phase below starts with it.
    ModuleParser.typeCheckPhase()
      .chainWith(new CommonSubexpressionReductionPhase[ModuleNode]())
      .chainWith(new ConstantFoldingPhase[ModuleNode]())
      .chainWith(RuntimeModuleCompiler.compilerPhase(moduleNodeLoader))
  }

  /**
    * Mapping (document) pipeline that starts from [[MappingParser.scopePhase]].
    * It uses the scope phase instead of the type-check phase used by `fullCompilerPhase`.
    *
    * @param moduleNodeLoader passed through to [[RuntimeDocumentNodeCompiler.compilerPhase]]
    * @return a phase from raw parsing input to a compiled document
    */
  private def noCheckCompilerPhase(moduleNodeLoader: RuntimeModuleNodeCompiler): CompilationPhase[ParsingContentInput, CompilationResult[DocumentNode]] = {
    // No explicit LoggingContextInjectionPhase: the runtime compiler phase below starts with it.
    MappingParser.scopePhase()
      .chainWith(new CommonSubexpressionReductionPhase[DocumentNode]())
      .chainWith(new ConstantFoldingPhase[DocumentNode]())
      .chainWith(RuntimeDocumentNodeCompiler.compilerPhase(moduleNodeLoader))
  }

  /**
    * Compiles a mapping using the scope-phase pipeline, without the type-check phase.
    *
    * @param input            the DataWeave source to compile
    * @param parsingContext   context handed to [[MappingParser.parse]]
    * @param moduleNodeLoader loader used by the runtime compiler phase
    * @return the phase result wrapping the compiled document
    */
  def compileWithNoCheck(input: WeaveResource, parsingContext: ParsingContext, moduleNodeLoader: RuntimeModuleNodeCompiler): PhaseResult[CompilationResult[DocumentNode]] = {
    MappingParser.parse(noCheckCompilerPhase(moduleNodeLoader), input, parsingContext)
  }

  /**
    * Same as the three-argument `compileWithNoCheck`, using a fresh `RuntimeModuleNodeCompiler()`.
    *
    * @param input          the DataWeave source to compile
    * @param parsingContext context handed to [[MappingParser.parse]]
    * @return the phase result wrapping the compiled document
    */
  def compileWithNoCheck(input: WeaveResource, parsingContext: ParsingContext): PhaseResult[CompilationResult[DocumentNode]] = {
    compileWithNoCheck(input, parsingContext, RuntimeModuleNodeCompiler())
  }

  /**
    * Same as the three-argument `compile`, using a fresh `RuntimeModuleNodeCompiler()`.
    *
    * @param input          the DataWeave source to compile
    * @param parsingContext context handed to [[MappingParser.parse]]
    * @return the phase result wrapping the compiled document
    */
  def compile(input: WeaveResource, parsingContext: ParsingContext): PhaseResult[CompilationResult[DocumentNode]] = {
    compile(input, parsingContext, RuntimeModuleNodeCompiler())
  }

  /**
    * Compiles a mapping using the type-check pipeline.
    *
    * @param input            the DataWeave source to compile
    * @param parsingContext   context handed to [[MappingParser.parse]]
    * @param moduleNodeLoader loader used by the runtime compiler phase
    * @return the phase result wrapping the compiled document
    */
  def compile(input: WeaveResource, parsingContext: ParsingContext, moduleNodeLoader: RuntimeModuleNodeCompiler): PhaseResult[CompilationResult[DocumentNode]] = {
    MappingParser.parse(fullCompilerPhase(moduleNodeLoader), input, parsingContext)
  }

  /**
    * Compiles a module using the type-check pipeline.
    *
    * @param input            the DataWeave module source to compile
    * @param parsingContext   context handed to [[ModuleParser.parse]]
    * @param moduleNodeLoader loader used by the runtime compiler phase
    * @return the phase result wrapping the compiled module
    */
  def compileModule(input: WeaveResource, parsingContext: ParsingContext, moduleNodeLoader: RuntimeModuleNodeCompiler): PhaseResult[CompilationResult[ModuleNode]] = {
    ModuleParser.parse(fullModuleCompilerPhase(moduleNodeLoader), input, parsingContext)
  }

  /**
    * Runs only the runtime compiler phase on a module that has already been scoped.
    *
    * @param moduleNode       the scoped module
    * @param parsingContext   context passed to the phase
    * @param moduleNodeLoader loader used by the runtime compiler phase
    * @return the phase result wrapping the compiled module
    */
  def runtimeModuleCompilation(moduleNode: ScopeGraphResult[ModuleNode], parsingContext: ParsingContext, moduleNodeLoader: RuntimeModuleNodeCompiler): PhaseResult[CompilationResult[ModuleNode]] = {
    RuntimeModuleCompiler
      .compilerPhase(moduleNodeLoader)
      .call(moduleNode, parsingContext)
  }

  /**
    * Runs only the runtime compiler phase on a document that has already been scoped.
    *
    * @param documentNode     the scoped document
    * @param parsingContext   context passed to the phase
    * @param moduleNodeLoader loader used by the runtime compiler phase
    * @return the phase result wrapping the compiled document
    */
  def runtimeCompilation(documentNode: ScopeGraphResult[DocumentNode], parsingContext: ParsingContext, moduleNodeLoader: RuntimeModuleNodeCompiler): PhaseResult[CompilationResult[DocumentNode]] = {
    RuntimeDocumentNodeCompiler
      .compilerPhase(moduleNodeLoader)
      .call(documentNode, parsingContext)
  }
}

/**
  * Builds the runtime-compilation tail for modules, taking a scoped [[ModuleNode]] to a [[CompilationResult]].
  */
object RuntimeModuleCompiler {

  /**
    * Chains three phases in order:
    *  - [[LoggingContextInjectionPhase]];
    *  - [[InterpreterPreCompilerPhase]];
    *  - [[InterpreterModuleCompilerPhase]].
    *
    * @param moduleNodeLoader handed to [[InterpreterModuleCompilerPhase]]
    * @return a phase from a scoped module to its compilation result
    */
  def compilerPhase(moduleNodeLoader: RuntimeModuleNodeCompiler): CompilationPhase[ScopeGraphResult[ModuleNode], CompilationResult[ModuleNode]] = {
    val injectLoggingContext = new LoggingContextInjectionPhase[ModuleNode]()
    val preCompile = new InterpreterPreCompilerPhase[ModuleNode, ScopeGraphResult[ModuleNode]]()
    injectLoggingContext
      .chainWith(preCompile)
      .chainWith(new InterpreterModuleCompilerPhase(moduleNodeLoader))
  }
}

/**
  * Builds the runtime-compilation tail for mappings, taking a scoped [[DocumentNode]] to a [[CompilationResult]].
  */
object RuntimeDocumentNodeCompiler {

  /**
    * Chains three phases in order:
    *  - [[LoggingContextInjectionPhase]];
    *  - [[InterpreterPreCompilerPhase]];
    *  - [[InterpreterMappingCompilerPhase]].
    *
    * @param moduleNodeLoader handed to [[InterpreterMappingCompilerPhase]]
    * @return a phase from a scoped document to its compilation result
    */
  def compilerPhase(moduleNodeLoader: RuntimeModuleNodeCompiler): CompilationPhase[ScopeGraphResult[DocumentNode], CompilationResult[DocumentNode]] = {
    val injectLoggingContext = new LoggingContextInjectionPhase[DocumentNode]()
    val preCompile = new InterpreterPreCompilerPhase[DocumentNode, ScopeGraphResult[DocumentNode]]()
    injectLoggingContext
      .chainWith(preCompile)
      .chainWith(new InterpreterMappingCompilerPhase(moduleNodeLoader))
  }
}
