//! Cell code generation.
//!
//! Takes a parsed `CellDef` and generates all output code:
//! state structs, wrapper, Cell trait, migration, error enum,
//! metadata, methods, and channel type aliases.

use proc_macro2::TokenStream;
use quote::{format_ident, quote, ToTokens};
use syn::{Ident, Result, ReturnType, Type};

use super::parse::{CellDef, CellField, CellMethod, MethodVis, MigrateSource, SelfArg};

pub fn generate(cell: &CellDef) -> Result<TokenStream> {
    let state_struct = gen_state_struct(cell);
    let step_state_struct = gen_step_state_struct(cell);
    let wrapper_struct = gen_wrapper_struct(cell);
    let cell_trait_impl = gen_cell_trait_impl(cell);
    let migrate_impl = gen_migrate_impl(cell);
    let error_enum = gen_error_enum(cell);
    let error_aliases = gen_error_aliases(cell);
    let metadata_impl = gen_metadata_impl(cell);
    let methods_impl = gen_methods_impl(cell);
    let channel_types = gen_channel_types(cell);

    Ok(quote! {
        #state_struct
        #step_state_struct
        #wrapper_struct
        #cell_trait_impl
        #migrate_impl
        #error_enum
        #error_aliases
        #metadata_impl
        #methods_impl
        #channel_types
    })
}

/// Emits `pub struct {Cell}State { .. }` with one public member per entry
/// in `cell.state_fields`.
fn gen_state_struct(cell: &CellDef) -> TokenStream {
    let struct_ident = format_ident!("{}State", cell.name);

    let mut members: Vec<TokenStream> = Vec::new();
    for field in cell.state_fields.iter() {
        let (ident, ty) = (&field.name, &field.ty);
        members.push(quote! { pub #ident: #ty, });
    }

    quote! {
        pub struct #struct_ident {
            #(#members)*
        }
    }
}

/// Emits `{Cell}StepState` together with its `rs_lang::StepReset` impl.
///
/// With no step-state fields the struct is a unit struct and `reset` has an
/// empty body; otherwise the struct gets one public member per field and
/// `reset` runs the statement produced by `step_reset_stmt` for each field,
/// in declaration order.
fn gen_step_state_struct(cell: &CellDef) -> TokenStream {
    let struct_ident = format_ident!("{}StepState", cell.name);
    let step_fields = &cell.step_state_fields;

    let declaration = if step_fields.is_empty() {
        quote! { pub struct #struct_ident; }
    } else {
        let members: Vec<TokenStream> = step_fields
            .iter()
            .map(|field| {
                let (ident, ty) = (&field.name, &field.ty);
                quote! { pub #ident: #ty, }
            })
            .collect();
        quote! {
            pub struct #struct_ident {
                #(#members)*
            }
        }
    };

    // For an empty field list this is empty, which yields `fn reset(&mut self) {}`.
    let resets: Vec<TokenStream> = step_fields
        .iter()
        .map(|field| step_reset_stmt(field))
        .collect();

    quote! {
        #declaration
        impl rs_lang::StepReset for #struct_ident {
            fn reset(&mut self) {
                #(#resets)*
            }
        }
    }
}

/// Builds the statement that resets one step-state field inside
/// `StepReset::reset`.
///
/// The statement is chosen from the last path segment of the field's type
/// (see `extract_type_name`):
///
/// - primitive integers, including the pointer-sized `usize`/`isize`, are
///   assigned `0`;
/// - `bool` is assigned `false`;
/// - `Option` is assigned `None`;
/// - `BoundedVec` / `BoundedMap` are emptied with `clear()`;
/// - `AtomicU32` / `AtomicU64` are stored `0` with `SeqCst` ordering;
/// - anything else (including non-path types) delegates to
///   `rs_lang::StepReset::reset` on the field.
///
/// Matching is purely by name, so a user type that shares one of these
/// names is treated like the built-in one.
fn step_reset_stmt(field: &CellField) -> TokenStream {
    let name = &field.name;
    let type_name = extract_type_name(&field.ty);
    match type_name.as_deref() {
        // `usize`/`isize` are handled here as well so that every primitive
        // integer type is reset the same way.
        Some("u8") | Some("u16") | Some("u32") | Some("u64") | Some("u128") | Some("usize")
        | Some("i8") | Some("i16") | Some("i32") | Some("i64") | Some("i128")
        | Some("isize") => {
            quote! { self.#name = 0; }
        }
        Some("bool") => quote! { self.#name = false; },
        Some("Option") => quote! { self.#name = None; },
        Some("BoundedVec") | Some("BoundedMap") => quote! { self.#name.clear(); },
        Some("AtomicU32") | Some("AtomicU64") => {
            quote! { self.#name.store(0, ::core::sync::atomic::Ordering::SeqCst); }
        }
        _ => quote! { rs_lang::StepReset::reset(&mut self.#name); },
    }
}

/// Emits the cell wrapper struct and its `new()` constructor.
///
/// The wrapper owns the persistent state, the step state, and a `__step`
/// counter. `new()` initialises every state and step-state field with
/// `Default::default()` and starts the counter at `0`. When there are no
/// step-state fields the step state is constructed as a unit struct.
fn gen_wrapper_struct(cell: &CellDef) -> TokenStream {
    let wrapper = &cell.name;
    let state_ty = format_ident!("{}State", cell.name);
    let step_ty = format_ident!("{}StepState", cell.name);

    let state_members: Vec<_> = cell.state_fields.iter().map(|f| &f.name).collect();
    let state_init = quote! {
        #state_ty { #(#state_members: Default::default(),)* }
    };

    let step_init = if cell.step_state_fields.is_empty() {
        // Unit struct: the bare type name is the value.
        step_ty.to_token_stream()
    } else {
        let step_members: Vec<_> = cell.step_state_fields.iter().map(|f| &f.name).collect();
        quote! {
            #step_ty { #(#step_members: Default::default(),)* }
        }
    };

    quote! {
        pub struct #wrapper {
            state: #state_ty,
            step_state: #step_ty,
            __step: u64,
        }

        impl #wrapper {
            /// Create a new cell instance with default state.
            pub fn new() -> Self {
                Self {
                    state: #state_init,
                    step_state: #step_init,
                    __step: 0,
                }
            }
        }
    }
}

fn gen_cell_trait_impl(cell: &CellDef) -> TokenStream {
    let name = &cell.name;
    let cell_name_str = cell.name.to_string();
    let version = cell.version;
    let budget = &cell.budget;
    let heartbeat = &cell.heartbeat;
    quote! {
        impl rs_lang::Cell for #name {
            const NAME: &'static str = #cell_name_str;
            const VERSION: u32 = #version;
            const BUDGET: ::core::time::Duration = #budget;
            const HEARTBEAT: ::core::time::Duration = #heartbeat;
            fn current_step(&self) -> u64 { self.__step }
            fn health_check(&self) -> rs_lang::HealthStatus {
                rs_lang::HealthStatus::Healthy
            }
            fn reset_step_state(&mut self) {
                rs_lang::StepReset::reset(&mut self.step_state);
            }
        }
    }
}

/// Emits `impl rs_lang::MigrateFrom<Old> for {Cell}State` when the cell
/// declares a migration; emits nothing otherwise.
///
/// `Old` is `{Cell}StateV{n}` for a version-number source, or the given
/// path verbatim for a path source. The generated `migrate` builds `Self`
/// from the declared `field: expr` mappings, in declaration order.
fn gen_migrate_impl(cell: &CellDef) -> TokenStream {
    cell.migrate.as_ref().map_or_else(TokenStream::new, |spec| {
        let new_state = format_ident!("{}State", cell.name);

        let source_ty = match &spec.from_version {
            MigrateSource::Path(path) => path.to_token_stream(),
            MigrateSource::Version(n) => {
                format_ident!("{}StateV{}", cell.name, n).into_token_stream()
            }
        };

        let targets: Vec<_> = spec.field_mappings.iter().map(|fm| &fm.name).collect();
        let values: Vec<_> = spec.field_mappings.iter().map(|fm| &fm.expr).collect();

        quote! {
            impl rs_lang::MigrateFrom<#source_ty> for #new_state {
                fn migrate(old: #source_ty) -> Self {
                    Self { #(#targets: #values,)* }
                }
            }
        }
    })
}

fn gen_error_enum(cell: &CellDef) -> TokenStream {
    let error_name = format_ident!("{}Error", cell.name);
    let mut variants = collect_error_variants(cell);
    let has_async_deadline = cell.methods.iter().any(|m| m.deadline.is_some());
    if has_async_deadline && !variants.contains(&"Timeout".to_string()) {
        variants.push("Timeout".to_string());
    }
    if variants.is_empty() {
        return quote! {
            #[derive(Debug)]
            pub enum #error_name {}
        };
    }
    let variant_idents: Vec<Ident> = variants.iter().map(|v| format_ident!("{}", v)).collect();
    let timeout_impl = if has_async_deadline {
        quote! {
            impl From<rs_lang::Timeout> for #error_name {
                fn from(_: rs_lang::Timeout) -> Self { #error_name::Timeout }
            }
        }
    } else {
        TokenStream::new()
    };
    quote! {
        #[derive(Debug)]
        pub enum #error_name {
            #(#variant_idents,)*
        }
        #timeout_impl
    }
}

/// Placeholder for error-alias generation; intentionally emits no tokens.
///
/// The spec says `Error::Variant` resolves to `{CellName}Error::Variant`
/// inside cell methods. No alias item is generated for that. Instead,
/// `gen_single_method` passes each method's return type and body through
/// `rewrite_error_ident`, which substitutes the error enum name for every
/// `Error` identifier it encounters.
fn gen_error_aliases(_cell: &CellDef) -> TokenStream {
    TokenStream::new()
}

fn collect_error_variants(cell: &CellDef) -> Vec<String> {
    let mut variants = Vec::new();
    for method in &cell.methods {
        scan_for_error_variants(&method.body, &mut variants);
    }
    let mut seen = std::collections::HashSet::new();
    variants.retain(|v| seen.insert(v.clone()));
    variants
}

/// Appends to `variants` every `X` that appears as `Error :: X` in `stream`.
///
/// The stream is walked left to right and delimited groups are descended
/// into at the position where they occur, so names are pushed in source
/// order. Duplicates are not filtered here.
fn scan_for_error_variants(stream: &TokenStream, variants: &mut Vec<String>) {
    use proc_macro2::TokenTree;

    let trees: Vec<TokenTree> = stream.clone().into_iter().collect();
    for (pos, tree) in trees.iter().enumerate() {
        match tree {
            TokenTree::Ident(head) if head == "Error" => {
                // Need exactly three following tokens: `:` `:` <ident>.
                if let Some([TokenTree::Punct(first), TokenTree::Punct(second), TokenTree::Ident(variant)]) =
                    trees.get(pos + 1..pos + 4)
                {
                    if first.as_char() == ':' && second.as_char() == ':' {
                        variants.push(variant.to_string());
                    }
                }
            }
            TokenTree::Group(nested) => scan_for_error_variants(&nested.stream(), variants),
            _ => {}
        }
    }
}

/// Emits `impl rs_lang::CellMetadata for {Cell}`.
///
/// `interface()` returns one `rs_lang::FunctionSignature` per public method,
/// in declaration order. Argument and return types are recorded as the
/// stringified tokens of the type (`"()"` when no return type is written);
/// `deadline` is `Some(<deadline>)` when the method declares one, else `None`.
fn gen_metadata_impl(cell: &CellDef) -> TokenStream {
    let cell_ident = &cell.name;

    let mut signatures: Vec<TokenStream> = Vec::new();
    for method in cell.methods.iter().filter(|m| m.vis == MethodVis::Public) {
        let fn_name = method.name.to_string();

        let arg_types: Vec<String> = method
            .args
            .iter()
            .map(|arg| arg.ty.to_token_stream().to_string())
            .collect();

        let ret_type = match &method.ret {
            ReturnType::Type(_, ty) => ty.to_token_stream().to_string(),
            ReturnType::Default => String::from("()"),
        };

        let deadline = match method.deadline.as_ref() {
            None => quote! { None },
            Some(limit) => quote! { Some(#limit) },
        };

        // With no arguments the repetition is empty and this is just `&[]`.
        signatures.push(quote! {
            rs_lang::FunctionSignature {
                name: #fn_name,
                args: &[#(#arg_types),*],
                ret: #ret_type,
                deadline: #deadline,
            }
        });
    }

    quote! {
        impl rs_lang::CellMetadata for #cell_ident {
            fn interface() -> &'static [rs_lang::FunctionSignature] {
                &[#(#signatures),*]
            }
        }
    }
}

fn gen_methods_impl(cell: &CellDef) -> TokenStream {
    let name = &cell.name;
    let error_name = format_ident!("{}Error", cell.name);
    let method_fns: Vec<TokenStream> = cell
        .methods
        .iter()
        .map(|m| gen_single_method(m, &error_name))
        .collect();
    quote! {
        impl #name {
            #(#method_fns)*
        }
    }
}

fn gen_single_method(method: &CellMethod, error_name: &Ident) -> TokenStream {
    let vis = match method.vis {
        MethodVis::Public => quote! { pub },
        MethodVis::Private => quote! {},
    };
    let fn_name = &method.name;
    let attrs = &method.attrs;
    let self_param = match method.self_arg {
        SelfArg::Ref => quote! { &self, },
        SelfArg::RefMut => quote! { &mut self, },
        SelfArg::None => quote! {},
    };
    let params: Vec<TokenStream> = method
        .args
        .iter()
        .map(|a| {
            let name = &a.name;
            let ty = &a.ty;
            quote! { #name: #ty }
        })
        .collect();
    let ret = rewrite_error_ident(&method.ret.to_token_stream(), error_name);
    let body = rewrite_error_ident(&method.body, error_name);
    if method.is_async {
        if let Some(deadline) = &method.deadline {
            quote! {
                #(#attrs)*
                #vis fn #fn_name(#self_param #(#params),*) #ret {
                    rs_lang::runtime::with_deadline(#deadline, async move #body)
                }
            }
        } else {
            quote! {
                #(#attrs)*
                #vis async fn #fn_name(#self_param #(#params),*) #ret #body
            }
        }
    } else {
        quote! {
            #(#attrs)*
            #vis fn #fn_name(#self_param #(#params),*) #ret #body
        }
    }
}

fn gen_channel_types(cell: &CellDef) -> TokenStream {
    let name = &cell.name;
    let mut output = TokenStream::new();
    if let Some(input_ch) = &cell.input_channel {
        let alias = format_ident!("{}Input", name);
        let ty = &input_ch.ty;
        output.extend(quote! { pub type #alias = #ty; });
    }
    if let Some(output_ch) = &cell.output_channel {
        let alias = format_ident!("{}Output", name);
        let ty = &output_ch.ty;
        output.extend(quote! { pub type #alias = #ty; });
    }
    output
}

/// Returns the identifier of the last path segment of `ty` (for example
/// `Option` for `core::option::Option<u8>`), ignoring generic arguments.
///
/// Returns `None` for non-path types and for a path with no segments.
fn extract_type_name(ty: &Type) -> Option<String> {
    if let Type::Path(type_path) = ty {
        type_path
            .path
            .segments
            .last()
            .map(|segment| segment.ident.to_string())
    } else {
        None
    }
}

/// Rewrite cell-specific names in a token stream.
///
/// - `Error` → `{CellName}Error` (in any position)
/// - `Result<T>` → `::core::result::Result<T, {CellName}Error>`, but only for
///   an unqualified `Result` followed by exactly one generic argument
///
/// The following are left as written (their contents are still rewritten by
/// the normal token walk):
///
/// - `Result<T, E>` — it already names its error type, so appending another
///   argument would produce an invalid three-argument `Result`;
/// - path-qualified uses such as `io::Result<T>` — splicing an absolute path
///   after `io::` would produce an invalid path;
/// - `Result <` with no matching `>` in the same token sequence (for example
///   a comparison), and `Result<>`.
///
/// Delimited groups are rewritten recursively and keep their original span.
fn rewrite_error_ident(stream: &TokenStream, error_name: &Ident) -> TokenStream {
    use proc_macro2::TokenTree;

    let tokens: Vec<TokenTree> = stream.clone().into_iter().collect();
    let len = tokens.len();
    let mut out = TokenStream::new();
    let mut i = 0;
    while i < len {
        match &tokens[i] {
            TokenTree::Ident(ident) if ident == "Error" => {
                out.extend(quote! { #error_name });
            }
            TokenTree::Ident(ident) if ident == "Result" => {
                // proc_macro2 doesn't parse `<T>` as a group, so `Result<T>`
                // appears as: Ident("Result") Punct('<') ... Punct('>').
                let path_qualified = i >= 2
                    && token_is_punct(&tokens[i - 1], ':')
                    && token_is_punct(&tokens[i - 2], ':');
                let opens_generics = i + 1 < len && token_is_punct(&tokens[i + 1], '<');

                let single_arg = if opens_generics && !path_qualified {
                    single_generic_arg_span(&tokens[i + 2..])
                } else {
                    None
                };

                if let Some((arg_len, consumed)) = single_arg {
                    let first = i + 2;
                    let inner: TokenStream = tokens[first..first + arg_len].iter().cloned().collect();
                    let rewritten_inner = rewrite_error_ident(&inner, error_name);
                    out.extend(quote! {
                        ::core::result::Result<#rewritten_inner, #error_name>
                    });
                    // Skip `Result`, `<`, the argument and the closing `>`.
                    i = first + consumed;
                    continue;
                }

                // Not a one-argument shorthand — pass `Result` through and let
                // the loop handle whatever follows it.
                out.extend(core::iter::once(tokens[i].clone()));
            }
            TokenTree::Group(group) => {
                let rewritten = rewrite_error_ident(&group.stream(), error_name);
                let mut new_group = proc_macro2::Group::new(group.delimiter(), rewritten);
                new_group.set_span(group.span());
                out.extend(core::iter::once(TokenTree::Group(new_group)));
            }
            other => {
                out.extend(core::iter::once(other.clone()));
            }
        }
        i += 1;
    }
    out
}

/// Returns `true` when `tree` is the punctuation character `ch`.
fn token_is_punct(tree: &proc_macro2::TokenTree, ch: char) -> bool {
    match tree {
        proc_macro2::TokenTree::Punct(p) => p.as_char() == ch,
        _ => false,
    }
}

/// Measures a single generic argument in `tokens`, which must start right
/// after an opening `<`.
///
/// Returns `Some((arg_len, consumed))` when the matching `>` is found and the
/// angle brackets enclose exactly one non-empty argument: `tokens[..arg_len]`
/// is that argument (a trailing comma is excluded) and `consumed` is the
/// number of tokens up to and including the closing `>`.
///
/// Returns `None` when the brackets are unbalanced, empty, or contain two or
/// more top-level arguments. The `>` of a `->` arrow is not treated as a
/// closing bracket, and commas nested inside inner `<..>` pairs or delimited
/// groups do not count as argument separators.
fn single_generic_arg_span(tokens: &[proc_macro2::TokenTree]) -> Option<(usize, usize)> {
    use proc_macro2::TokenTree;

    let mut depth = 1usize;
    let mut after_dash = false;
    let mut first_comma: Option<usize> = None;

    for (idx, tree) in tokens.iter().enumerate() {
        let ch = match tree {
            TokenTree::Punct(p) => p.as_char(),
            _ => {
                after_dash = false;
                continue;
            }
        };
        let arrow_head = after_dash && ch == '>';
        after_dash = ch == '-';

        match ch {
            '<' => depth += 1,
            '>' if !arrow_head => {
                depth -= 1;
                if depth == 0 {
                    let arg_len = match first_comma {
                        None => idx,
                        // `Result<T,>`: the comma is only a trailing separator.
                        Some(comma) if comma + 1 == idx => comma,
                        // Something follows the comma: two or more arguments.
                        Some(_) => return None,
                    };
                    if arg_len == 0 {
                        return None;
                    }
                    return Some((arg_len, idx + 1));
                }
            }
            ',' if depth == 1 && first_comma.is_none() => first_comma = Some(idx),
            _ => {}
        }
    }
    // Ran out of tokens before the brackets balanced.
    None
}

Dimensions

rs/macros/src/registers/codegen.rs
trident/benches/references/std/compiler/codegen.rs

Local Graph