module std.nn.tensor

// Neural network tensor primitives over the Goldilocks field.
//
// All values are field elements (mod p = 2^64 - 2^32 + 1).
// "Negative" values live in the upper half of the field (>= p/2).
// Matrices stored in RAM as row-major flattened arrays.
//
// These primitives are the foundation for provable AI inference:
// every matmul, every activation, every layer — inside a STARK proof.
use vm.core.field

use vm.core.convert

use vm.io.mem

// ---------------------------------------------------------------------------
// Threshold for sign detection in field arithmetic.
// Values >= HALF_P are treated as "negative" for ReLU and comparisons.
// HALF_P = (p - 1) / 2 where p = 2^64 - 2^32 + 1 (Goldilocks)
// ---------------------------------------------------------------------------
fn half_p() -> Field {
    // Returns (p-1)/2, the boundary between "positive" (lower half) and
    // "negative" (upper half) field values.
    // In Goldilocks: 0 - 1 = p - 1, so (p-1) * inv(2) = (p-1)/2
    // field.neg(1) evaluates to p - 1, which is even, so multiplying by the
    // field inverse of 2 yields exactly (p-1)/2 — no integer division needed.
    field.neg(1) * field.inv(2)
}

// ---------------------------------------------------------------------------
// Dot product: sum(a[i] * b[i]) for i in 0..N
// ---------------------------------------------------------------------------
pub fn dot<N>(a: [Field; N], b: [Field; N]) -> Field {
    // Inner product of two length-N field vectors: sum of pairwise products.
    let mut acc: Field = 0
    for k in 0..N bounded 4096 {
        acc = acc + a[k] * b[k]
    }
    acc
}

// ---------------------------------------------------------------------------
// Scalar multiply: x * s
// Trivial but explicit for cost tracking and readability.
// ---------------------------------------------------------------------------
pub fn scale(x: Field, s: Field) -> Field {
    // One field multiplication; a named primitive so scaling shows up
    // explicitly in cost accounting. Field multiplication is commutative.
    s * x
}

// ---------------------------------------------------------------------------
// Field-native ReLU: if x < p/2 then x, else 0.
// Upper half of field represents negative values.
// ---------------------------------------------------------------------------
// Compare hi 32-bit word of two field elements.
// Separate function keeps stack shallow during split.
fn field_hi_lt(a: Field, b: Field) -> Bool {
    // Splits each element into (hi, lo) 32-bit words and compares ONLY the
    // hi words. Two values sharing a hi word therefore compare as "not less"
    // even when their lo words differ — the ordering has 32-bit resolution.
    // NOTE(review): presumably acceptable for coarse sign detection against
    // half_p; confirm no caller needs full 64-bit ordering.
    // The lo bindings are unused; convert.split appears to require binding
    // both tuple components — TODO confirm whether `_` patterns are allowed.
    let (a_hi, a_lo) = convert.split(a)
    let (b_hi, b_lo) = convert.split(b)
    a_hi < b_hi
}

pub fn relu(x: Field) -> Field {
    // Field-native ReLU: values in the lower half of the field pass through;
    // values at or above half_p (the "negative" range) become 0.
    if field_hi_lt(x, half_p()) {
        x
    } else {
        0
    }
}

// ---------------------------------------------------------------------------
// Bias add: x + b for each element.
// RAM-based: reads N values from x_addr, adds bias, writes to out_addr.
// ---------------------------------------------------------------------------
pub fn bias_add(x_addr: Field, bias_addr: Field, out_addr: Field, n: Field) {
    // Element-wise out[i] = x[i] + bias[i] over n consecutive RAM slots.
    for j in 0..n bounded 4096 {
        let off: Field = convert.as_field(j)
        let xv: Field = mem.read(x_addr + off)
        let bv: Field = mem.read(bias_addr + off)
        mem.write(out_addr + off, xv + bv)
    }
}

// ---------------------------------------------------------------------------
// Matrix-vector multiply (RAM-based).
// mat: M x N matrix at mat_addr (row-major)
// vec: N-element vector at vec_addr
// out: M-element result at out_addr
// out[i] = sum_j(mat[i*n + j] * vec[j])
// ---------------------------------------------------------------------------
pub fn matvec(
    mat_addr: Field,
    vec_addr: Field,
    out_addr: Field,
    m: Field,
    n: Field
) {
    // Matrix-vector product from RAM: out[r] = <row r of mat, vec>.
    // mat is an m x n row-major matrix at mat_addr; vec has n elements.
    for r in 0..m bounded 4096 {
        let row: Field = convert.as_field(r)
        // Base address of row r: mat_addr + r * n.
        let row_base: Field = mat_addr + row * n
        let mut acc: Field = 0
        for c in 0..n bounded 4096 {
            let col: Field = convert.as_field(c)
            acc = acc + mem.read(row_base + col) * mem.read(vec_addr + col)
        }
        mem.write(out_addr + row, acc)
    }
}

// ---------------------------------------------------------------------------
// Matrix multiply (RAM-based).
// a: M x N matrix at a_addr (row-major)
// b: N x K matrix at b_addr (row-major)
// out: M x K result at out_addr (row-major)
// out[i][j] = sum_l(a[i*n + l] * b[l*k + j])
// ---------------------------------------------------------------------------
pub fn matmul(
    a_addr: Field,
    b_addr: Field,
    out_addr: Field,
    m: Field,
    n: Field,
    k: Field
) {
    // Row-major matrix product: out[i][j] = sum_l a[i*n + l] * b[l*k + j].
    // a: m x n at a_addr, b: n x k at b_addr, out: m x k at out_addr.
    for i in 0..m bounded 4096 {
        let row_i: Field = convert.as_field(i)
        // Hoist per-row base addresses out of the inner loops: row_i * n and
        // row_i * k are invariant over j and l, so compute them once per row
        // instead of once per inner-loop iteration (the original recomputed
        // row_i * n on every single multiply-accumulate).
        let a_row: Field = a_addr + row_i * n
        let out_row: Field = out_addr + row_i * k
        for j in 0..k bounded 4096 {
            let mut sum: Field = 0
            let col_j: Field = convert.as_field(j)
            for l in 0..n bounded 4096 {
                let col_l: Field = convert.as_field(l)
                let a_val: Field = mem.read(a_row + col_l)
                let b_val: Field = mem.read(b_addr + col_l * k + col_j)
                sum = sum + a_val * b_val
            }
            mem.write(out_row + col_j, sum)
        }
    }
}

// ---------------------------------------------------------------------------
// ReLU layer (RAM-based): apply relu to each of N elements.
// ---------------------------------------------------------------------------
pub fn relu_layer(x_addr: Field, out_addr: Field, n: Field) {
    // Apply the field-native ReLU to each of the n elements at x_addr,
    // writing results to out_addr (in-place is safe when the addresses match).
    for j in 0..n bounded 4096 {
        let off: Field = convert.as_field(j)
        let v: Field = mem.read(x_addr + off)
        mem.write(out_addr + off, relu(v))
    }
}

// ---------------------------------------------------------------------------
// Dense layer: out = relu(W * x + b)
// Combines matrix-vector multiply, bias add, and activation in one call.
// w_addr: rows x cols weight matrix
// x_addr: cols-element input vector
// b_addr: rows-element bias vector
// out_addr: rows-element output vector
// tmp_addr: rows-element scratch space for pre-activation
// ---------------------------------------------------------------------------
pub fn dense(
    w_addr: Field,
    x_addr: Field,
    b_addr: Field,
    out_addr: Field,
    tmp_addr: Field,
    rows: Field,
    cols: Field
) {
    // Pre-activation: tmp = W * x (rows x cols weights times cols-vector).
    matvec(w_addr, x_addr, tmp_addr, rows, cols)
    // out = tmp + b.
    bias_add(tmp_addr, b_addr, out_addr, rows)
    // Activation applied in place on out.
    relu_layer(out_addr, out_addr, rows)
}

// ---------------------------------------------------------------------------
// Argmax: index of the largest positive value in RAM.
// "Positive" = below HALF_P. If all values are negative (>= HALF_P),
// returns 0. Ties go to the first occurrence.
// ---------------------------------------------------------------------------
// Check if a field value is "positive" (below half_p).
fn is_positive(x: Field) -> Bool {
    // "Positive" means the value sits in the lower half of the field,
    // i.e. strictly below half_p under the hi-word ordering.
    let threshold: Field = half_p()
    field_hi_lt(x, threshold)
}

// Check if a > b as unsigned field values (hi-word comparison).
fn field_hi_gt(a: Field, b: Field) -> Bool {
    // a > b is just b < a under the same hi-word comparison.
    field_hi_lt(b, a)
}

pub fn argmax(addr: Field, n: Field) -> Field {
    // Scans the n field elements at addr and returns the index of the
    // largest "positive" value (below half_p). Negative values never win,
    // so if every element is negative the initial best_idx of 0 is
    // returned. Ties keep the first occurrence, because replacement
    // requires a strictly greater hi word (field_hi_gt).
    //
    // Store all loop state in RAM scratch (1073741872..1073741875).
    // Avoids compiler dup-offset bug with mutable vars across loops.
    mem.write(1073741872, addr)           // base address, reloaded each iteration
    mem.write(1073741873, 0)              // best_idx
    mem.write(1073741874, mem.read(addr)) // best_val seeded from element 0
    // Element 0 already seeds best_val, so the scan starts at index 1.
    for i in 1..n bounded 4096 {
        let idx: Field = convert.as_field(i)
        let r_addr: Field = mem.read(1073741872)
        let val: Field = mem.read(r_addr + idx)
        let val_pos: Bool = is_positive(val)
        if val_pos {
            let r_best_val: Field = mem.read(1073741874)
            let best_pos: Bool = is_positive(r_best_val)
            if best_pos {
                // Both positive — bigger u64 wins
                // NOTE(review): best_val is deliberately re-read instead of
                // reusing r_best_val — presumably part of the same compiler
                // workaround noted above; keep the duplicate read.
                let r_bv2: Field = mem.read(1073741874)
                if field_hi_gt(val, r_bv2) {
                    mem.write(1073741873, idx)
                    mem.write(1073741874, val)
                }
            } else {
                // val positive, best negative — val wins
                mem.write(1073741873, idx)
                mem.write(1073741874, val)
            }
        }
    }
    mem.read(1073741873)
}

// Local Graph