use crate::constants::ROUND_CONSTANTS;
use crate::field::{Goldilocks, matmul_internal, mds_light_permutation};
use crate::trace::{FullRoundWitnesses, RoundVisitor};
/// Number of round constants consumed by the full (external) rounds:
/// 2 groups (initial + terminal) × 4 full rounds × 16 lanes.
const NUM_EXTERNAL: usize = 2 * 4 * 16;
/// Applies the permutation to `state` in place using the crate's built-in
/// `ROUND_CONSTANTS` table.
pub fn permute(state: &mut [Goldilocks; 16]) {
    let constants: &[Goldilocks] = &ROUND_CONSTANTS;
    permute_with_constants(state, constants)
}
/// Applies the permutation to `state` in place using an explicit round-constant table.
///
/// `constants` is laid out as:
/// - `constants[..NUM_EXTERNAL / 2]`: 4 initial full rounds × 16 lanes,
/// - `constants[NUM_EXTERNAL / 2..NUM_EXTERNAL]`: 4 terminal full rounds × 16 lanes,
/// - `constants[NUM_EXTERNAL..NUM_EXTERNAL + 16]`: one constant per partial round.
///
/// Entries beyond `NUM_EXTERNAL + 16` are ignored.
///
/// # Panics
///
/// Panics if `constants` holds fewer than `NUM_EXTERNAL + 16` elements.
pub fn permute_with_constants(state: &mut [Goldilocks; 16], constants: &[Goldilocks]) {
    const WIDTH: usize = 16;
    const PARTIAL_ROUNDS: usize = 16;

    /// One full round: add a round constant to every lane, raise every lane to x^7,
    /// then apply the full linear layer.
    fn full_round(state: &mut [Goldilocks; 16], rc: &[Goldilocks]) {
        for (lane, &c) in state.iter_mut().zip(rc) {
            *lane += c;
            *lane = lane.pow7();
        }
        mds_light_permutation(state);
    }

    // Validate up front so a short table fails with a clear message instead of an
    // opaque slice-index panic halfway through mutating `state`.
    let required = NUM_EXTERNAL + PARTIAL_ROUNDS;
    assert!(
        constants.len() >= required,
        "permute_with_constants: need at least {required} round constants, got {}",
        constants.len()
    );

    let (external, internal) = constants.split_at(NUM_EXTERNAL);
    let (initial_rc, terminal_rc) = external.split_at(NUM_EXTERNAL / 2);

    // Linear layer applied once before the first full round.
    mds_light_permutation(state);

    // 64 constants / 16 lanes = exactly 4 initial full rounds.
    for rc in initial_rc.chunks_exact(WIDTH) {
        full_round(state, rc);
    }

    // Partial rounds: only lane 0 goes through the S-box, which here is field inversion.
    // NOTE(review): published Poseidon2 uses x^7 in partial rounds too; inversion is used
    // consistently throughout this file, so it looks deliberate — confirm against the spec.
    for &c in &internal[..PARTIAL_ROUNDS] {
        state[0] += c;
        state[0] = state[0].inv();
        matmul_internal(state);
    }

    // 64 constants / 16 lanes = exactly 4 terminal full rounds.
    for rc in terminal_rc.chunks_exact(WIDTH) {
        full_round(state, rc);
    }
}
/// Applies a single round of the permutation to `state`, using `ROUND_CONSTANTS`.
///
/// Rounds are numbered globally:
/// - `0..4`: initial full rounds (constant addition and x^7 on all 16 lanes, then
///   `mds_light_permutation`),
/// - `4..20`: partial rounds (constant addition and inversion on lane 0, then
///   `matmul_internal`),
/// - `20..24`: terminal full rounds.
///
/// The initial `mds_light_permutation` that `permute` performs before round 0 is *not*
/// applied here. To reproduce `permute`, apply it once and then call this for rounds `0..24`.
///
/// # Panics
///
/// Panics if `round >= 24`.
pub(crate) fn permute_one_round(state: &mut [Goldilocks; 16], round: usize) {
    const WIDTH: usize = 16;

    /// One full round: add a round constant to every lane, raise every lane to x^7,
    /// then apply the full linear layer.
    fn full_round(state: &mut [Goldilocks; 16], rc: &[Goldilocks]) {
        for (lane, &c) in state.iter_mut().zip(rc) {
            *lane += c;
            *lane = lane.pow7();
        }
        mds_light_permutation(state);
    }

    let (external, internal) = ROUND_CONSTANTS.split_at(NUM_EXTERNAL);
    let (initial_rc, terminal_rc) = external.split_at(NUM_EXTERNAL / 2);

    match round {
        0..=3 => full_round(state, &initial_rc[round * WIDTH..(round + 1) * WIDTH]),
        4..=19 => {
            state[0] += internal[round - 4];
            state[0] = state[0].inv();
            matmul_internal(state);
        }
        20..=23 => {
            let tr = round - 20;
            full_round(state, &terminal_rc[tr * WIDTH..(tr + 1) * WIDTH]);
        }
        // Previously this fell into the terminal branch and died on a slice index.
        _ => panic!("permute_one_round: round {round} out of range (expected 0..24)"),
    }
}
/// Runs the permutation with `ROUND_CONSTANTS` (as `permute` does), reporting every round
/// to `visitor`.
///
/// Full rounds are reported with indices `0..8` (4 initial, then 4 terminal), together with
/// the per-lane S-box witnesses from `full_round_step`. Partial rounds are reported with
/// indices `0..16`, together with lane 0's S-box output (captured before `matmul_internal`).
/// In both cases the state passed to the visitor is the state after that round's linear layer.
pub fn permute_traced<V: RoundVisitor>(state: &mut [Goldilocks; 16], visitor: &mut V) {
    let (external, partial_rc) = ROUND_CONSTANTS.split_at(NUM_EXTERNAL);
    let (head_rc, tail_rc) = external.split_at(NUM_EXTERNAL / 2);

    mds_light_permutation(state);

    for (idx, rc) in (0u8..4).zip(head_rc.chunks_exact(16)) {
        let witnesses = full_round_step(state, rc);
        visitor.full_round(idx, state, &witnesses);
    }

    for idx in 0u8..16 {
        state[0] += partial_rc[usize::from(idx)];
        let sbox_out = state[0].inv();
        state[0] = sbox_out;
        matmul_internal(state);
        visitor.partial_round(idx, state, sbox_out);
    }

    for (idx, rc) in (4u8..8).zip(tail_rc.chunks_exact(16)) {
        let witnesses = full_round_step(state, rc);
        visitor.full_round(idx, state, &witnesses);
    }
}
/// Performs one full round on `state` and returns the S-box witnesses for every lane.
///
/// For each lane, the round constant is added, giving `x`. The lane's witness is `[x^2, x^3]`,
/// and the lane is replaced by `x^7`, computed as `x^3 * (x^2)^2`. Afterwards
/// `mds_light_permutation` is applied to the whole state.
#[inline]
fn full_round_step(state: &mut [Goldilocks; 16], rc: &[Goldilocks]) -> FullRoundWitnesses {
    let rc = &rc[..16];
    let mut witnesses = [[Goldilocks::ZERO; 2]; 16];
    for ((lane, &c), slot) in state.iter_mut().zip(rc).zip(witnesses.iter_mut()) {
        *lane += c;
        let x = *lane;
        let square = x * x;
        let cube = square * x;
        *slot = [square, cube];
        *lane = cube * (square * square);
    }
    mds_light_permutation(state);
    witnesses
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::{FullRoundWitnesses, RoundVisitor};

    /// Visitor that only counts how many full and partial rounds it was shown.
    struct RoundCounter {
        full: u8,
        partial: u8,
    }

    impl RoundVisitor for RoundCounter {
        fn full_round(&mut self, _: u8, _: &[Goldilocks; 16], _: &FullRoundWitnesses) {
            self.full += 1;
        }
        fn partial_round(&mut self, _: u8, _: &[Goldilocks; 16], _: Goldilocks) {
            self.partial += 1;
        }
    }

    #[test]
    fn traced_matches_plain() {
        let mut state_plain = [Goldilocks::new(42); 16];
        let mut state_traced = state_plain;
        let mut counter = RoundCounter { full: 0, partial: 0 };
        permute(&mut state_plain);
        permute_traced(&mut state_traced, &mut counter);
        assert_eq!(state_plain, state_traced);
        assert_eq!(counter.full, 8);
        assert_eq!(counter.partial, 16);
    }

    #[test]
    fn trace_round_indices_are_ordered() {
        struct IndexRecorder {
            full_indices: [u8; 8],
            full_count: usize,
            partial_indices: [u8; 16],
            partial_count: usize,
        }
        impl RoundVisitor for IndexRecorder {
            fn full_round(&mut self, index: u8, _: &[Goldilocks; 16], _: &FullRoundWitnesses) {
                self.full_indices[self.full_count] = index;
                self.full_count += 1;
            }
            fn partial_round(&mut self, index: u8, _: &[Goldilocks; 16], _: Goldilocks) {
                self.partial_indices[self.partial_count] = index;
                self.partial_count += 1;
            }
        }
        let mut state = [Goldilocks::ZERO; 16];
        let mut rec = IndexRecorder {
            full_indices: [0; 8],
            full_count: 0,
            partial_indices: [0; 16],
            partial_count: 0,
        };
        permute_traced(&mut state, &mut rec);
        assert_eq!(rec.full_count, 8);
        assert_eq!(rec.partial_count, 16);
        for i in 0..8u8 {
            assert_eq!(rec.full_indices[i as usize], i, "full round index {i}");
        }
        for i in 0..16u8 {
            assert_eq!(rec.partial_indices[i as usize], i, "partial round index {i}");
        }
    }

    /// Recomputes each S-box input from the previous round's output state and
    /// `ROUND_CONSTANTS`, then checks every recorded witness against it.
    ///
    /// Full round `r` (0..8) uses `ROUND_CONSTANTS[r * 16..(r + 1) * 16]`, because the initial
    /// and terminal halves are contiguous. Partial round `k` uses
    /// `ROUND_CONSTANTS[NUM_EXTERNAL + k]`.
    #[test]
    fn full_round_witnesses_correct() {
        struct WitnessChecker {
            /// State after the most recently reported round (the next round's input).
            prev: [Goldilocks; 16],
            full_checked: u8,
            partial_checked: u8,
        }
        impl RoundVisitor for WitnessChecker {
            fn full_round(
                &mut self,
                index: u8,
                state: &[Goldilocks; 16],
                witnesses: &FullRoundWitnesses,
            ) {
                let base = usize::from(index) * 16;
                for lane in 0..16 {
                    let mut x = self.prev[lane];
                    x += ROUND_CONSTANTS[base + lane];
                    assert_eq!(witnesses[lane][0], x * x, "round {index}, lane {lane}: x^2");
                    assert_eq!(witnesses[lane][1], x * x * x, "round {index}, lane {lane}: x^3");
                }
                self.prev = *state;
                self.full_checked += 1;
            }
            fn partial_round(&mut self, index: u8, state: &[Goldilocks; 16], sbox_out: Goldilocks) {
                let mut x = self.prev[0];
                x += ROUND_CONSTANTS[NUM_EXTERNAL + usize::from(index)];
                assert_eq!(sbox_out, x.inv(), "partial round {index}: S-box output");
                self.prev = *state;
                self.partial_checked += 1;
            }
        }

        let initial = [Goldilocks::new(7); 16];
        // permute_traced applies one linear layer before the first full round.
        let mut prev = initial;
        mds_light_permutation(&mut prev);
        let mut checker = WitnessChecker { prev, full_checked: 0, partial_checked: 0 };
        let mut state = initial;
        permute_traced(&mut state, &mut checker);
        assert_eq!(checker.full_checked, 8);
        assert_eq!(checker.partial_checked, 16);
    }

    #[test]
    fn one_round_steps_match_permute() {
        let mut expected = [Goldilocks::new(3); 16];
        let mut stepped = expected;
        permute(&mut expected);
        // permute_one_round omits the initial linear layer; apply it here.
        mds_light_permutation(&mut stepped);
        for round in 0..24 {
            permute_one_round(&mut stepped, round);
        }
        assert_eq!(stepped, expected);
    }

    #[test]
    #[should_panic]
    fn short_constant_table_panics() {
        let mut state = [Goldilocks::ZERO; 16];
        let constants = [Goldilocks::ZERO; 8];
        permute_with_constants(&mut state, &constants);
    }

    #[test]
    fn permutation_is_deterministic() {
        let mut s1 = [Goldilocks::ZERO; 16];
        let mut s2 = [Goldilocks::ZERO; 16];
        permute(&mut s1);
        permute(&mut s2);
        assert_eq!(s1, s2);
    }

    #[test]
    fn permutation_changes_state() {
        let mut state = [Goldilocks::ZERO; 16];
        let original = state;
        permute(&mut state);
        assert_ne!(state, original);
    }

    #[test]
    fn different_inputs_different_outputs() {
        let mut s1 = [Goldilocks::ZERO; 16];
        let mut s2 = [Goldilocks::ZERO; 16];
        s2[0] = Goldilocks::new(1);
        permute(&mut s1);
        permute(&mut s2);
        assert_ne!(s1, s2);
    }
}