// ---
// tags: trop, trident
// crystal-type: circuit
// crystal-domain: comp
// ---

//  Tropical matrix over fixed-size storage.
//
//  Square matrices up to MAX_N x MAX_N stored in a flat U32 array.
//  All operations use (min, +) semiring arithmetic.

module trop.matrix

use trop.element.{INF, Tropical, add, mul, is_inf, zero, one, from_u32}

/// Maximum supported matrix dimension.
/// Also the row stride of `TropMatrix.data`: entry (i, j) lives at
/// index i * MAX_N + j regardless of the matrix's runtime dimension n.
pub const MAX_N: U32 = 64

/// Total storage slots: MAX_N * MAX_N = 64 * 64 = 4096.
/// NOTE(review): the `[U32; 4096]` array types and the `0..64` loop bounds in
/// this module are written as literals, so they must be kept in sync with
/// these constants by hand. This constant is not referenced in this chunk.
const MAX_ENTRIES: U32 = 4096

/// A square tropical matrix with runtime dimension n.
///
/// Row-major flat storage with a fixed row stride of MAX_N (not n): entry
/// (i, j) is stored at data[i * MAX_N + j]. Only slots with i < n and j < n
/// are meaningful. For n < MAX_N these are spread across the array rather
/// than packed into the first n*n slots.
pub struct TropMatrix {
    // Logical dimension; identity/matadd/matmul only process indices < n.
    n: U32,
    // Raw U32 entries, MAX_N * MAX_N = 4096 slots; INF encodes the tropical zero.
    data: [U32; 4096]
}

/// Build an n x n tropical zero matrix: every entry is INF.
///
/// All 4096 storage slots are set to INF, not just those inside the n x n
/// window, so unused storage also reads as the tropical zero.
/// NOTE(review): n is stored as given; values above MAX_N are not rejected here.
pub fn new(n: U32) -> TropMatrix {
    let cells: [U32; 4096] = [INF; 4096]
    TropMatrix { n: n, data: cells }
}

/// Build the n x n tropical identity matrix.
///
/// Diagonal entries (d, d) for d < n are 0 (the tropical one); every other
/// slot keeps the INF written by `new`.
pub fn identity(n: U32) -> TropMatrix {
    let mut eye: TropMatrix = new(n)
    for d in 0..64 {
        if d < n {
            // Diagonal slot (d, d) in row-major storage with stride MAX_N.
            let slot: U32 = d * MAX_N + d
            eye.data[slot] = 0
        }
    }
    eye
}

/// Read entry (i, j) and wrap it as a Tropical via `from_u32`.
///
/// The slot is i * MAX_N + j; i and j are not checked against m.n.
pub fn get(m: TropMatrix, i: U32, j: U32) -> Tropical {
    let slot: U32 = i * MAX_N + j
    from_u32(m.data[slot])
}

/// Store a Tropical value at (i, j) by writing its raw `val` field.
///
/// The slot is i * MAX_N + j; i and j are not checked against m.n.
pub fn set(m: &mut TropMatrix, i: U32, j: U32, val: Tropical) {
    let slot: U32 = i * MAX_N + j
    m.data[slot] = val.val
}

/// Store a raw U32 at (i, j) without going through Tropical.
///
/// The slot is i * MAX_N + j. Neither the indices nor the value are
/// validated here.
pub fn set_raw(m: &mut TropMatrix, i: U32, j: U32, val: U32) {
    let slot: U32 = i * MAX_N + j
    m.data[slot] = val
}

/// Tropical matrix addition: the elementwise `add` (min) of `a` and `b`.
///
/// Only the leading dim x dim window is combined, where dim = a.n. The
/// result has dimension dim, and every slot outside that window keeps the
/// INF written by `new`. `b.n` is not read; `b` is assumed to match `a`.
pub fn matadd(a: TropMatrix, b: TropMatrix) -> TropMatrix {
    let dim: U32 = a.n
    let mut out: TropMatrix = new(dim)
    for r in 0..64 {
        if r < dim {
            let base: U32 = r * MAX_N
            for c in 0..64 {
                if c < dim {
                    let slot: U32 = base + c
                    let lhs: Tropical = from_u32(a.data[slot])
                    let rhs: Tropical = from_u32(b.data[slot])
                    out.data[slot] = add(lhs, rhs).val
                }
            }
        }
    }
    out
}

/// Tropical matrix multiplication.
/// C[i][j] = min_k (A[i][k] + B[k][j]), over indices < n where n = a.n.
///
/// `b.n` is not read; `b` is assumed to have the same dimension as `a`.
/// Whenever A[i][k] is INF (the tropical zero), the whole j-loop for that
/// (i, k) pair is skipped, because it cannot lower any C[i][j].
pub fn matmul(a: TropMatrix, b: TropMatrix) -> TropMatrix {
    let n: U32 = a.n
    let mut result: TropMatrix = new(n)
    for i in 0..64 {
        if i < n {
            // Row offset of A[i][*] and C[i][*]; invariant across the k and j loops.
            let row_i: U32 = i * MAX_N
            for k in 0..64 {
                if k < n {
                    let a_ik: Tropical = from_u32(a.data[row_i + k])
                    if is_inf(a_ik) == false {
                        // Row offset of B[k][*]; invariant across the j loop.
                        let row_k: U32 = k * MAX_N
                        for j in 0..64 {
                            if j < n {
                                let b_kj: Tropical = from_u32(b.data[row_k + j])
                                let product: Tropical = mul(a_ik, b_kj)
                                let idx: U32 = row_i + j
                                let current: Tropical = from_u32(result.data[idx])
                                result.data[idx] = add(current, product).val
                            }
                        }
                    }
                }
            }
        }
    }
    result
}

/// Tropical matrix exponentiation by repeated squaring.
/// Computes A^exp under tropical multiplication.
/// A^0 = identity matrix of dimension n = a.n.
///
/// exp_bits is the number of low-order bits of exp to consume (max 32).
/// Any set bits of exp at or above position exp_bits are ignored.
///
/// Each squaring is a full `matmul`, so the base is squared only when a
/// later iteration will run and still has a set bit left to multiply in.
/// This skips the final unused squaring(s) without changing the result.
pub fn power(a: TropMatrix, exp: U32, exp_bits: U32) -> TropMatrix {
    let n: U32 = a.n
    let mut result: TropMatrix = identity(n)
    let mut base: TropMatrix = a
    let mut e: U32 = exp
    for bit in 0..32 {
        if bit < exp_bits {
            // Invariant here: base = A^(2^bit).
            let low: U32 = e & 1
            if low == 1 {
                result = matmul(result, base)
            }
            e = e >> 1
            // Once e == 0 every remaining low bit is 0, so base is never used
            // again. Likewise if no further iteration will execute.
            if bit + 1 < exp_bits {
                if 0 < e {
                    base = matmul(base, base)
                }
            }
        }
    }
    result
}

// Local Graph