module std.nn.tensor
// Neural network tensor primitives over the Goldilocks field.
//
// All values are field elements (mod p = 2^64 - 2^32 + 1).
// "Negative" values live in the upper half of the field (>= p/2).
// Matrices stored in RAM as row-major flattened arrays.
//
// These primitives are the foundation for provable AI inference:
// every matmul, every activation, every layer — inside a STARK proof.
use vm.core.field
use vm.core.convert
use vm.io.mem
// ---------------------------------------------------------------------------
// Threshold for sign detection in field arithmetic.
// Values >= HALF_P are treated as "negative" for ReLU and comparisons.
// HALF_P = (p - 1) / 2 where p = 2^64 - 2^32 + 1 (Goldilocks)
// ---------------------------------------------------------------------------
fn half_p() -> Field {
// (p - 1) / 2 computed entirely in-field: neg(1) yields p - 1, and
// multiplying by the inverse of 2 performs the exact halving (p - 1 is even).
let minus_one: Field = field.neg(1)
minus_one * field.inv(2)
}
// ---------------------------------------------------------------------------
// Dot product: sum(a[i] * b[i]) for i in 0..N
// ---------------------------------------------------------------------------
pub fn dot<N>(a: [Field; N], b: [Field; N]) -> Field {
// Accumulate the pairwise products a[i] * b[i]; the static bound caps the
// proof trace length for the loop.
let mut acc: Field = 0
for i in 0..N bounded 4096 {
acc = acc + a[i] * b[i]
}
acc
}
// ---------------------------------------------------------------------------
// Scalar multiply: x * s
// Trivial but explicit for cost tracking and readability.
// ---------------------------------------------------------------------------
pub fn scale(x: Field, s: Field) -> Field {
// Multiply x by the scalar s. Field multiplication is commutative, so the
// operand order is irrelevant; the function exists for cost tracking.
s * x
}
// ---------------------------------------------------------------------------
// Field-native ReLU: if x < p/2 then x, else 0.
// Upper half of field represents negative values.
// ---------------------------------------------------------------------------
// Compare hi 32-bit word of two field elements.
// Separate function keeps stack shallow during split.
// Approximate unsigned "a < b" using only the high 32-bit words of each
// element. When a_hi == b_hi this returns false even if a_lo < b_lo, so
// values whose difference lives entirely in the low word compare as equal.
// For the half_p threshold this misclassifies only a ~2^-33 sliver of the
// field near the sign boundary.
// NOTE(review): a_lo and b_lo are bound but never used — confirm the
// hi-word-only comparison is intentional (the header comment above suggests
// it is, to keep the stack shallow during split).
fn field_hi_lt(a: Field, b: Field) -> Bool {
let (a_hi, a_lo) = convert.split(a)
let (b_hi, b_lo) = convert.split(b)
a_hi < b_hi
}
// Field-native ReLU: pass x through when it encodes a non-negative value
// (hi word below p/2), otherwise clamp to zero.
pub fn relu(x: Field) -> Field {
if field_hi_lt(x, half_p()) {
x
} else {
0
}
}
// ---------------------------------------------------------------------------
// Bias add: x + b for each element.
// RAM-based: reads N values from x_addr, adds bias, writes to out_addr.
// ---------------------------------------------------------------------------
// Element-wise out[i] = x[i] + bias[i] over RAM-resident vectors of length n.
pub fn bias_add(x_addr: Field, bias_addr: Field, out_addr: Field, n: Field) {
for i in 0..n bounded 4096 {
let off: Field = convert.as_field(i)
let xv: Field = mem.read(x_addr + off)
let bv: Field = mem.read(bias_addr + off)
mem.write(out_addr + off, xv + bv)
}
}
// ---------------------------------------------------------------------------
// Matrix-vector multiply (RAM-based).
// mat: M x N matrix at mat_addr (row-major)
// vec: N-element vector at vec_addr
// out: M-element result at out_addr
// out[i] = sum_j(mat[i*n + j] * vec[j])
// ---------------------------------------------------------------------------
// Matrix-vector multiply (RAM-based): out[i] = sum_j mat[i*n + j] * vec[j].
// mat: m x n row-major matrix at mat_addr; vec: n elements at vec_addr;
// out: m elements at out_addr.
pub fn matvec(
mat_addr: Field,
vec_addr: Field,
out_addr: Field,
m: Field,
n: Field
) {
for i in 0..m bounded 4096 {
let mut sum: Field = 0
// Hoist the row base address: mat_addr + i*n is invariant across j, so
// compute the sum of base and offset once per row instead of re-adding
// it on every inner-loop iteration.
let row_base: Field = mat_addr + convert.as_field(i) * n
for j in 0..n bounded 4096 {
let col: Field = convert.as_field(j)
let mat_val: Field = mem.read(row_base + col)
let vec_val: Field = mem.read(vec_addr + col)
sum = sum + mat_val * vec_val
}
mem.write(out_addr + convert.as_field(i), sum)
}
}
// ---------------------------------------------------------------------------
// Matrix multiply (RAM-based).
// a: M x N matrix at a_addr (row-major)
// b: N x K matrix at b_addr (row-major)
// out: M x K result at out_addr (row-major)
// out[i][j] = sum_l(a[i*n + l] * b[l*k + j])
// ---------------------------------------------------------------------------
// Matrix multiply (RAM-based): out[i][j] = sum_l a[i*n + l] * b[l*k + j].
// a: m x n at a_addr, b: n x k at b_addr, out: m x k at out_addr (all
// row-major). Addresses hold field-encoded RAM offsets.
pub fn matmul(
a_addr: Field,
b_addr: Field,
out_addr: Field,
m: Field,
n: Field,
k: Field
) {
for i in 0..m bounded 4096 {
let row_i: Field = convert.as_field(i)
// Hoist per-row invariants out of the j and l loops: the a-row base and
// the output-row base depend only on i, so computing them here removes a
// multiply-and-add from every innermost iteration (m*k*n of them).
let a_row: Field = a_addr + row_i * n
let out_row: Field = out_addr + row_i * k
for j in 0..k bounded 4096 {
let mut sum: Field = 0
let col_j: Field = convert.as_field(j)
for l in 0..n bounded 4096 {
let col_l: Field = convert.as_field(l)
let a_val: Field = mem.read(a_row + col_l)
let b_val: Field = mem.read(b_addr + col_l * k + col_j)
sum = sum + a_val * b_val
}
mem.write(out_row + col_j, sum)
}
}
}
// ---------------------------------------------------------------------------
// ReLU layer (RAM-based): apply relu to each of N elements.
// ---------------------------------------------------------------------------
// Apply the field-native ReLU to each of n elements: reads x_addr + i,
// writes the activated value to out_addr + i. Safe to call with
// x_addr == out_addr since each slot is read before it is written.
pub fn relu_layer(x_addr: Field, out_addr: Field, n: Field) {
for i in 0..n bounded 4096 {
let off: Field = convert.as_field(i)
let v: Field = mem.read(x_addr + off)
mem.write(out_addr + off, relu(v))
}
}
// ---------------------------------------------------------------------------
// Dense layer: out = relu(W * x + b)
// Combines matmul, bias add, and activation in one call.
// w_addr: rows x cols weight matrix
// x_addr: cols-element input vector
// b_addr: rows-element bias vector
// out_addr: rows-element output vector
// tmp_addr: rows-element scratch space for pre-activation
// ---------------------------------------------------------------------------
pub fn dense(
w_addr: Field,
x_addr: Field,
b_addr: Field,
out_addr: Field,
tmp_addr: Field,
rows: Field,
cols: Field
) {
// Pre-activation z = W * x lands in the caller-provided scratch
// (tmp_addr, rows elements); tmp must not alias out for bias_add below.
matvec(w_addr, x_addr, tmp_addr, rows, cols)
// z + b is written to out_addr, leaving the raw matvec result in tmp.
bias_add(tmp_addr, b_addr, out_addr, rows)
// In-place activation: relu_layer reads each slot before writing it,
// so out_addr == out_addr aliasing is safe.
relu_layer(out_addr, out_addr, rows)
}
// ---------------------------------------------------------------------------
// Argmax: index of the largest positive value in RAM.
// "Positive" = below HALF_P. If all values are negative (>= HALF_P),
// returns 0. Ties go to the first occurrence.
// ---------------------------------------------------------------------------
// Check if a field value is "positive" (below half_p).
// True when x encodes a "positive" value, i.e. its hi word sits below p/2.
// Same sign test relu uses.
fn is_positive(x: Field) -> Bool {
let threshold: Field = half_p()
field_hi_lt(x, threshold)
}
// Check if a > b as unsigned field values (hi-word comparison).
// "a > b" by swapping the operands of the hi-word "<" test. Inherits the
// hi-word-only approximation: values sharing a hi word compare as
// not-greater regardless of their low words.
fn field_hi_gt(a: Field, b: Field) -> Bool {
field_hi_lt(b, a)
}
// Returns the index of the largest "positive" (< p/2) value in the n-element
// vector at addr. Returns 0 when all values are negative; ties keep the
// first occurrence.
pub fn argmax(addr: Field, n: Field) -> Field {
// All loop state lives in RAM scratch (1073741872..1073741875) and is
// re-read from RAM on each access. This works around a compiler
// dup-offset bug with mutable vars across loops — do not "simplify" the
// redundant reads below without confirming the bug is fixed.
mem.write(1073741872, addr) // base address of the input vector
mem.write(1073741873, 0) // best_idx (the eventual return value)
mem.write(1073741874, mem.read(addr)) // best_val, seeded with element 0
// NOTE(review): when n == 0 the seed read above touches possibly
// uninitialized RAM, but best_idx is still returned as 0.
for i in 1..n bounded 4096 {
let idx: Field = convert.as_field(i)
let r_addr: Field = mem.read(1073741872)
let val: Field = mem.read(r_addr + idx)
let val_pos: Bool = is_positive(val)
// Negative candidates can never win, so only positives are examined.
if val_pos {
let r_best_val: Field = mem.read(1073741874)
let best_pos: Bool = is_positive(r_best_val)
if best_pos {
// Both positive -> bigger value wins.
// Deliberate re-read of best_val; part of the RAM-scratch workaround.
let r_bv2: Field = mem.read(1073741874)
// NOTE(review): field_hi_gt compares hi 32-bit words only, so two
// positives sharing a hi word tie and the earlier index is kept.
if field_hi_gt(val, r_bv2) {
mem.write(1073741873, idx)
mem.write(1073741874, val)
}
} else {
// val positive, best negative -> val wins
mem.write(1073741873, idx)
mem.write(1073741874, val)
}
}
}
mem.read(1073741873)
}
// trident/std/nn/tensor.tri