use crate::ring::RingElement;
use nebu::Goldilocks;
use nebu::field::P;
/// Marsaglia xorshift pseudo-random generator with 64 bits of state,
/// using the (13, 7, 17) shift triple.
///
/// The output sequence is fully determined by the seed. This is a fast
/// statistical PRNG, NOT a cryptographically secure one.
struct Xorshift64 {
    /// Current generator state. Zero is a fixed point of the xorshift step,
    /// so `new` never stores it.
    state: u64,
}

impl Xorshift64 {
    /// Builds a generator from `seed`.
    ///
    /// A zero seed is replaced by a fixed non-zero constant; otherwise every
    /// call to `next` would return zero.
    fn new(seed: u64) -> Self {
        const ZERO_SEED_REPLACEMENT: u64 = 0xDEAD_BEEF_CAFE_BABE;
        Self {
            state: match seed {
                0 => ZERO_SEED_REPLACEMENT,
                nonzero => nonzero,
            },
        }
    }

    /// Performs one xorshift step, stores the result as the new state, and
    /// returns it.
    fn next(&mut self) -> u64 {
        let s0 = self.state;
        let s1 = s0 ^ (s0 << 13);
        let s2 = s1 ^ (s1 >> 7);
        let s3 = s2 ^ (s2 << 17);
        self.state = s3;
        s3
    }
}
/// Samples `n` coefficients uniformly from `[0, P)` using a seeded `Xorshift64`.
///
/// Each coefficient is drawn by rejection sampling: raw 64-bit outputs that
/// are `>= P` are discarded and redrawn, so accepted values are uniform over
/// `[0, P)` (assuming uniform generator output). The same `seed` and `n`
/// always produce the same element.
///
/// Note: `Xorshift64` is not a cryptographically secure generator.
///
/// The result is allocated with `RingElement::new(n)`, and coefficient
/// indices `0..n` are overwritten.
pub fn sample_uniform(seed: u64, n: usize) -> RingElement {
    let mut prng = Xorshift64::new(seed);
    let mut out = RingElement::new(n);
    for idx in 0..n {
        // Redraw until the candidate lands inside the field range.
        let accepted = loop {
            let candidate = prng.next();
            if candidate < P {
                break candidate;
            }
        };
        out.coeffs[idx] = Goldilocks::new(accepted);
    }
    out
}
/// Samples `n` ternary coefficients (`0`, `1` or `-1`) from a seeded `Xorshift64`.
///
/// Each coefficient is chosen from `next() % 3`:
/// `0 -> Goldilocks::ZERO`, `1 -> Goldilocks::ONE`, `2 -> Goldilocks::NEG_ONE`.
/// Because 2^64 is not a multiple of 3, residue 0 has one extra preimage out
/// of 2^64, so the bias is negligible. The same `seed` and `n` always produce
/// the same element.
///
/// The result is allocated with `RingElement::new(n)`, and coefficient
/// indices `0..n` are overwritten.
pub fn sample_ternary(seed: u64, n: usize) -> RingElement {
    let mut prng = Xorshift64::new(seed);
    let mut out = RingElement::new(n);
    for idx in 0..n {
        let trit = prng.next() % 3;
        out.coeffs[idx] = if trit == 0 {
            Goldilocks::ZERO
        } else if trit == 1 {
            Goldilocks::ONE
        } else {
            // Only remaining residue of `% 3` is 2.
            Goldilocks::NEG_ONE
        };
    }
    out
}
/// Samples `n` coefficients from a centered binomial distribution with
/// parameter `eta`, driven by a seeded `Xorshift64`.
///
/// For each coefficient, one 64-bit word is drawn. Its lowest `eta` bits are
/// counted as `pos`, and the next `eta` bits (bits `eta..2*eta`) are counted
/// as `neg`. The coefficient is `pos - neg`, which lies in `[-eta, eta]`.
/// Negative values are mapped into the field as `P - |value|`. The same
/// `seed`, `n` and `eta` always produce the same element.
///
/// The result is allocated with `RingElement::new(n)`, and coefficient
/// indices `0..n` are overwritten.
///
/// # Panics
///
/// Panics if `eta` is not in `[1, 16]`.
pub fn sample_cbd(seed: u64, n: usize, eta: usize) -> RingElement {
    assert!((1..=16).contains(&eta), "eta must be in [1, 16]");
    let mut rng = Xorshift64::new(seed);
    let mut elem = RingElement::new(n);
    // Selects `eta` bits; this is loop-invariant, so compute it once.
    // `eta <= 16` keeps the shift well inside 64 bits.
    let mask: u64 = (1u64 << eta) - 1;
    for i in 0..n {
        let bits = rng.next();
        // Hardware popcount replaces the per-bit summation loop.
        let pos = (bits & mask).count_ones();
        let neg = ((bits >> eta) & mask).count_ones();
        elem.coeffs[i] = if pos >= neg {
            Goldilocks::new(u64::from(pos - neg))
        } else {
            // Represent -(neg - pos) as its additive inverse modulo P.
            Goldilocks::new(P - u64::from(neg - pos))
        };
    }
    elem
}