/// Field modulus, taken from `crate::field::goldilocks`.
const P: u64 = crate::field::goldilocks::MODULUS;
/// Permutation width: number of field elements in the sponge state.
const T: usize = 8;
/// Sponge rate: elements absorbed into / squeezed from `state[..RATE]` per permutation.
const RATE: usize = 4;
/// Number of full rounds, split evenly before and after the partial rounds.
const R_F: usize = 8;
/// Number of partial rounds (S-box applied to `state[0]` only).
const R_P: usize = 22;
/// Diagonal multipliers of the internal linear layer (each entry is `2^i + 1`).
const DIAG: [u64; T] = [2, 3, 5, 9, 17, 33, 65, 129];
/// S-box exponent; only referenced by tests, since `GoldilocksField::sbox` hard-codes `x^7`.
#[cfg(test)]
const ALPHA: u64 = 7;
/// An element of the Goldilocks prime field, stored as a raw `u64`.
///
/// The reduction logic assumes `P = 2^64 - 2^32 + 1`, so that `2^64 ≡ 2^32 - 1 (mod P)`.
/// Every method returns a canonical value (`< P`). Because the inner field is public, callers can
/// also build non-canonical representatives such as `GoldilocksField(P)`; `add`, `sub` and `mul`
/// all accept any `u64` representative.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GoldilocksField(pub u64);

impl GoldilocksField {
    /// The additive identity.
    pub const ZERO: Self = Self(0);
    /// The multiplicative identity.
    pub const ONE: Self = Self(1);

    /// `2^64 mod P`, i.e. `2^32 - 1` for the Goldilocks prime.
    const EPSILON: u64 = u32::MAX as u64;

    /// Builds a field element from any `u64`, reducing it modulo `P`.
    #[inline]
    pub fn new(v: u64) -> Self {
        Self(v % P)
    }

    /// Maps any `u64` into `[0, P)`.
    ///
    /// One conditional subtraction suffices because `u64::MAX < 2 * P`.
    #[inline]
    fn canonicalize(x: u64) -> u64 {
        if x >= P {
            x - P
        } else {
            x
        }
    }

    /// Reduces a 128-bit value modulo `P`, returning a canonical element.
    ///
    /// Writes `x = lo + hi * 2^64` and folds the high part with `2^64 ≡ EPSILON (mod P)` until
    /// the value fits in 64 bits.
    #[inline]
    fn reduce128(x: u128) -> Self {
        let eps = Self::EPSILON as u128;
        let lo = x as u64;
        let hi = (x >> 64) as u64;
        // First fold: the result is below 2^96, so hi2 < 2^32.
        let sum = lo as u128 + (hi as u128) * eps;
        let lo2 = sum as u64;
        let hi2 = (sum >> 64) as u64;
        if hi2 == 0 {
            return Self(Self::canonicalize(lo2));
        }
        // Second fold: r < 2^65, so hi3 is 0 or 1.
        let r = lo2 as u128 + (hi2 as u128) * eps;
        let lo3 = r as u64;
        let hi3 = (r >> 64) as u64;
        if hi3 == 0 {
            Self(Self::canonicalize(lo3))
        } else {
            // Here lo3 <= 2^64 - 2^33, so adding EPSILON cannot wrap.
            let v = lo3.wrapping_add(hi3.wrapping_mul(Self::EPSILON));
            Self(Self::canonicalize(v))
        }
    }

    /// Field addition.
    #[inline]
    pub fn add(self, rhs: Self) -> Self {
        // Canonicalize first: with raw non-canonical operands the carry fold below could
        // overflow (e.g. u64::MAX + u64::MAX).
        let a = Self::canonicalize(self.0);
        let b = Self::canonicalize(rhs.0);
        let (sum, carry) = a.overflowing_add(b);
        if carry {
            // a + b = 2^64 + sum ≡ sum + EPSILON; with a, b < P this stays below 2^64.
            Self(Self::canonicalize(sum + Self::EPSILON))
        } else {
            Self(Self::canonicalize(sum))
        }
    }

    /// Field subtraction.
    #[inline]
    pub fn sub(self, rhs: Self) -> Self {
        // Canonicalize first so `P - b` cannot underflow when rhs is a representative >= P.
        let a = Self::canonicalize(self.0);
        let b = Self::canonicalize(rhs.0);
        if a >= b {
            Self(a - b)
        } else {
            // a < b < P, so P - b + a lies in [1, P).
            Self(P - b + a)
        }
    }

    /// Field multiplication via a full 128-bit product and `reduce128`.
    #[inline]
    pub fn mul(self, rhs: Self) -> Self {
        Self::reduce128((self.0 as u128) * (rhs.0 as u128))
    }

    /// Raises `self` to `exp` by square-and-multiply.
    pub fn pow(self, mut exp: u64) -> Self {
        let mut base = self;
        let mut acc = Self::ONE;
        while exp > 0 {
            if exp & 1 == 1 {
                acc = acc.mul(base);
            }
            base = base.mul(base);
            exp >>= 1;
        }
        acc
    }

    /// The S-box `x^7`, computed as `x^6 * x` with four multiplications.
    #[inline]
    pub fn sbox(self) -> Self {
        let x2 = self.mul(self);
        let x3 = x2.mul(self);
        let x6 = x3.mul(x3);
        x6.mul(self)
    }
}
const TOTAL_ROUNDS: usize = R_F + R_P;
fn round_constant(round: usize, element: usize) -> GoldilocksField {
let tag = format!("Poseidon2-Goldilocks-t8-RF8-RP22-{round}-{element}");
let digest = blake3::hash(tag.as_bytes());
let bytes: [u8; 8] = digest.as_bytes()[..8].try_into().unwrap_or([0u8; 8]);
GoldilocksField::new(u64::from_le_bytes(bytes))
}
/// Builds the flat list of round constants in the order `permutation` consumes them.
///
/// Each of the first and last `R_F / 2` (full) rounds contributes `T` constants, one per state
/// element. Each partial round in between contributes a single constant (element index 0).
fn generate_all_constants() -> Vec<GoldilocksField> {
    let partial_rounds = R_F / 2..R_F / 2 + R_P;
    (0..TOTAL_ROUNDS)
        .flat_map(|round| {
            let per_round = if partial_rounds.contains(&round) { 1 } else { T };
            (0..per_round).map(move |element| round_constant(round, element))
        })
        .collect()
}
/// Returns the round constants, generating them on first use.
///
/// The constants live in a process-wide `OnceLock`, so generation happens at most once and every
/// call returns the same slice.
fn cached_round_constants() -> &'static [GoldilocksField] {
    use std::sync::OnceLock;

    static ROUND_CONSTANTS: OnceLock<Vec<GoldilocksField>> = OnceLock::new();
    ROUND_CONSTANTS
        .get_or_init(generate_all_constants)
        .as_slice()
}
/// Permutation state of width `T` for the Poseidon2-style construction in this module.
pub struct Poseidon2Sponge {
    /// Raw state; public so callers (such as `Poseidon2Hasher`) can absorb into and read from it.
    pub state: [GoldilocksField; T],
}

impl Default for Poseidon2Sponge {
    /// Same as [`Poseidon2Sponge::new`]: an all-zero state.
    fn default() -> Self {
        Self::new()
    }
}

impl Poseidon2Sponge {
    /// Creates a sponge with every state element set to zero.
    pub fn new() -> Self {
        Self {
            state: [GoldilocksField::ZERO; T],
        }
    }

    /// Applies the S-box (`x^7`) to every state element (full round).
    #[inline]
    fn full_sbox(&mut self) {
        for s in self.state.iter_mut() {
            *s = s.sbox();
        }
    }

    /// Applies the S-box to `state[0]` only (partial round).
    #[inline]
    fn partial_sbox(&mut self) {
        self.state[0] = self.state[0].sbox();
    }

    /// External (full-round) linear layer: `s_i <- s_i + sum_j s_j`.
    ///
    /// This is multiplication by `I + J`, where `J` is the all-ones matrix. The sum is taken over
    /// the state before any element is updated.
    fn external_linear(&mut self) {
        let sum = self
            .state
            .iter()
            .fold(GoldilocksField::ZERO, |a, &b| a.add(b));
        for s in self.state.iter_mut() {
            *s = s.add(sum);
        }
    }

    /// Internal (partial-round) linear layer: `s_i <- DIAG[i] * s_i + sum_j s_j`.
    ///
    /// This is multiplication by `diag(DIAG) + J`. The sum is taken over the state before any
    /// element is updated.
    fn internal_linear(&mut self) {
        let sum = self
            .state
            .iter()
            .fold(GoldilocksField::ZERO, |a, &b| a.add(b));
        for (s, &d) in self.state.iter_mut().zip(DIAG.iter()) {
            *s = GoldilocksField(d).mul(*s).add(sum);
        }
    }

    /// Applies the permutation to `state` in place.
    ///
    /// Round layout:
    /// - `R_F / 2` full rounds: add `T` constants, apply the S-box to every element, then
    ///   `external_linear`.
    /// - `R_P` partial rounds: add one constant to `state[0]`, apply the S-box to `state[0]`,
    ///   then `internal_linear`.
    /// - `R_F / 2` more full rounds.
    ///
    /// Constants are consumed sequentially from `cached_round_constants`.
    ///
    /// NOTE(review): as far as I can tell, the Poseidon2 paper also applies the external linear
    /// layer once to the input before the first full round, and uses a different external matrix
    /// for `t = 8`. Neither is done here. Confirm this deviation is intentional; changing it
    /// would change every digest.
    pub fn permutation(&mut self) {
        let constants = cached_round_constants();
        let mut ci = 0;
        for _ in 0..R_F / 2 {
            for s in self.state.iter_mut() {
                *s = s.add(constants[ci]);
                ci += 1;
            }
            self.full_sbox();
            self.external_linear();
        }
        for _ in 0..R_P {
            self.state[0] = self.state[0].add(constants[ci]);
            ci += 1;
            self.partial_sbox();
            self.internal_linear();
        }
        for _ in 0..R_F / 2 {
            for s in self.state.iter_mut() {
                *s = s.add(constants[ci]);
                ci += 1;
            }
            self.full_sbox();
            self.external_linear();
        }
        // Every generated constant must have been used exactly once.
        debug_assert_eq!(ci, constants.len());
    }
}
/// Sponge-mode hasher over [`GoldilocksField`] that wraps a [`Poseidon2Sponge`].
///
/// Input is added into `state[..RATE]`. The permutation runs lazily:
/// - `absorb` permutes only when the rate block is already full and another element arrives.
/// - `squeeze` permutes before reading, and again after every `RATE` outputs.
///
/// NOTE(review): `absorb` applies no padding or length binding. Adding `GoldilocksField::ZERO`
/// into a partially filled rate block leaves the state unchanged, so absorbing `[x]` and
/// `[x, GoldilocksField::ZERO]` squeezes identical output. `absorb_bytes` appends the input
/// length to avoid this; direct callers of `absorb` need fixed-length input or their own framing.
pub struct Poseidon2Hasher {
    /// Underlying permutation state.
    state: Poseidon2Sponge,
    /// Number of elements already written into the current rate block (`0..=RATE`).
    absorbed: usize,
}

impl Default for Poseidon2Hasher {
    /// Same as [`Poseidon2Hasher::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Poseidon2Hasher {
    /// Creates a hasher with a zeroed sponge state and nothing absorbed.
    pub fn new() -> Self {
        Self {
            state: Poseidon2Sponge::new(),
            absorbed: 0,
        }
    }

    /// Adds each element of `elements` into the next free rate slot.
    ///
    /// When all `RATE` slots of the current block are used, the permutation is applied before
    /// the next element is written, and the block restarts at slot 0.
    pub fn absorb(&mut self, elements: &[GoldilocksField]) {
        for &elem in elements {
            if self.absorbed == RATE {
                self.state.permutation();
                self.absorbed = 0;
            }
            self.state.state[self.absorbed] = self.state.state[self.absorbed].add(elem);
            self.absorbed += 1;
        }
    }

    /// Absorbs a byte string.
    ///
    /// `data` is split into 7-byte chunks, each read as a little-endian `u64` with the last chunk
    /// zero-padded. Every value is therefore below `2^56`, already reduced for the Goldilocks
    /// modulus. `data.len()` is appended as a final element, so inputs that differ only in
    /// trailing zero bytes absorb different element sequences.
    pub fn absorb_bytes(&mut self, data: &[u8]) {
        const BYTES_PER_ELEM: usize = 7;
        let mut elements = Vec::with_capacity(data.len() / BYTES_PER_ELEM + 2);
        for chunk in data.chunks(BYTES_PER_ELEM) {
            let mut buf = [0u8; 8];
            buf[..chunk.len()].copy_from_slice(chunk);
            elements.push(GoldilocksField::new(u64::from_le_bytes(buf)));
        }
        elements.push(GoldilocksField::new(data.len() as u64));
        self.absorb(&elements);
    }

    /// Returns `count` elements read from the rate portion of the state.
    ///
    /// Permutes once before reading, then again after every `RATE` elements, and resets the
    /// absorb position. Requesting zero elements returns an empty vector without touching the
    /// state.
    pub fn squeeze(&mut self, count: usize) -> Vec<GoldilocksField> {
        let mut out = Vec::with_capacity(count);
        // Squeezing nothing must not advance the sponge.
        if count == 0 {
            return out;
        }
        self.state.permutation();
        self.absorbed = 0;
        loop {
            let take = (count - out.len()).min(RATE);
            out.extend_from_slice(&self.state.state[..take]);
            if out.len() == count {
                return out;
            }
            self.state.permutation();
        }
    }

    /// Consumes the hasher and returns a single output element.
    pub fn finalize(mut self) -> GoldilocksField {
        self.squeeze(1)[0]
    }

    /// Consumes the hasher and returns four output elements.
    pub fn finalize_4(mut self) -> [GoldilocksField; 4] {
        let v = self.squeeze(4);
        [v[0], v[1], v[2], v[3]]
    }
}
/// Hashes arbitrary bytes to a 32-byte digest.
///
/// The input is absorbed with `Poseidon2Hasher::absorb_bytes`. The four squeezed field elements
/// are serialized little-endian, one after another, into the 32-byte output.
pub fn hash_bytes(data: &[u8]) -> [u8; 32] {
    let mut sponge = Poseidon2Hasher::new();
    sponge.absorb_bytes(data);
    let limbs = sponge.finalize_4();
    let mut digest = [0u8; 32];
    for (chunk, limb) in digest.chunks_exact_mut(8).zip(limbs.iter()) {
        chunk.copy_from_slice(&limb.0.to_le_bytes());
    }
    digest
}
/// Hashes a slice of field elements to four field elements.
///
/// NOTE(review): no length or padding is mixed in. Trailing `GoldilocksField::ZERO` elements
/// that stay inside the final rate block do not change the result, e.g.
/// `hash_fields(&[x]) == hash_fields(&[x, GoldilocksField::ZERO])`. Use this only for
/// fixed-length inputs, or append the length first as `absorb_bytes` does. Changing the
/// behaviour here would alter every existing digest.
pub fn hash_fields(elements: &[GoldilocksField]) -> [GoldilocksField; 4] {
    let mut sponge = Poseidon2Hasher::new();
    sponge.absorb(elements);
    sponge.finalize_4()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_goldilocks_arithmetic() {
        let a = GoldilocksField::new(P - 1);
        let b = GoldilocksField::ONE;
        assert_eq!(a.add(b), GoldilocksField::ZERO);
        assert_eq!(GoldilocksField::ZERO.sub(b), a);
        let x = GoldilocksField::new(123456789);
        assert_eq!(x.mul(GoldilocksField::ONE), x);
        assert_eq!(x.mul(GoldilocksField::ZERO), GoldilocksField::ZERO);
        let y = GoldilocksField::new(987654321);
        assert_eq!(x.mul(y), y.mul(x));
        assert_eq!(x.pow(0), GoldilocksField::ONE);
        assert_eq!(x.pow(1), x);
        assert_eq!(x.pow(3), x.mul(x).mul(x));
        assert_eq!(a.mul(a), GoldilocksField::ONE);
    }

    /// Cross-checks `add`, `sub` and `mul` against plain `u128` modular arithmetic, including
    /// values near `P`, `2^32` and `2^63` that exercise every branch of `reduce128`.
    #[test]
    fn test_arithmetic_matches_u128_reference() {
        let p = P as u128;
        let mut samples: Vec<u64> =
            vec![0, 1, 2, P - 1, P - 2, u32::MAX as u64, 1 << 32, 1 << 63, P / 2];
        let mut seed: u64 = 0x9E37_79B9_7F4A_7C15;
        for _ in 0..64 {
            seed = seed
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            samples.push(seed % P);
        }
        for &x in &samples {
            for &y in &samples {
                let (a, b) = (GoldilocksField::new(x), GoldilocksField::new(y));
                let (xr, yr) = (x as u128, y as u128);
                assert_eq!(a.mul(b).0 as u128, xr * yr % p, "mul {x} * {y}");
                assert_eq!(a.add(b).0 as u128, (xr + yr) % p, "add {x} + {y}");
                assert_eq!(a.sub(b).0 as u128, (xr + p - yr) % p, "sub {x} - {y}");
            }
        }
    }

    #[test]
    fn test_sbox() {
        let x = GoldilocksField::new(42);
        assert_eq!(x.sbox(), x.pow(ALPHA));
        assert_eq!(GoldilocksField::ZERO.sbox(), GoldilocksField::ZERO);
        assert_eq!(GoldilocksField::ONE.sbox(), GoldilocksField::ONE);
        let z = GoldilocksField::new(1000);
        assert_ne!(z.sbox(), z);
        assert_eq!(z.sbox(), z.pow(7));
    }

    #[test]
    fn test_permutation_deterministic() {
        let input: [GoldilocksField; T] =
            core::array::from_fn(|i| GoldilocksField::new(i as u64 + 1));
        let mut s1 = Poseidon2Sponge { state: input };
        let mut s2 = Poseidon2Sponge { state: input };
        s1.permutation();
        s2.permutation();
        assert_eq!(s1.state, s2.state);
    }

    #[test]
    fn test_permutation_diffusion() {
        let base: [GoldilocksField; T] =
            core::array::from_fn(|i| GoldilocksField::new(i as u64 + 100));
        let mut s_base = Poseidon2Sponge { state: base };
        s_base.permutation();
        let mut tweaked = base;
        tweaked[0] = tweaked[0].add(GoldilocksField::ONE);
        let mut s_tweak = Poseidon2Sponge { state: tweaked };
        s_tweak.permutation();
        for i in 0..T {
            assert_ne!(
                s_base.state[i], s_tweak.state[i],
                "Element {i} unchanged after input tweak"
            );
        }
    }

    #[test]
    fn test_hash_bytes_deterministic() {
        assert_eq!(hash_bytes(b"hello world"), hash_bytes(b"hello world"));
    }

    #[test]
    fn test_hash_bytes_different_inputs() {
        assert_ne!(hash_bytes(b"hello"), hash_bytes(b"world"));
    }

    /// Zero-padding of the final 7-byte chunk is disambiguated by the appended byte length.
    #[test]
    fn test_hash_bytes_length_separation() {
        assert_ne!(hash_bytes(b"a"), hash_bytes(b"a\0"));
        assert_ne!(hash_bytes(b""), hash_bytes(b"\0"));
        assert_ne!(hash_bytes(&[0u8; 7]), hash_bytes(&[0u8; 8]));
    }

    #[test]
    fn test_absorb_squeeze() {
        let elems: Vec<GoldilocksField> =
            (0..10).map(|i| GoldilocksField::new(i * 7 + 3)).collect();
        let mut h1 = Poseidon2Hasher::new();
        h1.absorb(&elems);
        let out1 = h1.squeeze(4);
        let mut h2 = Poseidon2Hasher::new();
        h2.absorb(&elems);
        let out2 = h2.squeeze(4);
        assert_eq!(out1, out2);
        assert!(out1.iter().any(|e| *e != GoldilocksField::ZERO));
    }

    #[test]
    fn test_hash_fields() {
        let elems: Vec<GoldilocksField> = (1..=5).map(GoldilocksField::new).collect();
        assert_eq!(hash_fields(&elems), hash_fields(&elems));
    }

    #[test]
    fn test_empty_hash() {
        let h = hash_bytes(b"");
        assert_eq!(h, hash_bytes(b""));
        assert_ne!(h, [0u8; 32]);
    }

    #[test]
    fn test_collision_resistance() {
        let hashes: Vec<[u8; 32]> = (0u64..20).map(|i| hash_bytes(&i.to_le_bytes())).collect();
        for i in 0..hashes.len() {
            for j in i + 1..hashes.len() {
                assert_ne!(hashes[i], hashes[j], "Collision between inputs {i} and {j}");
            }
        }
    }
}