use super::TIROp;
use std::collections::{BTreeMap, BTreeSet};
/// Simplifies single-cell memory spills in a straight-line `TIROp` stream.
///
/// Two exact op sequences are recognised, each keyed by the address literal
/// pushed first:
///
/// * spill write: `Push(addr), Swap(1), WriteMem(1), Pop(1)`
/// * reload:      `Push(addr), ReadMem(1), Pop(1)`
///
/// For every address that has at least one spill write:
///
/// * **no reload anywhere** - each write is replaced by a single `Pop(1)`.
///   This pass treats a write as net-consuming one stack element, so the
///   `Pop(1)` preserves stack depth while dropping the dead store.
/// * **exactly one write and one reload, with the write first** - both
///   sequences are deleted, leaving the spilled value on the stack instead of
///   round-tripping it through memory.
/// * anything else (several writes, several reloads, or a reload that comes
///   *before* the write and therefore reads an older value) - left untouched.
///
/// All other ops are moved to the output unchanged and in order. When nothing
/// qualifies, the input vector is returned as-is.
///
/// NOTE(review): only the literal sequences above are counted as accesses.
/// This assumes spill addresses are never read or written any other way
/// (computed addresses, multi-word `ReadMem(n)`/`WriteMem(n)`) - confirm
/// against the code generator. Deleting a non-adjacent write/reload pair also
/// keeps one extra element on the stack across the ops in between. That is
/// presumably only sound if those ops never address stack slots at or below
/// it - confirm.
pub(crate) fn eliminate_dead_spills(ops: Vec<TIROp>) -> Vec<TIROp> {
    const WRITE_LEN: usize = 4;
    const READ_LEN: usize = 3;

    // Start index of every spill write / reload, grouped by address. Both
    // sequences begin with `Push` and contain no other `Push`, so matches can
    // never overlap and each op belongs to at most one of them.
    let mut writes: BTreeMap<u64, Vec<usize>> = BTreeMap::new();
    for (start, window) in ops.windows(WRITE_LEN).enumerate() {
        if let [TIROp::Push(addr), TIROp::Swap(1), TIROp::WriteMem(1), TIROp::Pop(1)] = window {
            writes.entry(*addr).or_default().push(start);
        }
    }
    let mut reads: BTreeMap<u64, Vec<usize>> = BTreeMap::new();
    for (start, window) in ops.windows(READ_LEN).enumerate() {
        if let [TIROp::Push(addr), TIROp::ReadMem(1), TIROp::Pop(1)] = window {
            reads.entry(*addr).or_default().push(start);
        }
    }

    // Individual op indices to drop, plus the positions at which a `Pop(1)`
    // is emitted in place of a dead spill write.
    let mut removed: BTreeSet<usize> = BTreeSet::new();
    let mut pop_at: BTreeSet<usize> = BTreeSet::new();
    for (addr, write_starts) in &writes {
        match (write_starts.as_slice(), reads.get(addr).map(Vec::as_slice)) {
            // Never reloaded: every store is dead, but each must still discard
            // the stack value it would have consumed.
            (_, None) => {
                for &w in write_starts {
                    removed.extend(w..w + WRITE_LEN);
                    pop_at.insert(w);
                }
            }
            // One store followed later by its only reload: drop both. The
            // `w < r` guard rejects a reload that precedes the store, since that
            // reload observes a different (older) value.
            (&[w], Some(&[r])) if w < r => {
                removed.extend(w..w + WRITE_LEN);
                removed.extend(r..r + READ_LEN);
            }
            _ => {}
        }
    }

    if removed.is_empty() {
        return ops;
    }

    // Rebuild by moving ops out of the owned input rather than cloning them.
    let mut out = Vec::with_capacity(ops.len());
    for (i, op) in ops.into_iter().enumerate() {
        if pop_at.contains(&i) {
            out.push(TIROp::Pop(1));
        }
        if !removed.contains(&i) {
            out.push(op);
        }
    }
    out
}