#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
use super::math;
use super::reduce;
/// Replaces every element of `x` with its softmax value, in place.
///
/// The maximum reported by `reduce::max` is subtracted from each element
/// before exponentiation, the exponentials are written back into `x` and
/// summed, and finally every element is multiplied by the reciprocal of
/// that sum.
///
/// * An empty slice is left untouched.
/// * If the sum of exponentials compares equal to `0.0`, the function
///   returns early and `x` keeps the un-normalised exponentials.
///
/// On `aarch64` the bulk of the work is done 16 floats at a time, then
/// 4 at a time, with a scalar loop for the last `len % 4` elements. On
/// every other target a plain scalar loop is used.
pub fn softmax(x: &mut [f32]) {
    if x.is_empty() {
        return;
    }
    let peak = reduce::max(x);

    // End offsets of the 16-wide and 4-wide vector regions.
    #[cfg(target_arch = "aarch64")]
    let (wide_end, quad_end) = {
        let n = x.len();
        (n - n % 16, n - n % 4)
    };

    // Pass 1: x[i] <- exp(x[i] - peak), accumulating the total.
    #[cfg(target_arch = "aarch64")]
    let total = {
        // SAFETY: every offset visited by the first loop satisfies
        // `off + 16 <= x.len()` and every offset visited by the second loop
        // satisfies `off + 4 <= x.len()`, so each 4-float load/store stays
        // inside `x`. `base` is not used again once `x` is re-borrowed for
        // the scalar tail. `math::exp_neon` is kept inside the block as in
        // the original; its own contract is not visible from this file.
        unsafe {
            let base = x.as_mut_ptr();
            let shift = vdupq_n_f32(peak);
            let zero = vdupq_n_f32(0.0);
            // Four independent accumulators, one per 4-lane group.
            let (mut acc0, mut acc1, mut acc2, mut acc3) = (zero, zero, zero, zero);

            for off in (0..wide_end).step_by(16) {
                let p = base.add(off);
                let e0 = math::exp_neon(vsubq_f32(vld1q_f32(p), shift));
                let e1 = math::exp_neon(vsubq_f32(vld1q_f32(p.add(4)), shift));
                let e2 = math::exp_neon(vsubq_f32(vld1q_f32(p.add(8)), shift));
                let e3 = math::exp_neon(vsubq_f32(vld1q_f32(p.add(12)), shift));
                vst1q_f32(p, e0);
                vst1q_f32(p.add(4), e1);
                vst1q_f32(p.add(8), e2);
                vst1q_f32(p.add(12), e3);
                acc0 = vaddq_f32(acc0, e0);
                acc1 = vaddq_f32(acc1, e1);
                acc2 = vaddq_f32(acc2, e2);
                acc3 = vaddq_f32(acc3, e3);
            }

            // Up to three leftover 4-lane groups go into the first accumulator.
            for off in (wide_end..quad_end).step_by(4) {
                let p = base.add(off);
                let e = math::exp_neon(vsubq_f32(vld1q_f32(p), shift));
                vst1q_f32(p, e);
                acc0 = vaddq_f32(acc0, e);
            }

            // Pairwise combine, then horizontal add across lanes.
            let mut partial =
                vaddvq_f32(vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3)));

            // Scalar tail: the final `len % 4` elements.
            for slot in &mut x[quad_end..] {
                let e = math::exp_scalar(*slot - peak);
                *slot = e;
                partial += e;
            }
            partial
        }
    };
    #[cfg(not(target_arch = "aarch64"))]
    let total = x.iter_mut().fold(0.0f32, |acc, slot| {
        let e = math::exp_scalar(*slot - peak);
        *slot = e;
        acc + e
    });

    // Nothing sensible to divide by; leave the exponentials as they are.
    if total == 0.0 {
        return;
    }
    let recip = 1.0 / total;

    // Pass 2: scale everything by 1 / total.
    #[cfg(target_arch = "aarch64")]
    let tail_start = {
        // SAFETY: same bounds argument as in pass 1 — each offset leaves
        // room for the 4 (or 16) floats touched from it.
        unsafe {
            let base = x.as_mut_ptr();
            let factor = vdupq_n_f32(recip);

            for off in (0..wide_end).step_by(16) {
                let p = base.add(off);
                vst1q_f32(p, vmulq_f32(vld1q_f32(p), factor));
                vst1q_f32(p.add(4), vmulq_f32(vld1q_f32(p.add(4)), factor));
                vst1q_f32(p.add(8), vmulq_f32(vld1q_f32(p.add(8)), factor));
                vst1q_f32(p.add(12), vmulq_f32(vld1q_f32(p.add(12)), factor));
            }
            for off in (wide_end..quad_end).step_by(4) {
                let p = base.add(off);
                vst1q_f32(p, vmulq_f32(vld1q_f32(p), factor));
            }
        }
        quad_end
    };
    // Without NEON the scalar loop below covers the whole slice.
    #[cfg(not(target_arch = "aarch64"))]
    let tail_start = 0;

    for slot in &mut x[tail_start..] {
        *slot *= recip;
    }
}
/// Writes the RMS-normalised, weighted copy of `x` into `out`.
///
/// For every index `i < x.len()`:
///
/// ```text
/// out[i] = x[i] * (1 / sqrt(mean(x^2) + eps)) * weight[i]
/// ```
///
/// Elements of `out` past `x.len()` are not written. An empty `x` is a
/// no-op.
///
/// On `aarch64` both passes (sum of squares, then scaling) run 16 floats
/// at a time, then 4 at a time, with a scalar loop for the last
/// `len % 4` elements. Other targets use scalar loops throughout.
///
/// # Panics
///
/// * if `x.len() != weight.len()`;
/// * if `out.len() < x.len()`.
pub fn normalize(out: &mut [f32], x: &[f32], weight: &[f32], eps: f32) {
    let n = x.len();
    assert_eq!(n, weight.len(), "normalize: x and weight length mismatch");
    assert!(out.len() >= n, "normalize: output buffer too small");
    if n == 0 {
        return;
    }

    // End offsets of the 16-wide and 4-wide vector regions.
    #[cfg(target_arch = "aarch64")]
    let (wide_end, quad_end) = (n - n % 16, n - n % 4);

    // Pass 1: sum of squares of `x`.
    #[cfg(target_arch = "aarch64")]
    let sum_sq = {
        // SAFETY: offsets in the first loop satisfy `off + 16 <= n` and
        // offsets in the second satisfy `off + 4 <= n`, with `n == x.len()`,
        // so every 4-float load reads inside `x`.
        let lanes = unsafe {
            let src = x.as_ptr();
            let zero = vdupq_n_f32(0.0);
            // Four independent accumulators, one per 4-lane group.
            let (mut acc0, mut acc1, mut acc2, mut acc3) = (zero, zero, zero, zero);

            for off in (0..wide_end).step_by(16) {
                let p = src.add(off);
                let q0 = vld1q_f32(p);
                let q1 = vld1q_f32(p.add(4));
                let q2 = vld1q_f32(p.add(8));
                let q3 = vld1q_f32(p.add(12));
                acc0 = vfmaq_f32(acc0, q0, q0);
                acc1 = vfmaq_f32(acc1, q1, q1);
                acc2 = vfmaq_f32(acc2, q2, q2);
                acc3 = vfmaq_f32(acc3, q3, q3);
            }

            // Up to three leftover 4-lane groups go into the first accumulator.
            for off in (wide_end..quad_end).step_by(4) {
                let q = vld1q_f32(src.add(off));
                acc0 = vfmaq_f32(acc0, q, q);
            }

            // Pairwise combine, then horizontal add across lanes.
            vaddvq_f32(vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3)))
        };
        // Scalar tail: the final `n % 4` elements.
        x[quad_end..].iter().fold(lanes, |acc, &v| acc + v * v)
    };
    #[cfg(not(target_arch = "aarch64"))]
    let sum_sq = x.iter().fold(0.0f32, |acc, &v| acc + v * v);

    let mean_sq = sum_sq / n as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();

    // Pass 2: out[i] = (x[i] * inv_rms) * weight[i].
    #[cfg(target_arch = "aarch64")]
    let tail_start = {
        // SAFETY: same offset bounds as pass 1. `weight.len() == n` and
        // `out.len() >= n` were asserted above, so the loads from `weight`
        // and the stores to `out` are in bounds too. `out` is a unique
        // borrow, so it cannot overlap `x` or `weight`.
        unsafe {
            let src = x.as_ptr();
            let gain = weight.as_ptr();
            let dst = out.as_mut_ptr();
            let scale = vdupq_n_f32(inv_rms);

            for off in (0..wide_end).step_by(16) {
                let x0 = vld1q_f32(src.add(off));
                let x1 = vld1q_f32(src.add(off + 4));
                let x2 = vld1q_f32(src.add(off + 8));
                let x3 = vld1q_f32(src.add(off + 12));
                let w0 = vld1q_f32(gain.add(off));
                let w1 = vld1q_f32(gain.add(off + 4));
                let w2 = vld1q_f32(gain.add(off + 8));
                let w3 = vld1q_f32(gain.add(off + 12));
                vst1q_f32(dst.add(off), vmulq_f32(vmulq_f32(x0, scale), w0));
                vst1q_f32(dst.add(off + 4), vmulq_f32(vmulq_f32(x1, scale), w1));
                vst1q_f32(dst.add(off + 8), vmulq_f32(vmulq_f32(x2, scale), w2));
                vst1q_f32(dst.add(off + 12), vmulq_f32(vmulq_f32(x3, scale), w3));
            }
            for off in (wide_end..quad_end).step_by(4) {
                let xv = vld1q_f32(src.add(off));
                let wv = vld1q_f32(gain.add(off));
                vst1q_f32(dst.add(off), vmulq_f32(vmulq_f32(xv, scale), wv));
            }
        }
        quad_end
    };
    // Without NEON the scalar loop below covers the whole input.
    #[cfg(not(target_arch = "aarch64"))]
    let tail_start = 0;

    let tail = out[tail_start..n]
        .iter_mut()
        .zip(&x[tail_start..])
        .zip(&weight[tail_start..]);
    for ((dst, &xv), &wv) in tail {
        *dst = xv * inv_rms * wv;
    }
}
#[cfg(test)]
mod tests {
    //! Sanity checks for `softmax` and `normalize` on small inputs.
    use super::*;

    /// Softmax output sums to one and preserves the ordering of an
    /// ascending input.
    #[test]
    fn test_softmax_sums_to_one() {
        let mut probs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        softmax(&mut probs);

        let total: f32 = probs.iter().sum();
        assert!((total - 1.0).abs() < 1e-5);

        for pair in probs.windows(2) {
            assert!(pair[1] >= pair[0]);
        }
    }

    /// A single element always maps to probability one.
    #[test]
    fn test_softmax_single() {
        let mut only = vec![42.0];
        softmax(&mut only);
        let deviation = only[0] - 1.0;
        assert!(deviation.abs() < 1e-6);
    }

    /// Four equal inputs yield four equal probabilities of 0.25.
    #[test]
    fn test_softmax_equal() {
        let mut flat = vec![1.0; 4];
        softmax(&mut flat);
        for p in flat.iter().copied() {
            assert!((p - 0.25).abs() < 1e-5);
        }
    }

    /// With unit weights the output has a mean square close to one.
    #[test]
    fn test_normalize_unit_weight() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let ones = vec![1.0; 4];
        let mut result = vec![0.0; 4];
        normalize(&mut result, &input, &ones, 1e-5);

        let squares: f32 = result.iter().map(|v| v * v).sum::<f32>();
        let mean_sq = squares / result.len() as f32;
        assert!((mean_sq - 1.0).abs() < 1e-3);
    }

    /// A constant input normalises to one, so the output equals the weight.
    #[test]
    fn test_normalize_scaling() {
        let input = vec![2.0; 8];
        let gain = vec![3.0; 8];
        let mut result = vec![0.0; 8];
        normalize(&mut result, &input, &gain, 1e-5);
        for got in result.iter().copied() {
            assert!((got - 3.0).abs() < 1e-3);
        }
    }
}