// Hand-optimized TASM baseline: std.private.poly
//
// Written from first principles: algorithm + Triton VM stack machine.
// NOT derived from compiler output.
//
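// Triton VM instruction semantics (stacks written top-first: st0, st1, ...):
//   read_mem 1:  [addr, rest...] -> [addr-1, mem[addr], rest...]
//   write_mem 1: [addr, val, rest...] -> [addr+1, rest...]
//                (pointer in st0, value in st1; spec notation `_ v p -> _ p+1`)
//   pop N: removes top N elements
//   lt: [a, b, ...] -> [(a < b), ...]  (st0 < st1, u32 comparison;
//                                       spec notation `_ b a -> _ a<b`)
//   skiz: pops st0 and skips the next instruction if it is zero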
//
// Stack convention: first arg deepest, last arg on top.
//   eval(coeff_addr, x, n):           st0=n, st1=x, st2=coeff_addr
//   add(a_addr, b_addr, out_addr, n): st0=n, st1=out_addr, st2=b_addr, st3=a_addr
//   ntt(a_addr, omega, log_n):        st0=log_n, st1=omega, st2=a_addr
//
// Design principles:
//   - Entry functions match compiler's calling convention exactly.
//     No stack rearrangement before calling loops.
//   - Loops work with native arg order to minimize entry overhead.
//   - Single shared __pow2 subroutine for all 2^n computations.
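//   - NTT butterfly must exit when k >= n. `lt` yields (st0 < st1) and `skiz`
//     skips the next instruction on zero, so a bare `lt; skiz; return` applied
//     to (k < n) exits on the wrong branch.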


// ===========================================================================
// EVAL -- Horner evaluation: P(x) = a[n-1]*x^(n-1) + ... + a[0]
// ===========================================================================
// Entry: [n, x, coeff_addr]
// Exit:  [result]
//
// Horner loop stack: [counter, accum, coeff_addr, x]
// Counter counts down; after decrement it equals the coefficient index.
__eval:
    // Horner evaluation of the n coefficients stored at coeff_addr.
    // In:  [n, x, coeff_addr]      Out: [result]
    // First permute the arguments into the loop layout.
    swap 1
    // [x, n, coeff_addr]
    swap 2
    // [coeff_addr, n, x]
    push 0
    // [0, coeff_addr, n, x]
    swap 1
    // [coeff_addr, 0, n, x]
    swap 2
    // [counter=n, accum=0, coeff_addr, x]
    call __eval_loop
    // [0, result, coeff_addr, x]
    swap 1
    // [result, 0, coeff_addr, x]
    swap 3
    // [x, 0, coeff_addr, result]
    pop 3
    // [result]
    return

// Loop state: [counter, accum, coeff_addr, x]
// Each pass computes accum = accum * x + mem[coeff_addr + counter - 1].
__eval_loop:
    // Exit when counter == 0: eq leaves 1, skiz does not skip, return runs.
    dup 0
    push 0
    eq
    skiz
    return
    // The decremented counter is the index of the coefficient to fold in.
    push -1
    add
    // [idx, accum, coeff_addr, x]
    swap 1
    dup 3
    mul
    // [accum*x, idx, coeff_addr, x]
    dup 2
    dup 2
    add
    // [coeff_addr+idx, accum*x, idx, coeff_addr, x]
    read_mem 1
    pop 1
    // [coeff, accum*x, idx, coeff_addr, x]
    add
    // Put the counter back on top for the next pass.
    swap 1
    recurse


// ===========================================================================
// ADD -- out[i] = a[i] + b[i], in-place loop
// ===========================================================================
// Entry: [n, out_addr, b_addr, a_addr]
// Exit:  [] (stack cleaned)
//
// Loop works with native arg order: [idx, out, b, a]
// No rearrangement needed.
__add:
    // Element-wise sum of two length-n arrays: out[i] = a[i] + b[i].
    // In:  [n, out_addr, b_addr, a_addr]      Out: []
    call __add_loop
    // [0, out_addr, b_addr, a_addr]
    pop 4
    return

// Loop stack: [idx, out_addr, b_addr, a_addr]
// Counts down from n to 0; the decremented counter is the element index.
__add_loop:
    // Exit when idx == 0 (eq leaves 1, skiz does not skip, return runs).
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // [i, out, b, a]
    // read a[i]
    dup 0
    dup 4
    add
    read_mem 1
    pop 1
    // [a_val, i, out, b, a]
    // read b[i]
    dup 1
    dup 4
    add
    read_mem 1
    pop 1
    // [b_val, a_val, i, out, b, a]
    add
    // [sum, i, out, b, a]
    // write to out[i]: write_mem takes the pointer in st0 and the value in
    // st1 (Triton VM spec: `_ v p -> _ p+1`), so the address is computed
    // last and must stay on top -- no swap before the store.
    dup 1
    dup 3
    add
    // [out+i, sum, i, out, b, a]
    write_mem 1
    // [out+i+1, i, out, b, a]
    pop 1
    recurse


// ===========================================================================
// SUB -- out[i] = a[i] - b[i]
// ===========================================================================
// Entry: [n, out_addr, b_addr, a_addr]
// Exit:  []
__sub:
    // Element-wise difference of two length-n arrays: out[i] = a[i] - b[i].
    // In:  [n, out_addr, b_addr, a_addr]      Out: []
    call __sub_loop
    // [0, out_addr, b_addr, a_addr]
    pop 4
    return

// Loop stack: [idx, out_addr, b_addr, a_addr], counting idx down to 0.
__sub_loop:
    // Exit when idx == 0 (eq leaves 1, skiz does not skip, return runs).
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // [i, out, b, a]
    // read a[i]
    dup 0
    dup 4
    add
    read_mem 1
    pop 1
    // [a_val, i, out, b, a]
    // read b[i]
    dup 1
    dup 4
    add
    read_mem 1
    pop 1
    // [b_val, a_val, i, out, b, a]
    // Field subtraction: negate b_val, then add.
    push -1
    mul
    add
    // [a-b, i, out, b, a]
    // write to out[i]: write_mem takes the pointer in st0 and the value in
    // st1 (Triton VM spec: `_ v p -> _ p+1`), so the freshly computed
    // address stays on top -- no swap before the store.
    dup 1
    dup 3
    add
    // [out+i, a-b, i, out, b, a]
    write_mem 1
    // [out+i+1, i, out, b, a]
    pop 1
    recurse


// ===========================================================================
// POINTWISE_MUL -- out[i] = a[i] * b[i]
// ===========================================================================
// Entry: [n, out_addr, b_addr, a_addr]
// Exit:  []
__pointwise_mul:
    // Element-wise product of two length-n arrays: out[i] = a[i] * b[i].
    // In:  [n, out_addr, b_addr, a_addr]      Out: []
    call __pw_mul_loop
    // [0, out_addr, b_addr, a_addr]
    pop 4
    return

// Loop stack: [idx, out_addr, b_addr, a_addr], counting idx down to 0.
__pw_mul_loop:
    // Exit when idx == 0 (eq leaves 1, skiz does not skip, return runs).
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // [i, out, b, a]
    // read a[i]
    dup 0
    dup 4
    add
    read_mem 1
    pop 1
    // [a_val, i, out, b, a]
    // read b[i]
    dup 1
    dup 4
    add
    read_mem 1
    pop 1
    // [b_val, a_val, i, out, b, a]
    mul
    // [prod, i, out, b, a]
    // write to out[i]: write_mem takes the pointer in st0 and the value in
    // st1 (Triton VM spec: `_ v p -> _ p+1`), so the freshly computed
    // address stays on top -- no swap before the store.
    dup 1
    dup 3
    add
    // [out+i, prod, i, out, b, a]
    write_mem 1
    // [out+i+1, i, out, b, a]
    pop 1
    recurse


// ===========================================================================
// SCALE -- out[i] = a[i] * s
// ===========================================================================
// Entry: [n, out_addr, s, a_addr]
// Exit:  []
__scale:
    // Multiply every element of a length-n array by s: out[i] = a[i] * s.
    // In:  [n, out_addr, s, a_addr]      Out: []
    call __scale_loop
    // [0, out_addr, s, a_addr]
    pop 4
    return

// Loop stack: [idx, out_addr, s, a_addr], counting idx down to 0.
__scale_loop:
    // Exit when idx == 0 (eq leaves 1, skiz does not skip, return runs).
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // [i, out, s, a]
    // read a[i]
    dup 0
    dup 4
    add
    read_mem 1
    pop 1
    // [val, i, out, s, a]
    dup 3
    mul
    // [val*s, i, out, s, a]
    // write to out[i]: write_mem takes the pointer in st0 and the value in
    // st1 (Triton VM spec: `_ v p -> _ p+1`), so the freshly computed
    // address stays on top -- no swap before the store.
    dup 1
    dup 3
    add
    // [out+i, val*s, i, out, s, a]
    write_mem 1
    // [out+i+1, i, out, s, a]
    pop 1
    recurse


// ===========================================================================
// POW2 -- Shared 2^n computation: [counter, accum] -> [0, 2^counter * accum]
// ===========================================================================
// Used by ntt, intt, poly_mul for computing n = 2^log_n.
__pow2:
    // In:  [counter, accum]      Out: [0, accum * 2^counter]
    // Exit when counter == 0 (eq leaves 1, skiz does not skip, return runs).
    dup 0
    push 0
    eq
    skiz
    return
    // Double the accumulator first ...
    swap 1
    push 2
    mul
    swap 1
    // [counter, 2*accum]
    // ... then count the step off.
    push -1
    add
    // [counter-1, 2*accum]
    recurse


// ===========================================================================
// NTT -- Number Theoretic Transform (iterative Cooley-Tukey, in-place)
// ===========================================================================
// Entry: [log_n, omega, a_addr]
// Exit:  [] (stack cleaned)
//
// Algorithm: compute n = 2^log_n, then run stage/j/butterfly loops.
__ntt:
    // In:  [log_n, omega, a_addr]      Out: []
    // Derive n = 2^log_n with the shared doubling loop (accum starts at 1).
    push 1
    dup 1
    // [log_n, 1, log_n, omega, a_addr]
    call __pow2
    // [0, n, log_n, omega, a_addr]
    pop 1
    // [n, log_n, omega, a_addr]
    // Build the stage-loop frame [stage_count, len=1, w_step, a_addr, n]:
    // introduce len = 1, then rotate n down to the deepest slot.
    push 1
    swap 1
    // [n, 1, log_n, omega, a_addr]
    swap 4
    // [a_addr, 1, log_n, omega, n]
    swap 3
    // [omega, 1, log_n, a_addr, n]
    swap 2
    // [log_n, 1, omega, a_addr, n]
    call __ntt_stage
    // Drop the five frame slots left behind by the stage loop.
    pop 5
    return

// Stage loop: [sc, len, ws, a, n]
__ntt_stage:
    // One pass per stage; stack on entry and exit: [sc, len, ws, a, n].
    // Exit when the stage counter reaches 0.
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // [sc-1, len, ws, a, n]
    // Push j-loop vars so the frame becomes
    // [j_ctr=len, j=0, w=1, sc-1, len, ws, a, n].
    push 1
    push 0
    // [j=0, w=1, sc-1, len, ws, a, n]
    // len now sits at st3 (st4 is ws), so the j-loop trip count is `dup 3`.
    dup 3
    // [len, 0, 1, sc-1, len, ws, a, n]
    call __ntt_j
    // Discard j_ctr, j and w.
    pop 3
    // [sc-1, len, ws, a, n]
    // len *= 2
    swap 1
    dup 0
    add
    swap 1
    // ws *= ws
    // NOTE(review): len starts at 1 and doubles while ws starts at omega and
    // is squared. A textbook radix-2 stage of half-width len steps its twiddle
    // by a primitive (2*len)-th root of unity, and an in-place Cooley-Tukey
    // transform also needs a bit-reversal pass; neither is visible in this
    // file -- confirm against the reference implementation.
    swap 2
    dup 0
    mul
    swap 2
    recurse

// J-loop: [j_ctr, j, w, sc, len, ws, a, n]
__ntt_j:
    // Runs one butterfly sweep per value of j, j_ctr times in total.
    // Frame: [j_ctr, j, w, sc, len, ws, a, n]
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // Seed the butterfly index with k = j and sweep.
    dup 1
    // [k, j_ctr, j, w, sc, len, ws, a, n]
    call __ntt_bfly
    // Drop the k left on top by the sweep.
    pop 1
    // [j_ctr, j, w, sc, len, ws, a, n]
    // Advance the twiddle: w *= ws (ws is st5 while w is swapped to the top).
    swap 2
    dup 5
    mul
    swap 2
    // [j_ctr, j, w*ws, sc, len, ws, a, n]
    // Advance the offset: j += 1.
    swap 1
    push 1
    add
    swap 1
    // [j_ctr, j+1, w*ws, sc, len, ws, a, n]
    recurse

// Butterfly: [k, j_ctr, j, w, sc, len, ws, a, n]
// positions:  0    1     2  3   4   5    6   7  8
//
// For k, k + 2*len, k + 4*len, ... while k < n:
//   t          = mem[a+k+len] * w
//   mem[a+k]     = lo_val + t        (lo_val = old mem[a+k])
//   mem[a+k+len] = lo_val - t
// On return the same nine slots are on the stack, with the first k >= n in st0.
//
// Guard: return when k >= n.
//   lt   : yields (st0 < st1), so `dup 8; dup 1; lt` leaves (k < n).
//   skiz : pops st0 and skips the next instruction when it is ZERO.
// A bare `lt; skiz; return` would therefore return while k < n and fall into
// the body once k >= n. The flag is negated with `push 0; eq` first, giving
// the same `eq; skiz; return` shape as every other loop in this file.
// `lt` requires both operands to be in u32 range.
//
// write_mem takes the pointer in st0 and the value in st1
// (Triton VM spec: `_ v p -> _ p+1`).
__ntt_bfly:
    dup 8
    dup 1
    lt
    // [(k < n), k, j_ctr, j, w, sc, len, ws, a, n]
    push 0
    eq
    // [(k >= n), k, j_ctr, j, w, sc, len, ws, a, n]
    skiz
    return
    // lo_addr = a + k
    dup 0
    dup 8
    add
    // [lo, k, j_ctr, j, w, sc, len, ws, a, n]
    // hi_addr = lo + len
    dup 0
    dup 7
    add
    // [hi, lo, k, j_ctr, j, w, sc, len, ws, a, n]
    // Read hi_val from mem[hi]
    dup 0
    read_mem 1
    pop 1
    // [hi_val, hi, lo, k, j_ctr, j, w, sc, len, ws, a, n]
    // t = hi_val * w (w is st6 here)
    dup 6
    mul
    // [t, hi, lo, k, j_ctr, j, w, sc, len, ws, a, n]
    // Read lo_val from mem[lo]
    dup 2
    read_mem 1
    pop 1
    // [lo_val, t, hi, lo, k, ...]
    // Store lo_val + t at lo, keeping lo_val and t for the second output.
    dup 0
    dup 2
    add
    // [lo_val+t, lo_val, t, hi, lo, k, ...]
    dup 4
    // [lo, lo_val+t, lo_val, t, hi, lo, k, ...]  pointer on top
    write_mem 1
    pop 1
    // [lo_val, t, hi, lo, k, ...]
    // lo_val - t, computed as lo_val + (-1 * t)
    swap 1
    push -1
    mul
    add
    // [lo_val-t, hi, lo, k, ...]
    // Store at hi: bring the pointer to st0, write, then drop both the
    // advanced pointer and the now-unused lo address.
    swap 1
    // [hi, lo_val-t, lo, k, ...]
    write_mem 1
    // [hi+1, lo, k, ...]
    pop 2
    // [k, j_ctr, j, w, sc, len, ws, a, n]
    // k += 2*len
    dup 5
    dup 0
    add
    add
    recurse


// ===========================================================================
// INTT -- Inverse NTT: forward NTT with omega_inv, then scale by n_inv
// ===========================================================================
// Entry: [log_n, n_inv, omega_inv, a_addr]
// Exit:  []
__intt:
    // In:  [log_n, n_inv, omega_inv, a_addr]      Out: []
    // Step 1: forward transform with the inverse root.
    // Callee frame is [log_n, omega, a_addr], pushed deepest-first.
    dup 3
    // [a_addr, log_n, n_inv, omega_inv, a_addr]
    dup 3
    // [omega_inv, a_addr, log_n, n_inv, omega_inv, a_addr]
    dup 2
    // [log_n, omega_inv, a_addr, log_n, n_inv, omega_inv, a_addr]
    call __ntt
    // [log_n, n_inv, omega_inv, a_addr]
    // Step 2: n = 2^log_n via the shared doubling loop.
    push 1
    dup 1
    // [log_n, 1, log_n, n_inv, omega_inv, a_addr]
    call __pow2
    pop 1
    // [n, log_n, n_inv, omega_inv, a_addr]
    // Step 3: scale in place by n_inv.
    // Callee frame is [n, out=a_addr, s=n_inv, src=a_addr].
    dup 4
    // [a_addr, n, log_n, n_inv, omega_inv, a_addr]
    dup 3
    // [n_inv, a_addr, n, log_n, n_inv, omega_inv, a_addr]
    dup 1
    // [a_addr, n_inv, a_addr, n, log_n, n_inv, omega_inv, a_addr]
    dup 3
    // [n, a_addr, n_inv, a_addr, n, log_n, n_inv, omega_inv, a_addr]
    call __scale
    // [n, log_n, n_inv, omega_inv, a_addr]
    pop 5
    return


// ===========================================================================
// POLY_MUL -- Polynomial multiplication via NTT pipeline
// ===========================================================================
// Entry: [log_n, n_inv, omega_inv, omega, tmp, out, b, a]
// Exit:  []
//
// Steps: size = 2^log_n, copy a->out & b->tmp, NTT(out), NTT(tmp),
//        pointwise_mul(out, tmp, out), INTT(out)
__poly_mul:
    // In:  [log_n, n_inv, omega_inv, omega, tmp, out, b, a]      Out: []
    // Pipeline: copy a->out and b->tmp, transform both buffers, multiply
    // them element-wise into out, then inverse-transform out.
    // Compute size = 2^log_n
    push 1
    dup 1
    // [log_n, 1, log_n, n_inv, ...]
    call __pow2
    pop 1
    // [size, log_n, n_inv, omega_inv, omega, tmp, out, b, a]
    // pos: 0=size 1=log_n 2=n_inv 3=omega_inv 4=omega 5=tmp 6=out 7=b 8=a

    // Copy a->out, b->tmp.
    // __pm_copy frame is [counter, a, b, out, tmp], pushed deepest-first;
    // each push shifts the base frame one slot deeper.
    dup 5
    // tmp
    dup 7
    // out
    dup 9
    // b
    dup 11
    // a
    dup 4
    // [size, a, b, out, tmp, size, log_n, ...]
    call __pm_copy
    pop 5

    // NTT(out, omega, log_n)
    dup 6
    dup 5
    dup 3
    call __ntt

    // NTT(tmp, omega, log_n)
    dup 5
    dup 5
    dup 3
    call __ntt

    // pointwise_mul(out, tmp, out, size)
    // Callee frame is [n, out_addr, b_addr, a_addr].
    dup 6
    // [a_addr=out, size, log_n, n_inv, omega_inv, omega, tmp, out, b, a]
    dup 6
    // [b_addr=tmp, out, size, ...]
    // After two pushes the base frame's out is at st8 (st7 is tmp); the
    // product must land in out because that is the buffer INTT reads next.
    dup 8
    // [out_addr=out, tmp, out, size, ...]
    dup 3
    // [n=size, out, tmp, out, size, ...]
    call __pointwise_mul

    // INTT(out, omega_inv, n_inv, log_n)
    // Callee frame is [log_n, n_inv, omega_inv, a_addr].
    dup 6
    dup 4
    dup 4
    dup 4
    call __intt

    // Drop the nine-slot base frame.
    pop 5
    pop 4
    return

// Copy loop: a[i]->out[i], b[i]->tmp[i] simultaneously
// Stack: [counter, a_addr, b_addr, out_addr, tmp_addr]
__pm_copy:
    // Copies two length-counter arrays in one pass:
    // out[i] = a[i] and tmp[i] = b[i], for i = counter-1 down to 0.
    // Frame: [counter, a_addr, b_addr, out_addr, tmp_addr]; leaves counter = 0.
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // [i, a, b, out, tmp]
    // a[i] -> out[i]
    dup 0
    dup 2
    add
    read_mem 1
    pop 1
    // [a_val, i, a, b, out, tmp]
    dup 1
    dup 5
    add
    // [out+i, a_val, i, a, b, out, tmp]
    // write_mem takes the pointer in st0 and the value in st1
    // (Triton VM spec: `_ v p -> _ p+1`), so no swap before the store.
    write_mem 1
    pop 1
    // [i, a, b, out, tmp]
    // b[i] -> tmp[i]
    dup 0
    dup 3
    add
    read_mem 1
    pop 1
    // [b_val, i, a, b, out, tmp]
    dup 1
    dup 6
    add
    // [tmp+i, b_val, i, a, b, out, tmp]
    write_mem 1
    pop 1
    recurse

Neighbours