// ============================================================================
// tropical.wgsl — tropical semiring matrix multiplication for GPU compute
// ============================================================================
//
// The tropical semiring (min, +) over u32:
// - Tropical addition: min(a, b)
// - Tropical multiplication: saturating_add(a, b)
// - Additive identity (ZERO): 0xFFFFFFFF (infinity)
// - Multiplicative identity (ONE): 0
//
// Matrix multiplication:
// C[i][j] = min_k (A[i][k] + B[k][j])
//
// where min is tropical addition and + is tropical multiplication
// (which is ordinary addition, saturating at u32::MAX).
//
// The shader implements tiled matrix multiplication with 16x16 workgroups.
// Each thread computes one element of the output matrix by iterating over
// tiles of the shared dimension.
//
// WGSL constraints:
// - No u64 types; we use u32 (u32::MAX = infinity)
// - Workgroup shared memory for tiled access
// - Workgroup size 16x16 = 256 threads
// ============================================================================
// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------
// Edge length, in elements, of the square tiles staged in workgroup memory.
const TILE_DIM = 16u;
// Tropical "zero": +infinity, encoded as the largest u32. It is the identity
// for min (tropical addition) and absorbing for tropical multiplication.
const TROP_INF: u32 = 0xffffffffu;
// ---------------------------------------------------------------------------
// Tropical element operations
// ---------------------------------------------------------------------------
/// Tropical addition (the semiring's "plus"): the smaller of the two operands.
/// TROP_INF is its identity, so trop_add(x, TROP_INF) == x for every x.
fn trop_add(a: u32, b: u32) -> u32 {
    // select(f, t, cond) yields t when cond holds; on a tie b is returned,
    // which is equal to a anyway.
    return select(b, a, a < b);
}
/// Tropical multiplication (the semiring's "times"): a + b, saturating at TROP_INF.
/// A TROP_INF operand, or a true sum >= u32::MAX, yields TROP_INF.
fn trop_mul(a: u32, b: u32) -> u32 {
    // Runtime u32 addition wraps in WGSL, so an overflowed sum is always
    // smaller than `a`. That one comparison also covers infinite operands:
    // INF + b wraps below INF for any b > 0, a + INF wraps to a - 1 for any
    // a > 0, and INF + 0 / 0 + INF is already exactly TROP_INF.
    let wrapped = a + b;
    return select(wrapped, TROP_INF, wrapped < a);
}
// ---------------------------------------------------------------------------
// Buffer layout
// ---------------------------------------------------------------------------
//
// All three matrices are n x n arrays of u32 laid out row-major in flat
// buffers, i.e. element [r][c] sits at index r * n + c:
//   A — first input, B — second input, C — output.
//
// Scalar kernel parameters, bound as a uniform block.
struct Params {
    // Matrix dimension shared by A, B and C.
    n: u32,
}
// Input A (read-only).
@group(0) @binding(0)
var<storage, read> matrix_a: array<u32>;
// Input B (read-only).
@group(0) @binding(1)
var<storage, read> matrix_b: array<u32>;
// Output C, written by the kernel.
@group(0) @binding(2)
var<storage, read_write> matrix_c: array<u32>;
// Matrix dimension.
@group(0) @binding(3)
var<uniform> params: Params;
// ---------------------------------------------------------------------------
// Shared memory for tiling
// ---------------------------------------------------------------------------
// Per-workgroup staging tiles holding one TILE_DIM x TILE_DIM block of A and
// one of B, row-major (element [r][c] at index r * TILE_DIM + c).
// Sized from TILE_DIM rather than a literal so the storage always matches the
// tile the kernel indexes; a hard-coded 256 would leave the kernel indexing
// past the end of these arrays if TILE_DIM were ever raised.
var<workgroup> tile_a: array<u32, TILE_DIM * TILE_DIM>;
var<workgroup> tile_b: array<u32, TILE_DIM * TILE_DIM>;
// ---------------------------------------------------------------------------
// Tiled tropical matrix multiplication
// ---------------------------------------------------------------------------
//
// Each workgroup computes one TILE_DIM x TILE_DIM tile of the output matrix C;
// each invocation computes the single element C[row][col]. The algorithm walks
// tiles along the shared dimension k:
//
//   for each tile t in 0..ceil(n / TILE_DIM):
//     1. Load tile_a[local_row][local_col] = A[row][t*TILE_DIM + local_col]
//     2. Load tile_b[local_row][local_col] = B[t*TILE_DIM + local_row][col]
//     3. Barrier: both tiles are fully written before anyone reads them
//     4. For each k in 0..TILE_DIM:
//          acc = trop_add(acc, trop_mul(tile_a[local_row][k], tile_b[k][local_col]))
//     5. Barrier: everyone has finished reading before the next pass overwrites
//
//   C[row][col] = acc
//
// Thread mapping: the column comes from the x component and the row from y.
// Consecutive lid.x values are consecutive in local_invocation_index, so
// neighbouring invocations read neighbouring elements of A's and B's rows and
// write neighbouring elements of C, which lets those accesses coalesce on
// typical GPUs (mapping x to rows would put neighbouring accesses n elements
// apart). The bounds checks require at least n invocations in both x and y;
// because the output is square, the two dispatch dimensions are interchangeable.
//
// Padding: positions outside the n x n matrices (partial edge tiles, surplus
// invocations) are staged as TROP_INF. trop_mul with an INF operand returns
// INF, and INF is the identity of min, so padding never changes acc.
@compute @workgroup_size(TILE_DIM, TILE_DIM)
fn tropical_matmul(@builtin(global_invocation_id) gid: vec3<u32>,
                   @builtin(local_invocation_id) lid: vec3<u32>) {
    let n = params.n;
    // x -> column, y -> row (see the thread-mapping note above).
    let row = gid.y;
    let col = gid.x;
    let local_row = lid.y;
    let local_col = lid.x;
    // This invocation's slot in both staging tiles.
    let tile_idx = local_row * TILE_DIM + local_col;

    var acc: u32 = TROP_INF;
    // Derived only from a uniform, so every invocation runs the same number of
    // iterations and the barriers below are reached in uniform control flow.
    // For the same reason there is no early return: out-of-range invocations
    // still take part in loading padding and in both barriers.
    let num_tiles = (n + TILE_DIM - 1u) / TILE_DIM;
    for (var t: u32 = 0u; t < num_tiles; t++) {
        let k_base = t * TILE_DIM;

        // Stage A[row][k_base + local_col], or INF padding.
        var a_val = TROP_INF;
        let a_col = k_base + local_col;
        if row < n && a_col < n {
            a_val = matrix_a[row * n + a_col];
        }
        tile_a[tile_idx] = a_val;

        // Stage B[k_base + local_row][col], or INF padding.
        var b_val = TROP_INF;
        let b_row = k_base + local_row;
        if b_row < n && col < n {
            b_val = matrix_b[b_row * n + col];
        }
        tile_b[tile_idx] = b_val;

        workgroupBarrier();

        // acc = min(acc, min over this tile's k of A[row][k] + B[k][col]).
        for (var k: u32 = 0u; k < TILE_DIM; k++) {
            let a_k = tile_a[local_row * TILE_DIM + k];
            let b_k = tile_b[k * TILE_DIM + local_col];
            acc = trop_add(acc, trop_mul(a_k, b_k));
        }

        workgroupBarrier();
    }

    // Only invocations that map onto a real output element write it.
    if row < n && col < n {
        matrix_c[row * n + col] = acc;
    }
}
// Path: trop/wgsl/src/shaders/tropical.wgsl