module std.compiler.optimize
// Self-hosted TIR peephole optimizer.
//
// Runs pattern-based rewrites on a flat TIR array (stride 4) to reduce
// instruction count. Applied between TIR building and lowering.
//
// Eight peephole passes run in a fixed-point loop:
// 1. merge_hints — consecutive Hint(a)+Hint(b) → Hint(a+b)
// 2. merge_pops — consecutive Pop(a)+Pop(b) → Pop(a+b)
// 3. eliminate_nops — remove Swap(0), Pop(0)
// 4. dead_spills — dead spill/reload elimination
// 5. dup_pop_nops — Dup(0)+Pop(1) → nothing
// 6. double_swaps — Swap(N)+Swap(N) → nothing
// 7. swap_pop_chains — dup+swap+pop round-trips
// 8. epilogue_cleanup — swap(D)+pop(1) chains → compact
// + optimize_nested — recurse into structural op bodies
//
// Memory layout (state block):
// +0 tir_base Input TIR ops (stride 4)
// +1 tir_count Number of input ops
// +2 out_base Output buffer (stride 4)
// +3 out_count Output op count
// +4 addr_stats_base Address stats for dead spill (stride 3: addr, writes, reads)
// +5 addr_stats_count
use vm.core.field
use vm.core.convert
use vm.io.mem
use std.compiler.lower
// =========================================================================
// State accessors
// =========================================================================
// --- State block accessors --------------------------------------------
// `sb` points at the 6-word state header described at the top of the
// file: +0 in_base, +1 in_count, +2 out_base, +3 out_count,
// +4 addr_stats_base, +5 addr_stats_count.
// Base address of the current input TIR buffer (stride-4 ops).
fn s_in_base(sb: Field) -> Field { mem.read(sb) }
// Number of ops currently in the input buffer.
fn s_in_count(sb: Field) -> Field { mem.read(sb + 1) }
fn s_set_in_count(sb: Field, v: Field) { mem.write(sb + 1, v) }
// Base address of the output buffer the current pass emits into.
fn s_out_base(sb: Field) -> Field { mem.read(sb + 2) }
// Number of ops emitted so far by the current pass.
fn s_out_count(sb: Field) -> Field { mem.read(sb + 3) }
fn s_set_out_count(sb: Field, v: Field) { mem.write(sb + 3, v) }
// Scratch table used by pass_dead_spills: stride-3 records of
// (addr, write_count, read_count).
fn s_addr_base(sb: Field) -> Field { mem.read(sb + 4) }
fn s_addr_count(sb: Field) -> Field { mem.read(sb + 5) }
fn s_set_addr_count(sb: Field, v: Field) { mem.write(sb + 5, v) }
// TIR op accessors (stride 4 from a given base)
// --- TIR op accessors (stride 4 from a given base) ---------------------
// Each op occupies 4 consecutive memory words: opcode, arg0, arg1, arg2.
fn op_at(base: Field, idx: Field) -> Field { mem.read(base + idx * 4) }
fn a0_at(base: Field, idx: Field) -> Field { mem.read(base + idx * 4 + 1) }
fn a1_at(base: Field, idx: Field) -> Field { mem.read(base + idx * 4 + 2) }
fn a2_at(base: Field, idx: Field) -> Field { mem.read(base + idx * 4 + 3) }
// Write a TIR op to output
// Append one TIR op (opcode + three args) to the output buffer and bump
// out_count. No bounds check is performed here — the output buffer is
// assumed to be at least as large as the input (passes never grow the
// op count; TODO confirm the batching passes cannot exceed input size).
fn emit_op(sb: Field, opcode: Field, arg0: Field, arg1: Field, arg2: Field) {
let base: Field = s_out_base(sb)
let cnt: Field = s_out_count(sb)
mem.write(base + cnt * 4, opcode)
mem.write(base + cnt * 4 + 1, arg0)
mem.write(base + cnt * 4 + 2, arg1)
mem.write(base + cnt * 4 + 3, arg2)
s_set_out_count(sb, cnt + 1)
}
// Copy a TIR op from input to output
// Copy input op `idx` verbatim to the output buffer (all four words).
fn copy_op(sb: Field, idx: Field) {
let base: Field = s_in_base(sb)
emit_op(sb, op_at(base, idx), a0_at(base, idx), a1_at(base, idx), a2_at(base, idx))
}
// Swap input/output buffers (for next pass)
// Flip input and output roles after a pass: the pass's output becomes the
// next pass's input, the old input buffer is recycled as the next output,
// and out_count is reset to 0. The old in_count (ic) is intentionally
// discarded — it is superseded by the pass's emitted count (oc).
fn swap_buffers(sb: Field) {
let ib: Field = s_in_base(sb)
let ic: Field = s_in_count(sb)
let ob: Field = s_out_base(sb)
let oc: Field = s_out_count(sb)
mem.write(sb, ob)
s_set_in_count(sb, oc)
mem.write(sb + 2, ib)
s_set_out_count(sb, 0)
}
// =========================================================================
// Pass 1: Merge consecutive Hints
// Hint(a) + Hint(b) → Hint(a+b), emitted in batches of 5
// =========================================================================
// Pass 1: merge runs of consecutive Hint ops into their sum, then re-emit
// the sum as Hint batches of at most 5.
//
// The language has no `while`, so every scan is a `for … bounded N` loop
// with a manual cursor: iterations where the condition fails are burned
// as no-ops and the cursor simply stops advancing.
fn pass_merge_hints(sb: Field) {
let base: Field = s_in_base(sb)
let count: Field = s_in_count(sb)
s_set_out_count(sb, 0)
let mut i: Field = 0
let cu: U32 = convert.as_u32(count)
for step in 0..cu bounded 65536 {
if i == count {
// done — remaining iterations are no-ops
} else if op_at(base, i) == lower.OP_HINT() {
// Start of a hint run: accumulate a0 of every consecutive hint.
let mut total: Field = a0_at(base, i)
i = i + 1
// Merge consecutive hints
let inner: U32 = convert.as_u32(count)
for s2 in 0..inner bounded 65536 {
if i == count {
// done
} else if op_at(base, i) == lower.OP_HINT() {
total = total + a0_at(base, i)
i = i + 1
}
// non-hint op: cursor stops; loop idles until budget is spent
}
// Emit in batches of 5.
// NOTE(review): the batch loop is bounded at 20 iterations, so at
// most 20*5 = 100 hint units can be re-emitted; if a merged total
// ever exceeds 100 the remainder is silently dropped — TODO confirm
// hint totals are guaranteed <= 100 upstream.
let five: U32 = convert.as_u32(5)
let mut rem: U32 = convert.as_u32(total)
let bnd: U32 = convert.as_u32(20)
for s3 in 0..bnd bounded 20 {
if rem == convert.as_u32(0) {
// done
} else {
// batch = min(rem, 5)
let mut batch: U32 = rem
if batch < convert.as_u32(6) {
// batch <= 5, use as is
} else {
batch = five
}
emit_op(sb, lower.OP_HINT(), convert.as_field(batch), 0, 0)
// rem -= batch, computed in the field via additive negation
rem = convert.as_u32(convert.as_field(rem) + field.neg(convert.as_field(batch)))
}
}
} else {
// Not a hint: copy through unchanged.
copy_op(sb, i)
i = i + 1
}
}
swap_buffers(sb)
}
// =========================================================================
// Pass 2: Merge consecutive Pops
// Pop(a) + Pop(b) → Pop(a+b), emitted in batches of 5
// =========================================================================
// Pass 2: merge runs of consecutive Pop ops into their sum, then re-emit
// the sum as Pop batches of at most 5. Structurally identical to
// pass_merge_hints, with OP_POP in place of OP_HINT.
//
// NOTE(review): same 20-iteration batching cap as pass_merge_hints — a
// merged pop total above 100 would be silently truncated; TODO confirm
// this cannot occur.
fn pass_merge_pops(sb: Field) {
let base: Field = s_in_base(sb)
let count: Field = s_in_count(sb)
s_set_out_count(sb, 0)
let mut i: Field = 0
let cu: U32 = convert.as_u32(count)
for step in 0..cu bounded 65536 {
if i == count {
// done — remaining iterations are no-ops
} else if op_at(base, i) == lower.OP_POP() {
// Accumulate a0 of every consecutive Pop.
let mut total: Field = a0_at(base, i)
i = i + 1
let inner: U32 = convert.as_u32(count)
for s2 in 0..inner bounded 65536 {
if i == count {
// done
} else if op_at(base, i) == lower.OP_POP() {
total = total + a0_at(base, i)
i = i + 1
}
}
// Re-emit as Pop(min(rem, 5)) chunks.
let five: U32 = convert.as_u32(5)
let mut rem: U32 = convert.as_u32(total)
let bnd: U32 = convert.as_u32(20)
for s3 in 0..bnd bounded 20 {
if rem == convert.as_u32(0) {
// done
} else {
let mut batch: U32 = rem
if batch < convert.as_u32(6) {
// ok — batch <= 5
} else {
batch = five
}
emit_op(sb, lower.OP_POP(), convert.as_field(batch), 0, 0)
// rem -= batch via field negation
rem = convert.as_u32(convert.as_field(rem) + field.neg(convert.as_field(batch)))
}
}
} else {
copy_op(sb, i)
i = i + 1
}
}
swap_buffers(sb)
}
// =========================================================================
// Pass 3: Eliminate no-ops — Swap(0) and Pop(0)
// =========================================================================
// Pass 3: drop ops that have no effect — Swap(0) and Pop(0).
// Simple 1-op window, so a plain indexed loop suffices (no manual cursor).
fn pass_eliminate_nops(sb: Field) {
let base: Field = s_in_base(sb)
let count: Field = s_in_count(sb)
s_set_out_count(sb, 0)
let cu: U32 = convert.as_u32(count)
for i in 0..cu bounded 65536 {
let idx: Field = convert.as_field(i)
let op: Field = op_at(base, idx)
let a0: Field = a0_at(base, idx)
if op == lower.OP_SWAP() {
if a0 == 0 {
// skip Swap(0) — swapping with depth 0 is a no-op
} else {
copy_op(sb, idx)
}
} else if op == lower.OP_POP() {
if a0 == 0 {
// skip Pop(0) — popping zero elements is a no-op
} else {
copy_op(sb, idx)
}
} else {
copy_op(sb, idx)
}
}
swap_buffers(sb)
}
// =========================================================================
// Pass 4: Dead spill elimination
// Spill write: Push(addr), Swap(1), WriteMem(1), Pop(1)
// Spill read: Push(addr), ReadMem(1), Pop(1)
// =========================================================================
// Pass 4: dead spill elimination.
//
// Recognized instruction windows (exact shapes, checked field by field):
//   spill write: Push(addr), Swap(1), WriteMem(1), Pop(1)
//   spill read:  Push(addr), ReadMem(1), Pop(1)
//
// Phase 1 counts writes and reads per spill address into the addr-stats
// table (stride 3: addr, write_count, read_count). Phase 2 rewrites:
//   classify 1 ("pair", exactly 1 write + 1 read): both windows removed;
//   classify 2 ("dead", writes but 0 reads): each write replaced by Pop(1).
//
// NOTE(review): phase 1 only counts occurrences — it does not verify that
// the single read occurs after the single write, nor that intervening ops
// leave the spilled value's stack position undisturbed when a pair is
// elided. Presumably the TIR builder only emits these windows in shapes
// where that holds — TODO confirm against the builder.
fn pass_dead_spills(sb: Field) {
let base: Field = s_in_base(sb)
let count: Field = s_in_count(sb)
// Pass 1: count writes and reads per address
let ab: Field = s_addr_base(sb)
let mut ac: Field = 0
// Scan for write patterns (4-op window); idx+3 < count guards the window.
let cu: U32 = convert.as_u32(count)
for i in 0..cu bounded 65536 {
let idx: Field = convert.as_field(i)
if convert.as_u32(idx + 3) < cu {
if op_at(base, idx) == lower.OP_PUSH() {
if op_at(base, idx + 1) == lower.OP_SWAP() {
if a0_at(base, idx + 1) == 1 {
if op_at(base, idx + 2) == lower.OP_WRITE_MEM() {
if a0_at(base, idx + 2) == 1 {
if op_at(base, idx + 3) == lower.OP_POP() {
if a0_at(base, idx + 3) == 1 {
let addr: Field = a0_at(base, idx)
// Find or create addr stat (linear scan of the stats table)
let mut found: Field = 0
let acu: U32 = convert.as_u32(ac)
for j in 0..acu bounded 4096 {
let ji: Field = convert.as_field(j)
if mem.read(ab + ji * 3) == addr {
mem.write(ab + ji * 3 + 1, mem.read(ab + ji * 3 + 1) + 1)
found = 1
}
}
if found == 0 {
mem.write(ab + ac * 3, addr)
mem.write(ab + ac * 3 + 1, 1) // writes
mem.write(ab + ac * 3 + 2, 0) // reads
ac = ac + 1
}
}
}
}
}
}
}
}
}
}
// Scan for read patterns (3-op window).
// Reads to an address with no recorded write are ignored (no stat row
// is created here), so such reads are always kept.
for i in 0..cu bounded 65536 {
let idx: Field = convert.as_field(i)
if convert.as_u32(idx + 2) < cu {
if op_at(base, idx) == lower.OP_PUSH() {
if op_at(base, idx + 1) == lower.OP_READ_MEM() {
if a0_at(base, idx + 1) == 1 {
if op_at(base, idx + 2) == lower.OP_POP() {
if a0_at(base, idx + 2) == 1 {
let addr: Field = a0_at(base, idx)
let acu: U32 = convert.as_u32(ac)
for j in 0..acu bounded 4096 {
let ji: Field = convert.as_field(j)
if mem.read(ab + ji * 3) == addr {
mem.write(ab + ji * 3 + 2, mem.read(ab + ji * 3 + 2) + 1)
}
}
}
}
}
}
}
}
}
// Classify: pair (1 write, 1 read) or dead (writes only, 0 reads).
// Cheap pre-check so the rewrite (and its buffer swap) is skipped when
// nothing would change. Must mirror classify_addr exactly.
let mut has_changes: Field = 0
let acu: U32 = convert.as_u32(ac)
for j in 0..acu bounded 4096 {
let ji: Field = convert.as_field(j)
let wc: Field = mem.read(ab + ji * 3 + 1)
let rc: Field = mem.read(ab + ji * 3 + 2)
if wc == 1 {
if rc == 1 {
has_changes = 1
} else if rc == 0 {
has_changes = 1
}
} else if rc == 0 {
// multiple writes, never read: all writes are dead
has_changes = 1
}
}
if has_changes == 0 {
// No changes, skip rewrite — leaving buffers as-is keeps the input valid
return
}
// Pass 2: rewrite
s_set_out_count(sb, 0)
let mut i: Field = 0
for step in 0..cu bounded 65536 {
if i == count {
// done
} else {
let mut skip: Field = 0
// Check write pattern (same 4-op window as the counting scan)
if convert.as_u32(i + 3) < cu {
if op_at(base, i) == lower.OP_PUSH() {
if op_at(base, i + 1) == lower.OP_SWAP() {
if a0_at(base, i + 1) == 1 {
if op_at(base, i + 2) == lower.OP_WRITE_MEM() {
if a0_at(base, i + 2) == 1 {
if op_at(base, i + 3) == lower.OP_POP() {
if a0_at(base, i + 3) == 1 {
let addr: Field = a0_at(base, i)
let cls: Field = classify_addr(ab, ac, addr)
if cls == 1 {
// pair: remove entirely — value stays on the stack
i = i + 4
skip = 1
} else if cls == 2 {
// dead: replace with Pop(1) to discard the spilled value
emit_op(sb, lower.OP_POP(), 1, 0, 0)
i = i + 4
skip = 1
}
}
}
}
}
}
}
}
}
// Check read pattern (only pairs are elided; dead reads cannot occur
// here because classify 2 requires zero reads)
if skip == 0 {
if convert.as_u32(i + 2) < cu {
if op_at(base, i) == lower.OP_PUSH() {
if op_at(base, i + 1) == lower.OP_READ_MEM() {
if a0_at(base, i + 1) == 1 {
if op_at(base, i + 2) == lower.OP_POP() {
if a0_at(base, i + 2) == 1 {
let addr: Field = a0_at(base, i)
let cls: Field = classify_addr(ab, ac, addr)
if cls == 1 {
// pair: remove entirely — matching write was also removed
i = i + 3
skip = 1
}
}
}
}
}
}
}
}
if skip == 0 {
copy_op(sb, i)
i = i + 1
}
}
}
swap_buffers(sb)
}
// Classify an address: 0=keep, 1=pair (remove both), 2=dead (replace write with pop)
fn classify_addr(ab: Field, ac: Field, addr: Field) -> Field {
let acu: U32 = convert.as_u32(ac)
let mut result: Field = 0
for j in 0..acu bounded 4096 {
let ji: Field = convert.as_field(j)
if mem.read(ab + ji * 3) == addr {
let wc: Field = mem.read(ab + ji * 3 + 1)
let rc: Field = mem.read(ab + ji * 3 + 2)
if wc == 1 {
if rc == 1 {
result = 1 // pair
} else if rc == 0 {
result = 2 // dead
}
} else if rc == 0 {
result = 2 // dead
}
}
}
return result
}
// =========================================================================
// Pass 5: Eliminate Dup(0)+Pop(1) and Dup(0)+Swap(1)+Pop(1) no-ops
// =========================================================================
// Pass 5: eliminate duplicate-then-discard round trips.
// Dup(0) copies the stack top; Swap(1) then Pop(1), or Pop(1) alone,
// discards the copy — net effect is nothing, so the window is dropped.
// The 3-op pattern is tried before the 2-op pattern so the longer match
// wins when both apply.
fn pass_dup_pop_nops(sb: Field) {
let base: Field = s_in_base(sb)
let count: Field = s_in_count(sb)
s_set_out_count(sb, 0)
let mut i: Field = 0
let cu: U32 = convert.as_u32(count)
for step in 0..cu bounded 65536 {
if i == count {
// done
} else {
let mut skip: Field = 0
// 3-op pattern: Dup(0), Swap(1), Pop(1)
if convert.as_u32(i + 2) < cu {
if op_at(base, i) == lower.OP_DUP() {
if a0_at(base, i) == 0 {
if op_at(base, i + 1) == lower.OP_SWAP() {
if a0_at(base, i + 1) == 1 {
if op_at(base, i + 2) == lower.OP_POP() {
if a0_at(base, i + 2) == 1 {
i = i + 3
skip = 1
}
}
}
}
}
}
}
// 2-op pattern: Dup(0), Pop(1)
if skip == 0 {
if convert.as_u32(i + 1) < cu {
if op_at(base, i) == lower.OP_DUP() {
if a0_at(base, i) == 0 {
if op_at(base, i + 1) == lower.OP_POP() {
if a0_at(base, i + 1) == 1 {
i = i + 2
skip = 1
}
}
}
}
}
}
if skip == 0 {
copy_op(sb, i)
i = i + 1
}
}
}
swap_buffers(sb)
}
// =========================================================================
// Pass 6: Eliminate double swaps — Swap(N)+Swap(N) → nothing
// =========================================================================
// Pass 6: two identical swaps in a row cancel out — Swap(N), Swap(N)
// restores the original stack, so both ops are dropped.
fn pass_double_swaps(sb: Field) {
let base: Field = s_in_base(sb)
let count: Field = s_in_count(sb)
s_set_out_count(sb, 0)
let mut i: Field = 0
let cu: U32 = convert.as_u32(count)
for step in 0..cu bounded 65536 {
if i == count {
// done
} else {
let mut skip: Field = 0
if convert.as_u32(i + 1) < cu {
if op_at(base, i) == lower.OP_SWAP() {
if op_at(base, i + 1) == lower.OP_SWAP() {
if a0_at(base, i) == a0_at(base, i + 1) {
// same depth: the pair is a no-op
i = i + 2
skip = 1
}
}
}
}
if skip == 0 {
copy_op(sb, i)
i = i + 1
}
}
}
swap_buffers(sb)
}
// =========================================================================
// Pass 7: Collapse dup+swap+pop round-trips
// Pattern: N×Dup(D), Swap(N), Pop(N) where D+1==N → no-op
// =========================================================================
// Pass 7: collapse dup+swap+pop round-trips.
// Pattern: N consecutive Dup(D) ops, then Swap(N), then Pops totaling
// exactly N, where D+1 == N — the duplicates are all discarded and the
// swap restores the original order, so the whole window is a no-op.
fn pass_swap_pop_chains(sb: Field) {
let base: Field = s_in_base(sb)
let count: Field = s_in_count(sb)
s_set_out_count(sb, 0)
let mut i: Field = 0
let cu: U32 = convert.as_u32(count)
for step in 0..cu bounded 65536 {
if i == count {
// done
} else {
let mut skip: Field = 0
if op_at(base, i) == lower.OP_DUP() {
let d_val: Field = a0_at(base, i)
// Count consecutive Dup(d_val); cursor j stops at the first
// non-matching op and the loop idles out its budget.
let mut dup_count: Field = 0
let mut j: Field = i
let inner: U32 = convert.as_u32(count)
for s2 in 0..inner bounded 65536 {
if j == count {
// done
} else if op_at(base, j) == lower.OP_DUP() {
if a0_at(base, j) == d_val {
dup_count = dup_count + 1
j = j + 1
}
}
}
// Check for Swap(dup_count) + Pop totaling dup_count.
// NOTE(review): the first guard compares dup_count-1 against the
// op count — presumably a sanity bound on the run length; the
// real window check is the j < count guard that follows.
if convert.as_u32(dup_count + field.neg(1)) < cu {
if convert.as_u32(j) < cu {
if op_at(base, j) == lower.OP_SWAP() {
if a0_at(base, j) == dup_count {
let after_swap: Field = j + 1
let mut total_popped: Field = 0
let mut k: Field = after_swap
let inner2: U32 = convert.as_u32(count)
for s3 in 0..inner2 bounded 65536 {
if k == count {
// done
} else if op_at(base, k) == lower.OP_POP() {
total_popped = total_popped + a0_at(base, k)
k = k + 1
if total_popped == dup_count {
// enough — no break construct exists, so the loop
// keeps consuming any further consecutive Pops;
// if they push total_popped past dup_count the
// equality check below fails and the window is
// conservatively left untouched.
}
}
}
if total_popped == dup_count {
if d_val + 1 == dup_count {
// No-op round trip: skip everything through the pops
i = k
skip = 1
}
}
}
}
}
}
}
if skip == 0 {
copy_op(sb, i)
i = i + 1
}
}
}
swap_buffers(sb)
}
// =========================================================================
// Pass 8: Collapse epilogue swap(D)+pop(1) chains
// =========================================================================
// Pass 8: collapse function-epilogue chains of Swap(D), Pop(1) pairs.
// A run of pair_count such pairs is rewritten as a single Swap followed
// by Pops totaling pair_count (batched in chunks of <= 5). Runs shorter
// than 3 pairs are left alone.
//
// NOTE(review): for D == 1 the rewrite Swap(pair_count)+Pop(pair_count)
// is stack-equivalent to the pair chain (each pair deletes the element
// below the top). For D != 1 the emitted Swap(D)+Pop(pair_count) is
// assumed equivalent for the specific epilogue shapes the lowerer
// produces — TODO confirm against the lowering code, since it is not
// equivalent for arbitrary stacks.
fn pass_epilogue_cleanup(sb: Field) {
let base: Field = s_in_base(sb)
let count: Field = s_in_count(sb)
s_set_out_count(sb, 0)
let mut i: Field = 0
let cu: U32 = convert.as_u32(count)
for step in 0..cu bounded 65536 {
if i == count {
// done
} else {
let mut skip: Field = 0
// Need at least 4 ops (2 pairs minimum for optimization)
if convert.as_u32(i + 3) < cu {
if op_at(base, i) == lower.OP_SWAP() {
if op_at(base, i + 1) == lower.OP_POP() {
if a0_at(base, i + 1) == 1 {
let first_d: Field = a0_at(base, i)
// Count consecutive swap(D)+pop(1) pairs following the
// first one. Note: only the pop depth is checked here, so
// later pairs may have a different swap depth than first_d
// — presumably epilogues keep D constant; TODO confirm.
let mut pair_count: Field = 1
let mut j: Field = i + 2
let inner: U32 = convert.as_u32(count)
for s2 in 0..inner bounded 65536 {
if convert.as_u32(j + 1) < cu {
if op_at(base, j) == lower.OP_SWAP() {
if op_at(base, j + 1) == lower.OP_POP() {
if a0_at(base, j + 1) == 1 {
pair_count = pair_count + 1
j = j + 2
}
}
}
}
}
// Optimize if 3+ pairs (pair_count >= 3)
if pair_count == 3 {
// Exactly 3 pairs: single Swap + single Pop(3) (3 <= 5,
// so no batching needed).
if first_d == 1 {
emit_op(sb, lower.OP_SWAP(), pair_count, 0, 0)
emit_op(sb, lower.OP_POP(), pair_count, 0, 0)
} else {
emit_op(sb, lower.OP_SWAP(), first_d, 0, 0)
emit_op(sb, lower.OP_POP(), pair_count, 0, 0)
}
i = j
skip = 1
} else if convert.as_u32(pair_count) < convert.as_u32(3) {
// pair_count < 3, don't optimize — fall through to copy
} else {
// pair_count > 3: same rewrite, but Pops batched in chunks
// of <= 5 (bounded at 20 chunks, i.e. pair_count <= 100)
if first_d == 1 {
emit_op(sb, lower.OP_SWAP(), pair_count, 0, 0)
// Batch pops
let mut rem: Field = pair_count
let bnd: U32 = convert.as_u32(20)
for s3 in 0..bnd bounded 20 {
if rem == 0 {
// done
} else {
let mut batch: Field = rem
if convert.as_u32(batch) < convert.as_u32(6) {
// ok — batch <= 5
} else {
batch = 5
}
emit_op(sb, lower.OP_POP(), batch, 0, 0)
rem = rem + field.neg(batch)
}
}
} else {
emit_op(sb, lower.OP_SWAP(), first_d, 0, 0)
let mut rem: Field = pair_count
let bnd: U32 = convert.as_u32(20)
for s3 in 0..bnd bounded 20 {
if rem == 0 {
// done
} else {
let mut batch: Field = rem
if convert.as_u32(batch) < convert.as_u32(6) {
// ok
} else {
batch = 5
}
emit_op(sb, lower.OP_POP(), batch, 0, 0)
rem = rem + field.neg(batch)
}
}
}
i = j
skip = 1
}
}
}
}
}
if skip == 0 {
copy_op(sb, i)
i = i + 1
}
}
}
swap_buffers(sb)
}
// =========================================================================
// Nested optimization — optimize bodies of structural ops in-place
//
// For our flat TIR format, structural ops reference body ranges within
// the same array. We optimize those sub-ranges by copying them out,
// running optimize on the sub-array, and writing them back.
// This is a simplified version โ we just run all non-nested passes on
// sub-ranges referenced by IfElse/IfOnly/Loop/ProofBlock.
// =========================================================================
// Note: Full nested optimization would require allocating temp buffers
// for each sub-range and running the full optimize loop. For the initial
// version, we skip nested optimization — the flat passes already handle
// most patterns. The codegen module can apply optimization before
// inserting body ops into structural op ranges.
// =========================================================================
// Main entry point — fixed-point loop over all passes
// =========================================================================
// Run all eight peephole passes repeatedly (up to 20 rounds) until a
// fixed point is reached. After each pass the buffers are swapped, so the
// final result always lives in the "input" slot — read it back with
// result_base/result_count.
//
// NOTE(review): the fixed point is detected by comparing op COUNT only.
// A round that rewrites ops without changing their number (e.g.
// Hint(3)+Hint(4) -> Hint(5)+Hint(2)) terminates the loop; that is fine
// as long as no pass can make count-preserving changes that enable
// further reduction — TODO confirm.
pub fn optimize(state_base: Field) {
let sb: Field = state_base
let max_passes: U32 = convert.as_u32(20)
for pass in 0..max_passes bounded 20 {
let before: Field = s_in_count(sb)
pass_merge_hints(sb)
pass_merge_pops(sb)
pass_eliminate_nops(sb)
pass_dead_spills(sb)
pass_dup_pop_nops(sb)
pass_double_swaps(sb)
pass_swap_pop_chains(sb)
pass_epilogue_cleanup(sb)
let after: Field = s_in_count(sb)
if before == after {
// Fixed point reached — no pass reduced the op count this round
return
}
}
}
// Get the optimized op count
// Number of ops in the optimized program. Valid after optimize() returns
// (the final output always ends up in the input slot via swap_buffers).
pub fn result_count(state_base: Field) -> Field {
s_in_count(state_base)
}
// Base address of the optimized op array (stride 4), same validity rule.
pub fn result_base(state_base: Field) -> Field {
s_in_base(state_base)
}
// Source path: trident/std/compiler/optimize.tri