// ---------------------------------------------------------------------------
// Rotary Positional Embedding (RoPE) -- NEON fast-path + scalar fallback
// ---------------------------------------------------------------------------
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
/// Apply rotary positional embedding.
///
/// `x` contains pairs of values (x0, x1) for each dimension pair.
/// `freqs` contains the base frequencies for each pair (length = x.len() / 2).
/// `pos` is the token position.
///
/// For each pair (x0, x1) with frequency f:
/// theta = pos * f
/// out[2*i] = x0 * cos(theta) - x1 * sin(theta)
/// out[2*i+1] = x0 * sin(theta) + x1 * cos(theta)
pub fn rotate(out: &mut [f32], x: &[f32], freqs: &[f32], pos: usize) {
let dim = x.len();
assert!(dim.is_multiple_of(2), "rotate: dimension must be even");
let n_pairs = dim / 2;
assert!(freqs.len() >= n_pairs, "rotate: not enough frequencies");
assert!(out.len() >= dim, "rotate: output buffer too small");
let pos_f = pos as f32;
let mut i = 0; // pair index
#[cfg(target_arch = "aarch64")]
{
// Process 2 pairs at a time (= 4 f32 values from x)
unsafe {
while i + 2 <= n_pairs {
// Load two frequencies
let f0 = freqs[i];
let f1 = freqs[i + 1];
let theta0 = pos_f * f0;
let theta1 = pos_f * f1;
let cos0 = theta0.cos();
let sin0 = theta0.sin();
let cos1 = theta1.cos();
let sin1 = theta1.sin();
// Load 4 values: [x0_a, x1_a, x0_b, x1_b]
let xv = vld1q_f32(x.as_ptr().add(i * 2));
// Build cos/sin vectors matching the interleaved layout
let cos_v = {
let arr: [f32; 4] = [cos0, cos0, cos1, cos1];
vld1q_f32(arr.as_ptr())
};
let sin_v = {
let arr: [f32; 4] = [sin0, sin0, sin1, sin1];
vld1q_f32(arr.as_ptr())
};
// Build the "swapped and negated" version:
// [-x1_a, x0_a, -x1_b, x0_b]
// We negate odd-indexed originals and swap pairs.
let x_arr: [f32; 4] = [
vgetq_lane_f32::<0>(xv),
vgetq_lane_f32::<1>(xv),
vgetq_lane_f32::<2>(xv),
vgetq_lane_f32::<3>(xv),
];
let swapped_arr: [f32; 4] = [-x_arr[1], x_arr[0], -x_arr[3], x_arr[2]];
let swapped = vld1q_f32(swapped_arr.as_ptr());
// result = x * cos + swapped * sin
let result = vfmaq_f32(vmulq_f32(xv, cos_v), swapped, sin_v);
vst1q_f32(out.as_mut_ptr().add(i * 2), result);
i += 2;
}
}
}
// Scalar tail / fallback
while i < n_pairs {
let theta = pos_f * freqs[i];
let (sin_t, cos_t) = theta.sin_cos();
let x0 = x[2 * i];
let x1 = x[2 * i + 1];
out[2 * i] = x0 * cos_t - x1 * sin_t;
out[2 * i + 1] = x0 * sin_t + x1 * cos_t;
i += 1;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rope_identity_at_zero() {
// At pos=0, theta=0, cos=1, sin=0, so output == input.
let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let freqs = vec![0.1, 0.01, 0.001, 0.0001];
let mut out = vec![0.0; 8];
rotate(&mut out, &x, &freqs, 0);
for i in 0..x.len() {
assert!(
(out[i] - x[i]).abs() < 1e-5,
"mismatch at {}: {} vs {}",
i,
out[i],
x[i]
);
}
}
#[test]
fn test_rope_preserves_norm() {
// RoPE is a rotation, so the norm of each pair should be preserved.
let x = vec![3.0, 4.0, 1.0, 0.0];
let freqs = vec![1.0, 2.0];
let mut out = vec![0.0; 4];
rotate(&mut out, &x, &freqs, 7);
let norm_in_0 = (x[0] * x[0] + x[1] * x[1]).sqrt();
let norm_out_0 = (out[0] * out[0] + out[1] * out[1]).sqrt();
assert!(
(norm_in_0 - norm_out_0).abs() < 1e-4,
"norm not preserved: {} vs {}",
norm_in_0,
norm_out_0
);
let norm_in_1 = (x[2] * x[2] + x[3] * x[3]).sqrt();
let norm_out_1 = (out[2] * out[2] + out[3] * out[3]).sqrt();
assert!(
(norm_in_1 - norm_out_1).abs() < 1e-4,
"norm not preserved: {} vs {}",
norm_in_1,
norm_out_1
);
}
#[test]
fn test_rope_rotation() {
// Verify explicit rotation for a single pair.
let x = vec![1.0, 0.0];
let freq = 1.0;
let pos = 1;
let theta = pos as f32 * freq;
let mut out = vec![0.0; 2];
rotate(&mut out, &x, &[freq], pos);
assert!((out[0] - theta.cos()).abs() < 1e-5);
assert!((out[1] - theta.sin()).abs() < 1e-5);
}
#[test]
fn test_rope_odd_pairs() {
// 3 pairs = 6 elements, tests scalar tail when NEON handles 2 pairs.
let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let freqs = vec![0.5, 1.0, 1.5];
let mut out = vec![0.0; 6];
rotate(&mut out, &x, &freqs, 3);
// Just verify norms are preserved
for p in 0..3 {
let ni = (x[2 * p] * x[2 * p] + x[2 * p + 1] * x[2 * p + 1]).sqrt();
let no = (out[2 * p] * out[2 * p] + out[2 * p + 1] * out[2 * p + 1]).sqrt();
assert!((ni - no).abs() < 1e-4, "pair {}: {} vs {}", p, ni, no);
}
}
}
// ---------------------------------------------------------------------------
// Rotary Positional Embedding (RoPE) -- NEON fast-path + scalar fallback
// ---------------------------------------------------------------------------
use *;
/// Apply rotary positional embedding.
///
/// `x` contains pairs of values (x0, x1) for each dimension pair.
/// `freqs` contains the base frequencies for each pair (length = x.len() / 2).
/// `pos` is the token position.
///
/// For each pair (x0, x1) with frequency f:
/// theta = pos * f
/// out[2*i] = x0 * cos(theta) - x1 * sin(theta)
/// out[2*i+1] = x0 * sin(theta) + x1 * cos(theta)