// Hand-optimized TASM baseline: std.private.poly
//
// Written from first principles: algorithm + Triton VM stack machine.
// NOT derived from compiler output.
//
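// Triton VM instruction semantics (stacks are written top-first: [st0, st1, ...]):
// read_mem 1: [addr, rest...] -> [addr-1, mem[addr], rest...]
// write_mem 1: [addr, val, rest...] -> [addr+1, rest...] (the RAM pointer is st0)
// pop N: removes top N elements
// lt: [a, b, ...] -> [(a < b), ...] (u32 comparison: st0 < st1)
// skiz: pops st0 and skips the next instruction iff st0 == 0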
//
// Stack convention: first arg deepest, last arg on top.
// eval(coeff_addr, x, n): st0=n, st1=x, st2=coeff_addr
// add(a_addr, b_addr, out_addr, n): st0=n, st1=out_addr, st2=b_addr, st3=a_addr
// ntt(a_addr, omega, log_n): st0=log_n, st1=omega, st2=a_addr
//
// Design principles:
// - Entry functions match compiler's calling convention exactly.
// No stack rearrangement before calling loops.
// - Loops work with native arg order to minimize entry overhead.
// - Single shared __pow2 subroutine for all 2^n computations.
// - NTT butterfly uses direct lt guard (no push 0; eq negation).
// ===========================================================================
// EVAL -- Horner evaluation: P(x) = a[n-1]*x^(n-1) + ... + a[0]
// ===========================================================================
// Entry: [n, x, coeff_addr]
// Exit: [result]
//
// Horner loop stack: [counter, accum, coeff_addr, x]
// Counter counts down; after decrement it equals the coefficient index.
__eval:
// Horner evaluation entry point.
// Entry stack: [n, x, coeff_addr]; exit stack: [result].
// The three swaps reorder the arguments into the layout the loop works on.
swap 1
// [x, n, coeff_addr]
swap 2
// [coeff_addr, n, x]
swap 1
// [n, coeff_addr, x]
// Slide a zero accumulator in directly below the counter.
push 0
swap 1
// [counter=n, accum=0, coeff_addr, x]
call __eval_loop
// The loop comes back with [0, result, coeff_addr, x].
// Drop the spent counter, then rotate the result to the bottom so the two
// leftover arguments can be popped off the top.
pop 1
swap 2
pop 2
return
__eval_loop:
// One Horner step per pass: accum <- accum * x + mem[coeff_addr + counter - 1].
// Stack on entry and at every recurse: [counter, accum, coeff_addr, x].
// Returns as soon as the counter is 0, leaving [0, accum, coeff_addr, x].
dup 0
push 0
eq
// eq left 1 when counter == 0; skiz skips `return` only when that flag is 0
skiz
return
push -1
add
// [idx, accum, coeff_addr, x]
swap 1
dup 3
mul
// [accum*x, idx, coeff_addr, x]
dup 1
dup 3
add
// [coeff_addr+idx, accum*x, idx, coeff_addr, x]
read_mem 1
// read_mem leaves the decremented pointer on top of the loaded value; discard it
pop 1
// [coeff, accum*x, idx, coeff_addr, x]
add
swap 1
// [idx, accum*x+coeff, coeff_addr, x]
recurse
// ===========================================================================
// ADD -- out[i] = a[i] + b[i], in-place loop
// ===========================================================================
// Entry: [n, out_addr, b_addr, a_addr]
// Exit: [] (stack cleaned)
//
// Loop works with native arg order: [idx, out, b, a]
// No rearrangement needed.
__add:
// Entry stack: [n, out_addr, b_addr, a_addr], which is already the layout
// __add_loop iterates on, so n is handed over unchanged as the loop counter.
call __add_loop
// The loop comes back with [0, out_addr, b_addr, a_addr]; discard all four.
pop 4
return
// Loop stack: [idx, out_addr, b_addr, a_addr]
// Counts down from n to 0.
__add_loop:
// Element-wise addition: out[i] = a[i] + b[i] for i = counter-1 down to 0.
// Stack on entry and at every recurse: [counter, out_addr, b_addr, a_addr].
// Returns once the counter is 0, leaving [0, out_addr, b_addr, a_addr].
dup 0
push 0
eq
// eq left 1 when counter == 0; skiz skips `return` only when that flag is 0
skiz
return
// counter != 0: step it down; the decremented counter is the element index i
push -1
add
// [i, out, b, a]
// read a[i]
dup 0
dup 4
add
read_mem 1
pop 1
// [a_val, i, out, b, a]
// read b[i]
dup 1
dup 4
add
read_mem 1
pop 1
// [b_val, a_val, i, out, b, a]
add
// [sum, i, out, b, a]
// write to out[i]: write_mem takes the RAM pointer in st0 and the value in st1,
// so the freshly computed address must stay on top of the sum (no swap)
dup 1
dup 3
add
// [out+i, sum, i, out, b, a]
write_mem 1
// [out+i+1, i, out, b, a]: drop the advanced pointer
pop 1
recurse
// ===========================================================================
// SUB -- out[i] = a[i] - b[i]
// ===========================================================================
// Entry: [n, out_addr, b_addr, a_addr]
// Exit: []
__sub:
// Entry stack: [n, out_addr, b_addr, a_addr], passed straight through to
// __sub_loop with n acting as the loop counter.
call __sub_loop
// The loop comes back with [0, out_addr, b_addr, a_addr]; discard all four.
pop 4
return
__sub_loop:
// Element-wise subtraction: out[i] = a[i] - b[i] for i = counter-1 down to 0.
// Stack on entry and at every recurse: [counter, out_addr, b_addr, a_addr].
// Returns once the counter is 0, leaving [0, out_addr, b_addr, a_addr].
dup 0
push 0
eq
// eq left 1 when counter == 0; skiz skips `return` only when that flag is 0
skiz
return
// counter != 0: step it down; the decremented counter is the element index i
push -1
add
// [i, out, b, a]
// read a[i]
dup 0
dup 4
add
read_mem 1
pop 1
// [a_val, i, out, b, a]
// read b[i]
dup 1
dup 4
add
read_mem 1
pop 1
// [b_val, a_val, i, out, b, a]
// negate b_val (multiply by -1 in the field) and add: a_val - b_val
push -1
mul
add
// [a-b, i, out, b, a]
// write to out[i]: write_mem takes the RAM pointer in st0 and the value in st1,
// so the freshly computed address must stay on top of the difference (no swap)
dup 1
dup 3
add
// [out+i, a-b, i, out, b, a]
write_mem 1
// [out+i+1, i, out, b, a]: drop the advanced pointer
pop 1
recurse
// ===========================================================================
// POINTWISE_MUL -- out[i] = a[i] * b[i]
// ===========================================================================
// Entry: [n, out_addr, b_addr, a_addr]
// Exit: []
__pointwise_mul:
// Entry stack: [n, out_addr, b_addr, a_addr], passed straight through to
// __pw_mul_loop with n acting as the loop counter.
call __pw_mul_loop
// The loop comes back with [0, out_addr, b_addr, a_addr]; discard all four.
pop 4
return
__pw_mul_loop:
// Element-wise product: out[i] = a[i] * b[i] for i = counter-1 down to 0.
// Stack on entry and at every recurse: [counter, out_addr, b_addr, a_addr].
// Returns once the counter is 0, leaving [0, out_addr, b_addr, a_addr].
dup 0
push 0
eq
// eq left 1 when counter == 0; skiz skips `return` only when that flag is 0
skiz
return
// counter != 0: step it down; the decremented counter is the element index i
push -1
add
// [i, out, b, a]
// read a[i]
dup 0
dup 4
add
read_mem 1
pop 1
// [a_val, i, out, b, a]
// read b[i]
dup 1
dup 4
add
read_mem 1
pop 1
// [b_val, a_val, i, out, b, a]
mul
// [prod, i, out, b, a]
// write to out[i]: write_mem takes the RAM pointer in st0 and the value in st1,
// so the freshly computed address must stay on top of the product (no swap)
dup 1
dup 3
add
// [out+i, prod, i, out, b, a]
write_mem 1
// [out+i+1, i, out, b, a]: drop the advanced pointer
pop 1
recurse
// ===========================================================================
// SCALE -- out[i] = a[i] * s
// ===========================================================================
// Entry: [n, out_addr, s, a_addr]
// Exit: []
__scale:
// Entry stack: [n, out_addr, s, a_addr], passed straight through to
// __scale_loop with n acting as the loop counter.
call __scale_loop
// The loop comes back with [0, out_addr, s, a_addr]; discard all four.
pop 4
return
__scale_loop:
// Scalar multiplication: out[i] = a[i] * s for i = counter-1 down to 0.
// Stack on entry and at every recurse: [counter, out_addr, s, a_addr].
// Returns once the counter is 0, leaving [0, out_addr, s, a_addr].
dup 0
push 0
eq
// eq left 1 when counter == 0; skiz skips `return` only when that flag is 0
skiz
return
// counter != 0: step it down; the decremented counter is the element index i
push -1
add
// [i, out, s, a]
// read a[i]
dup 0
dup 4
add
read_mem 1
pop 1
// [val, i, out, s, a]
dup 3
mul
// [val*s, i, out, s, a]
// write to out[i]: write_mem takes the RAM pointer in st0 and the value in st1,
// so the freshly computed address must stay on top of the product (no swap)
dup 1
dup 3
add
// [out+i, val*s, i, out, s, a]
write_mem 1
// [out+i+1, i, out, s, a]: drop the advanced pointer
pop 1
recurse
// ===========================================================================
// POW2 -- Shared 2^n computation: [counter, accum] -> [0, 2^counter * accum]
// ===========================================================================
// Used by ntt, intt, poly_mul for computing n = 2^log_n.
__pow2:
// Repeated doubling. Stack: [remaining, value].
// Each pass takes one off `remaining` and doubles `value`, so the routine
// returns [0, value * 2^remaining] (callers seed value with 1 to get 2^n).
dup 0
push 0
eq
// eq left 1 when remaining == 0; skiz skips `return` only when that flag is 0
skiz
return
push -1
add
// [remaining-1, value]
swap 1
// double the value with a constant multiply
push 2
mul
// [2*value, remaining-1]
swap 1
recurse
// ===========================================================================
// NTT -- Number Theoretic Transform (iterative Cooley-Tukey, in-place)
// ===========================================================================
// Entry: [log_n, omega, a_addr]
// Exit: [] (stack cleaned)
//
// Algorithm: compute n = 2^log_n, then run stage/j/butterfly loops.
__ntt:
// Entry stack: [log_n, omega, a_addr]; all three are consumed, exit stack: [].
// NOTE(review): nothing in this routine or the loops it calls reorders the
// array (no bit-reversal pass) -- confirm the element ordering callers expect.
// Compute n = 2^log_n
dup 0
push 1
swap 1
// [counter=log_n, accum=1, log_n, omega, a_addr]
call __pow2
// [0, n, log_n, omega, a_addr]: drop the spent counter
pop 1
// [n, log_n, omega, a_addr]
// Rearrange to stage stack: [stage_count=log_n, len=1, w_step=omega, a_addr, n]
swap 3
// [a_addr, log_n, omega, n]
swap 2
// [omega, log_n, a_addr, n]
swap 1
// [log_n, omega, a_addr, n]
push 1
swap 1
// [log_n, 1, omega, a_addr, n]
call __ntt_stage
// The stage loop leaves five entries behind (its counter at 0, the final len
// and w_step, a_addr and n); none of them is a result, so drop them all.
pop 5
return
// Stage loop: [sc, len, ws, a, n]
__ntt_stage:
// Runs one NTT stage per pass and recurses until the stage counter is 0.
// Stack: [sc, len, ws, a, n]
//   sc  = stages still to run
//   len = half-width of the butterflies in this stage (doubled after each stage)
//   ws  = twiddle step for this stage (squared after each stage)
//   a   = base address of the array, n = element count
// Returns with [0, len, ws, a, n] where len and ws carry their final values.
dup 0
push 0
eq
// eq left 1 when sc == 0; skiz skips `return` only when that flag is 0
skiz
return
push -1
add
// Push j-loop vars: [j_ctr=len, j=0, w=1, sc, len, ws, a, n]
push 1
push 0
// [0, 1, sc-1, len, ws, a, n]: len now sits at depth 3 (depth 4 is ws), and the
// j-loop must run exactly len times
dup 3
// [len, 0, 1, sc-1, len, ws, a, n]
call __ntt_j
// the j-loop returns [0, j, w, sc-1, len, ws, a, n]; drop its three locals
pop 3
// [sc-1, len, ws, a, n]
// len *= 2
swap 1
dup 0
add
swap 1
// ws *= ws
swap 2
dup 0
mul
swap 2
recurse
// J-loop: [j_ctr, j, w, sc, len, ws, a, n]
__ntt_j:
// Twiddle-index loop of one stage. Stack: [j_ctr, j, w, sc, len, ws, a, n].
// Each pass runs the butterfly chain that starts at k = j, then advances the
// twiddle (w <- w * ws) and the index (j <- j + 1). Stops when j_ctr hits 0.
dup 0
push 0
eq
// eq left 1 when j_ctr == 0; skiz skips `return` only when that flag is 0
skiz
return
push -1
add
// Seed the butterfly chain with k = j
dup 1
// [k=j, j_ctr, j, w, sc, len, ws, a, n]
call __ntt_bfly
// discard the final k
pop 1
// [j_ctr, j, w, sc, len, ws, a, n]
// Advance the twiddle: bring w to the top, multiply by ws, put it back
swap 2
dup 5
mul
swap 2
// Advance the index: bring j to the top, add one, put it back
swap 1
push 1
add
swap 1
// [j_ctr, j+1, w*ws, sc, len, ws, a, n]
recurse
// Butterfly chain for one twiddle index: visits k, k + 2*len, k + 4*len, ...
// for as long as k < n and, at each k, rewrites the pair
//   mem[a + k]       <- lo + w * hi
//   mem[a + k + len] <- lo - w * hi
// where lo and hi are the values stored at those two addresses beforehand.
//
// Stack: [k, j_ctr, j, w, sc, len, ws, a, n]
// depth:  0  1      2  3  4   5    6   7  8
// Only k changes; on return it holds the first visited value that is >= n.
//
// Exit guard: lt yields (st0 < st1) on u32 operands, and skiz skips the next
// instruction only when the popped flag is 0. `return` therefore needs a flag
// that is 1 exactly when the chain is finished: (n < k + 1), i.e. (k >= n).
__ntt_bfly:
dup 0
push 1
add
dup 9
// [n, k+1, k, j_ctr, j, w, sc, len, ws, a, n]
lt
skiz
return
// lo_addr = a + k
dup 0
dup 8
add
// [lo, k, j_ctr, j, w, sc, len, ws, a, n]
// hi_addr = lo + len
dup 0
dup 7
add
// [hi, lo, k, j_ctr, j, w, sc, len, ws, a, n]
// Read hi_val from mem[hi]
dup 0
read_mem 1
pop 1
// [hi_val, hi, lo, k, j_ctr, j, w, sc, len, ws, a, n]
// t = hi_val * w (w sits at depth 6: its home depth 3 plus three pushes)
dup 6
mul
// [t, hi, lo, k, j_ctr, j, w, sc, len, ws, a, n]
// Read lo_val from mem[lo]
dup 2
read_mem 1
pop 1
// [lo_val, t, hi, lo, k, j_ctr, j, w, sc, len, ws, a, n]
// Write (lo_val + t) to lo_addr. write_mem takes the RAM pointer in st0 and
// the value in st1, so the address is pushed last and left on top.
dup 0
dup 2
add
dup 4
// [lo, lo_val+t, lo_val, t, hi, lo, k, ...]
write_mem 1
pop 1
// [lo_val, t, hi, lo, k, ...]
// Compute lo_val - t
swap 1
push -1
mul
add
// [lo_val-t, hi, lo, k, ...]
// Write it to hi_addr: one swap puts the pointer on top of the value
swap 1
// [hi, lo_val-t, lo, k, ...]
write_mem 1
// [hi+1, lo, k, ...]: both leftover addresses are dead now
pop 2
// [k, j_ctr, j, w, sc, len, ws, a, n]
// k += 2*len
dup 5
dup 0
add
add
recurse
// ===========================================================================
// INTT -- Inverse NTT: forward NTT with omega_inv, then scale by n_inv
// ===========================================================================
// Entry: [log_n, n_inv, omega_inv, a_addr]
// Exit: []
__intt:
// Entry stack: [log_n, n_inv, omega_inv, a_addr]; all four are consumed.
// Forward NTT(a_addr, omega_inv, log_n)
// Copies of the three NTT arguments are pushed deepest-first; each dup depth
// accounts for the copies already pushed above the originals.
dup 3
// [a_addr, log_n, n_inv, omega_inv, a_addr]
dup 3
// [omega_inv, a_addr, log_n, n_inv, omega_inv, a_addr]
dup 2
// [log_n, omega_inv, a_addr, log_n, n_inv, omega_inv, a_addr]
call __ntt
// __ntt consumed its three arguments
// [log_n, n_inv, omega_inv, a_addr]
// Compute n = 2^log_n
dup 0
push 1
swap 1
call __pow2
// [0, n, log_n, ...]: drop the spent counter
pop 1
// [n, log_n, n_inv, omega_inv, a_addr]
// Scale(a_addr, n_inv, a_addr, n): need [n, out=a_addr, s=n_inv, src=a_addr]
dup 4
// [src=a_addr, n, log_n, n_inv, omega_inv, a_addr]
dup 3
// [s=n_inv, src, n, log_n, n_inv, omega_inv, a_addr]
dup 6
// [out=a_addr, s, src, n, log_n, n_inv, omega_inv, a_addr]
dup 3
// [n, out, s, src, n, log_n, n_inv, omega_inv, a_addr]
call __scale
// __scale consumed its four arguments; what remains is n plus the four
// original arguments, none of which is a result
pop 5
return
// ===========================================================================
// POLY_MUL -- Polynomial multiplication via NTT pipeline
// ===========================================================================
// Entry: [log_n, n_inv, omega_inv, omega, tmp, out, b, a]
// Exit: []
//
// Steps: size = 2^log_n, copy a->out & b->tmp, NTT(out), NTT(tmp),
// pointwise_mul(out, tmp, out), INTT(out)
__poly_mul:
// Polynomial product through the transform pipeline:
//   copy a -> out and b -> tmp, NTT(out), NTT(tmp), out[i] = out[i] * tmp[i],
//   then INTT(out).
// Entry stack: [log_n, n_inv, omega_inv, omega, tmp, out, b, a]; exit stack: [].
// Every callee below consumes exactly the argument copies pushed for it, so the
// nine-entry frame shown after __pow2 is what each dup depth is counted against.
// Compute size = 2^log_n
dup 0
push 1
swap 1
call __pow2
pop 1
// [size, log_n, n_inv, omega_inv, omega, tmp, out, b, a]
// pos: 0=size 1=log_n 2=n_inv 3=omega_inv 4=omega 5=tmp 6=out 7=b 8=a
// Copy a->out, b->tmp. __pm_copy copies its slot 1 into slot 3 and its slot 2
// into slot 4 (the counter being slot 0), so handing it (b, a, tmp, out)
// performs b->tmp and a->out.
dup 6
dup 6
dup 10
dup 10
dup 4
// [size, b, a, tmp, out, size, ...]
call __pm_copy
// __pm_copy leaves its five working entries on the stack
pop 5
// NTT(out, omega, log_n)
dup 6
dup 5
dup 3
// [log_n, omega, out, size, ...]
call __ntt
// NTT(tmp, omega, log_n)
dup 5
dup 5
dup 3
// [log_n, omega, tmp, size, ...]
call __ntt
// pointwise_mul(out, tmp, out, size): callee expects [n, out_addr, b_addr, a_addr]
dup 6
// [a=out, size, ...]
dup 6
// [b=tmp, a=out, size, ...]
// The destination is `out`: two pushes above its home depth 6 puts it at
// depth 8 (depth 7 is tmp, which would send the product to the wrong buffer).
dup 8
// [out_addr=out, b=tmp, a=out, size, ...]
dup 3
// [n=size, out_addr=out, b=tmp, a=out, size, ...]
call __pointwise_mul
// INTT(out, omega_inv, n_inv, log_n)
dup 6
dup 4
dup 4
dup 4
// [log_n, n_inv, omega_inv, out, size, ...]
call __intt
// Drop the nine-entry frame: size plus the eight original arguments
pop 5
pop 4
return
// Copy loop: a[i]->out[i], b[i]->tmp[i] simultaneously
// Stack: [counter, a_addr, b_addr, out_addr, tmp_addr]
__pm_copy:
// Two-buffer copy: out[i] = a[i] and tmp[i] = b[i] for i = counter-1 down to 0.
// Stack on entry and at every recurse:
//   [counter, a_addr, b_addr, out_addr, tmp_addr]
// Returns once the counter is 0, leaving [0, a_addr, b_addr, out_addr, tmp_addr].
dup 0
push 0
eq
// eq left 1 when counter == 0; skiz skips `return` only when that flag is 0
skiz
return
// counter != 0: step it down; the decremented counter is the element index i
push -1
add
// [i, a, b, out, tmp]
// a[i] -> out[i]
dup 0
dup 2
add
read_mem 1
pop 1
// [a_val, i, a, b, out, tmp]
dup 1
dup 5
add
// [out+i, a_val, i, a, b, out, tmp]
// write_mem takes the RAM pointer in st0 and the value in st1, so the address
// stays on top of the value (no swap)
write_mem 1
pop 1
// [i, a, b, out, tmp]
// b[i] -> tmp[i]
dup 0
dup 3
add
read_mem 1
pop 1
// [b_val, i, a, b, out, tmp]
dup 1
dup 6
add
// [tmp+i, b_val, i, a, b, out, tmp]
write_mem 1
pop 1
recurse