use jali::encoding;
use jali::ntt;
use jali::ring::RingElement;
use jali::sample;
use jali_wgsl::GpuContext;
use nebu::Goldilocks;
use std::hint::black_box;
use std::time::Instant;
#[derive(Clone, Copy, PartialEq)]
enum Backend {
Cpu,
Gpu,
}
impl std::fmt::Display for Backend {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Backend::Cpu => write!(f, "cpu"),
Backend::Gpu => write!(f, "gpu"),
}
}
}
struct Ctx {
gpu: Option<GpuContext>,
forced: Option<Backend>,
}
impl Ctx {
fn new(forced: Option<Backend>) -> Self {
let gpu = if forced == Some(Backend::Cpu) {
None
} else {
GpuContext::new()
};
if forced == Some(Backend::Gpu) && gpu.is_none() {
eprintln!("error: --gpu requested but no GPU adapter available");
std::process::exit(1);
}
Self { gpu, forced }
}
fn backend(&self) -> Backend {
if self.forced == Some(Backend::Cpu) {
return Backend::Cpu;
}
if self.gpu.is_some() {
Backend::Gpu
} else {
Backend::Cpu
}
}
fn gpu(&self) -> &GpuContext {
self.gpu.as_ref().unwrap()
}
}
fn parse_backend_flag(args: &[String]) -> (Option<Backend>, Vec<String>) {
let mut forced = None;
let mut rest = Vec::new();
for a in args {
match a.as_str() {
"--gpu" => forced = Some(Backend::Gpu),
"--cpu" => forced = Some(Backend::Cpu),
_ => rest.push(a.clone()),
}
}
(forced, rest)
}
fn to_lohi(g: Goldilocks) -> (u32, u32) {
let v = g.as_u64();
(v as u32, (v >> 32) as u32)
}
fn from_lohi(lo: u32, hi: u32) -> Goldilocks {
Goldilocks::new((hi as u64) << 32 | lo as u64)
}
fn ring_to_gpu(elem: &RingElement) -> Vec<(u32, u32)> {
elem.active_coeffs().iter().map(|g| to_lohi(*g)).collect()
}
fn gpu_to_ring(data: &[(u32, u32)], n: usize) -> RingElement {
let coeffs: Vec<Goldilocks> = data.iter().map(|&(lo, hi)| from_lohi(lo, hi)).collect();
RingElement::from_coeffs(&coeffs, n)
}
fn print_usage() {
eprintln!("jali โ polynomial ring arithmetic R_q = F_p[x]/(x^n+1)");
eprintln!();
eprintln!("usage:");
eprintln!(" jali calc add <n> <a0,a1,...> <b0,b1,...> โ add two polynomials");
eprintln!(" jali calc mul <n> <a0,a1,...> <b0,b1,...> โ multiply two polynomials");
eprintln!(" jali ntt forward <n> <a0,a1,...> โ forward NTT");
eprintln!(" jali ntt inverse <n> <a0,a1,...> โ inverse NTT");
eprintln!(" jali sample uniform <seed> <n> โ sample uniform polynomial");
eprintln!(" jali sample ternary <seed> <n> โ sample ternary polynomial");
eprintln!(" jali sample cbd <seed> <n> <eta> โ sample CBD polynomial");
eprintln!(" jali bench [n] [iters] โ benchmark ring ops");
eprintln!(" jali help โ show this message");
eprintln!();
eprintln!("flags:");
eprintln!(" --gpu force GPU backend");
eprintln!(" --cpu force CPU backend");
eprintln!(" (default: GPU if available, else CPU)");
}
fn parse_poly(s: &str, n: usize) -> RingElement {
let mut elem = RingElement::new(n);
for (i, tok) in s.split(',').enumerate() {
if i >= n {
break;
}
let v: u64 = tok.trim().parse().unwrap_or(0);
elem.coeffs[i] = Goldilocks::new(v);
}
elem
}
fn print_poly(elem: &RingElement) {
let n = elem.n;
let mut first = true;
for i in 0..n {
let v = elem.coeffs[i].as_u64();
if !first {
print!(",");
}
print!("{}", v);
first = false;
}
println!();
}
fn print_timed(backend: Backend, elapsed: std::time::Duration) {
let us = elapsed.as_nanos() as f64 / 1000.0;
if us < 1000.0 {
eprintln!("\x1b[90m[{backend} {us:.0}us]\x1b[0m");
} else {
eprintln!("\x1b[90m[{backend} {:.2}ms]\x1b[0m", us / 1000.0);
}
}
fn cmd_calc(forced: Option<Backend>, args: &[String]) {
if args.len() < 4 {
eprintln!("error: calc requires: <add|mul> <n> <poly_a> <poly_b>");
return;
}
let op = args[0].as_str();
let n: usize = args[1].parse().unwrap_or(0);
if !n.is_power_of_two() || n > 4096 {
eprintln!("error: n must be a power of 2 <= 4096");
return;
}
let a = parse_poly(&args[2], n);
let b = parse_poly(&args[3], n);
let ctx = Ctx::new(forced);
let backend = ctx.backend();
let t = Instant::now();
let result = match op {
"add" => {
if backend == Backend::Gpu {
let ga = ring_to_gpu(&a);
let gb = ring_to_gpu(&b);
let gr = ctx.gpu().run_ring_add(&ga, &gb, n);
gpu_to_ring(&gr, n)
} else {
a.add(&b)
}
}
"mul" => {
if backend == Backend::Gpu {
let ga = ring_to_gpu(&a);
let gb = ring_to_gpu(&b);
let gr = ctx.gpu().run_ring_mul(&ga, &gb, n);
gpu_to_ring(&gr, n)
} else {
a.mul(&b)
}
}
_ => {
eprintln!("error: unknown calc op '{}'", op);
return;
}
};
print_timed(backend, t.elapsed());
print_poly(&result);
}
fn cmd_ntt(forced: Option<Backend>, args: &[String]) {
if args.len() < 3 {
eprintln!("error: ntt requires: <forward|inverse> <n> <coeffs>");
return;
}
let direction = args[0].as_str();
let n: usize = args[1].parse().unwrap_or(0);
if !n.is_power_of_two() || n > 4096 {
eprintln!("error: n must be a power of 2 <= 4096");
return;
}
let mut elem = parse_poly(&args[2], n);
let ctx = Ctx::new(forced);
let backend = ctx.backend();
let t = Instant::now();
match direction {
"forward" => {
if backend == Backend::Gpu {
eprintln!("\x1b[90m[ntt: using CPU โ standalone NTT not yet on GPU]\x1b[0m");
}
ntt::to_ntt(&mut elem);
print_timed(Backend::Cpu, t.elapsed());
print_poly(&elem);
}
"inverse" => {
if backend == Backend::Gpu {
eprintln!("\x1b[90m[intt: using CPU โ standalone INTT not yet on GPU]\x1b[0m");
}
elem.is_ntt = true;
ntt::from_ntt(&mut elem);
print_timed(Backend::Cpu, t.elapsed());
print_poly(&elem);
}
_ => {
eprintln!("error: unknown ntt direction '{}'", direction);
}
}
}
fn cmd_sample(args: &[String]) {
if args.len() < 3 {
eprintln!("error: sample requires: <uniform|ternary|cbd> <seed> <n> [eta]");
return;
}
let kind = args[0].as_str();
let seed: u64 = args[1].parse().unwrap_or(0);
let n: usize = args[2].parse().unwrap_or(0);
if !n.is_power_of_two() || n > 4096 {
eprintln!("error: n must be a power of 2 <= 4096");
return;
}
let elem = match kind {
"uniform" => sample::sample_uniform(seed, n),
"ternary" => sample::sample_ternary(seed, n),
"cbd" => {
if args.len() < 4 {
eprintln!("error: cbd requires eta parameter");
return;
}
let eta: usize = args[3].parse().unwrap_or(2);
sample::sample_cbd(seed, n, eta)
}
_ => {
eprintln!("error: unknown sample kind '{}'", kind);
return;
}
};
print_poly(&elem);
}
fn cmd_bench(forced: Option<Backend>, args: &[String]) {
let n: usize = if args.is_empty() {
1024
} else {
args[0].parse().unwrap_or(1024)
};
let iters: u64 = if args.len() < 2 {
1000
} else {
args[1].parse().unwrap_or(1000)
};
if !n.is_power_of_two() || n > 4096 {
eprintln!("error: n must be a power of 2 <= 4096");
return;
}
let ctx = Ctx::new(forced);
let backend = ctx.backend();
eprintln!("\x1b[90m[backend: {backend}, n={n}, iters={iters}]\x1b[0m");
let a = sample::sample_uniform(1, n);
let b = sample::sample_uniform(2, n);
eprintln!();
eprintln!("\x1b[36m--- CPU ---\x1b[0m");
let start = Instant::now();
for _ in 0..iters {
black_box(black_box(&a).add(black_box(&b)));
}
let elapsed = start.elapsed();
eprintln!(
"ring_add n={}: {:.1} us/op ({} iters)",
n,
elapsed.as_micros() as f64 / iters as f64,
iters
);
let mul_iters = iters.min(100);
let start = Instant::now();
for _ in 0..mul_iters {
black_box(black_box(&a).mul(black_box(&b)));
}
let elapsed = start.elapsed();
eprintln!(
"ring_mul n={}: {:.1} us/op ({} iters)",
n,
elapsed.as_micros() as f64 / mul_iters as f64,
mul_iters
);
let start = Instant::now();
for _ in 0..iters {
let mut c = a.clone();
ntt::to_ntt(black_box(&mut c));
}
let elapsed = start.elapsed();
eprintln!(
"ntt_fwd n={}: {:.1} us/op ({} iters)",
n,
elapsed.as_micros() as f64 / iters as f64,
iters
);
let mut buf = vec![0u8; n * 8];
let start = Instant::now();
for _ in 0..iters {
encoding::encode_ring(black_box(&a), black_box(&mut buf));
black_box(encoding::decode_ring(black_box(&buf), n));
}
let elapsed = start.elapsed();
eprintln!(
"enc_rt n={}: {:.1} us/op ({} iters)",
n,
elapsed.as_micros() as f64 / iters as f64,
iters
);
if backend == Backend::Gpu {
eprintln!();
eprintln!("\x1b[33m--- GPU ---\x1b[0m");
let ga = ring_to_gpu(&a);
let gb = ring_to_gpu(&b);
let start = Instant::now();
for _ in 0..iters {
black_box(ctx.gpu().run_ring_add(black_box(&ga), black_box(&gb), n));
}
let elapsed = start.elapsed();
eprintln!(
"ring_add n={}: {:.1} us/op ({} iters)",
n,
elapsed.as_micros() as f64 / iters as f64,
iters
);
let gpu_mul_iters = iters.min(100);
let start = Instant::now();
for _ in 0..gpu_mul_iters {
black_box(ctx.gpu().run_ring_mul(black_box(&ga), black_box(&gb), n));
}
let elapsed = start.elapsed();
eprintln!(
"ring_mul n={}: {:.1} us/op ({} iters)",
n,
elapsed.as_micros() as f64 / gpu_mul_iters as f64,
gpu_mul_iters
);
let start = Instant::now();
for _ in 0..iters {
black_box(
ctx.gpu()
.run_ring_pointwise_mul(black_box(&ga), black_box(&gb), n),
);
}
let elapsed = start.elapsed();
eprintln!(
"pointwise n={}: {:.1} us/op ({} iters)",
n,
elapsed.as_micros() as f64 / iters as f64,
iters
);
}
}
fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() < 2 {
print_usage();
return;
}
let cmd = args[1].clone();
let (forced, rest) = parse_backend_flag(&args[2..]);
match cmd.as_str() {
"calc" => cmd_calc(forced, &rest),
"ntt" => cmd_ntt(forced, &rest),
"sample" => cmd_sample(&rest),
"bench" => cmd_bench(forced, &rest),
"help" | "--help" | "-h" => print_usage(),
_ => {
eprintln!("error: unknown command '{}'\n", cmd);
print_usage();
}
}
}