//! Quantization operations: f32 <-> int8 with scale/zero-point.
//!
//! Uses NEON vectorized paths on aarch64 for bulk conversions.
/// Quantize f32 to int8: `dst[i] = clamp(round(src[i] / scale), -128, 127)`.
///
/// `scale` is the quantization step size. Typically computed as
/// `(max - min) / 255` for asymmetric or `max(abs) / 127` for symmetric.
pub fn cast_f32_i8(dst: &mut [i8], src: &[f32], scale: f32) {
let n = dst.len().min(src.len());
let inv_scale = 1.0 / scale;
#[cfg(target_arch = "aarch64")]
{
let mut i = 0;
// Process 16 f32 -> 16 i8 per iteration
while i + 16 <= n {
unsafe {
let s = src.as_ptr().add(i);
let d = dst.as_mut_ptr().add(i);
std::arch::asm!(
// Broadcast inv_scale into v31
"dup v31.4s, {inv:w}",
// Load 16 x f32
"ldp q0, q1, [{s}]",
"ldp q2, q3, [{s}, #32]",
// Multiply by inv_scale
"fmul v0.4s, v0.4s, v31.4s",
"fmul v1.4s, v1.4s, v31.4s",
"fmul v2.4s, v2.4s, v31.4s",
"fmul v3.4s, v3.4s, v31.4s",
// Round to nearest integer (f32)
"frintn v0.4s, v0.4s",
"frintn v1.4s, v1.4s",
"frintn v2.4s, v2.4s",
"frintn v3.4s, v3.4s",
// Convert to i32
"fcvtzs v0.4s, v0.4s",
"fcvtzs v1.4s, v1.4s",
"fcvtzs v2.4s, v2.4s",
"fcvtzs v3.4s, v3.4s",
// Narrow i32 -> i16 (saturating)
"sqxtn v0.4h, v0.4s",
"sqxtn2 v0.8h, v1.4s",
"sqxtn v2.4h, v2.4s",
"sqxtn2 v2.8h, v3.4s",
// Narrow i16 -> i8 (saturating)
"sqxtn v0.8b, v0.8h",
"sqxtn2 v0.16b, v2.8h",
// Store 16 x i8
"str q0, [{d}]",
inv = in(reg) inv_scale.to_bits(),
s = in(reg) s,
d = in(reg) d,
out("v0") _, out("v1") _, out("v2") _, out("v3") _,
out("v31") _,
);
}
i += 16;
}
// Scalar remainder
for j in i..n {
let q = (src[j] * inv_scale).round();
dst[j] = q.clamp(-128.0, 127.0) as i8;
}
}
#[cfg(not(target_arch = "aarch64"))]
{
for i in 0..n {
let q = (src[i] * inv_scale).round();
dst[i] = q.clamp(-128.0, 127.0) as i8;
}
}
}
/// Dequantize int8 to f32: `dst[i] = (src[i] - zero) * scale`.
///
/// `zero` is the zero-point offset. For symmetric quantization, `zero = 0`.
pub fn cast_i8_f32(dst: &mut [f32], src: &[i8], scale: f32, zero: i8) {
let n = dst.len().min(src.len());
#[cfg(target_arch = "aarch64")]
{
let mut i = 0;
// Process 16 i8 -> 16 f32 per iteration
while i + 16 <= n {
unsafe {
let s = src.as_ptr().add(i);
let d = dst.as_mut_ptr().add(i);
std::arch::asm!(
// Broadcast scale and zero
"dup v30.4s, {sc:w}",
"dup v31.16b, {z:w}",
// Load 16 x i8
"ldr q0, [{s}]",
// Subtract zero point (i8 saturating)
"sqsub v0.16b, v0.16b, v31.16b",
// Widen i8 -> i16
"sxtl v1.8h, v0.8b",
"sxtl2 v2.8h, v0.16b",
// Widen i16 -> i32
"sxtl v3.4s, v1.4h",
"sxtl2 v4.4s, v1.8h",
"sxtl v5.4s, v2.4h",
"sxtl2 v6.4s, v2.8h",
// Convert i32 -> f32
"scvtf v3.4s, v3.4s",
"scvtf v4.4s, v4.4s",
"scvtf v5.4s, v5.4s",
"scvtf v6.4s, v6.4s",
// Multiply by scale
"fmul v3.4s, v3.4s, v30.4s",
"fmul v4.4s, v4.4s, v30.4s",
"fmul v5.4s, v5.4s, v30.4s",
"fmul v6.4s, v6.4s, v30.4s",
// Store 16 x f32
"stp q3, q4, [{d}]",
"stp q5, q6, [{d}, #32]",
sc = in(reg) scale.to_bits(),
z = in(reg) zero as i32,
s = in(reg) s,
d = in(reg) d,
out("v0") _, out("v1") _, out("v2") _,
out("v3") _, out("v4") _, out("v5") _, out("v6") _,
out("v30") _, out("v31") _,
);
}
i += 16;
}
// Scalar remainder
for j in i..n {
dst[j] = ((src[j] as i32 - zero as i32) as f32) * scale;
}
}
#[cfg(not(target_arch = "aarch64"))]
{
for i in 0..n {
dst[i] = ((src[i] as i32 - zero as i32) as f32) * scale;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn quantize_dequantize_symmetric() {
let src: Vec<f32> = (0..32).map(|i| (i as f32 - 16.0) * 0.1).collect();
let scale = 0.1_f32 / 127.0 * 16.0; // rough scale
let mut qi8 = vec![0i8; 32];
let mut dst = vec![0.0f32; 32];
cast_f32_i8(&mut qi8, &src, scale);
cast_i8_f32(&mut dst, &qi8, scale, 0);
for i in 0..32 {
let err = (dst[i] - src[i]).abs();
assert!(
err < scale * 1.5, // within ~1 quantization step
"mismatch at {i}: src={}, dst={}, err={}",
src[i],
dst[i],
err,
);
}
}
#[test]
fn clamp_overflow() {
let src = [1000.0f32, -1000.0];
let mut dst = [0i8; 2];
cast_f32_i8(&mut dst, &src, 1.0);
assert_eq!(dst[0], 127);
assert_eq!(dst[1], -128);
}
}
//! Quantization operations: f32 <-> int8 with scale/zero-point.
//!
//! Uses NEON vectorized paths on aarch64 for bulk conversions.
/// Quantize f32 to int8: `dst[i] = clamp(round(src[i] / scale), -128, 127)`.
///
/// `scale` is the quantization step size. Typically computed as
/// `(max - min) / 255` for asymmetric or `max(abs) / 127` for symmetric.
/// Dequantize int8 to f32: `dst[i] = (src[i] - zero) * scale`.
///
/// `zero` is the zero-point offset. For symmetric quantization, `zero = 0`.