use rane::mil;
use rane::{f32_to_fp16, fp16_to_f32, Buffer, Program};
/// Prints `label` without a trailing newline and flushes stdout immediately.
///
/// Rust's `Stdout` is line-buffered, so a bare `print!` would only become
/// visible once a later newline is written — that is, after the step the
/// label announces has already finished (or not at all if that step fails).
fn announce(label: &str) -> std::io::Result<()> {
    print!("{label}");
    std::io::Write::flush(&mut std::io::stdout())
}

/// Builds a 64x64 matmul program with `mil::matmul`, compiles and loads it,
/// runs it once, and checks that every output value equals 1.0.
///
/// The input buffer is filled so that, for each input channel, the first
/// `seq` positions hold 1.0 and the following `oc` positions hold one row of
/// an identity matrix. The run is reported as verified when every output
/// element decodes to exactly 1.0; otherwise the first eight output values
/// are printed for inspection.
///
/// # Errors
///
/// Propagates any error from compiling, loading, buffer allocation, running
/// the program, or flushing stdout.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("ANE Matmul — Pure Rust\n");

    let ic = 64;
    let oc = 64;
    let seq = 64;

    let program = mil::matmul(ic, oc, seq);
    let (in_ch, in_sp) = program.input_shape();
    let (out_ch, out_sp) = program.output_shape();

    println!(" MIL: matmul({ic}x{oc}, seq={seq})");
    println!(
        " Input: [1, {in_ch}, 1, {in_sp}] fp16 ({} KB)",
        program.input_size() / 1024
    );
    println!(
        " Output: [1, {out_ch}, 1, {out_sp}] fp16 ({} KB)\n",
        program.output_size() / 1024
    );

    announce(" Compiling...")?;
    let mut model = Program::compile(&program, &[])?;
    println!(" OK");

    announce(" Loading...")?;
    model.load()?;
    println!(" OK");

    let input = Buffer::new(program.input_size())?;
    let output = Buffer::new(program.output_size())?;

    // The fp16 encoding of 1.0 is loop-invariant; convert it once.
    let one = f32_to_fp16(1.0);
    input.write(|data| {
        for ch in 0..ic {
            let row = ch * in_sp;
            // Activations: the first `seq` slots of each channel row are 1.0.
            for s in 0..seq {
                data[row + s] = one;
            }
            // Weights: the next `oc` slots hold row `ch` of an identity matrix.
            for o in 0..oc {
                data[row + seq + o] = if ch == o { one } else { 0 };
            }
        }
    });
    println!(" Input: all 1.0 activations, identity weight matrix");

    announce(" Evaluating on ANE...")?;
    model.run(&input, &output)?;
    println!(" OK\n");

    output.read(|data| {
        let preview: Vec<String> = data[..8]
            .iter()
            .map(|&v| format!("{:.1}", fp16_to_f32(v)))
            .collect();
        println!(" Output[0..8] = [{}]", preview.join(", "));

        // All-ones activations times an identity matrix must give exactly
        // 1.0 everywhere, so an exact float comparison is intended here.
        let total = out_ch * out_sp;
        let all_ones = data[..total].iter().all(|&v| fp16_to_f32(v) == 1.0);

        if all_ones {
            println!("\n VERIFIED: all {} output values = 1.0", total);
            println!(" Pure Rust → MIL → ANE bytecode → ANE hardware → correct result");
        } else {
            // Dump the leading values at full precision to aid diagnosis.
            let head: Vec<_> = data[..8].iter().map(|&v| fp16_to_f32(v)).collect();
            println!("\n Output values: {:?}", head);
        }
    });

    Ok(())
}