use std::time::Instant;
// Number of timed samples `median_ns` collects for each measurement.
const ITERS: usize = 10_000;
// Tile edge length: `decompose` splits the matrix into strips/tiles of `MR`
// rows or columns when packing and when counting tiles.
const MR: usize = 16;
/// Returns the median wall-clock time, in nanoseconds, of a single call to `f`.
///
/// `f` is first invoked 100 times untimed as a warm-up, then `ITERS` more
/// times with each call timed individually. The value returned is the sample
/// at index `ITERS / 2` of the sorted timings.
fn median_ns<F: FnMut()>(mut f: F) -> u64 {
    const WARMUP_RUNS: usize = 100;

    // Untimed warm-up calls before any sample is recorded.
    (0..WARMUP_RUNS).for_each(|_| f());

    let mut samples: Vec<u64> = (0..ITERS)
        .map(|_| {
            let start = Instant::now();
            f();
            start.elapsed().as_nanos() as u64
        })
        .collect();

    // Plain integers: an unstable sort yields exactly the same ordering.
    samples.sort_unstable();
    samples[ITERS / 2]
}
// Foreign binding to the CBLAS single-precision GEMM routine. The `#[link]`
// attribute pulls in the "Accelerate" framework, so this file only links on
// platforms that provide it.
#[link(name = "Accelerate", kind = "framework")]
extern "C" {
/// CBLAS `sgemm`. Per the CBLAS interface this computes
/// `C = alpha * op(A) * op(B) + beta * C` for single-precision matrices.
///
/// Parameters (following the CBLAS convention):
/// - `order`: storage order selector; this file passes `101` (`CblasRowMajor`).
/// - `transa` / `transb`: transpose selectors for `A` / `B`; this file passes
///   `111` (`CblasNoTrans`) for both.
/// - `m`, `n`, `k`: `op(A)` is `m`×`k`, `op(B)` is `k`×`n`, `C` is `m`×`n`.
/// - `alpha`, `beta`: scalar multipliers for the product and for the existing `C`.
/// - `a`, `b`: pointers to the input matrices; `c`: pointer to the output matrix.
/// - `lda`, `ldb`, `ldc`: leading dimensions (row strides in row-major order).
///
/// NOTE(review): the C header declares `order`/`transa`/`transb` as enum
/// types; declaring them as `i32` assumes those enums are `int`-sized.
///
/// # Safety
///
/// The caller must pass pointers that are valid for the matrix extents implied
/// by the dimensions and leading dimensions; nothing is checked on the Rust side.
fn cblas_sgemm(
order: i32,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: f32,
a: *const f32,
lda: i32,
b: *const f32,
ldb: i32,
beta: f32,
c: *mut f32,
ldc: i32,
);
}
/// Entry point: prints one breakdown table per benchmarked matrix size, each
/// followed by a blank separator line.
fn main() {
    const SIZES: [usize; 3] = [64, 128, 256];

    SIZES.iter().copied().for_each(|n| {
        decompose(n);
        println!();
    });
}
/// Times the individual stages of an `n`×`n`×`n` single-precision matrix
/// multiply in isolation and prints a table relating each stage to the full
/// `acpu::matmul_f32` call, followed by a comparison against Accelerate's
/// `cblas_sgemm`.
///
/// Every figure is the value returned by `median_ns` for that stage. The
/// "%sgemm" column is the stage's time as a percentage of the full
/// `acpu::matmul_f32` time; the final line is the ratio acpu / Accelerate.
///
/// # Panics
///
/// Panics if `n` is not a multiple of `MR` or is larger than 256: the scratch
/// buffers are fixed-size and the NEON packing loop always reads `MR` full
/// rows per strip through raw pointers, so other sizes would access memory
/// out of bounds. Also panics if `acpu::Matrix::new()` returns an error.
fn decompose(n: usize) {
    // Upper bound on `n`; every fixed-size scratch buffer below is sized for it.
    const MAX_N: usize = 256;
    // `CblasRowMajor` and `CblasNoTrans` as defined in `cblas.h`.
    const CBLAS_ROW_MAJOR: i32 = 101;
    const CBLAS_NO_TRANS: i32 = 111;

    // 64-byte-aligned panel of all-ones input for the compute kernels. The
    // per-op loop below advances 64 bytes (16 floats) per `k` step, so
    // `MAX_N * MR` floats cover `k` up to `MAX_N`.
    #[repr(align(64))]
    struct Panel([f32; MAX_N * MR]);

    // Reject sizes that would make the raw-pointer code below read or write
    // outside its buffers.
    assert!(
        n <= MAX_N && n % MR == 0,
        "decompose: unsupported n={n}; it must be a multiple of {MR} no larger than {MAX_N}"
    );

    let a: Vec<f32> = (0..n * n).map(|i| (i % 7) as f32 * 0.1).collect();
    let b: Vec<f32> = (0..n * n).map(|i| (i % 11) as f32 * 0.1).collect();
    let n_mr = n.div_ceil(MR);
    let n_tiles = n_mr * n_mr;
    println!("=== n={n} ({n_tiles} tiles, k={n}) ===");
    println!("{:>32} {:>7} {:>5}", "component", "ns", "%sgemm");

    // --- Full multiply: acpu, then Accelerate as the reference. ---
    let mut c = vec![0.0f32; n * n];
    let sgemm_ns = median_ns(|| {
        c.fill(0.0);
        acpu::matmul_f32(&a, &b, &mut c, n, n, n);
        std::hint::black_box(&c);
    });

    // Lossless: `n <= MAX_N` was asserted above.
    let dim = n as i32;
    let accel_ns = median_ns(|| {
        c.fill(0.0);
        // SAFETY: `a`, `b` and `c` each hold `n * n` floats, matching the
        // row-major `n`×`n` extents and leading dimension `n` passed here.
        unsafe {
            cblas_sgemm(
                CBLAS_ROW_MAJOR,
                CBLAS_NO_TRANS,
                CBLAS_NO_TRANS,
                dim,
                dim,
                dim,
                1.0,
                a.as_ptr(),
                dim,
                b.as_ptr(),
                dim,
                0.0,
                c.as_mut_ptr(),
                dim,
            );
        }
        std::hint::black_box(&c);
    });

    // --- Creating and dropping an `acpu::Matrix` context ("AMX set+clr"). ---
    let setclr_ns = median_ns(|| {
        let ctx = acpu::Matrix::new().unwrap();
        std::hint::black_box(&ctx);
        drop(ctx);
    });

    // --- Scalar packing of A into MR-row strips, stored k-major. ---
    let pack_scalar_ns = {
        use core::mem::MaybeUninit;
        #[repr(align(64))]
        struct APack([MaybeUninit<f32>; MAX_N * MAX_N]);
        median_ns(|| {
            // SAFETY: `APack` consists solely of `MaybeUninit<f32>` elements,
            // which carry no validity requirement, so leaving the value
            // uninitialised is sound.
            let mut buf: APack = unsafe { MaybeUninit::uninit().assume_init() };
            let ap = &mut buf.0[..];
            for s in 0..n_mr {
                let rs = s * MR;
                let rows = MR.min(n - rs);
                let base = s * n * MR;
                // Element (row rs+i, column p) lands at base + p*MR + i.
                for i in 0..rows {
                    let a_row = (rs + i) * n;
                    for p in 0..n {
                        ap[base + p * MR + i] = MaybeUninit::new(a[a_row + p]);
                    }
                }
                // Zero-pad the rows of a partial last strip.
                for i in rows..MR {
                    for p in 0..n {
                        ap[base + p * MR + i] = MaybeUninit::new(0.0);
                    }
                }
            }
            std::hint::black_box(&buf);
        })
    };

    // --- NEON packing of A: 4×4 transposes written into the same layout. ---
    let pack_neon_ns = {
        #[repr(align(64))]
        struct APack([f32; MAX_N * MAX_N]);
        median_ns(|| {
            // NOTE(review): unlike the scalar variant, this buffer is
            // zero-filled inside the timed closure, so the reported time
            // includes initialising all MAX_N * MAX_N floats — confirm that
            // this is intended.
            let mut buf = APack([0.0f32; MAX_N * MAX_N]);
            let dst = &mut buf.0[..];
            for s in 0..n_mr {
                let row_start = s * MR;
                let base = s * n * MR;
                // SAFETY: `n` is a multiple of `MR` and at most `MAX_N`
                // (asserted above), so every source row index is below `n`,
                // every 4-wide load stays inside `a` (`p + 4 <= n`), and every
                // 4-wide store stays inside the `MAX_N * MAX_N` destination.
                unsafe {
                    use core::arch::aarch64::*;
                    for ig in 0..(MR / 4) {
                        let ii = ig * 4;
                        let a0 = (row_start + ii) * n;
                        let a1 = (row_start + ii + 1) * n;
                        let a2 = (row_start + ii + 2) * n;
                        let a3 = (row_start + ii + 3) * n;
                        let mut p = 0;
                        while p + 4 <= n {
                            // Load a 4×4 block: four consecutive rows, columns p..p+4.
                            let r0 = vld1q_f32(a.as_ptr().add(a0 + p));
                            let r1 = vld1q_f32(a.as_ptr().add(a1 + p));
                            let r2 = vld1q_f32(a.as_ptr().add(a2 + p));
                            let r3 = vld1q_f32(a.as_ptr().add(a3 + p));
                            // Transpose it: interleave 32-bit lanes, then 64-bit pairs.
                            let lo01 = vzip1q_f32(r0, r1);
                            let hi01 = vzip2q_f32(r0, r1);
                            let lo23 = vzip1q_f32(r2, r3);
                            let hi23 = vzip2q_f32(r2, r3);
                            let c0 = vreinterpretq_f32_f64(vzip1q_f64(
                                vreinterpretq_f64_f32(lo01),
                                vreinterpretq_f64_f32(lo23),
                            ));
                            let c1 = vreinterpretq_f32_f64(vzip2q_f64(
                                vreinterpretq_f64_f32(lo01),
                                vreinterpretq_f64_f32(lo23),
                            ));
                            let c2 = vreinterpretq_f32_f64(vzip1q_f64(
                                vreinterpretq_f64_f32(hi01),
                                vreinterpretq_f64_f32(hi23),
                            ));
                            let c3 = vreinterpretq_f32_f64(vzip2q_f64(
                                vreinterpretq_f64_f32(hi01),
                                vreinterpretq_f64_f32(hi23),
                            ));
                            // Column p+j of the block goes to packed row p+j, lanes ii..ii+4.
                            let d = dst.as_mut_ptr().add(base);
                            vst1q_f32(d.add(p * MR + ii), c0);
                            vst1q_f32(d.add((p + 1) * MR + ii), c1);
                            vst1q_f32(d.add((p + 2) * MR + ii), c2);
                            vst1q_f32(d.add((p + 3) * MR + ii), c3);
                            p += 4;
                        }
                    }
                }
            }
            std::hint::black_box(&buf);
        })
    };

    // --- Compute only, one explicit LDX + LDY + FMA32 triple per k step. ---
    let compute_perop_ns = {
        let a_pack = Panel([1.0f32; MAX_N * MR]);
        let b_pack = Panel([1.0f32; MAX_N * MR]);
        // SAFETY (assumed — confirm against acpu's contracts): the operand
        // addresses stay inside the panels because `p < n <= MAX_N` and each
        // step advances 64 bytes; the ops are issued between `amx_set` and
        // `amx_clr`.
        median_ns(|| unsafe {
            use acpu::matrix::asm::*;
            use acpu::matrix::fma::*;
            use acpu::matrix::regs::*;
            amx_set();
            for _tile in 0..n_tiles {
                let ap = a_pack.0.as_ptr() as *const u8;
                let bp = b_pack.0.as_ptr() as *const u8;
                amx_op::<OP_LDX>(bp as u64);
                amx_op::<OP_LDY>(ap as u64);
                amx_op::<OP_FMA32>(fma_first(XRow::new_unchecked(0), YRow::new_unchecked(0), 0));
                for p in 1..n {
                    amx_op::<OP_LDX>(bp.add(p * 64) as u64);
                    amx_op::<OP_LDY>(ap.add(p * 64) as u64);
                    amx_op::<OP_FMA32>(fma_acc(XRow::new_unchecked(0), YRow::new_unchecked(0), 0));
                }
            }
            amx_clr();
        })
    };

    // --- Compute only, via the 16×16 micro-kernel, once per tile. ---
    let compute_uk16_ns = {
        let a_pack = Panel([1.0f32; MAX_N * MR]);
        let b_pack = Panel([1.0f32; MAX_N * MR]);
        // SAFETY (assumed — confirm against acpu's contracts): the kernel is
        // handed panels sized for `k = MAX_N` and is called with `k = n <= MAX_N`.
        median_ns(|| unsafe {
            use acpu::matrix::asm::*;
            amx_set();
            for _tile in 0..n_tiles {
                acpu::matrix::tile::microkernel_16x16(
                    a_pack.0.as_ptr() as *const u8,
                    b_pack.0.as_ptr() as *const u8,
                    n,
                );
            }
            amx_clr();
        })
    };

    // --- Compute only, via the 16×64 micro-kernel: four tiles per call, with
    //     any leftover tiles handled by the 16×16 kernel. ---
    let compute_uk64_ns = {
        let ap = Panel([1.0f32; MAX_N * MR]);
        let b0 = Panel([1.0f32; MAX_N * MR]);
        let b1 = Panel([1.0f32; MAX_N * MR]);
        let b2 = Panel([1.0f32; MAX_N * MR]);
        let b3 = Panel([1.0f32; MAX_N * MR]);
        let batches = n_tiles / 4;
        let leftover = n_tiles % 4;
        // SAFETY (assumed — confirm against acpu's contracts): same panel
        // sizing argument as for the 16×16 kernel above.
        median_ns(|| unsafe {
            use acpu::matrix::asm::*;
            amx_set();
            for _ in 0..batches {
                acpu::matrix::tile::microkernel_16x64(
                    ap.0.as_ptr() as *const u8,
                    b0.0.as_ptr() as *const u8,
                    b1.0.as_ptr() as *const u8,
                    b2.0.as_ptr() as *const u8,
                    b3.0.as_ptr() as *const u8,
                    n,
                );
            }
            for _ in 0..leftover {
                acpu::matrix::tile::microkernel_16x16(
                    ap.0.as_ptr() as *const u8,
                    b0.0.as_ptr() as *const u8,
                    n,
                );
            }
            amx_clr();
        })
    };

    // --- Scalar packing of B into MR-column strips. ---
    let pack_b_ns = {
        use core::mem::MaybeUninit;
        #[repr(align(64))]
        struct BPack([MaybeUninit<f32>; MAX_N * MAX_N]);
        median_ns(|| {
            // SAFETY: `BPack` consists solely of `MaybeUninit<f32>` elements,
            // so an uninitialised value is sound.
            let mut buf: BPack = unsafe { MaybeUninit::uninit().assume_init() };
            let bp = &mut buf.0[..];
            for s in 0..n_mr {
                let cs = s * MR;
                let cols = MR.min(n - cs);
                let base = s * n * MR;
                // Copy `cols` contiguous floats of each row p into packed row p.
                for p in 0..n {
                    let src = p * n + cs;
                    let dst = base + p * MR;
                    for j in 0..cols {
                        bp[dst + j] = MaybeUninit::new(b[src + j]);
                    }
                }
            }
            std::hint::black_box(&buf);
        })
    };

    // --- "New" accumulate path: preload four C tiles, run the accumulating
    //     16×64 kernel, store the four tiles back. ---
    let preload_path_ns = {
        let ap = Panel([1.0f32; MAX_N * MR]);
        let b0 = Panel([1.0f32; MAX_N * MR]);
        let b1 = Panel([1.0f32; MAX_N * MR]);
        let b2 = Panel([1.0f32; MAX_N * MR]);
        let b3 = Panel([1.0f32; MAX_N * MR]);
        let mut c_buf = vec![0.0f32; n * n];
        // Each batch covers four adjacent tiles of one tile row. Deriving the
        // count from whole batches per row keeps `ir < n_mr` for every batch
        // and means the division below never runs with a zero divisor (no
        // batches exist in that case). For `n_mr % 4 == 0` this equals
        // `n_tiles / 4`.
        let batches_per_row = n_mr / 4;
        let batches = n_mr * batches_per_row;
        // SAFETY (assumed — confirm against acpu's contracts): each `cp`
        // addresses a 16-row × 64-column window that lies inside the `n * n`
        // buffer, since `ir < n_mr`, `jr4 < n_mr / 4` and `n` is a multiple of `MR`.
        median_ns(|| unsafe {
            use acpu::matrix::asm::*;
            amx_set();
            for batch in 0..batches {
                let ir = batch / batches_per_row;
                let jr4 = batch % batches_per_row;
                let cp = c_buf.as_mut_ptr().add(ir * MR * n + jr4 * 4 * MR);
                for t in 0u8..4 {
                    acpu::matrix::tile::preload_c(cp.add(t as usize * MR), n, t);
                }
                acpu::matrix::tile::microkernel_16x64_acc(
                    ap.0.as_ptr() as *const u8,
                    b0.0.as_ptr() as *const u8,
                    b1.0.as_ptr() as *const u8,
                    b2.0.as_ptr() as *const u8,
                    b3.0.as_ptr() as *const u8,
                    n,
                    64,
                );
                for t in 0u8..4 {
                    acpu::matrix::tile::store_c(cp.add(t as usize * MR), n, t);
                }
            }
            amx_clr();
            std::hint::black_box(&c_buf);
        })
    };

    // --- "Old" accumulate path: one `accumulate_tile` call per C tile. ---
    let accum_old_ns = {
        let mut c_buf = vec![0.0f32; n * n];
        // SAFETY (assumed — confirm against acpu's contracts): each `cp`
        // addresses a 16×16 tile inside the `n * n` buffer, since
        // `ir, jr < n_mr` and `n` is a multiple of `MR`.
        median_ns(|| unsafe {
            use acpu::matrix::asm::*;
            amx_set();
            for t in 0..n_tiles {
                let ir = t / n_mr;
                let jr = t % n_mr;
                let cp = c_buf.as_mut_ptr().add(ir * MR * n + jr * MR);
                acpu::matrix::tile::accumulate_tile(cp, n, 0);
            }
            amx_clr();
            std::hint::black_box(&c_buf);
        })
    };

    // --- Report. Percentages are relative to the full acpu multiply. ---
    let row = |name: &str, ns: u64| {
        let pct = ns as f64 / sgemm_ns as f64 * 100.0;
        println!("{:>32} {:>7} {:>4.0}%", name, ns, pct);
    };
    row("AMX set+clr", setclr_ns);
    row("A pack scalar", pack_scalar_ns);
    row("A pack NEON", pack_neon_ns);
    row("B pack scalar", pack_b_ns);
    row("compute per-op (LDX+LDY+FMA)", compute_perop_ns);
    row("compute µkernel_16x16", compute_uk16_ns);
    row("compute µkernel_16x64 (fused)", compute_uk64_ns);
    row("preload+µk64_acc+store (new)", preload_path_ns);
    row("accum old (STZ→tmp→NEON add)", accum_old_ns);
    println!("{:>32} {:>7} {:>4.0}%", "----------", "---", "");
    row("sgemm (full)", sgemm_ns);
    row("Accelerate", accel_ns);
    let gap = sgemm_ns as f64 / accel_ns as f64;
    println!("{:>32} {:>7.1}x", "gap (acpu/accel)", gap);
}