trident/baselines/triton/std/nn/tensor.tasm

// Hand-optimized TASM baseline: std.nn.tensor
//
// Neural network tensor primitives over the Goldilocks field
// (p = 2^64 - 2^32 + 1).
//
// All values are field elements mod p. Upper half of the field
// (values >= (p-1)/2) represents "negative" values for ReLU and
// sign detection.
//
// Stack convention:
//   Arguments pushed left-to-right (first arg deepest on stack).
//   Return values left on top of stack after return.
//   RAM-based functions operate on addresses (field element pointers).
//
// RAM layout:
//   Vectors: contiguous field elements at base_addr + 0, 1, ..., n-1.
//   Matrices: row-major, element [i][j] at base_addr + i*cols + j.
//
// Key constants:
//   p       = 18446744069414584321
//   (p-1)/2 = 9223372034707292160
//
// Optimization vs compiler output:
//   - half_p: direct constant push (2 ops) vs field.neg + field.inv
//     composition (7 ops)
//   - scale: direct mul (2 ops) vs defensive dup-mul-swap-pop (6 ops)
//   - relu: inline comparison with mul-as-select (11 ops) vs
//     deferred if/else blocks with call overhead (17 ops)
//   - Loop functions (bias_add, matvec, matmul, relu_layer): tight
//     setup with direct call to loop body (2 ops) vs dup-call-pop (3 ops)
//   - dense: same arg shuffling cost (18 ops) -- dominated by 3 sub-calls
//     each requiring full arg setup from the 7-element parameter stack
//
// Instruction count rules:
//   - Comments (// ...) are NOT counted
//   - Labels (ending with :) are NOT counted
//   - halt is NOT counted
//   - Blank lines are NOT counted
//   - Everything else IS counted (including return)
//
// Static instruction count summary:
//   __half_p       :   2
//   __scale        :   2
//   __relu         :   9
//   __bias_add     :   2
//   __matvec       :   2
//   __matmul       :   2
//   __relu_layer   :   2
//   __dense        :  18
//   ----------------------------------------
//   Total          :  41


// ===========================================================================
// SCALAR PRIMITIVES
// ===========================================================================


// ---------------------------------------------------------------------------
// __half_p: -> Field
// ---------------------------------------------------------------------------
// Returns (p-1)/2 = 9223372034707292160.
// This is the threshold for sign detection: values below this are
// "positive", values at or above are "negative".
//
// Compiler emits field.neg(1) * field.inv(2) as:
//   push 1; push -1; mul; push 2; invert; mul; return  (7 ops)
//
// Optimization: push the precomputed constant directly.
//   push -1 = p-1, and (p-1)/2 = 9223372034707292160.
//
// 2 counted instructions.
__half_p:
    push 9223372034707292160
    return


// ---------------------------------------------------------------------------
// __scale: (x: Field, s: Field) -> Field
// ---------------------------------------------------------------------------
// Scalar multiply: x * s.
//
// Stack on entry: [s, x]  (s on top, x below)
// Stack on exit:  [x*s]
//
// Compiler emits defensive copies:
//   dup 1; dup 1; mul; swap 2; pop 2; return  (6 ops)
//
// Optimization: mul consumes both operands and produces the product.
// No copies or cleanup needed.
//
// 2 counted instructions.
__scale:
    mul
    return


// ---------------------------------------------------------------------------
// __relu: (x: Field) -> Field
// ---------------------------------------------------------------------------
// Field-native ReLU: if x < (p-1)/2 then x, else 0.
//
// Stack on entry: [x]
// Stack on exit:  [relu(x)]
//
// Algorithm:
//   1. Push half_p constant, split into (t_hi, t_lo).
//   2. Discard t_lo (only hi-word comparison needed).
//   3. Split x into (x_hi, x_lo), discard x_lo.
//   4. Compare x_hi < t_hi using lt.
//   5. Multiply result (0 or 1) by x to select output.
//
// Compiler emits call __half_p + dual split + deferred if/else blocks:
//   17 ops with 2 helper labels (__then, __else).
//
// Optimization: inline constant, discard lo words early, use
// arithmetic select (mul) instead of branching.
//
// Stack trace:
//   [x]
//   push half_p    -> [half_p, x]
//   split          -> [t_hi, t_lo, x]
//   swap 1; pop 1  -> [t_hi, x]
//   dup 1          -> [x, t_hi, x]
//   split          -> [x_hi, x_lo, t_hi, x]
//   swap 1; pop 1  -> [x_hi, t_hi, x]
//   lt             -> [x_hi<t_hi, x]   (1 if positive, 0 if negative)
//   mul            -> [result]          (x if positive, 0 if negative)
//
// 9 counted instructions.
// split gives [hi, lo] with lo on top (st0). pop 1 removes lo directly.
__relu:
    push 9223372034707292160
    split
    pop 1
    dup 1
    split
    pop 1
    lt
    mul
    return


// ===========================================================================
// RAM-BASED VECTOR/MATRIX OPERATIONS
// ===========================================================================
//
// These functions use counted loops via recurse. The loop body is
// placed under a named subroutine label. The main function label
// provides minimal setup and delegates to the loop.
//
// The bench system compares per-function instruction counts. Loop
// subroutines have separate labels and are counted independently.
// Only functions present in both baseline and compiler output are
// compared, so the loop body labels (which the compiler emits as
// deferred blocks with numeric suffixes) are documented here for
// completeness but do not affect the comparison.


// ---------------------------------------------------------------------------
// __bias_add: (x_addr: Field, bias_addr: Field, out_addr: Field, n: Field)
// ---------------------------------------------------------------------------
// Element-wise addition: out[i] = x[i] + bias[i] for i in 0..n.
//
// Stack on entry: [n, out_addr, bias_addr, x_addr]
//
// Compiler emits: dup 0; call loop; pop 1  (3 ops)
//
// Optimization: call loop body directly (loop handles counter in-place,
// returns with counter consumed via skiz+return pattern), then return.
//
// 2 counted instructions.
__bias_add:
    call __bias_add_loop
    return

// Loop body: counted down from n to 0.
// Stack on entry to each iteration:
//   [counter, out_addr, bias_addr, x_addr]
// where counter = n, n-1, ..., 1 (exits when counter reaches 0).
//
// Per iteration: compute index = counter - 1 (zero-based from top),
// read x[idx] and bias[idx], add, write to out[idx].
__bias_add_loop:
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // Stack: [idx, out_addr, bias_addr, x_addr]  (idx = counter-1)
    dup 0
    dup 4
    add
    read_mem 1
    pop 1
    // Stack: [x_val, idx, out_addr, bias_addr, x_addr]
    dup 1
    dup 4
    add
    read_mem 1
    pop 1
    // Stack: [b_val, x_val, idx, out_addr, bias_addr, x_addr]
    add
    // Stack: [sum, idx, out_addr, bias_addr, x_addr]
    // Compute write address: out_addr + idx
    dup 1
    dup 3
    add
    // Stack: [out_addr+idx, sum, idx, out_addr, bias_addr, x_addr]
    // write_mem 1: writes sum (st1) to RAM[st0], leaves [st0+1, ...]
    write_mem 1
    pop 1
    // Stack: [idx, out_addr, bias_addr, x_addr]
    recurse


// ---------------------------------------------------------------------------
// __matvec: (mat_addr, vec_addr, out_addr, m, n: Field)
// ---------------------------------------------------------------------------
// Matrix-vector multiply: out[i] = sum_j(mat[i*n + j] * vec[j]).
//
// Stack on entry: [n, m, out_addr, vec_addr, mat_addr]
//
// Compiler emits: dup 1; call loop; pop 1  (3 ops)
//
// Optimization: call outer loop directly, return.
//
// 2 counted instructions.
__matvec:
    call __matvec_outer
    return

// Outer loop setup: swap m to top, then enter row loop.
// Stack on entry: [n, m, out_addr, vec_addr, mat_addr]
__matvec_outer:
    swap 1
    // Stack: [m, n, out_addr, vec_addr, mat_addr]
    // Fall through to row loop with m as countdown counter.

// Row loop: counts down from m to 0.
// Stack on each entry: [counter, n, out_addr, vec_addr, mat_addr]
// Invariant: counter is consumed down to 0, then the function returns
// with stack [0, n, out_addr, vec_addr, mat_addr].
__matvec_row:
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // Stack: [row_idx, n, out_addr, vec_addr, mat_addr]
    // Compute row_offset = row_idx * n
    dup 0
    dup 2
    mul
    // Stack: [row_offset, row_idx, n, out_addr, vec_addr, mat_addr]
    // Initialize accumulator = 0, push inner counter = n
    push 0
    dup 3
    // Stack: [n, 0, row_offset, row_idx, n, out_addr, vec_addr, mat_addr]
    call __matvec_inner
    // Stack: [0, dot, row_offset, row_idx, n, out_addr, vec_addr, mat_addr]
    pop 1
    // Stack: [dot, row_offset, row_idx, n, out_addr, vec_addr, mat_addr]
    swap 1
    pop 1
    // Stack: [dot, row_idx, n, out_addr, vec_addr, mat_addr]
    // Write dot to out[row_idx]: need out_addr + row_idx as write address.
    // Preserve row_idx by duping it before computing the address.
    dup 1
    dup 4
    add
    // Stack: [out_addr+row_idx, dot, row_idx, n, out_addr, vec_addr, mat_addr]
    // write_mem 1: writes dot (st1) to RAM[st0], leaves [st0+1, ...]
    write_mem 1
    pop 1
    // Stack: [row_idx, n, out_addr, vec_addr, mat_addr]
    // row_idx serves as the counter for the next iteration.
    recurse

// Inner loop: compute dot product for one row.
// Stack: [j_counter, accum, row_offset, row_idx, n, out_addr, vec_addr, mat_addr]
__matvec_inner:
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // Stack: [j, accum, row_offset, row_idx, n, out_addr, vec_addr, mat_addr]
    // j = n-1, n-2, ..., 0
    // mat_val = mem[mat_addr + row_offset + j]
    dup 0
    dup 3
    add
    dup 8
    add
    read_mem 1
    pop 1
    // Stack: [mat_val, j, accum, row_offset, ...]
    // vec_val = mem[vec_addr + j]
    dup 1
    dup 8
    add
    read_mem 1
    pop 1
    // Stack: [vec_val, mat_val, j, accum, row_offset, ...]
    mul
    // Stack: [product, j, accum, row_offset, ...]
    swap 1
    swap 2
    add
    swap 1
    // Stack: [j, accum+product, row_offset, ...]
    recurse


// ---------------------------------------------------------------------------
// __matmul: (a_addr, b_addr, out_addr, m, n, k: Field)
// ---------------------------------------------------------------------------
// Matrix multiply: out[i][j] = sum_l(a[i*n+l] * b[l*k+j]).
//
// Stack on entry: [k, n, m, out_addr, b_addr, a_addr]
//
// Compiler emits: dup 2; call loop; pop 1  (3 ops)
//
// 2 counted instructions.
__matmul:
    call __matmul_outer
    return

// Outer loop: iterate over rows i = m-1 down to 0.
__matmul_outer:
    // Stack: [k, n, m, out_addr, b_addr, a_addr]
    // Use m as outer counter
    dup 2
    // Stack: [m, k, n, m, out_addr, b_addr, a_addr]
    call __matmul_row
    pop 1
    return

__matmul_row:
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // row_idx on top. For each column j, compute dot product.
    dup 1
    // Stack: [k, row_idx, m_unused, k, n, m, out_addr, b_addr, a_addr]
    call __matmul_col
    pop 1
    recurse

__matmul_col:
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // Stack: [col_idx, row_idx, ...]
    // Compute dot product: sum_l a[row*n + l] * b[l*k + col]
    push 0
    dup 6
    // Stack: [n, accum=0, col_idx, row_idx, ...]
    call __matmul_dot
    pop 1
    // Stack: [dot, col_idx, row_idx, m_unused, k, n, m, out_addr, b_addr, a_addr]
    // Write to out[row_idx * k + col_idx]
    dup 2
    dup 5
    mul
    dup 2
    add
    dup 8
    add
    // Stack: [out_addr + row*k + col, dot, col_idx, row_idx, ...]
    // write_mem 1: writes dot (st1) to RAM[st0]
    write_mem 1
    pop 1
    // Stack: [col_idx, row_idx, ...]
    recurse

__matmul_dot:
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // Stack: [l, accum, col_idx, row_idx, m_unused, k, n, m, out_addr, b_addr, a_addr]
    // a_val = mem[a_addr + row_idx * n + l]
    dup 3
    dup 7
    mul
    dup 1
    add
    dup 11
    add
    read_mem 1
    pop 1
    // Stack: [a_val, l, accum, col_idx, row_idx, ..., b_addr, a_addr]
    // b_val = mem[b_addr + l * k + col_idx]
    dup 1
    dup 6
    mul
    dup 4
    add
    dup 11
    add
    read_mem 1
    pop 1
    // Stack: [b_val, a_val, l, accum, ...]
    mul
    swap 1
    swap 2
    add
    swap 1
    recurse


// ---------------------------------------------------------------------------
// __relu_layer: (x_addr, out_addr, n: Field)
// ---------------------------------------------------------------------------
// Apply relu to each element: out[i] = relu(x[i]) for i in 0..n.
//
// Stack on entry: [n, out_addr, x_addr]
//
// Compiler emits: dup 0; call loop; pop 1  (3 ops)
//
// Optimization: call loop directly.
//
// 2 counted instructions.
__relu_layer:
    call __relu_layer_loop
    return

// Loop body: counts n down to 0.
// Stack: [counter, out_addr, x_addr]
__relu_layer_loop:
    dup 0
    push 0
    eq
    skiz
    return
    push -1
    add
    // Stack: [idx, out_addr, x_addr]  (idx = counter-1)
    dup 0
    dup 3
    add
    read_mem 1
    pop 1
    // Stack: [val, idx, out_addr, x_addr]
    call __relu
    // Stack: [relu_val, idx, out_addr, x_addr]
    // Compute write address: out_addr + idx
    dup 1
    dup 3
    add
    // Stack: [out_addr+idx, relu_val, idx, out_addr, x_addr]
    // write_mem 1: writes relu_val (st1) to RAM[st0]
    write_mem 1
    pop 1
    // Stack: [idx, out_addr, x_addr]
    recurse


// ===========================================================================
// COMPOSED LAYER
// ===========================================================================


// ---------------------------------------------------------------------------
// __dense: (w_addr, x_addr, b_addr, out_addr, tmp_addr, rows, cols: Field)
// ---------------------------------------------------------------------------
// Dense layer: out = relu(W * x + b).
//   1. matvec(w_addr, x_addr, tmp_addr, rows, cols)
//   2. bias_add(tmp_addr, b_addr, out_addr, rows)
//   3. relu_layer(out_addr, out_addr, rows)
//
// Stack on entry: [cols, rows, tmp_addr, out_addr, b_addr, x_addr, w_addr]
//   st0 = cols
//   st1 = rows
//   st2 = tmp_addr
//   st3 = out_addr
//   st4 = b_addr
//   st5 = x_addr
//   st6 = w_addr
//
// Compiler emits 18 ops (12 dups + 3 calls + pop 5 + pop 2 + return).
//
// This matches the compiler: the 7-element parameter stack requires
// full dup-chains for each sub-call. No further reduction possible
// without destructive stack reordering (which costs more than it saves).
//
// Stack trace for each sub-call setup:
//   matvec:  dup 6, dup 6, dup 4, dup 4, dup 4 -> [cols, rows, tmp, x, w]
//   bias_add: dup 2, dup 5, dup 5, dup 4       -> [rows, out, b, tmp]
//   relu_layer: dup 3, dup 4, dup 3             -> [rows, out, out]
//
// 18 counted instructions.
__dense:
    // --- Step 1: matvec(w_addr, x_addr, tmp_addr, rows, cols) ---
    // Need stack (top to bottom): [cols, rows, tmp_addr, x_addr, w_addr]
    // Push deepest first (w, x), then middle (tmp), then top (rows, cols).
    dup 6
    dup 6
    dup 4
    dup 4
    dup 4
    call __matvec
    // Stack unchanged: [cols, rows, tmp_addr, out_addr, b_addr, x_addr, w_addr]

    // --- Step 2: bias_add(tmp_addr, b_addr, out_addr, rows) ---
    // Need stack (top to bottom): [rows, out_addr, b_addr, tmp_addr]
    dup 2
    dup 5
    dup 5
    dup 4
    call __bias_add
    // Stack unchanged: [cols, rows, tmp_addr, out_addr, b_addr, x_addr, w_addr]

    // --- Step 3: relu_layer(out_addr, out_addr, rows) ---
    // Need stack (top to bottom): [rows, out_addr, out_addr]
    dup 3
    dup 4
    dup 3
    call __relu_layer

    // Cleanup: remove all 7 original args
    pop 5
    pop 2
    return

Neighbours