module std.compiler.pipeline

// Self-hosted compiler pipeline: source bytes โ†’ optimized TIR.
//
// Chains all 6 compiler stages:
//   1. Lex    (source โ†’ tokens)
//   2. Parse  (tokens โ†’ AST)
//   3. Check  (AST โ†’ typed AST + symbol tables)
//   4. Codegen (typed AST โ†’ TIR)
//   5. Optimize (TIR โ†’ optimized TIR)
//   6. Lower  (optimized TIR โ†’ TASM)  โ† stub, not wired yet
//
// Stage 6 is a placeholder โ€” lower needs core + warrior to target.
// compile() currently produces optimized TIR. When lower is wired,
// it will produce TASM bytes.
//
// Memory layout (state block, 12 words):
//   +0   src_base          Source bytes in RAM
//   +1   src_len           Source byte count
//   +2   out_tir_base      Output: optimized TIR base (set by compile)
//   +3   out_tir_count     Output: optimized TIR op count (set by compile)
//   +4   file_kind         0=library, 1=program
//   +5   digest_width      VM param (5)
//   +6   hash_rate         VM param (10)
//   +7   xfield_width      VM param (3)
//   +8   max_stack_depth   VM param (16)
//   +9   spill_ram_base    VM param (2^30)
//   +10  scratch_base      Start of scratch memory (~100K words)
//   +11  err_count         Total errors across all stages (set by compile)

use vm.core.field

use vm.core.convert

use vm.io.mem

use std.compiler.lexer

use std.compiler.parser

use std.compiler.typecheck

use std.compiler.codegen

use std.compiler.optimize

// =========================================================================
// State accessors
// =========================================================================

// Read/write helpers over the 12-word state block (layout documented in
// the file header). `sb` is the state block base address; each offset
// matches the corresponding "+N" entry in the header comment.
fn p_src_base(sb: Field) -> Field { mem.read(sb) }           // +0  source bytes base
fn p_src_len(sb: Field) -> Field { mem.read(sb + 1) }        // +1  source byte count
fn p_set_tir_base(sb: Field, v: Field) { mem.write(sb + 2, v) }   // +2  out: TIR base
fn p_set_tir_count(sb: Field, v: Field) { mem.write(sb + 3, v) }  // +3  out: TIR op count
fn p_file_kind(sb: Field) -> Field { mem.read(sb + 4) }      // +4  0=library, 1=program
fn p_digest_w(sb: Field) -> Field { mem.read(sb + 5) }       // +5  VM digest width
fn p_hash_rate(sb: Field) -> Field { mem.read(sb + 6) }      // +6  VM hash rate
fn p_xfield_w(sb: Field) -> Field { mem.read(sb + 7) }       // +7  VM extension-field width
fn p_max_depth(sb: Field) -> Field { mem.read(sb + 8) }      // +8  VM max stack depth
fn p_spill_ram(sb: Field) -> Field { mem.read(sb + 9) }      // +9  VM spill RAM base
fn p_scratch(sb: Field) -> Field { mem.read(sb + 10) }       // +10 scratch memory base
fn p_set_err_count(sb: Field, v: Field) { mem.write(sb + 11, v) } // +11 out: total errors

// =========================================================================
// Scratch memory layout โ€” offsets from scratch_base
// =========================================================================

// Regions are laid out back-to-back in ascending order; each region's
// capacity (noted below) is implied by the distance to the next offset.
// Total footprint runs to ~100K words, matching the scratch_base note
// in the header.

// Stage 1: Lexer
fn OFF_TOK_BASE() -> Field { 0 }        // token buffer        (4000 words)
fn OFF_LEX_ERR() -> Field { 4000 }      // lexer error buffer  (256 words)
fn OFF_LEX_STATE() -> Field { 4256 }    // lexer state         (24 words)

// Stage 2: Parser
fn OFF_AST_BASE() -> Field { 4280 }     // AST node buffer     (8000 words)
fn OFF_PARSE_ERR() -> Field { 12280 }   // parser error buffer (256 words)
fn OFF_PARSE_STATE() -> Field { 12536 } // parser state        (64 words)
fn OFF_PARSE_STACK() -> Field { 12600 } // parser work stack   (2000 words)

// Stage 3: Typecheck
fn OFF_TC_STATE() -> Field { 14600 }    // typecheck state     (32 words; 30 used by compile)
fn OFF_TC_TP() -> Field { 14632 }       // type pool           (2000 words)
fn OFF_TC_FN() -> Field { 16632 }       // function table      (2000 words)
fn OFF_TC_VAR() -> Field { 18632 }      // variable scope      (2000 words)
fn OFF_TC_ST() -> Field { 20632 }       // struct table        (1000 words)
fn OFF_TC_SF() -> Field { 21632 }       // struct fields       (1000 words)
fn OFF_TC_CT() -> Field { 22632 }       // constant table      (500 words)
fn OFF_TC_EV() -> Field { 23132 }       // event table         (500 words)
fn OFF_TC_EF() -> Field { 23632 }       // event fields        (500 words)
fn OFF_TC_ERR() -> Field { 24132 }      // error buffer        (500 words)
fn OFF_TC_WS() -> Field { 24632 }       // work stack          (4000 words)
fn OFF_TC_RS() -> Field { 28632 }       // result stack        (1000 words)

// Stage 4: Codegen
fn OFF_CG_STATE() -> Field { 29632 }    // codegen state       (48 words; 40 used by compile)
fn OFF_CG_TIR() -> Field { 29680 }      // TIR output buffer   (16000 words)
fn OFF_CG_STACK() -> Field { 45680 }    // virtual stack model (2000 words)
fn OFF_CG_SPILL() -> Field { 47680 }    // spill table         (1000 words)
fn OFF_CG_WORK() -> Field { 48680 }     // work stack          (8000 words)
fn OFF_CG_RS() -> Field { 56680 }       // result stack        (1000 words)
fn OFF_CG_SIDE() -> Field { 57680 }     // side-effects buffer (9000 words)
fn OFF_CG_SAVED() -> Field { 66680 }    // saved states        (8000 words)

// Stage 5: Optimize
fn OFF_OPT_STATE() -> Field { 74680 }   // optimizer state     (8 words; 6 used by compile)
fn OFF_OPT_OUT() -> Field { 74688 }     // optimized TIR out   (16000 words)
fn OFF_OPT_STATS() -> Field { 90688 }   // optimizer stats     (to end of scratch)

// =========================================================================
// Public API
// =========================================================================

// Run the full front-end pipeline (lex → parse → typecheck → codegen →
// optimize) over the source described by the state block at `state_base`
// (layout in the file header). All intermediate buffers are carved out of
// the scratch region at +10, using the OFF_* layout above.
//
// On return, the state block holds:
//   +2   optimized TIR base
//   +3   optimized TIR op count
//   +11  total error count
//
// NOTE(review): total_errors sums only lex + parse + typecheck errors;
// codegen and optimize contribute nothing. Confirm those stages cannot
// report errors, or wire their counts in as well.
pub fn compile(state_base: Field) {
    let sb: Field = state_base
    let sc: Field = p_scratch(sb)
    let mut total_errors: Field = 0

    // =====================================================================
    // Stage 1: Lex
    // =====================================================================

    let tok_base: Field = sc + OFF_TOK_BASE()
    let lex_err: Field = sc + OFF_LEX_ERR()
    let lex_state: Field = sc + OFF_LEX_STATE()

    lexer.lex(p_src_base(sb), p_src_len(sb), tok_base, lex_err, lex_state)

    // Assumes lexer state layout: +1 = token count, +2 = error count
    // — confirm against std.compiler.lexer.
    let tok_count: Field = mem.read(lex_state + 1)
    let lex_errs: Field = mem.read(lex_state + 2)
    total_errors = total_errors + lex_errs

    // =====================================================================
    // Stage 2: Parse
    // =====================================================================

    let ast_base: Field = sc + OFF_AST_BASE()
    let parse_err: Field = sc + OFF_PARSE_ERR()
    let parse_state: Field = sc + OFF_PARSE_STATE()
    let parse_stack: Field = sc + OFF_PARSE_STACK()

    parser.parse(tok_base, tok_count, ast_base, parse_err, parse_state, parse_stack)

    // Assumes parser state layout: +2 = AST node count, +3 = error count
    // — confirm against std.compiler.parser.
    let node_count: Field = mem.read(parse_state + 2)
    let parse_errs: Field = mem.read(parse_state + 3)
    total_errors = total_errors + parse_errs

    // =====================================================================
    // Stage 3: Typecheck
    // =====================================================================

    // Marshal the typechecker's 30-word state block in place. Each table
    // is a (base, count) pair with the count initialized to 0; the offsets
    // +0..+29 must match typecheck's expected layout.
    let tc: Field = sc + OFF_TC_STATE()

    // AST input
    mem.write(tc, ast_base)
    mem.write(tc + 1, node_count)
    // Type pool
    mem.write(tc + 2, sc + OFF_TC_TP())
    mem.write(tc + 3, 0)
    // Function table
    mem.write(tc + 4, sc + OFF_TC_FN())
    mem.write(tc + 5, 0)
    // Variable scope
    mem.write(tc + 6, sc + OFF_TC_VAR())
    mem.write(tc + 7, 0)
    // Struct table
    mem.write(tc + 8, sc + OFF_TC_ST())
    mem.write(tc + 9, 0)
    // Struct fields
    mem.write(tc + 10, sc + OFF_TC_SF())
    mem.write(tc + 11, 0)
    // Constant table
    mem.write(tc + 12, sc + OFF_TC_CT())
    mem.write(tc + 13, 0)
    // Event table
    mem.write(tc + 14, sc + OFF_TC_EV())
    mem.write(tc + 15, 0)
    // Event fields
    mem.write(tc + 16, sc + OFF_TC_EF())
    mem.write(tc + 17, 0)
    // Error output
    mem.write(tc + 18, sc + OFF_TC_ERR())
    mem.write(tc + 19, 0)
    // Work stack
    mem.write(tc + 20, sc + OFF_TC_WS())
    mem.write(tc + 21, 0)
    // Scope depth
    mem.write(tc + 22, 0)
    // VM params: digest_width, hash_rate, field_limbs, xfield_width
    mem.write(tc + 23, p_digest_w(sb))
    mem.write(tc + 24, p_hash_rate(sb))
    // field_limbs is hardcoded to 2 here — unlike the other VM params it
    // is not read from the state block. NOTE(review): confirm 2 is correct
    // for the target field, or promote it to a state-block parameter.
    mem.write(tc + 25, 2)
    mem.write(tc + 26, p_xfield_w(sb))
    // In pure fn
    mem.write(tc + 27, 0)
    // Result stack
    mem.write(tc + 28, sc + OFF_TC_RS())
    mem.write(tc + 29, 0)

    typecheck.check(tc)

    let tc_errs: Field = typecheck.error_count(tc)
    total_errors = total_errors + tc_errs

    // =====================================================================
    // Stage 4: Codegen
    // =====================================================================

    // Marshal codegen's 40-word state block. Slots +11..+24 forward the
    // symbol tables typecheck populated, so codegen reads the *final*
    // (base, count) pairs back out of tc rather than the initial zeros.
    let cg: Field = sc + OFF_CG_STATE()

    // AST input (from parser)
    mem.write(cg, ast_base)
    mem.write(cg + 1, node_count)
    // TIR output
    mem.write(cg + 2, sc + OFF_CG_TIR())
    mem.write(cg + 3, 0)
    // Virtual stack model
    mem.write(cg + 4, sc + OFF_CG_STACK())
    mem.write(cg + 5, 0)
    // Access counter
    mem.write(cg + 6, 0)
    // Spill table
    mem.write(cg + 7, sc + OFF_CG_SPILL())
    mem.write(cg + 8, 0)
    // Next spill address
    mem.write(cg + 9, p_spill_ram(sb))
    // Label counter
    mem.write(cg + 10, 0)
    // Function table (from typecheck +4/+5)
    mem.write(cg + 11, mem.read(tc + 4))
    mem.write(cg + 12, mem.read(tc + 5))
    // Struct table (from typecheck +8/+9)
    mem.write(cg + 13, mem.read(tc + 8))
    mem.write(cg + 14, mem.read(tc + 9))
    // Struct fields (from typecheck +10/+11)
    mem.write(cg + 15, mem.read(tc + 10))
    mem.write(cg + 16, mem.read(tc + 11))
    // Event table (from typecheck +14/+15)
    mem.write(cg + 17, mem.read(tc + 14))
    mem.write(cg + 18, mem.read(tc + 15))
    // Event fields (from typecheck +16/+17)
    mem.write(cg + 19, mem.read(tc + 16))
    mem.write(cg + 20, mem.read(tc + 17))
    // Constant table (from typecheck +12/+13)
    mem.write(cg + 21, mem.read(tc + 12))
    mem.write(cg + 22, mem.read(tc + 13))
    // Type pool (from typecheck +2/+3)
    mem.write(cg + 23, mem.read(tc + 2))
    mem.write(cg + 24, mem.read(tc + 3))
    // VM params
    mem.write(cg + 25, p_max_depth(sb))
    mem.write(cg + 26, p_spill_ram(sb))
    mem.write(cg + 27, p_digest_w(sb))
    mem.write(cg + 28, p_hash_rate(sb))
    mem.write(cg + 29, p_xfield_w(sb))
    // Work stack
    mem.write(cg + 30, sc + OFF_CG_WORK())
    mem.write(cg + 31, 0)
    // Result stack
    mem.write(cg + 32, sc + OFF_CG_RS())
    mem.write(cg + 33, 0)
    // Side-effects buffer
    mem.write(cg + 34, sc + OFF_CG_SIDE())
    mem.write(cg + 35, 0)
    // Saved states
    mem.write(cg + 36, sc + OFF_CG_SAVED())
    mem.write(cg + 37, 0)
    // Token base (for name comparison)
    mem.write(cg + 38, tok_base)
    // File kind
    mem.write(cg + 39, p_file_kind(sb))

    codegen.codegen(cg)

    let tir_base: Field = codegen.result_base(cg)
    let tir_count: Field = codegen.result_count(cg)

    // =====================================================================
    // Stage 5: Optimize
    // =====================================================================

    // Optimizer state: input TIR (base, count), output buffer (base,
    // count=0), stats buffer (base, count=0).
    let opt: Field = sc + OFF_OPT_STATE()

    mem.write(opt, tir_base)
    mem.write(opt + 1, tir_count)
    mem.write(opt + 2, sc + OFF_OPT_OUT())
    mem.write(opt + 3, 0)
    mem.write(opt + 4, sc + OFF_OPT_STATS())
    mem.write(opt + 5, 0)

    optimize.optimize(opt)

    let opt_tir_base: Field = optimize.result_base(opt)
    let opt_tir_count: Field = optimize.result_count(opt)

    // =====================================================================
    // Stage 6: Lower (stub — needs core + warrior)
    // =====================================================================

    // TODO: wire lower.lower() here when targeting TASM
    // For now, output is optimized TIR.

    // =====================================================================
    // Write results
    // =====================================================================

    p_set_tir_base(sb, opt_tir_base)
    p_set_tir_count(sb, opt_tir_count)
    p_set_err_count(sb, total_errors)
}

// Read back results after compile()
// Offsets mirror the state-block slots compile() writes: +2 TIR base,
// +3 TIR op count, +11 total error count.
pub fn result_tir_base(state_base: Field) -> Field { mem.read(state_base + 2) }   // optimized TIR base
pub fn result_tir_count(state_base: Field) -> Field { mem.read(state_base + 3) }  // optimized TIR op count
pub fn result_err_count(state_base: Field) -> Field { mem.read(state_base + 11) } // errors across all stages

// NOTE(review): the following three lines appear to be editor/UI residue
// ("Dimensions", a repository path, "Local Graph") accidentally pasted
// into the file; they are not valid source text. Commented out so the
// file parses — delete once confirmed.
// Dimensions
// trident/benches/harnesses/std/compiler/pipeline.tri
// Local Graph