/// Expands a bf16 bit pattern to `f32` by placing the 16 bits in the upper
/// half of the `f32` representation (bf16 is the truncated upper half of f32).
#[inline(always)]
pub fn bf16_to_f32(v: u16) -> f32 {
    f32::from_bits((v as u32) << 16)
}

/// Converts an `f32` to its bf16 bit pattern using round-to-nearest-even.
#[inline(always)]
pub fn f32_to_bf16(v: f32) -> u16 {
    f32_to_bf16_rne(v)
}

/// Bulk-converts bf16 bit patterns in `src` to `f32` values in `dst`,
/// covering `min(dst.len(), src.len())` elements. On aarch64 the hot loop
/// widens 16 elements at a time with NEON `shll`/`shll2` (a widening left
/// shift by 16, which is exactly the bf16 -> f32 expansion).
pub fn cast_bf16_f32(dst: &mut [f32], src: &[u16]) {
    let n = dst.len().min(src.len());
    #[cfg(target_arch = "aarch64")]
    {
        let mut i = 0;
        // 16 elements per iteration: two 128-bit loads, four widening shifts,
        // two 128-bit store-pairs.
        while i + 16 <= n {
            unsafe {
                let s = src.as_ptr().add(i);
                let d = dst.as_mut_ptr().add(i);
                std::arch::asm!(
                    "ldp q0, q1, [{s}]",
                    "shll v2.4s, v0.4h, #16",
                    "shll2 v3.4s, v0.8h, #16",
                    "shll v4.4s, v1.4h, #16",
                    "shll2 v5.4s, v1.8h, #16",
                    "stp q2, q3, [{d}]",
                    "stp q4, q5, [{d}, #32]",
                    s = in(reg) s,
                    d = in(reg) d,
                    out("v0") _, out("v1") _,
                    out("v2") _, out("v3") _, out("v4") _, out("v5") _,
                );
            }
            i += 16;
        }
        // 4 elements per iteration for the tail.
        while i + 4 <= n {
            unsafe {
                std::arch::asm!(
                    "ldr d0, [{s}]",
                    "shll v1.4s, v0.4h, #16",
                    "str q1, [{d}]",
                    s = in(reg) src.as_ptr().add(i),
                    d = in(reg) dst.as_mut_ptr().add(i),
                    out("v0") _, out("v1") _,
                );
            }
            i += 4;
        }
        // Scalar remainder.
        for j in i..n {
            dst[j] = bf16_to_f32(src[j]);
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        for i in 0..n {
            dst[i] = bf16_to_f32(src[i]);
        }
    }
}

/// Bulk-converts `f32` values in `src` to bf16 bit patterns in `dst`,
/// covering `min(dst.len(), src.len())` elements. With the `bf16` target
/// feature the hot loop narrows 16 at a time via BFCVTN/BFCVTN2
/// (round-to-nearest-even); otherwise it falls back to the scalar path.
pub fn cast_f32_bf16(dst: &mut [u16], src: &[f32]) {
    let n = dst.len().min(src.len());
    #[cfg(all(target_arch = "aarch64", target_feature = "bf16"))]
    {
        let mut i = 0;
        // 16 elements per iteration. The `.inst` words are the BFCVTN/BFCVTN2
        // encodings (bfcvtn v0.4h,v0.4s / bfcvtn2 v0.8h,v1.4s /
        // bfcvtn v2.4h,v2.4s / bfcvtn2 v2.8h,v3.4s), emitted as raw words so
        // the code assembles even without bf16 assembler support.
        while i + 16 <= n {
            unsafe {
                let s = src.as_ptr().add(i);
                let d = dst.as_mut_ptr().add(i);
                std::arch::asm!(
                    "ldp q0, q1, [{s}]",
                    "ldp q2, q3, [{s}, #32]",
                    ".inst 0x0ea16800",
                    ".inst 0x4ea16820",
                    ".inst 0x0ea16842",
                    ".inst 0x4ea16862",
                    "stp q0, q2, [{d}]",
                    s = in(reg) s,
                    d = in(reg) d,
                    out("v0") _, out("v1") _, out("v2") _, out("v3") _,
                );
            }
            i += 16;
        }
        // Scalar remainder.
        for j in i..n {
            dst[j] = f32_to_bf16_rne(src[j]);
        }
    }
    #[cfg(not(all(target_arch = "aarch64", target_feature = "bf16")))]
    {
        for i in 0..n {
            dst[i] = f32_to_bf16_rne(src[i]);
        }
    }
}

/// Scalar `f32` -> bf16 conversion: round-to-nearest-even for numeric values,
/// with a quiet-NaN fixup so NaN inputs stay NaN after truncation.
fn f32_to_bf16_rne(v: f32) -> u16 {
    let bits = v.to_bits();
    // NaN: truncate and force a mantissa bit so the result cannot collapse
    // to an infinity encoding.
    if (bits & 0x7FFF_FFFF) > 0x7F80_0000 {
        return ((bits >> 16) | 0x0040) as u16;
    }
    // Round to nearest, ties to even: add 0x7FFF plus the lowest retained
    // mantissa bit, then truncate to the upper 16 bits.
    let rounding_bias = 0x7FFF + ((bits >> 16) & 1);
    let rounded = bits.wrapping_add(rounding_bias);
    (rounded >> 16) as u16
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip_bf16() {
        let values = [
            0.0f32,
            1.0,
            -1.0,
            3.14,
            -100.0,
            f32::INFINITY,
            f32::NEG_INFINITY,
        ];
        for &v in &values {
            let h = f32_to_bf16(v);
            let back = bf16_to_f32(h);
            if v.is_finite() {
                assert!(
                    (back - v).abs() / v.abs().max(1.0) < 0.01,
                    "roundtrip failed for {v}: got {back}"
                );
            } else {
                assert_eq!(back, v);
            }
        }
    }

    #[test]
    fn bulk_bf16() {
        let src: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.5).collect();
        let mut bf = vec![0u16; 64];
        let mut dst = vec![0.0f32; 64];
        cast_f32_bf16(&mut bf, &src);
        cast_bf16_f32(&mut dst, &bf);
        for i in 0..64 {
            assert!(
                (dst[i] - src[i]).abs() < 0.5,
                "mismatch at {i}: expected {}, got {}",
                src[i],
                dst[i]
            );
        }
    }
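
    // Length 23 is chosen so the aarch64 fast path in cast_bf16_f32 has to run
    // its 16-wide block, its 4-wide block, and the scalar remainder (and
    // cast_f32_bf16 its 16-wide block plus remainder). The inputs are
    // half-integers with small magnitude, which are exactly representable in
    // bf16, so the round trip is expected to be lossless; the chosen length
    // and tolerance are otherwise arbitrary.
    #[test]
    fn bulk_bf16_tail_paths() {
        let src: Vec<f32> = (0..23).map(|i| i as f32 * 0.5 - 7.0).collect();
        let mut bf = vec![0u16; 23];
        let mut dst = vec![0.0f32; 23];
        cast_f32_bf16(&mut bf, &src);
        cast_bf16_f32(&mut dst, &bf);
        for i in 0..23 {
            assert!(
                (dst[i] - src[i]).abs() < 1e-6,
                "mismatch at {i}: expected {}, got {}",
                src[i],
                dst[i]
            );
        }
    }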

    #[test]
    fn rne_rounding() {
        assert_eq!(f32_to_bf16(1.0), 0x3F80);
        assert_eq!(f32_to_bf16(0.0), 0x0000);
    }
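
    // Tie handling: values exactly halfway between two bf16 neighbors should
    // round to the one with an even mantissa. 0x3F80_8000 sits halfway between
    // bf16 0x3F80 and 0x3F81 and should stay at the even 0x3F80, while
    // 0x3F81_8000 should move up to the even 0x3F82.
    #[test]
    fn rne_ties_to_even() {
        assert_eq!(f32_to_bf16(f32::from_bits(0x3F80_8000)), 0x3F80);
        assert_eq!(f32_to_bf16(f32::from_bits(0x3F81_8000)), 0x3F82);
        // Just above the halfway point rounds up regardless of parity.
        assert_eq!(f32_to_bf16(f32::from_bits(0x3F80_8001)), 0x3F81);
    }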
}