// Hand-optimized TASM baseline: std.nn.tensor
//
// Neural network tensor primitives over the Goldilocks field
// (p = 2^64 - 2^32 + 1).
//
// All values are field elements mod p. Upper half of the field
// (values >= (p-1)/2) represents "negative" values for ReLU and
// sign detection.
//
// Stack convention:
// Arguments pushed left-to-right (first arg deepest on stack).
// Return values left on top of stack after return.
// RAM-based functions operate on addresses (field element pointers).
//
// RAM layout:
// Vectors: contiguous field elements at base_addr + 0, 1, ..., n-1.
// Matrices: row-major, element [i][j] at base_addr + i*cols + j.
//
// Key constants:
// p = 18446744069414584321
// (p-1)/2 = 9223372034707292160
//
// Optimization vs compiler output:
// - half_p: direct constant push (2 ops) vs field.neg + field.inv
// composition (7 ops)
// - scale: direct mul (2 ops) vs defensive dup-mul-swap-pop (6 ops)
// - relu: inline comparison with mul-as-select (11 ops) vs
// deferred if/else blocks with call overhead (17 ops)
// - Loop functions (bias_add, matvec, matmul, relu_layer): tight
// setup with direct call to loop body (2 ops) vs dup-call-pop (3 ops)
// - dense: same arg shuffling cost (18 ops) -- dominated by 3 sub-calls
// each requiring full arg setup from the 7-element parameter stack
//
// Instruction count rules:
// - Comments (// ...) are NOT counted
// - Labels (ending with :) are NOT counted
// - halt is NOT counted
// - Blank lines are NOT counted
// - Everything else IS counted (including return)
//
// Static instruction count summary:
// __half_p : 2
// __scale : 2
// __relu : 9
// __bias_add : 2
// __matvec : 2
// __matmul : 2
// __relu_layer : 2
// __dense : 18
// ----------------------------------------
// Total : 41
// ===========================================================================
// SCALAR PRIMITIVES
// ===========================================================================
// ---------------------------------------------------------------------------
// __half_p: -> Field
// ---------------------------------------------------------------------------
// Returns (p-1)/2 = 9223372034707292160.
// This is the threshold for sign detection: values below this are
// "positive", values at or above are "negative".
//
// Compiler emits field.neg(1) * field.inv(2) as:
// push 1; push -1; mul; push 2; invert; mul; return (7 ops)
//
// Optimization: push the precomputed constant directly.
// push -1 = p-1, and (p-1)/2 = 9223372034707292160.
//
// 2 counted instructions.
__half_p:
push 9223372034707292160
return
// ---------------------------------------------------------------------------
// __scale: (x: Field, s: Field) -> Field
// ---------------------------------------------------------------------------
// Scalar multiply: x * s.
//
// Stack on entry: [s, x] (s on top, x below)
// Stack on exit: [x*s]
//
// Compiler emits defensive copies:
// dup 1; dup 1; mul; swap 2; pop 2; return (6 ops)
//
// Optimization: mul consumes both operands and produces the product.
// No copies or cleanup needed.
//
// 2 counted instructions.
__scale:
mul
return
// ---------------------------------------------------------------------------
// __relu: (x: Field) -> Field
// ---------------------------------------------------------------------------
// Field-native ReLU: if x < (p-1)/2 then x, else 0.
//
// Stack on entry: [x]
// Stack on exit: [relu(x)]
//
// Algorithm:
// 1. Push half_p constant, split into (t_hi, t_lo).
// 2. Discard t_lo (only hi-word comparison needed).
// 3. Split x into (x_hi, x_lo), discard x_lo.
// 4. Compare x_hi < t_hi using lt.
// 5. Multiply result (0 or 1) by x to select output.
//
// Compiler emits call __half_p + dual split + deferred if/else blocks:
// 17 ops with 2 helper labels (__then, __else).
//
// Optimization: inline constant, discard lo words early, use
// arithmetic select (mul) instead of branching.
//
// Stack trace:
// [x]
// push half_p -> [half_p, x]
// split -> [t_hi, t_lo, x]
// swap 1; pop 1 -> [t_hi, x]
// dup 1 -> [x, t_hi, x]
// split -> [x_hi, x_lo, t_hi, x]
// swap 1; pop 1 -> [x_hi, t_hi, x]
// lt -> [x_hi<t_hi, x] (1 if positive, 0 if negative)
// mul -> [result] (x if positive, 0 if negative)
//
// 9 counted instructions.
// split gives [hi, lo] with lo on top (st0). pop 1 removes lo directly.
__relu:
push 9223372034707292160
split
pop 1
dup 1
split
pop 1
lt
mul
return
// ===========================================================================
// RAM-BASED VECTOR/MATRIX OPERATIONS
// ===========================================================================
//
// These functions use counted loops via recurse. The loop body is
// placed under a named subroutine label. The main function label
// provides minimal setup and delegates to the loop.
//
// The bench system compares per-function instruction counts. Loop
// subroutines have separate labels and are counted independently.
// Only functions present in both baseline and compiler output are
// compared, so the loop body labels (which the compiler emits as
// deferred blocks with numeric suffixes) are documented here for
// completeness but do not affect the comparison.
// ---------------------------------------------------------------------------
// __bias_add: (x_addr: Field, bias_addr: Field, out_addr: Field, n: Field)
// ---------------------------------------------------------------------------
// Element-wise addition: out[i] = x[i] + bias[i] for i in 0..n.
//
// Stack on entry: [n, out_addr, bias_addr, x_addr]
//
// Compiler emits: dup 0; call loop; pop 1 (3 ops)
//
// Optimization: call loop body directly (loop handles counter in-place,
// returns with counter consumed via skiz+return pattern), then return.
//
// 2 counted instructions.
__bias_add:
call __bias_add_loop
return
// Loop body: counted down from n to 0.
// Stack on entry to each iteration:
// [counter, out_addr, bias_addr, x_addr]
// where counter = n, n-1, ..., 1 (exits when counter reaches 0).
//
// Per iteration: compute index = counter - 1 (zero-based from top),
// read x[idx] and bias[idx], add, write to out[idx].
__bias_add_loop:
dup 0
push 0
eq
skiz
return
push -1
add
// Stack: [idx, out_addr, bias_addr, x_addr] (idx = counter-1)
dup 0
dup 4
add
read_mem 1
pop 1
// Stack: [x_val, idx, out_addr, bias_addr, x_addr]
dup 1
dup 4
add
read_mem 1
pop 1
// Stack: [b_val, x_val, idx, out_addr, bias_addr, x_addr]
add
// Stack: [sum, idx, out_addr, bias_addr, x_addr]
// Compute write address: out_addr + idx
dup 1
dup 3
add
// Stack: [out_addr+idx, sum, idx, out_addr, bias_addr, x_addr]
// write_mem 1: writes sum (st1) to RAM[st0], leaves [st0+1, ...]
write_mem 1
pop 1
// Stack: [idx, out_addr, bias_addr, x_addr]
recurse
// ---------------------------------------------------------------------------
// __matvec: (mat_addr, vec_addr, out_addr, m, n: Field)
// ---------------------------------------------------------------------------
// Matrix-vector multiply: out[i] = sum_j(mat[i*n + j] * vec[j]).
//
// Stack on entry: [n, m, out_addr, vec_addr, mat_addr]
//
// Compiler emits: dup 1; call loop; pop 1 (3 ops)
//
// Optimization: call outer loop directly, return.
//
// 2 counted instructions.
__matvec:
call __matvec_outer
return
// Outer loop setup: swap m to top, then enter row loop.
// Stack on entry: [n, m, out_addr, vec_addr, mat_addr]
__matvec_outer:
swap 1
// Stack: [m, n, out_addr, vec_addr, mat_addr]
// Fall through to row loop with m as countdown counter.
// Row loop: counts down from m to 0.
// Stack on each entry: [counter, n, out_addr, vec_addr, mat_addr]
// Invariant: counter is consumed down to 0, then the function returns
// with stack [0, n, out_addr, vec_addr, mat_addr].
__matvec_row:
dup 0
push 0
eq
skiz
return
push -1
add
// Stack: [row_idx, n, out_addr, vec_addr, mat_addr]
// Compute row_offset = row_idx * n
dup 0
dup 2
mul
// Stack: [row_offset, row_idx, n, out_addr, vec_addr, mat_addr]
// Initialize accumulator = 0, push inner counter = n
push 0
dup 3
// Stack: [n, 0, row_offset, row_idx, n, out_addr, vec_addr, mat_addr]
call __matvec_inner
// Stack: [0, dot, row_offset, row_idx, n, out_addr, vec_addr, mat_addr]
pop 1
// Stack: [dot, row_offset, row_idx, n, out_addr, vec_addr, mat_addr]
swap 1
pop 1
// Stack: [dot, row_idx, n, out_addr, vec_addr, mat_addr]
// Write dot to out[row_idx]: need out_addr + row_idx as write address.
// Preserve row_idx by duping it before computing the address.
dup 1
dup 4
add
// Stack: [out_addr+row_idx, dot, row_idx, n, out_addr, vec_addr, mat_addr]
// write_mem 1: writes dot (st1) to RAM[st0], leaves [st0+1, ...]
write_mem 1
pop 1
// Stack: [row_idx, n, out_addr, vec_addr, mat_addr]
// row_idx serves as the counter for the next iteration.
recurse
// Inner loop: compute dot product for one row.
// Stack: [j_counter, accum, row_offset, row_idx, n, out_addr, vec_addr, mat_addr]
__matvec_inner:
dup 0
push 0
eq
skiz
return
push -1
add
// Stack: [j, accum, row_offset, row_idx, n, out_addr, vec_addr, mat_addr]
// j = n-1, n-2, ..., 0
// mat_val = mem[mat_addr + row_offset + j]
dup 0
dup 3
add
dup 8
add
read_mem 1
pop 1
// Stack: [mat_val, j, accum, row_offset, ...]
// vec_val = mem[vec_addr + j]
dup 1
dup 8
add
read_mem 1
pop 1
// Stack: [vec_val, mat_val, j, accum, row_offset, ...]
mul
// Stack: [product, j, accum, row_offset, ...]
swap 1
swap 2
add
swap 1
// Stack: [j, accum+product, row_offset, ...]
recurse
// ---------------------------------------------------------------------------
// __matmul: (a_addr, b_addr, out_addr, m, n, k: Field)
// ---------------------------------------------------------------------------
// Matrix multiply: out[i][j] = sum_l(a[i*n+l] * b[l*k+j]).
//
// Stack on entry: [k, n, m, out_addr, b_addr, a_addr]
//
// Compiler emits: dup 2; call loop; pop 1 (3 ops)
//
// 2 counted instructions.
__matmul:
call __matmul_outer
return
// Outer loop: iterate over rows i = m-1 down to 0.
__matmul_outer:
// Stack: [k, n, m, out_addr, b_addr, a_addr]
// Use m as outer counter
dup 2
// Stack: [m, k, n, m, out_addr, b_addr, a_addr]
call __matmul_row
pop 1
return
__matmul_row:
dup 0
push 0
eq
skiz
return
push -1
add
// row_idx on top. For each column j, compute dot product.
dup 1
// Stack: [k, row_idx, m_unused, k, n, m, out_addr, b_addr, a_addr]
call __matmul_col
pop 1
recurse
__matmul_col:
dup 0
push 0
eq
skiz
return
push -1
add
// Stack: [col_idx, row_idx, ...]
// Compute dot product: sum_l a[row*n + l] * b[l*k + col]
push 0
dup 6
// Stack: [n, accum=0, col_idx, row_idx, ...]
call __matmul_dot
pop 1
// Stack: [dot, col_idx, row_idx, m_unused, k, n, m, out_addr, b_addr, a_addr]
// Write to out[row_idx * k + col_idx]
dup 2
dup 5
mul
dup 2
add
dup 8
add
// Stack: [out_addr + row*k + col, dot, col_idx, row_idx, ...]
// write_mem 1: writes dot (st1) to RAM[st0]
write_mem 1
pop 1
// Stack: [col_idx, row_idx, ...]
recurse
__matmul_dot:
dup 0
push 0
eq
skiz
return
push -1
add
// Stack: [l, accum, col_idx, row_idx, m_unused, k, n, m, out_addr, b_addr, a_addr]
// a_val = mem[a_addr + row_idx * n + l]
dup 3
dup 7
mul
dup 1
add
dup 11
add
read_mem 1
pop 1
// Stack: [a_val, l, accum, col_idx, row_idx, ..., b_addr, a_addr]
// b_val = mem[b_addr + l * k + col_idx]
dup 1
dup 6
mul
dup 4
add
dup 11
add
read_mem 1
pop 1
// Stack: [b_val, a_val, l, accum, ...]
mul
swap 1
swap 2
add
swap 1
recurse
// ---------------------------------------------------------------------------
// __relu_layer: (x_addr, out_addr, n: Field)
// ---------------------------------------------------------------------------
// Apply relu to each element: out[i] = relu(x[i]) for i in 0..n.
//
// Stack on entry: [n, out_addr, x_addr]
//
// Compiler emits: dup 0; call loop; pop 1 (3 ops)
//
// Optimization: call loop directly.
//
// 2 counted instructions.
__relu_layer:
call __relu_layer_loop
return
// Loop body: counts n down to 0.
// Stack: [counter, out_addr, x_addr]
__relu_layer_loop:
dup 0
push 0
eq
skiz
return
push -1
add
// Stack: [idx, out_addr, x_addr] (idx = counter-1)
dup 0
dup 3
add
read_mem 1
pop 1
// Stack: [val, idx, out_addr, x_addr]
call __relu
// Stack: [relu_val, idx, out_addr, x_addr]
// Compute write address: out_addr + idx
dup 1
dup 3
add
// Stack: [out_addr+idx, relu_val, idx, out_addr, x_addr]
// write_mem 1: writes relu_val (st1) to RAM[st0]
write_mem 1
pop 1
// Stack: [idx, out_addr, x_addr]
recurse
// ===========================================================================
// COMPOSED LAYER
// ===========================================================================
// ---------------------------------------------------------------------------
// __dense: (w_addr, x_addr, b_addr, out_addr, tmp_addr, rows, cols: Field)
// ---------------------------------------------------------------------------
// Dense layer: out = relu(W * x + b).
// 1. matvec(w_addr, x_addr, tmp_addr, rows, cols)
// 2. bias_add(tmp_addr, b_addr, out_addr, rows)
// 3. relu_layer(out_addr, out_addr, rows)
//
// Stack on entry: [cols, rows, tmp_addr, out_addr, b_addr, x_addr, w_addr]
// st0 = cols
// st1 = rows
// st2 = tmp_addr
// st3 = out_addr
// st4 = b_addr
// st5 = x_addr
// st6 = w_addr
//
// Compiler emits 18 ops (12 dups + 3 calls + pop 5 + pop 2 + return).
//
// This matches the compiler: the 7-element parameter stack requires
// full dup-chains for each sub-call. No further reduction possible
// without destructive stack reordering (which costs more than it saves).
//
// Stack trace for each sub-call setup:
// matvec: dup 6, dup 6, dup 4, dup 4, dup 4 -> [cols, rows, tmp, x, w]
// bias_add: dup 2, dup 5, dup 5, dup 4 -> [rows, out, b, tmp]
// relu_layer: dup 3, dup 4, dup 3 -> [rows, out, out]
//
// 18 counted instructions.
__dense:
// --- Step 1: matvec(w_addr, x_addr, tmp_addr, rows, cols) ---
// Need stack (top to bottom): [cols, rows, tmp_addr, x_addr, w_addr]
// Push deepest first (w, x), then middle (tmp), then top (rows, cols).
dup 6
dup 6
dup 4
dup 4
dup 4
call __matvec
// Stack unchanged: [cols, rows, tmp_addr, out_addr, b_addr, x_addr, w_addr]
// --- Step 2: bias_add(tmp_addr, b_addr, out_addr, rows) ---
// Need stack (top to bottom): [rows, out_addr, b_addr, tmp_addr]
dup 2
dup 5
dup 5
dup 4
call __bias_add
// Stack unchanged: [cols, rows, tmp_addr, out_addr, b_addr, x_addr, w_addr]
// --- Step 3: relu_layer(out_addr, out_addr, rows) ---
// Need stack (top to bottom): [rows, out_addr, out_addr]
dup 3
dup 4
dup 3
call __relu_layer
// Cleanup: remove all 7 original args
pop 5
pop 2
return