use super::goldilocks::*;
/// Runs the Poseidon2-style permutation over a 16-lane state, in place.
///
/// Round schedule, in order:
/// 1. one `mds_external` pass before any round;
/// 2. four full rounds: `add_rc_full`, `sbox7_full`, `mds_external`;
/// 3. sixteen partial rounds: a constant is added to lane 0, lane 0 is passed
///    through `gl_inv`, then `mds_internal` mixes all lanes;
/// 4. four more full rounds, same shape as step 2.
///
/// # Parameters
/// * `state` - the 16 lanes to permute; overwritten with the result.
/// * `rc` - round constants. Entries `0..128` feed the eight full rounds
///   (16 per round, in round order); entries `128..144` feed the sixteen
///   partial rounds (one per round).
/// * `diag` - per-lane multipliers forwarded to `mds_internal`.
///
/// NOTE(review): the partial rounds use `gl_inv` on lane 0 rather than the
/// seventh-power map used by the full rounds - presumably a deliberate choice
/// of this variant; confirm against the specification this file implements.
#[inline(never)]
pub fn poseidon2_permute(state: &mut [u64; 16], rc: &[u64; 144], diag: &[u64; 16]) {
    const WIDTH: usize = 16;
    const FULL_ROUNDS_PER_HALF: usize = 4;
    // Partial-round constants start right after the full-round ones (index 128).
    const PARTIAL_RC_START: usize = 2 * FULL_ROUNDS_PER_HALF * WIDTH;

    mds_external(state);

    // Opening full rounds: constant offsets 0, 16, 32, 48.
    for round in 0..FULL_ROUNDS_PER_HALF {
        add_rc_full(state, rc, round * WIDTH);
        sbox7_full(state);
        mds_external(state);
    }

    // Partial rounds: one per trailing constant, touching only lane 0 before mixing.
    for &round_constant in &rc[PARTIAL_RC_START..] {
        state[0] = gl_inv(gl_add(state[0], round_constant));
        mds_internal(state, diag);
    }

    // Closing full rounds: constant offsets 64, 80, 96, 112.
    for round in FULL_ROUNDS_PER_HALF..2 * FULL_ROUNDS_PER_HALF {
        add_rc_full(state, rc, round * WIDTH);
        sbox7_full(state);
        mds_external(state);
    }
}
/// Adds one full round's constants to the state, lane by lane.
///
/// Lane `i` becomes `gl_add(state[i], rc[offset + i])` for `i` in `0..16`.
///
/// # Panics
/// Panics on the first lane for which `offset + i` is outside `rc`.
#[inline(always)]
fn add_rc_full(state: &mut [u64; 16], rc: &[u64; 144], offset: usize) {
    for (lane, slot) in state.iter_mut().enumerate() {
        *slot = gl_add(*slot, rc[offset + lane]);
    }
}
/// Replaces every lane `x` with its seventh power, built from `gl_mul`.
///
/// The power is assembled as `x^7 = x^3 * x^4`, with `x^3 = x^2 * x` and
/// `x^4 = x^2 * x^2`, i.e. four `gl_mul` calls per lane. Each intermediate
/// power is computed for all 16 lanes before moving on to the next one.
#[inline(always)]
fn sbox7_full(state: &mut [u64; 16]) {
    let squares: [u64; 16] = core::array::from_fn(|i| gl_mul(state[i], state[i]));
    let cubes: [u64; 16] = core::array::from_fn(|i| gl_mul(squares[i], state[i]));
    let fourths: [u64; 16] = squares.map(|sq| gl_mul(sq, sq));
    *state = core::array::from_fn(|i| gl_mul(cubes[i], fourths[i]));
}
/// External linear layer over all 16 lanes.
///
/// Two steps:
/// 1. `apply_m4` is applied independently to each consecutive group of four
///    lanes (`0..4`, `4..8`, `8..12`, `12..16`);
/// 2. for each position `k` in `0..4`, the four lanes `k`, `k + 4`, `k + 8`,
///    `k + 12` are summed with `gl_add`, and that sum is added back onto each
///    of those same four lanes.
#[inline(always)]
fn mds_external(state: &mut [u64; 16]) {
    for quad in state.chunks_exact_mut(4) {
        apply_m4(quad);
    }

    for col in 0..4 {
        // Sum is taken over the post-`apply_m4` values, before any lane in
        // this column is updated.
        let upper = gl_add(state[col], state[col + 4]);
        let lower = gl_add(state[col + 8], state[col + 12]);
        let col_sum = gl_add(upper, lower);

        for lane in (col..16).step_by(4) {
            state[lane] = gl_add(state[lane], col_sum);
        }
    }
}
/// Mixes the first four elements of `x` in place using only `gl_add`.
///
/// Writing the inputs as `(a, b, c, d)`, the outputs are, in terms of
/// repeated `gl_add`:
///
/// ```text
/// x[0] = 2a + 3b +  c +  d
/// x[1] =  a + 2b + 3c +  d
/// x[2] =  a +  b + 2c + 3d
/// x[3] = 3a +  b +  c + 2d
/// ```
///
/// Elements past index 3 are left untouched.
///
/// # Panics
/// Panics if `x` has fewer than four elements.
#[inline(always)]
fn apply_m4(x: &mut [u64]) {
    // Snapshot all inputs first so the writes below cannot disturb them.
    let (a, b, c, d) = (x[0], x[1], x[2], x[3]);

    let ab = gl_add(a, b);
    let cd = gl_add(c, d);
    let total = gl_add(ab, cd);
    let total_plus_b = gl_add(total, b);
    let total_plus_d = gl_add(total, d);
    let twice_a = gl_add(a, a);
    let twice_c = gl_add(c, c);

    x[0] = gl_add(total_plus_b, ab);
    x[1] = gl_add(total_plus_b, twice_c);
    x[2] = gl_add(total_plus_d, cd);
    x[3] = gl_add(total_plus_d, twice_a);
}
/// Internal linear layer: scales each lane by its `diag` entry and adds the
/// sum of all lanes.
///
/// With `sum` the `gl_add`-sum of the 16 incoming lanes (accumulated from
/// lane 0 upward), each lane becomes
/// `gl_add(gl_mul(diag[i], state[i]), sum)`. The sum is taken once, before
/// any lane is overwritten.
#[inline(always)]
fn mds_internal(state: &mut [u64; 16], diag: &[u64; 16]) {
    let total = state[1..]
        .iter()
        .fold(state[0], |acc, &lane| gl_add(acc, lane));

    for (lane, &factor) in state.iter_mut().zip(diag.iter()) {
        *lane = gl_add(gl_mul(factor, *lane), total);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic stand-in round constants: entry `i` is
    /// `(i + 1) * 0x9E3779B97F4A7C15` with wrapping multiplication.
    /// These are test fixtures only, not production constants.
    fn test_rc() -> [u64; 144] {
        core::array::from_fn(|i| (i as u64 + 1).wrapping_mul(0x9E3779B97F4A7C15))
    }

    /// Deterministic stand-in diagonal: entry `i` is
    /// `(i + 2) * 0x517CC1B727220A95` with wrapping multiplication.
    fn test_diag() -> [u64; 16] {
        core::array::from_fn(|i| (i as u64 + 2).wrapping_mul(0x517CC1B727220A95))
    }

    /// The permutation must not return its input unchanged.
    #[test]
    fn permute_changes_state() {
        let original: [u64; 16] = core::array::from_fn(|i| i as u64 + 1);
        let mut state = original;
        poseidon2_permute(&mut state, &test_rc(), &test_diag());
        assert_ne!(state, original);
    }

    /// Two runs on identical inputs must agree.
    #[test]
    fn permute_deterministic() {
        let (rc, diag) = (test_rc(), test_diag());
        let input: [u64; 16] = core::array::from_fn(|i| i as u64 + 1);
        let mut first = input;
        let mut second = input;
        poseidon2_permute(&mut first, &rc, &diag);
        poseidon2_permute(&mut second, &rc, &diag);
        assert_eq!(first, second);
    }

    /// Bumping lane 0 of the input by one must change every output lane.
    #[test]
    fn permute_diffusion() {
        let (rc, diag) = (test_rc(), test_diag());
        let mut base: [u64; 16] = core::array::from_fn(|i| i as u64 + 100);
        let mut tweaked = base;
        tweaked[0] += 1;
        poseidon2_permute(&mut base, &rc, &diag);
        poseidon2_permute(&mut tweaked, &rc, &diag);
        for (i, (a, b)) in base.iter().zip(tweaked.iter()).enumerate() {
            assert_ne!(a, b, "element {} unchanged after input tweak", i);
        }
    }
}