program optimize_bench

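// Benchmark: lex → parse → typecheck → codegen → optimize.
// Reads layout parameters and source bytes from public input, runs the 5 stages,
// and asserts that lexing and parsing report no errors and that the token count,
// typecheck error count, and optimized TIR op count match the expected values.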
use vm.io.io

use vm.io.mem

use vm.core.convert

use std.compiler.lexer

use std.compiler.parser

use std.compiler.typecheck

use std.compiler.codegen

use std.compiler.optimize

// Entry point: runs the full compiler pipeline over a source buffer supplied
// as public input and checks each stage's results against expected values.
//
// Public input, read in this order:
//   src_base, src_len, tok_base, lex_err_base, lex_state_base, ast_base,
//   parse_err_base, parse_state_base, parse_stack_base, tc_state_base,
//   cg_state_base, opt_state_base, expected_tok_count,
//   expected_opt_tir_count, expected_err_count,
//   followed by src_len source bytes (one Field each).
//
// Assertions:
//   - token count == expected_tok_count, and 0 lex errors;
//   - 0 parse errors;
//   - typecheck error count == expected_err_count;
//   - optimized TIR op count == expected_opt_tir_count.
// Codegen and optimize still run when expected_err_count is non-zero; there
// is no guard on tc_err_count.
//
// Public output, in order: tok_count, opt_tir_count, tc_err_count.
//
// NOTE(review): every *_base is a RAM address taken from public input. The
// stage state blocks below reserve fixed-offset regions reaching at least
// tc_state_base + 14032, cg_state_base + 37048 and opt_state_base + 16008.
// Nothing here checks that these regions avoid overlapping each other or the
// source/token/AST buffers. Presumably the input provider guarantees this;
// TODO confirm.
fn main() {
    // Read layout parameters from public input
    let src_base: Field = io.read()
    let src_len: Field = io.read()
    let tok_base: Field = io.read()
    let lex_err_base: Field = io.read()
    let lex_state_base: Field = io.read()
    let ast_base: Field = io.read()
    let parse_err_base: Field = io.read()
    let parse_state_base: Field = io.read()
    let parse_stack_base: Field = io.read()
    let tc_state_base: Field = io.read()
    let cg_state_base: Field = io.read()
    let opt_state_base: Field = io.read()
    let expected_tok_count: Field = io.read()
    let expected_opt_tir_count: Field = io.read()
    let expected_err_count: Field = io.read()

    // Load source bytes from public input into RAM at src_base..src_base+src_len.
    // The loop is bounded at 2048 iterations, so src_len is presumably required
    // to be <= 2048. What happens for larger inputs depends on `bounded`
    // semantics (TODO confirm).
    let src_len_u32: U32 = convert.as_u32(src_len)
    for i in 0..src_len_u32 bounded 2048 {
        let idx: Field = convert.as_field(i)
        let byte: Field = io.read()
        mem.write(src_base + idx, byte)
    }

    // Stage 1: Lex
    // The lexer reports through lex_state_base:
    //   word +1 is read as the token count;
    //   word +2 is read as the lex error count.
    lexer.lex(src_base, src_len, tok_base, lex_err_base, lex_state_base)
    let tok_count: Field = mem.read(lex_state_base + 1)
    let lex_err_count: Field = mem.read(lex_state_base + 2)
    assert(tok_count == expected_tok_count)
    assert(lex_err_count == 0)

    // Stage 2: Parse
    // The parser reports through parse_state_base:
    //   word +2 is read as the AST node count;
    //   word +3 is read as the parse error count.
    parser.parse(tok_base, tok_count, ast_base, parse_err_base, parse_state_base, parse_stack_base)
    let node_count: Field = mem.read(parse_state_base + 2)
    let parse_err_count: Field = mem.read(parse_state_base + 3)
    assert(parse_err_count == 0)

    // Stage 3: Typecheck
    // Build the typecheck state block at tc_state_base.
    //   Words 0-1: hand over the AST (ast_base, node_count).
    //   Words 2..21: written as (address, 0) pairs, each address pointing into
    //     this same block starting at tc_state_base + 32. Presumably these are
    //     (buffer base, entry count) pairs; TODO confirm against
    //     std.compiler.typecheck.
    //   Gaps between consecutive addresses (words 2,4,...,20, then 28) are
    //     2000, 2000, 2000, 1000, 1000, 500, 500, 500, 500 and 4000 words.
    //     The extent of the last region (+14032) is not set here.
    //   Words 22-27: scalar constants (0, 5, 10, 2, 3, 0) whose meaning is
    //     defined by the typecheck module (TODO document).
    mem.write(tc_state_base, ast_base)
    mem.write(tc_state_base + 1, node_count)
    mem.write(tc_state_base + 2, tc_state_base + 32)
    mem.write(tc_state_base + 3, 0)
    mem.write(tc_state_base + 4, tc_state_base + 2032)
    mem.write(tc_state_base + 5, 0)
    mem.write(tc_state_base + 6, tc_state_base + 4032)
    mem.write(tc_state_base + 7, 0)
    mem.write(tc_state_base + 8, tc_state_base + 6032)
    mem.write(tc_state_base + 9, 0)
    mem.write(tc_state_base + 10, tc_state_base + 7032)
    mem.write(tc_state_base + 11, 0)
    mem.write(tc_state_base + 12, tc_state_base + 8032)
    mem.write(tc_state_base + 13, 0)
    mem.write(tc_state_base + 14, tc_state_base + 8532)
    mem.write(tc_state_base + 15, 0)
    mem.write(tc_state_base + 16, tc_state_base + 9032)
    mem.write(tc_state_base + 17, 0)
    mem.write(tc_state_base + 18, tc_state_base + 9532)
    mem.write(tc_state_base + 19, 0)
    mem.write(tc_state_base + 20, tc_state_base + 10032)
    mem.write(tc_state_base + 21, 0)
    mem.write(tc_state_base + 22, 0)
    mem.write(tc_state_base + 23, 5)
    mem.write(tc_state_base + 24, 10)
    mem.write(tc_state_base + 25, 2)
    mem.write(tc_state_base + 26, 3)
    mem.write(tc_state_base + 27, 0)
    mem.write(tc_state_base + 28, tc_state_base + 14032)
    mem.write(tc_state_base + 29, 0)

    typecheck.check(tc_state_base)
    let tc_err_count: Field = typecheck.error_count(tc_state_base)
    // Typecheck errors are allowed; only their count must match.
    assert(tc_err_count == expected_err_count)

    // Stage 4: Codegen
    // Build the codegen state block at cg_state_base.
    //   Words 0-1: hand over the AST (ast_base, node_count).
    //   Words 2, 4, 7, 30, 32, 34, 36: addresses inside this block, at
    //     +48, +16048, +18048, +19048, +27048, +28048 and +37048.
    //   Words 3, 5, 6, 8, 10, 31, 33, 35, 37: zeroed.
    //   Words 11-24: forward typecheck state words. They are read here, after
    //     typecheck.check has run, so they see post-check values. The mapping
    //     is not in tc offset order (e.g. tc +12/+13 land at cg +21/+22, and
    //     tc +2/+3 at cg +23/+24). It must match codegen's expected layout;
    //     verify against std.compiler.codegen before reordering.
    //   Words 9 and 26: 1073741824 (= 2^30).
    //   Words 25, 27-29 and 39: small constants (16, 5, 10, 3, 1).
    //     These and the 2^30 words have codegen-defined meanings
    //     (TODO document).
    //   Word 38: tok_base.
    mem.write(cg_state_base, ast_base)
    mem.write(cg_state_base + 1, node_count)
    mem.write(cg_state_base + 2, cg_state_base + 48)
    mem.write(cg_state_base + 3, 0)
    mem.write(cg_state_base + 4, cg_state_base + 16048)
    mem.write(cg_state_base + 5, 0)
    mem.write(cg_state_base + 6, 0)
    mem.write(cg_state_base + 7, cg_state_base + 18048)
    mem.write(cg_state_base + 8, 0)
    mem.write(cg_state_base + 9, 1073741824)
    mem.write(cg_state_base + 10, 0)
    mem.write(cg_state_base + 11, mem.read(tc_state_base + 4))
    mem.write(cg_state_base + 12, mem.read(tc_state_base + 5))
    mem.write(cg_state_base + 13, mem.read(tc_state_base + 8))
    mem.write(cg_state_base + 14, mem.read(tc_state_base + 9))
    mem.write(cg_state_base + 15, mem.read(tc_state_base + 10))
    mem.write(cg_state_base + 16, mem.read(tc_state_base + 11))
    mem.write(cg_state_base + 17, mem.read(tc_state_base + 14))
    mem.write(cg_state_base + 18, mem.read(tc_state_base + 15))
    mem.write(cg_state_base + 19, mem.read(tc_state_base + 16))
    mem.write(cg_state_base + 20, mem.read(tc_state_base + 17))
    mem.write(cg_state_base + 21, mem.read(tc_state_base + 12))
    mem.write(cg_state_base + 22, mem.read(tc_state_base + 13))
    mem.write(cg_state_base + 23, mem.read(tc_state_base + 2))
    mem.write(cg_state_base + 24, mem.read(tc_state_base + 3))
    mem.write(cg_state_base + 25, 16)
    mem.write(cg_state_base + 26, 1073741824)
    mem.write(cg_state_base + 27, 5)
    mem.write(cg_state_base + 28, 10)
    mem.write(cg_state_base + 29, 3)
    mem.write(cg_state_base + 30, cg_state_base + 19048)
    mem.write(cg_state_base + 31, 0)
    mem.write(cg_state_base + 32, cg_state_base + 27048)
    mem.write(cg_state_base + 33, 0)
    mem.write(cg_state_base + 34, cg_state_base + 28048)
    mem.write(cg_state_base + 35, 0)
    mem.write(cg_state_base + 36, cg_state_base + 37048)
    mem.write(cg_state_base + 37, 0)
    mem.write(cg_state_base + 38, tok_base)
    mem.write(cg_state_base + 39, 1)

    codegen.codegen(cg_state_base)

    // Location and length of the generated TIR, as reported by codegen.
    let tir_base: Field = codegen.result_base(cg_state_base)
    let tir_count: Field = codegen.result_count(cg_state_base)

    // Stage 5: Optimize
    // Optimizer state block at opt_state_base.
    //   Words 0-1: the codegen TIR (tir_base, tir_count).
    //   Words 2 and 4: addresses at opt_state_base + 8 and + 16008
    //     (16000 words apart).
    //   Words 3 and 5: zeroed.
    mem.write(opt_state_base, tir_base)
    mem.write(opt_state_base + 1, tir_count)
    mem.write(opt_state_base + 2, opt_state_base + 8)
    mem.write(opt_state_base + 3, 0)
    mem.write(opt_state_base + 4, opt_state_base + 16008)
    mem.write(opt_state_base + 5, 0)

    optimize.optimize(opt_state_base)

    let opt_tir_count: Field = optimize.result_count(opt_state_base)
    assert(opt_tir_count == expected_opt_tir_count)

    // Output counts as proof of work
    // Order: tok_count, opt_tir_count, tc_err_count.
    io.write(tok_count)
    io.write(opt_tir_count)
    io.write(tc_err_count)
}

Dimensions

trident/std/compiler/optimize.tri

Local Graph