use std::sync::{Arc, Barrier, Mutex};
/// A unit of work handed to a pool worker: an owned, sendable closure that is run exactly once.
type Job = Box<dyn FnOnce() + Send + 'static>;
pub(super) struct Pool {
senders: Vec<std::sync::mpsc::Sender<Job>>,
barrier: Arc<Barrier>,
}
impl Pool {
fn new(n: usize) -> Self {
let barrier = Arc::new(Barrier::new(n + 1));
let mut senders = Vec::with_capacity(n);
for _ in 0..n {
let (tx, rx) = std::sync::mpsc::channel::<Job>();
let bar = barrier.clone();
std::thread::spawn(move || {
let _ = crate::sync::affinity::pin_p_core();
super::ensure_amx();
while let Ok(job) = rx.recv() {
job();
bar.wait();
}
});
senders.push(tx);
}
Self { senders, barrier }
}
pub(super) fn len(&self) -> usize {
self.senders.len()
}
pub(super) fn run(&self, jobs: Vec<Job>) {
assert_eq!(jobs.len(), self.senders.len());
for (tx, job) in self.senders.iter().zip(jobs) {
tx.send(job).expect("worker thread panicked");
}
self.barrier.wait();
}
}
static POOL: Mutex<Option<Arc<Pool>>> = Mutex::new(None);
pub(super) fn get_pool(n_threads: usize) -> Arc<Pool> {
let mut guard = POOL.lock().unwrap();
if let Some(ref pool) = *guard {
if pool.len() == n_threads {
return pool.clone();
}
}
let pool = Arc::new(Pool::new(n_threads));
*guard = Some(pool.clone());
pool
}