fn main() {
println!("=== acpu_probe ===\n");
println!("[1] Capabilities");
let caps = acpu::probe::scan();
println!(" chip: {}", caps.chip);
println!(" AMX ver: {}", caps.amx_ver);
println!(" P-cores: {}", caps.p_cores);
println!(" E-cores: {}", caps.e_cores);
println!(" L1 line: {} B", caps.l1_line);
println!(" L2 size: {} KiB", caps.l2_size / 1024);
println!(" FP16: {}", caps.has_fp16);
println!(" BF16: {}", caps.has_bf16);
println!(" DotProd: {}", caps.has_dotprod);
println!(" I8MM: {}", caps.has_i8mm);
println!(" FCMA: {}", caps.has_fcma);
println!(" RDM: {}", caps.has_rdm);
println!(" LSE: {}", caps.has_lse);
println!(" LRCPC: {}", caps.has_lrcpc);
println!();
println!("[2] AMX set/clr");
if caps.amx_ver > 0 {
match amx_set_clr_test() {
Ok(()) => println!(" PASS: AMX set/clr cycle succeeded"),
Err(e) => println!(" FAIL: {e}"),
}
} else {
println!(" SKIP: no AMX detected");
}
println!();
println!("[3] AMX load/store/fma32");
if caps.amx_ver > 0 {
match amx_fma_test() {
Ok(()) => println!(" PASS: AMX fma32 2x2 matmul correct"),
Err(e) => println!(" FAIL: {e}"),
}
} else {
println!(" SKIP: no AMX detected");
}
println!();
println!("[4] NEON matmul_f32");
match neon_matmul_f32_test() {
Ok(()) => println!(" PASS: 4x4 matmul_f32 correct"),
Err(e) => println!(" FAIL: {e}"),
}
println!();
println!("[5] PMU counters");
match pmu_test() {
Ok(()) => println!(" PASS: PMU counters read successfully"),
Err(e) => println!(" SKIP/FAIL: {e}"),
}
println!();
println!("=== done ===");
}
fn amx_set_clr_test() -> Result<(), String> {
unsafe {
acpu::matrix::asm::amx_set();
acpu::matrix::asm::amx_clr();
}
Ok(())
}
/// Smoke-tests the AMX load -> fma32 -> store path.
///
/// Steps:
/// 1. Fill one X row and one Y row with sixteen `1.0f32` lanes.
/// 2. Issue `fma32(0)`.
/// 3. Store Z rows `0..8` into a local buffer.
///
/// The test passes if any stored `f32` lane is non-zero. It confirms the
/// pipeline produced output; it does not verify exact numeric results.
///
/// # Errors
/// Returns an error string if `acpu::Matrix::new` fails, or if every stored
/// lane is zero.
fn amx_fma_test() -> Result<(), String> {
    // Bytes per X/Y/Z row as used here (16 lanes of f32).
    const ROW_BYTES: usize = 64;
    // Number of Z rows copied back for inspection.
    const Z_ROWS: u8 = 8;

    // 128-byte alignment is kept from the original. AMX loads/stores
    // presumably require aligned addresses — TODO confirm against acpu docs.
    #[repr(align(128))]
    struct Aligned([u8; ROW_BYTES]);
    #[repr(align(128))]
    struct ZBuf([u8; ROW_BYTES * Z_ROWS as usize]);

    // Encode the all-ones f32 row as bytes. This avoids reinterpreting the
    // byte buffer through an unsafe `slice::from_raw_parts_mut` cast.
    let mut ones = [0u8; ROW_BYTES];
    for lane in ones.chunks_exact_mut(4) {
        lane.copy_from_slice(&1.0f32.to_ne_bytes());
    }
    let x_buf = Aligned(ones);
    let y_buf = Aligned(ones);
    let mut z_all = ZBuf([0u8; ROW_BYTES * Z_ROWS as usize]);

    let ctx = acpu::Matrix::new().map_err(|e| e.to_string())?;

    // SAFETY:
    // - x_buf and y_buf each hold one ROW_BYTES row; z_all holds Z_ROWS such
    //   rows, so every `add` stays inside z_all.
    // - All buffers are 128-byte aligned.
    // - Row indices 0 (X/Y) and 0..Z_ROWS (Z) are passed to `new_unchecked`.
    //   These are assumed to be in range for the acpu register types —
    //   TODO confirm bounds in `acpu::matrix::regs`.
    unsafe {
        use acpu::matrix::regs::*;
        ctx.ldx(x_buf.0.as_ptr(), XRow::new_unchecked(0));
        ctx.ldy(y_buf.0.as_ptr(), YRow::new_unchecked(0));
        ctx.fma32(0u64);
        for row in 0..Z_ROWS {
            ctx.stz(
                z_all.0.as_mut_ptr().add(usize::from(row) * ROW_BYTES),
                ZRow::new_unchecked(row),
            );
        }
    }
    // Release the AMX context before inspecting the stored output.
    drop(ctx);

    // Decode every 4-byte lane back to f32 (same bit pattern as the store).
    let any_nonzero = z_all
        .0
        .chunks_exact(4)
        .map(|b| f32::from_ne_bytes([b[0], b[1], b[2], b[3]]))
        .any(|v| v != 0.0);
    if !any_nonzero {
        return Err("AMX fma32 produced all zeros".into());
    }
    Ok(())
}
/// Checks `acpu::matmul_f32` on a 4x4 case with a known answer.
///
/// A holds 1.0 ..= 16.0 and B is the 4x4 identity, so C = A·B must equal A
/// element-wise, within 1e-4. This holds whether the routine treats the data
/// as row- or column-major.
///
/// # Errors
/// Returns an error describing the first index where C differs from A.
fn neon_matmul_f32_test() -> Result<(), String> {
    const N: usize = 4;
    const TOLERANCE: f32 = 1e-4;

    let a: Vec<f32> = (1..=N * N).map(|v| v as f32).collect();
    // Identity: 1.0 exactly where row == column.
    let b: Vec<f32> = (0..N * N)
        .map(|idx| if idx / N == idx % N { 1.0 } else { 0.0 })
        .collect();
    let mut c = vec![0.0f32; N * N];

    acpu::matmul_f32(&a, &b, &mut c, N, N, N);

    let first_bad = a
        .iter()
        .zip(&c)
        .position(|(expected, got)| (got - expected).abs() > TOLERANCE);
    match first_bad {
        None => Ok(()),
        Some(i) => Err(format!(
            "mismatch at index {i}: expected {}, got {}",
            a[i], c[i]
        )),
    }
}
/// Reads the cycle and instruction PMU counters around a short integer loop.
///
/// Starts the counters, takes a reading, runs 10,000 wrapping additions,
/// takes a second reading, and stops the counters. It then prints the
/// `cycles` and `instructions` deltas reported by `Counters::elapsed`.
///
/// # Errors
/// Returns an error string if:
/// - `Counters::new` fails, or
/// - both deltas are zero, which the message attributes to kpc access being
///   denied.
fn pmu_test() -> Result<(), String> {
    use acpu::pulse::{Counter, Counters};
    use std::hint::black_box;

    let mut ctx =
        Counters::new(&[Counter::Cycles, Counter::Instructions]).map_err(|e| e.to_string())?;
    ctx.start();
    let a = ctx.read();

    // Pass each addend through black_box. Otherwise the optimizer can replace
    // this constant-range loop with a closed-form sum, leaving no real work
    // between the two reads.
    let mut sum = 0u64;
    for i in 0..10_000u64 {
        sum = sum.wrapping_add(black_box(i));
    }
    black_box(sum);

    let b = ctx.read();
    ctx.stop();

    let counts = ctx.elapsed(&a, &b);
    println!(" cycles: {}", counts.cycles);
    println!(" instructions: {}", counts.instructions);
    if counts.cycles == 0 && counts.instructions == 0 {
        return Err("all counters zero — kpc access likely denied".into());
    }
    Ok(())
}