/// Element-wise complex multiply-accumulate: `acc[k] += a[k] * b[k]` over
/// interleaved `[re, im, re, im, ...]` `f32` slices.
///
/// Only the first `min(acc.len(), a.len())` floats are processed; any extra
/// tail of `acc` is left untouched. On AArch64 CPUs that implement FEAT_FCMA
/// (Armv8.3-A) the bulk of the work uses `fcmla` NEON instructions; on other
/// CPUs — and for the sub-vector tail — a scalar loop produces identical
/// results (up to floating-point rounding).
///
/// # Panics
/// Panics if `a.len() != b.len()`, if `a.len()` is odd, or if the processed
/// length `min(acc.len(), a.len())` is odd (i.e. `acc` would split a
/// complex pair).
pub fn complex_mul_acc(acc: &mut [f32], a: &[f32], b: &[f32]) {
    assert_eq!(a.len(), b.len(), "a and b must have the same length");
    assert_eq!(a.len() % 2, 0, "complex slices must have even length");
    let n = acc.len().min(a.len());
    assert_eq!(n % 2, 0, "acc length must be even");
    let mut i = 0;
    #[cfg(target_arch = "aarch64")]
    {
        // BUG FIX: `fcmla` is an Armv8.3-A (FEAT_FCMA) instruction. The
        // original issued it unconditionally on every aarch64 CPU, which
        // raises SIGILL on pre-v8.3 cores (e.g. Cortex-A72). Gate the SIMD
        // kernel behind runtime feature detection; non-FCMA CPUs fall
        // through to the scalar loop below.
        if std::arch::is_aarch64_feature_detected!("fcma") {
            // SAFETY: FCMA support was just verified at runtime, and
            // `fcma_mul_acc` only reads/writes indices `< n`, which is
            // within the bounds of all three slices.
            i = unsafe { fcma_mul_acc(acc, a, b, n) };
        }
    }
    // Scalar path: portable fallback and cleanup for the tail shorter than
    // one NEON vector. (ar + ai·i)(br + bi·i) = (ar·br − ai·bi) + (ar·bi + ai·br)·i.
    while i + 1 < n {
        let ar = a[i];
        let ai = a[i + 1];
        let br = b[i];
        let bi = b[i + 1];
        acc[i] += ar * br - ai * bi;
        acc[i + 1] += ar * bi + ai * br;
        i += 2;
    }
}

/// NEON/FCMA kernel: accumulates `a[k] * b[k]` into `acc` for as many full
/// 4-lane vectors as fit in `n`, returning the index where scalar
/// processing must resume. The `fcmla #0` + `fcmla #90` pair together
/// perform one full complex multiply-accumulate per lane pair, matching
/// the scalar loop in [`complex_mul_acc`].
///
/// # Safety
/// The caller must ensure the CPU supports FEAT_FCMA (checked via
/// `is_aarch64_feature_detected!("fcma")`) and that
/// `n <= acc.len().min(a.len()).min(b.len())`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "fcma")]
unsafe fn fcma_mul_acc(acc: &mut [f32], a: &[f32], b: &[f32], n: usize) -> usize {
    use core::arch::aarch64::{vld1q_f32, vst1q_f32};
    let pa = a.as_ptr();
    let pb = b.as_ptr();
    let pc = acc.as_mut_ptr();
    let mut i = 0;
    // 4x-unrolled main loop: 16 floats (8 complex numbers) per iteration,
    // keeping four independent fcmla dependency chains in flight.
    while i + 16 <= n {
        let mut c0 = vld1q_f32(pc.add(i));
        let mut c1 = vld1q_f32(pc.add(i + 4));
        let mut c2 = vld1q_f32(pc.add(i + 8));
        let mut c3 = vld1q_f32(pc.add(i + 12));
        let a0 = vld1q_f32(pa.add(i));
        let a1 = vld1q_f32(pa.add(i + 4));
        let a2 = vld1q_f32(pa.add(i + 8));
        let a3 = vld1q_f32(pa.add(i + 12));
        let b0 = vld1q_f32(pb.add(i));
        let b1 = vld1q_f32(pb.add(i + 4));
        let b2 = vld1q_f32(pb.add(i + 8));
        let b3 = vld1q_f32(pb.add(i + 12));
        core::arch::asm!(
            "fcmla {c0:v}.4s, {a0:v}.4s, {b0:v}.4s, #0",
            "fcmla {c1:v}.4s, {a1:v}.4s, {b1:v}.4s, #0",
            "fcmla {c2:v}.4s, {a2:v}.4s, {b2:v}.4s, #0",
            "fcmla {c3:v}.4s, {a3:v}.4s, {b3:v}.4s, #0",
            "fcmla {c0:v}.4s, {a0:v}.4s, {b0:v}.4s, #90",
            "fcmla {c1:v}.4s, {a1:v}.4s, {b1:v}.4s, #90",
            "fcmla {c2:v}.4s, {a2:v}.4s, {b2:v}.4s, #90",
            "fcmla {c3:v}.4s, {a3:v}.4s, {b3:v}.4s, #90",
            c0 = inout(vreg) c0, c1 = inout(vreg) c1,
            c2 = inout(vreg) c2, c3 = inout(vreg) c3,
            a0 = in(vreg) a0, a1 = in(vreg) a1,
            a2 = in(vreg) a2, a3 = in(vreg) a3,
            b0 = in(vreg) b0, b1 = in(vreg) b1,
            b2 = in(vreg) b2, b3 = in(vreg) b3,
            // Register-only computation: no memory access, no stack use.
            options(nomem, nostack),
        );
        vst1q_f32(pc.add(i), c0);
        vst1q_f32(pc.add(i + 4), c1);
        vst1q_f32(pc.add(i + 8), c2);
        vst1q_f32(pc.add(i + 12), c3);
        i += 16;
    }
    // Single-vector cleanup: 4 floats (2 complex numbers) at a time.
    while i + 4 <= n {
        let mut c0 = vld1q_f32(pc.add(i));
        let a0 = vld1q_f32(pa.add(i));
        let b0 = vld1q_f32(pb.add(i));
        core::arch::asm!(
            "fcmla {c:v}.4s, {a:v}.4s, {b:v}.4s, #0",
            "fcmla {c:v}.4s, {a:v}.4s, {b:v}.4s, #90",
            c = inout(vreg) c0, a = in(vreg) a0, b = in(vreg) b0,
            options(nomem, nostack),
        );
        vst1q_f32(pc.add(i), c0);
        i += 4;
    }
    i
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the small fixed-value checks.
    const EPS: f32 = 1e-6;

    #[test]
    fn complex_mul_acc_basic() {
        // (1 + 2i) * (3 + 4i) = -5 + 10i, accumulated into zeros.
        let mut acc = [0.0f32; 2];
        complex_mul_acc(&mut acc, &[1.0, 2.0], &[3.0, 4.0]);
        assert!((acc[0] - (-5.0)).abs() < EPS);
        assert!((acc[1] - 10.0).abs() < EPS);
    }

    #[test]
    fn complex_mul_acc_accumulates() {
        // 1*1 = 1 and i*i = -1, each added onto non-zero initial values.
        let mut acc = [10.0f32, 20.0, 30.0, 40.0];
        complex_mul_acc(&mut acc, &[1.0, 0.0, 0.0, 1.0], &[1.0, 0.0, 0.0, 1.0]);
        let expected = [11.0f32, 20.0, 29.0, 40.0];
        for (got, want) in acc.iter().zip(expected.iter()) {
            assert!((got - want).abs() < EPS);
        }
    }

    #[test]
    fn complex_mul_acc_large() {
        // Big enough to drive the unrolled SIMD path on aarch64; the result
        // must agree lane-for-lane with a plain scalar recomputation.
        let n = 4096;
        let a: Vec<f32> = (0..n).map(|k| (k % 7) as f32 * 0.1).collect();
        let b: Vec<f32> = (0..n).map(|k| (k % 11) as f32 * 0.1).collect();
        let mut acc_neon = vec![0.0f32; n];
        complex_mul_acc(&mut acc_neon, &a, &b);
        let mut acc_scalar = vec![0.0f32; n];
        for k in (0..n).step_by(2) {
            acc_scalar[k] += a[k] * b[k] - a[k + 1] * b[k + 1];
            acc_scalar[k + 1] += a[k] * b[k + 1] + a[k + 1] * b[k];
        }
        for k in 0..n {
            assert!(
                (acc_neon[k] - acc_scalar[k]).abs() < 1e-4,
                "mismatch at {}: {} vs {}",
                k,
                acc_neon[k],
                acc_scalar[k]
            );
        }
    }
}