module std.compiler.optimize

// Self-hosted TIR peephole optimizer.
//
// Runs pattern-based rewrites on a flat TIR array (stride 4) to reduce
// instruction count. Applied between TIR building and lowering.
//
// Eight peephole passes run in a fixed-point loop:
//   1. merge_hints      — consecutive Hint(a)+Hint(b) → Hint(a+b)
//   2. merge_pops       — consecutive Pop(a)+Pop(b) → Pop(a+b)
//   3. eliminate_nops   — remove Swap(0), Pop(0)
//   4. dead_spills      — dead spill/reload elimination
//   5. dup_pop_nops     — Dup(0)+Pop(1) → nothing
//   6. double_swaps     — Swap(N)+Swap(N) → nothing
//   7. swap_pop_chains  — dup+swap+pop round-trips
//   8. epilogue_cleanup — swap(D)+pop(1) chains → compact
//   + optimize_nested   — recurse into structural op bodies (not implemented in this version)
//
// Memory layout (state block):
//   +0   tir_base       Input TIR ops (stride 4)
//   +1   tir_count      Number of input ops
//   +2   out_base       Output buffer (stride 4)
//   +3   out_count      Output op count
//   +4   addr_stats_base  Address stats for dead spill (stride 3: addr, writes, reads)
//   +5   addr_stats_count

use vm.core.field

use vm.core.convert

use vm.io.mem

use std.compiler.lower

// =========================================================================
// State accessors
// =========================================================================

// Base address of the input TIR buffer (state slot +0).
fn s_in_base(sb: Field) -> Field {
    return mem.read(sb)
}
// Number of ops currently in the input buffer (state slot +1).
fn s_in_count(sb: Field) -> Field {
    let slot: Field = sb + 1
    return mem.read(slot)
}
// Store v as the input op count (state slot +1).
fn s_set_in_count(sb: Field, v: Field) {
    let slot: Field = sb + 1
    mem.write(slot, v)
}
// Base address of the output TIR buffer (state slot +2).
fn s_out_base(sb: Field) -> Field {
    let slot: Field = sb + 2
    return mem.read(slot)
}
// Number of ops written to the output buffer so far (state slot +3).
fn s_out_count(sb: Field) -> Field {
    let slot: Field = sb + 3
    return mem.read(slot)
}
// Store v as the output op count (state slot +3).
fn s_set_out_count(sb: Field, v: Field) {
    let slot: Field = sb + 3
    mem.write(slot, v)
}
// Base address of the per-address spill statistics table (state slot +4).
fn s_addr_base(sb: Field) -> Field {
    let slot: Field = sb + 4
    return mem.read(slot)
}
// Value stored in the address-stats count slot (state slot +5).
fn s_addr_count(sb: Field) -> Field {
    let slot: Field = sb + 5
    return mem.read(slot)
}
// Store v in the address-stats count slot (state slot +5).
fn s_set_addr_count(sb: Field, v: Field) {
    let slot: Field = sb + 5
    mem.write(slot, v)
}

// TIR op accessors (stride 4 from a given base)
// Opcode word (offset 0) of the idx-th op in the stride-4 array at base.
fn op_at(base: Field, idx: Field) -> Field {
    let slot: Field = base + idx * 4
    return mem.read(slot)
}
// First argument word (offset 1) of the idx-th op in the stride-4 array at base.
fn a0_at(base: Field, idx: Field) -> Field {
    let slot: Field = base + idx * 4
    return mem.read(slot + 1)
}
// Second argument word (offset 2) of the idx-th op in the stride-4 array at base.
fn a1_at(base: Field, idx: Field) -> Field {
    let slot: Field = base + idx * 4
    return mem.read(slot + 2)
}
// Third argument word (offset 3) of the idx-th op in the stride-4 array at base.
fn a2_at(base: Field, idx: Field) -> Field {
    let slot: Field = base + idx * 4
    return mem.read(slot + 3)
}

// Append one TIR op (opcode word plus three argument words) at the end of
// the output buffer, then increment the output op count by one.
fn emit_op(sb: Field, opcode: Field, arg0: Field, arg1: Field, arg2: Field) {
    let n: Field = s_out_count(sb)
    // Address of the first free 4-word record in the output buffer.
    let slot: Field = s_out_base(sb) + n * 4
    mem.write(slot, opcode)
    mem.write(slot + 1, arg0)
    mem.write(slot + 2, arg1)
    mem.write(slot + 3, arg2)
    s_set_out_count(sb, n + 1)
}

// Append the idx-th input op, unchanged, to the output buffer.
fn copy_op(sb: Field, idx: Field) {
    let src: Field = s_in_base(sb)
    let opcode: Field = op_at(src, idx)
    let x0: Field = a0_at(src, idx)
    let x1: Field = a1_at(src, idx)
    let x2: Field = a2_at(src, idx)
    emit_op(sb, opcode, x0, x1, x2)
}

// Swap input/output buffers (for next pass).
//
// After the call the former output buffer (and its op count) is the input,
// the former input buffer becomes the output, and the output count is reset
// to 0. The old input count is simply overwritten, so it is not read here.
fn swap_buffers(sb: Field) {
    let old_in: Field = s_in_base(sb)
    let new_in: Field = s_out_base(sb)
    let new_count: Field = s_out_count(sb)
    mem.write(sb, new_in)
    s_set_in_count(sb, new_count)
    mem.write(sb + 2, old_in)
    s_set_out_count(sb, 0)
}

// =========================================================================
// Pass 1: Merge consecutive Hints
// Hint(a) + Hint(b) → Hint(a+b), emitted in batches of 5
// =========================================================================

// Merge every run of consecutive Hint ops in the input buffer into Hint ops
// with argument at most 5, writing the result to the output buffer and then
// swapping buffers.
fn pass_merge_hints(sb: Field) {
    merge_op_runs(sb, lower.OP_HINT())
}

// Shared body of the Hint and Pop merge passes.
//
// Walks the input buffer once. Each maximal run of consecutive ops whose
// opcode equals `opcode` is replaced by batches carrying the sum of the run's
// first arguments; every other op is copied unchanged. Ends with a buffer
// swap, so the result is the new input.
fn merge_op_runs(sb: Field, opcode: Field) {
    let base: Field = s_in_base(sb)
    let count: Field = s_in_count(sb)
    s_set_out_count(sb, 0)
    let mut i: Field = 0
    let cu: U32 = convert.as_u32(count)
    // Every non-idle step consumes at least one input op, so `count` steps
    // are always enough to reach the end of the input.
    for step in 0..cu bounded 65536 {
        if i == count {
            // input exhausted: idle for the remaining steps
        } else if op_at(base, i) == opcode {
            let mut total: Field = a0_at(base, i)
            i = i + 1
            // Absorb the rest of the run. Once a different opcode is seen,
            // i stops advancing and the remaining iterations do nothing.
            for s2 in 0..cu bounded 65536 {
                if i == count {
                    // input exhausted
                } else if op_at(base, i) == opcode {
                    total = total + a0_at(base, i)
                    i = i + 1
                }
            }
            emit_batched(sb, opcode, total)
        } else {
            copy_op(sb, i)
            i = i + 1
        }
    }
    swap_buffers(sb)
}

// Emit `total` as a sequence of `opcode` ops with argument at most 5: full
// batches of 5 followed by the remainder. Emits nothing when total is 0.
//
// The loop runs `total` times; since each batch removes at least 1, that is
// always enough to drain the remainder. (The earlier fixed limit of 20
// batches silently dropped everything beyond a total of 100.) Totals above
// the declared loop bound of 65536 are not supported.
fn emit_batched(sb: Field, opcode: Field, total: Field) {
    let zero: U32 = convert.as_u32(0)
    let five: U32 = convert.as_u32(5)
    let six: U32 = convert.as_u32(6)
    let mut rem: U32 = convert.as_u32(total)
    // Snapshot of the initial remainder; rem itself shrinks inside the loop.
    let steps: U32 = rem
    for s in 0..steps bounded 65536 {
        if rem == zero {
            // fully emitted: idle
        } else {
            let mut batch: U32 = five
            if rem < six {
                // 5 or fewer left: emit them all in one op
                batch = rem
            }
            emit_op(sb, opcode, convert.as_field(batch), 0, 0)
            // rem = rem - batch, computed in the field because U32 values
            // are subtracted here via field negation.
            rem = convert.as_u32(convert.as_field(rem) + field.neg(convert.as_field(batch)))
        }
    }
}

// =========================================================================
// Pass 2: Merge consecutive Pops
// Pop(a) + Pop(b) → Pop(a+b), emitted in batches of 5
// =========================================================================

// Merge every run of consecutive Pop ops in the input buffer into Pop ops
// with argument at most 5, writing the result to the output buffer and then
// swapping buffers.
fn pass_merge_pops(sb: Field) {
    merge_op_runs(sb, lower.OP_POP())
}

// =========================================================================
// Pass 3: Eliminate no-ops — Swap(0) and Pop(0)
// =========================================================================

// Copy the input buffer to the output buffer, leaving out every Swap or Pop
// whose first argument is 0, then swap buffers.
fn pass_eliminate_nops(sb: Field) {
    let src: Field = s_in_base(sb)
    let n: U32 = convert.as_u32(s_in_count(sb))
    s_set_out_count(sb, 0)
    for k in 0..n bounded 65536 {
        let at: Field = convert.as_field(k)
        let opcode: Field = op_at(src, at)
        let arg: Field = a0_at(src, at)
        // is_nop becomes 1 only for Swap(0) and Pop(0).
        let mut is_nop: Field = 0
        if arg == 0 {
            if opcode == lower.OP_SWAP() {
                is_nop = 1
            } else if opcode == lower.OP_POP() {
                is_nop = 1
            }
        }
        if is_nop == 0 {
            copy_op(sb, at)
        }
    }
    swap_buffers(sb)
}

// =========================================================================
// Pass 4: Dead spill elimination
// Spill write: Push(addr), Swap(1), WriteMem(1), Pop(1)
// Spill read:  Push(addr), ReadMem(1), Pop(1)
// =========================================================================

// Dead spill elimination.
//
// Phase 1 builds a table (stride 3: addr, writes, reads) at the address-stats
// base: one row per address seen in a spill-write pattern, with the number of
// write patterns and read patterns found for it. Read patterns for addresses
// that never appear in a write pattern are not recorded.
// Phase 2 checks whether any row is classified as pair or dead; if none is,
// the function returns without touching the buffers.
// Phase 3 rewrites the op stream: write patterns of "pair" addresses are
// dropped, write patterns of "dead" addresses become Pop(1), read patterns of
// "pair" addresses are dropped, everything else is copied. Buffers are then
// swapped.
fn pass_dead_spills(sb: Field) {
    let src: Field = s_in_base(sb)
    let n_ops: Field = s_in_count(sb)
    let n_u: U32 = convert.as_u32(n_ops)
    let stats: Field = s_addr_base(sb)
    // Number of rows in the stats table; kept in a local only.
    let mut n_stats: Field = 0

    // Phase 1a: tally spill writes per address.
    for w in 0..n_u bounded 65536 {
        let w_at: Field = convert.as_field(w)
        if ds_is_spill_write(src, w_at, n_u) == 1 {
            let w_addr: Field = a0_at(src, w_at)
            let mut seen: Field = 0
            let known: U32 = convert.as_u32(n_stats)
            for e in 0..known bounded 4096 {
                let row: Field = stats + convert.as_field(e) * 3
                if mem.read(row) == w_addr {
                    mem.write(row + 1, mem.read(row + 1) + 1)
                    seen = 1
                }
            }
            if seen == 0 {
                // First write to this address: append a fresh row.
                let new_row: Field = stats + n_stats * 3
                mem.write(new_row, w_addr)
                mem.write(new_row + 1, 1) // writes
                mem.write(new_row + 2, 0) // reads
                n_stats = n_stats + 1
            }
        }
    }

    // Phase 1b: tally spill reads for the addresses recorded above.
    for r in 0..n_u bounded 65536 {
        let r_at: Field = convert.as_field(r)
        if ds_is_spill_read(src, r_at, n_u) == 1 {
            let r_addr: Field = a0_at(src, r_at)
            let rows: U32 = convert.as_u32(n_stats)
            for e in 0..rows bounded 4096 {
                let r_row: Field = stats + convert.as_field(e) * 3
                if mem.read(r_row) == r_addr {
                    mem.write(r_row + 2, mem.read(r_row + 2) + 1)
                }
            }
        }
    }

    // Phase 2: is there any row that is a pair (1 write, 1 read) or dead
    // (no reads at all)?
    let mut actionable: Field = 0
    let all_rows: U32 = convert.as_u32(n_stats)
    for e in 0..all_rows bounded 4096 {
        let c_row: Field = stats + convert.as_field(e) * 3
        let writes: Field = mem.read(c_row + 1)
        let reads: Field = mem.read(c_row + 2)
        if reads == 0 {
            actionable = 1
        } else if reads == 1 {
            if writes == 1 {
                actionable = 1
            }
        }
    }

    if actionable == 0 {
        // Nothing to rewrite: leave input and output buffers as they are.
        return
    }

    // Phase 3: rewrite the op stream.
    s_set_out_count(sb, 0)
    let mut pos: Field = 0
    for step in 0..n_u bounded 65536 {
        if pos == n_ops {
            // input exhausted: idle
        } else {
            // consumed becomes 1 once a pattern at pos has been handled.
            let mut consumed: Field = 0
            if ds_is_spill_write(src, pos, n_u) == 1 {
                let w_kind: Field = classify_addr(stats, n_stats, a0_at(src, pos))
                if w_kind == 1 {
                    // pair: drop the whole 4-op write
                    pos = pos + 4
                    consumed = 1
                } else if w_kind == 2 {
                    // dead: the 4-op write collapses to Pop(1)
                    emit_op(sb, lower.OP_POP(), 1, 0, 0)
                    pos = pos + 4
                    consumed = 1
                }
            }
            if consumed == 0 {
                if ds_is_spill_read(src, pos, n_u) == 1 {
                    let r_kind: Field = classify_addr(stats, n_stats, a0_at(src, pos))
                    if r_kind == 1 {
                        // pair: drop the whole 3-op read
                        pos = pos + 3
                        consumed = 1
                    }
                }
            }
            if consumed == 0 {
                copy_op(sb, pos)
                pos = pos + 1
            }
        }
    }
    swap_buffers(sb)
}

// Returns 1 when the four ops starting at index `at` are
// Push(_), Swap(1), WriteMem(1), Pop(1) and all four lie below `limit`;
// returns 0 otherwise.
fn ds_is_spill_write(base: Field, at: Field, limit: U32) -> Field {
    let mut hit: Field = 0
    if convert.as_u32(at + 3) < limit {
        if op_at(base, at) == lower.OP_PUSH() {
            if op_at(base, at + 1) == lower.OP_SWAP() {
                if a0_at(base, at + 1) == 1 {
                    if op_at(base, at + 2) == lower.OP_WRITE_MEM() {
                        if a0_at(base, at + 2) == 1 {
                            if op_at(base, at + 3) == lower.OP_POP() {
                                if a0_at(base, at + 3) == 1 {
                                    hit = 1
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    hit
}

// Returns 1 when the three ops starting at index `at` are
// Push(_), ReadMem(1), Pop(1) and all three lie below `limit`;
// returns 0 otherwise.
fn ds_is_spill_read(base: Field, at: Field, limit: U32) -> Field {
    let mut hit: Field = 0
    if convert.as_u32(at + 2) < limit {
        if op_at(base, at) == lower.OP_PUSH() {
            if op_at(base, at + 1) == lower.OP_READ_MEM() {
                if a0_at(base, at + 1) == 1 {
                    if op_at(base, at + 2) == lower.OP_POP() {
                        if a0_at(base, at + 2) == 1 {
                            hit = 1
                        }
                    }
                }
            }
        }
    }
    hit
}

// Look `addr` up in the stats table (`ac` rows of stride 3 at `ab`) and
// classify it:
//   0 = keep  (no matching row, or neither case below applies)
//   1 = pair  (exactly 1 write and exactly 1 read)
//   2 = dead  (no reads, whatever the write count)
fn classify_addr(ab: Field, ac: Field, addr: Field) -> Field {
    let mut verdict: Field = 0
    let rows: U32 = convert.as_u32(ac)
    for e in 0..rows bounded 4096 {
        let row: Field = ab + convert.as_field(e) * 3
        if mem.read(row) == addr {
            let reads: Field = mem.read(row + 2)
            if reads == 0 {
                verdict = 2 // dead
            } else if reads == 1 {
                if mem.read(row + 1) == 1 {
                    verdict = 1 // pair
                }
            }
        }
    }
    verdict
}

// =========================================================================
// Pass 5: Eliminate Dup(0)+Pop(1) and Dup(0)+Swap(1)+Pop(1) no-ops
// =========================================================================

// Drop the sequences Dup(0), Swap(1), Pop(1) and Dup(0), Pop(1) from the op
// stream; copy every other op unchanged, then swap buffers. The 3-op form is
// tried before the 2-op form.
fn pass_dup_pop_nops(sb: Field) {
    let src: Field = s_in_base(sb)
    let n_ops: Field = s_in_count(sb)
    s_set_out_count(sb, 0)
    let n_u: U32 = convert.as_u32(n_ops)
    let mut pos: Field = 0
    for step in 0..n_u bounded 65536 {
        if pos == n_ops {
            // input exhausted: idle
        } else {
            // Length of the no-op sequence found at pos (0 = none).
            let mut width: Field = 0
            if op_at(src, pos) == lower.OP_DUP() {
                if a0_at(src, pos) == 0 {
                    // Dup(0), Swap(1), Pop(1)
                    if convert.as_u32(pos + 2) < n_u {
                        if op_at(src, pos + 1) == lower.OP_SWAP() {
                            if a0_at(src, pos + 1) == 1 {
                                if op_at(src, pos + 2) == lower.OP_POP() {
                                    if a0_at(src, pos + 2) == 1 {
                                        width = 3
                                    }
                                }
                            }
                        }
                    }
                    // Dup(0), Pop(1)
                    if width == 0 {
                        if convert.as_u32(pos + 1) < n_u {
                            if op_at(src, pos + 1) == lower.OP_POP() {
                                if a0_at(src, pos + 1) == 1 {
                                    width = 2
                                }
                            }
                        }
                    }
                }
            }
            if width == 0 {
                copy_op(sb, pos)
                pos = pos + 1
            } else {
                pos = pos + width
            }
        }
    }
    swap_buffers(sb)
}

// =========================================================================
// Pass 6: Eliminate double swaps — Swap(N)+Swap(N) → nothing
// =========================================================================

// Drop every adjacent pair of Swap ops that carry the same first argument;
// copy every other op unchanged, then swap buffers.
fn pass_double_swaps(sb: Field) {
    let src: Field = s_in_base(sb)
    let n_ops: Field = s_in_count(sb)
    s_set_out_count(sb, 0)
    let n_u: U32 = convert.as_u32(n_ops)
    let mut pos: Field = 0
    for step in 0..n_u bounded 65536 {
        if pos == n_ops {
            // input exhausted: idle
        } else {
            // cancelled becomes 1 when ops pos and pos+1 undo each other.
            let mut cancelled: Field = 0
            if convert.as_u32(pos + 1) < n_u {
                if op_at(src, pos) == lower.OP_SWAP() {
                    if op_at(src, pos + 1) == lower.OP_SWAP() {
                        if a0_at(src, pos + 1) == a0_at(src, pos) {
                            cancelled = 1
                        }
                    }
                }
            }
            if cancelled == 1 {
                pos = pos + 2
            } else {
                copy_op(sb, pos)
                pos = pos + 1
            }
        }
    }
    swap_buffers(sb)
}

// =========================================================================
// Pass 7: Collapse dup+swap+pop round-trips
// Pattern: N×Dup(D), Swap(N), Pop(N) where D+1==N → no-op
// =========================================================================

// Remove dup+swap+pop round-trips.
//
// At each Dup(D) the pass counts the run of N consecutive Dup(D) ops. If the
// run is followed by Swap(N) and then by Pop ops whose arguments add up to
// exactly N, and D + 1 == N, the whole sequence is dropped. Anything else is
// copied unchanged. Ends with a buffer swap.
//
// Pop accumulation stops as soon as the running total reaches N, so Pop ops
// that follow a complete match stay in the stream instead of pushing the
// total past N and defeating the match.
fn pass_swap_pop_chains(sb: Field) {
    let base: Field = s_in_base(sb)
    let count: Field = s_in_count(sb)
    s_set_out_count(sb, 0)
    let mut i: Field = 0
    let cu: U32 = convert.as_u32(count)
    for step in 0..cu bounded 65536 {
        if i == count {
            // input exhausted: idle
        } else {
            let mut skip: Field = 0
            if op_at(base, i) == lower.OP_DUP() {
                let d_val: Field = a0_at(base, i)
                // Count consecutive Dup(d_val), starting with the op at i
                // itself, so dup_count is at least 1 afterwards.
                let mut dup_count: Field = 0
                let mut j: Field = i
                for s2 in 0..cu bounded 65536 {
                    if j == count {
                        // input exhausted
                    } else if op_at(base, j) == lower.OP_DUP() {
                        if a0_at(base, j) == d_val {
                            dup_count = dup_count + 1
                            j = j + 1
                        }
                    }
                }
                // The run must be followed by Swap(dup_count).
                if convert.as_u32(j) < cu {
                    if op_at(base, j) == lower.OP_SWAP() {
                        if a0_at(base, j) == dup_count {
                            let mut total_popped: Field = 0
                            let mut k: Field = j + 1
                            for s3 in 0..cu bounded 65536 {
                                if k == count {
                                    // input exhausted
                                } else if total_popped == dup_count {
                                    // enough: leave any further pops alone
                                } else if op_at(base, k) == lower.OP_POP() {
                                    total_popped = total_popped + a0_at(base, k)
                                    k = k + 1
                                }
                            }
                            if total_popped == dup_count {
                                if d_val + 1 == dup_count {
                                    // No-op: skip dups, swap and the pops
                                    // that were consumed.
                                    i = k
                                    skip = 1
                                }
                            }
                        }
                    }
                }
            }
            if skip == 0 {
                copy_op(sb, i)
                i = i + 1
            }
        }
    }
    swap_buffers(sb)
}

// =========================================================================
// Pass 8: Collapse epilogue swap(D)+pop(1) chains
// =========================================================================

// Collapse chains of Swap(D), Pop(1) pairs.
//
// Starting at a Swap(D) followed by Pop(1), the pass counts how many
// consecutive Swap(D), Pop(1) pairs follow with the SAME depth D. Chains of
// three or more pairs are replaced by a single Swap plus batched Pops:
//   D == 1:  Swap(n), Pop(n)      (n = number of pairs)
//   D != 1:  Swap(D), Pop(n)
// Shorter chains and all other ops are copied unchanged. Ends with a buffer
// swap.
//
// Later pairs must repeat the first pair's depth: a mixed-depth chain such as
// Swap(1),Pop(1),Swap(2),Pop(1),Swap(3),Pop(1) does not leave the same stack
// as Swap(3),Pop(3), so it must not be collapsed.
//
// NOTE(review): the D != 1 rewrite is carried over unchanged from the
// previous version; its equivalence is not evident from this file —
// confirm against the reference optimizer.
fn pass_epilogue_cleanup(sb: Field) {
    let base: Field = s_in_base(sb)
    let count: Field = s_in_count(sb)
    s_set_out_count(sb, 0)
    let mut i: Field = 0
    let cu: U32 = convert.as_u32(count)
    for step in 0..cu bounded 65536 {
        if i == count {
            // input exhausted: idle
        } else {
            let mut skip: Field = 0
            // Cheap pre-check: at least 4 ops must remain from i.
            if convert.as_u32(i + 3) < cu {
                if op_at(base, i) == lower.OP_SWAP() {
                    if op_at(base, i + 1) == lower.OP_POP() {
                        if a0_at(base, i + 1) == 1 {
                            let first_d: Field = a0_at(base, i)
                            // Count consecutive Swap(first_d), Pop(1) pairs.
                            let mut pair_count: Field = 1
                            let mut j: Field = i + 2
                            for s2 in 0..cu bounded 65536 {
                                if convert.as_u32(j + 1) < cu {
                                    if op_at(base, j) == lower.OP_SWAP() {
                                        if a0_at(base, j) == first_d {
                                            if op_at(base, j + 1) == lower.OP_POP() {
                                                if a0_at(base, j + 1) == 1 {
                                                    pair_count = pair_count + 1
                                                    j = j + 2
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            if convert.as_u32(pair_count) < convert.as_u32(3) {
                                // fewer than 3 pairs: not worth rewriting
                            } else {
                                // Depth-1 chains swap to depth n; other
                                // chains keep the original depth.
                                let mut swap_depth: Field = first_d
                                if first_d == 1 {
                                    swap_depth = pair_count
                                }
                                emit_op(sb, lower.OP_SWAP(), swap_depth, 0, 0)
                                emit_pop_batches(sb, pair_count)
                                i = j
                                skip = 1
                            }
                        }
                    }
                }
            }
            if skip == 0 {
                copy_op(sb, i)
                i = i + 1
            }
        }
    }
    swap_buffers(sb)
}

// Emit Pop ops removing `total` elements, at most 5 per op: full batches of 5
// followed by the remainder. Emits nothing when total is 0.
//
// The loop runs `total` times, which is always enough because every batch
// removes at least 1. (A fixed limit of 20 batches would drop everything
// beyond a total of 100.)
fn emit_pop_batches(sb: Field, total: Field) {
    let mut rem: Field = total
    let steps: U32 = convert.as_u32(total)
    for s in 0..steps bounded 65536 {
        if rem == 0 {
            // fully emitted: idle
        } else {
            let mut batch: Field = 5
            if convert.as_u32(rem) < convert.as_u32(6) {
                // 5 or fewer left: emit them all in one op
                batch = rem
            }
            emit_op(sb, lower.OP_POP(), batch, 0, 0)
            rem = rem + field.neg(batch)
        }
    }
}

// =========================================================================
// Nested optimization — optimize bodies of structural ops in-place
//
// For our flat TIR format, structural ops reference body ranges within
// the same array. We optimize those sub-ranges by copying them out,
// running optimize on the sub-array, and writing them back.
// This is a simplified version — we just run all non-nested passes on
// sub-ranges referenced by IfElse/IfOnly/Loop/ProofBlock.
// =========================================================================

// Note: Full nested optimization would require allocating temp buffers
// for each sub-range and running the full optimize loop. For the initial
// version, we skip nested optimization โ€” the flat passes already handle
// most patterns. The codegen module can apply optimization before
// inserting body ops into structural op ranges.

// =========================================================================
// Main entry point โ€” fixed-point loop over all passes
// =========================================================================

// Run the eight peephole passes, in order, over the state block at
// `state_base`. The whole sequence is repeated up to 20 times and stops early
// as soon as a full round leaves the input op count unchanged.
pub fn optimize(state_base: Field) {
    let rounds: U32 = convert.as_u32(20)
    for round in 0..rounds bounded 20 {
        let ops_before: Field = s_in_count(state_base)
        pass_merge_hints(state_base)
        pass_merge_pops(state_base)
        pass_eliminate_nops(state_base)
        pass_dead_spills(state_base)
        pass_dup_pop_nops(state_base)
        pass_double_swaps(state_base)
        pass_swap_pop_chains(state_base)
        pass_epilogue_cleanup(state_base)
        // Same op count as before the round: treat as the fixed point.
        if s_in_count(state_base) == ops_before {
            return
        }
    }
}

// Number of ops in the current input buffer, which holds the optimized
// program once optimize() has run.
pub fn result_count(state_base: Field) -> Field {
    let n: Field = s_in_count(state_base)
    return n
}

// Base address of the current input buffer, which holds the optimized
// program once optimize() has run.
pub fn result_base(state_base: Field) -> Field {
    let addr: Field = s_in_base(state_base)
    return addr
}

// Source: trident/benches/harnesses/std/compiler/optimize.tri