use burn::prelude::*;
use burn::tensor::IndexingUpdateOp;
/// Accumulates each row of `src_features` into the output row selected by the
/// matching entry of `dst_indices`.
///
/// The result has shape `[num_nodes, d]`, where `d` is the column count of
/// `src_features`, and starts from zeros. Rows that no index points at stay
/// zero; a row targeted several times receives the sum of all contributions.
pub fn scatter_add<B: Backend>(
    src_features: Tensor<B, 2>,
    dst_indices: Tensor<B, 1, Int>,
    num_nodes: usize,
) -> Tensor<B, 2> {
    let [num_edges] = dst_indices.dims();
    let [_, feature_dim] = src_features.dims();
    // Repeat each destination index across every feature column so that all
    // columns of an edge land in the same output row.
    let row_targets: Tensor<B, 2, Int> = dst_indices
        .unsqueeze_dim::<2>(1)
        .expand([num_edges, feature_dim]);
    Tensor::<B, 2>::zeros([num_nodes, feature_dim], &src_features.device()).scatter(
        0,
        row_targets,
        src_features,
        IndexingUpdateOp::Add,
    )
}
/// Normalises edge scores with a softmax over each destination node's
/// incoming edges.
///
/// `edge_scores` holds one row per edge; every column is normalised
/// independently, so a `[num_edges, heads]` matrix yields one softmax per
/// head. For each destination node and column, the returned values of its
/// incoming edges sum to (approximately) one.
///
/// Scores are clamped to `[-20, 20]` before exponentiation to keep `exp` in a
/// safe range. Scores outside that interval are therefore saturated, and the
/// result deviates from an exact softmax when `|score| > 20`. A small epsilon
/// in the denominator guards the division.
pub fn neighborhood_softmax<B: Backend>(
    edge_scores: Tensor<B, 2>,
    dst_indices: Tensor<B, 1, Int>,
    num_nodes: usize,
) -> Tensor<B, 2> {
    let device = edge_scores.device();
    let [num_edges, num_columns] = edge_scores.dims();
    // One scatter index per score entry: every column of an edge goes to the
    // same destination row. Expanding to the real column count (rather than 1)
    // keeps each column's normalisation independent.
    let indices_2d: Tensor<B, 2, Int> = dst_indices
        .clone()
        .unsqueeze_dim::<2>(1)
        .expand([num_edges, num_columns]);
    let exp_scores = edge_scores.clamp(-20.0, 20.0).exp();
    // Per-node, per-column sum of exponentiated scores of incoming edges.
    let node_sum = Tensor::<B, 2>::zeros([num_nodes, num_columns], &device).scatter(
        0,
        indices_2d,
        exp_scores.clone(),
        IndexingUpdateOp::Add,
    );
    // Gather each edge's destination-node sum back to edge granularity.
    let edge_sum = node_sum.select(0, dst_indices);
    exp_scores / (edge_sum + 1e-10)
}
/// Several graphs merged into a single disjoint-union graph.
///
/// When built by `batch_graphs`, node rows are stacked in input order and edge
/// endpoints are rewritten to index into that combined node numbering.
pub struct BatchedGraph<B: Backend> {
    /// Node feature matrices of all graphs, concatenated along dimension 0.
    pub node_features: Tensor<B, 2>,
    /// Edge source node indices, offset into the batched node numbering.
    pub src_indices: Tensor<B, 1, Int>,
    /// Edge destination node indices, offset into the batched node numbering.
    pub dst_indices: Tensor<B, 1, Int>,
    /// Edge type labels of all graphs, concatenated without modification.
    pub edge_types: Tensor<B, 1, Int>,
    /// For every node, the position of the graph it came from in the batch.
    pub graph_ids: Tensor<B, 1, Int>,
    /// Total number of nodes across all batched graphs.
    pub num_nodes: usize,
    /// Number of graphs in the batch.
    pub num_graphs: usize,
}
/// Merges several graphs into one disjoint-union [`BatchedGraph`].
///
/// Node features are concatenated along dimension 0. Each graph's edge
/// endpoint indices are shifted by the number of nodes that precede it in the
/// batch, so they keep referring to the same nodes after concatenation. Every
/// node is tagged in `graph_ids` with the position of its graph in the input
/// slices.
///
/// # Arguments
/// * `node_features_list` - per-graph node feature matrices, one row per node.
/// * `src_indices_list` / `dst_indices_list` - per-graph edge endpoints,
///   indexed locally within their own graph.
/// * `edge_types_list` - per-graph edge type labels, copied through unchanged.
/// * `device` - device on which the `graph_ids` tensor is created.
///
/// # Panics
/// Panics if `node_features_list` is empty, if the four slices differ in
/// length, or if a graph's source, destination and edge-type tensors disagree
/// on its number of edges.
pub fn batch_graphs<B: Backend>(
    node_features_list: &[Tensor<B, 2>],
    src_indices_list: &[Tensor<B, 1, Int>],
    dst_indices_list: &[Tensor<B, 1, Int>],
    edge_types_list: &[Tensor<B, 1, Int>],
    device: &B::Device,
) -> BatchedGraph<B> {
    let num_graphs = node_features_list.len();
    // `Tensor::cat` cannot concatenate an empty list; fail with a clear message.
    assert!(num_graphs > 0, "batch_graphs: at least one graph is required");
    assert!(
        src_indices_list.len() == num_graphs
            && dst_indices_list.len() == num_graphs
            && edge_types_list.len() == num_graphs,
        "batch_graphs: expected {} entries in every list, got src={}, dst={}, edge_types={}",
        num_graphs,
        src_indices_list.len(),
        dst_indices_list.len(),
        edge_types_list.len(),
    );

    let mut total_nodes = 0usize;
    let mut all_features = Vec::with_capacity(num_graphs);
    let mut all_src = Vec::with_capacity(num_graphs);
    let mut all_dst = Vec::with_capacity(num_graphs);
    let mut all_edge_types = Vec::with_capacity(num_graphs);
    let mut all_graph_ids = Vec::with_capacity(num_graphs);

    let graphs = node_features_list
        .iter()
        .zip(src_indices_list)
        .zip(dst_indices_list)
        .zip(edge_types_list)
        .enumerate();
    for (graph_idx, (((features, src), dst), edge_types)) in graphs {
        let num_edges = src.dims()[0];
        assert!(
            dst.dims()[0] == num_edges && edge_types.dims()[0] == num_edges,
            "batch_graphs: graph {} has mismatched edge counts (src={}, dst={}, edge_types={})",
            graph_idx,
            num_edges,
            dst.dims()[0],
            edge_types.dims()[0],
        );
        let n = features.dims()[0];
        // Nodes of this graph start after all nodes of the preceding graphs.
        let offset = total_nodes as i64;
        all_features.push(features.clone());
        all_src.push(src.clone().add_scalar(offset));
        all_dst.push(dst.clone().add_scalar(offset));
        all_edge_types.push(edge_types.clone());
        all_graph_ids.push(Tensor::<B, 1, Int>::full([n], graph_idx as i64, device));
        total_nodes += n;
    }

    BatchedGraph {
        node_features: Tensor::cat(all_features, 0),
        src_indices: Tensor::cat(all_src, 0),
        dst_indices: Tensor::cat(all_dst, 0),
        edge_types: Tensor::cat(all_edge_types, 0),
        graph_ids: Tensor::cat(all_graph_ids, 0),
        num_nodes: total_nodes,
        num_graphs,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type B = NdArray;

    #[test]
    fn scatter_add_basic() {
        let device = Default::default();
        let src = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], &device);
        let dst = Tensor::<B, 1, Int>::from_ints([0, 0, 1], &device);
        let result = scatter_add(src, dst, 2);
        let data = result.to_data();
        assert_eq!(data.as_slice::<f32>().unwrap(), &[4.0, 6.0, 5.0, 6.0]);
    }

    #[test]
    fn scatter_add_leaves_untargeted_rows_zero() {
        let device = Default::default();
        let src = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
        let dst = Tensor::<B, 1, Int>::from_ints([2, 2], &device);
        let result = scatter_add(src, dst, 3);
        let data = result.to_data();
        assert_eq!(
            data.as_slice::<f32>().unwrap(),
            &[0.0, 0.0, 0.0, 0.0, 4.0, 6.0]
        );
    }

    #[test]
    fn neighborhood_softmax_sums_to_one() {
        let device = Default::default();
        let scores = Tensor::<B, 2>::from_floats([[1.0], [2.0], [3.0], [4.0]], &device);
        let dst = Tensor::<B, 1, Int>::from_ints([0, 0, 1, 1], &device);
        let result = neighborhood_softmax(scores, dst, 2);
        let data = result.to_data();
        let vals = data.as_slice::<f32>().unwrap();
        let sum_node0 = vals[0] + vals[1];
        assert!((sum_node0 - 1.0).abs() < 1e-5, "node 0 sum: {}", sum_node0);
        let sum_node1 = vals[2] + vals[3];
        assert!((sum_node1 - 1.0).abs() < 1e-5, "node 1 sum: {}", sum_node1);
    }

    #[test]
    fn neighborhood_softmax_matches_exact_softmax() {
        let device = Default::default();
        let scores = Tensor::<B, 2>::from_floats([[1.0], [2.0], [0.5]], &device);
        let dst = Tensor::<B, 1, Int>::from_ints([0, 0, 1], &device);
        let result = neighborhood_softmax(scores, dst, 2);
        let data = result.to_data();
        let vals = data.as_slice::<f32>().unwrap();
        let e = 1.0f32.exp();
        // softmax([1, 2]) = [1 / (1 + e), e / (1 + e)]; a lone edge gets 1.
        let expected = [1.0 / (1.0 + e), e / (1.0 + e), 1.0];
        for (got, want) in vals.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-5, "got {}, want {}", got, want);
        }
    }

    #[test]
    fn batch_graphs_offsets_edges_and_tags_nodes() {
        let device = Default::default();
        let features = [
            Tensor::<B, 2>::from_floats([[1.0], [2.0]], &device),
            Tensor::<B, 2>::from_floats([[3.0], [4.0], [5.0]], &device),
        ];
        let src = [
            Tensor::<B, 1, Int>::from_ints([0], &device),
            Tensor::<B, 1, Int>::from_ints([0, 2], &device),
        ];
        let dst = [
            Tensor::<B, 1, Int>::from_ints([1], &device),
            Tensor::<B, 1, Int>::from_ints([1, 0], &device),
        ];
        let edge_types = [
            Tensor::<B, 1, Int>::from_ints([0], &device),
            Tensor::<B, 1, Int>::from_ints([1, 2], &device),
        ];
        let batched = batch_graphs::<B>(&features, &src, &dst, &edge_types, &device);

        assert_eq!(batched.num_nodes, 5);
        assert_eq!(batched.num_graphs, 2);

        let feature_data = batched.node_features.to_data();
        assert_eq!(
            feature_data.as_slice::<f32>().unwrap(),
            &[1.0, 2.0, 3.0, 4.0, 5.0]
        );
        let src_data = batched.src_indices.to_data();
        assert_eq!(src_data.as_slice::<i64>().unwrap(), &[0, 2, 4]);
        let dst_data = batched.dst_indices.to_data();
        assert_eq!(dst_data.as_slice::<i64>().unwrap(), &[1, 3, 2]);
        let type_data = batched.edge_types.to_data();
        assert_eq!(type_data.as_slice::<i64>().unwrap(), &[0, 1, 2]);
        let id_data = batched.graph_ids.to_data();
        assert_eq!(id_data.as_slice::<i64>().unwrap(), &[0, 0, 1, 1, 1]);
    }
}