use aruminium::{Gpu, GpuError};
/// Runs a 64x64x64 matmul on the GPU and verifies the result on the CPU.
///
/// A is the identity matrix and B is all-ones, so C = A * B must be all-ones;
/// any per-element deviation beyond float tolerance is reported as a failure.
///
/// # Errors
/// Returns a `GpuError` if device open, shader compilation, pipeline setup,
/// buffer allocation, or command submission fails.
fn main() -> Result<(), GpuError> {
    let device = Gpu::open()?;
    println!("Device: {}", device.name());
    let queue = device.new_command_queue()?;

    // Metal compute kernel: naive one-thread-per-output-element matmul
    // C[row][col] = sum_k A[row][k] * B[k][col]. Kernel arguments use
    // Metal's [[buffer(i)]] / [[thread_position_in_grid]] attribute syntax.
    let source = r#"
#include <metal_stdlib>
using namespace metal;
struct MatmulParams {
uint M;
uint N;
uint K;
};
kernel void matmul(device const float *A [[buffer(0)]],
device const float *B [[buffer(1)]],
device float *C [[buffer(2)]],
constant MatmulParams &params [[buffer(3)]],
uint2 gid [[thread_position_in_grid]]) {
uint row = gid.y;
uint col = gid.x;
if (row >= params.M || col >= params.N) return;
float sum = 0.0;
for (uint k = 0; k < params.K; k++) {
sum += A[row * params.K + k] * B[k * params.N + col];
}
C[row * params.N + col] = sum;
}
"#;
    let lib = device.compile(source)?;
    let func = lib.function("matmul")?;
    let pipeline = device.pipeline(&func)?;

    // Problem size: C is m x n, A is m x k, B is k x n.
    let m = 64usize;
    let n = 64usize;
    let k = 64usize;
    let elem = std::mem::size_of::<f32>(); // 4 bytes per float element
    let buf_a = device.buffer(m * k * elem)?;
    let buf_b = device.buffer(k * n * elem)?;
    let buf_c = device.buffer(m * n * elem)?;

    // A = identity (m x k, row-major).
    buf_a.write_f32(|d| {
        for i in 0..m {
            for j in 0..k {
                d[i * k + j] = if i == j { 1.0 } else { 0.0 };
            }
        }
    });
    // B = all-ones, so each output element should sum to exactly 1.0.
    buf_b.write_f32(|d| {
        for i in 0..d.len() {
            d[i] = 1.0;
        }
    });

    // Must match the MSL `MatmulParams` struct layout field-for-field:
    // three consecutive u32s, no padding.
    #[repr(C)]
    struct MatmulParams {
        m: u32,
        n: u32,
        k: u32,
    }
    let params = MatmulParams {
        m: m as u32,
        n: n as u32,
        k: k as u32,
    };
    // SAFETY: `params` is a live, properly aligned #[repr(C)] struct made of
    // plain u32 fields with no padding or interior pointers, so viewing it as
    // a byte slice of exactly size_of::<MatmulParams>() bytes is sound for the
    // duration of this borrow.
    let params_bytes = unsafe {
        std::slice::from_raw_parts(
            &params as *const MatmulParams as *const u8,
            std::mem::size_of::<MatmulParams>(),
        )
    };

    let cmd = queue.commands()?;
    let enc = cmd.encoder()?;
    enc.bind(&pipeline);
    enc.bind_buffer(&buf_a, 0, 0);
    enc.bind_buffer(&buf_b, 0, 1);
    enc.bind_buffer(&buf_c, 0, 2);
    enc.push(params_bytes, 3);
    // Grid is (x, y, z) = (cols, rows, 1): gid.x indexes columns (N),
    // gid.y indexes rows (M), matching the kernel's row/col assignment.
    enc.launch((n, m, 1), (16, 16, 1));
    enc.finish();
    cmd.submit();
    cmd.wait();

    // Verify: every element of C should equal 1.0 within float tolerance.
    buf_c.read_f32(|d| {
        let mut max_err: f32 = 0.0;
        for i in 0..m * n {
            max_err = max_err.max((d[i] - 1.0).abs());
        }
        if max_err < 1e-5 {
            println!(
                "PASS: {}x{}x{} matmul verified (max_err={:.2e})",
                m, n, k, max_err
            );
        } else {
            println!("FAIL: max_err = {:.6}", max_err);
        }
    });
    Ok(())
}