diff --git a/macros/node-poll-variants/src/lib.rs b/macros/node-poll-variants/src/lib.rs index 03675f3..490a7af 100644 --- a/macros/node-poll-variants/src/lib.rs +++ b/macros/node-poll-variants/src/lib.rs @@ -1,11 +1,13 @@ use proc_macro::TokenStream; use quote::quote; use std::ops::Deref; -use std::path::Path; use syn::__private::Span; use syn::punctuated::Punctuated; use syn::token::{Colon, PathSep, Plus}; -use syn::{parse_macro_input, Data, DeriveInput, GenericParam, Ident, PathSegment, Token, TraitBound, TraitBoundModifier, TypeParam, TypeParamBound, Meta, Expr, Lit}; +use syn::{ + parse_macro_input, Data, DeriveInput, Expr, GenericParam, Ident, Lit, Meta, Path, PathSegment, + Token, TraitBound, TraitBoundModifier, TypeParam, TypeParamBound, +}; // Struct name: Ads1256PollStateful, Ads1256PollPublish, Ads1256PollStatePub #[proc_macro_derive(PollVariants, attributes(value_type))] @@ -31,10 +33,10 @@ pub fn poll_variant_macro(input: TokenStream) -> TokenStream { let mut value_type_ident: Option = None; for attribute in attrs.iter() { // if the attribute is a named value - if let Meta::NameValue(meta) = &attribute.meta { - // if the name of the attribute is value_t - if meta.path.segments[0].ident == VALUE_T_NAME { - // if the value of the attribute is a literal + if let Meta::NameValue(meta) = &attribute.meta { + // if the name of the attribute is value_t + if meta.path.segments[0].ident == VALUE_T_NAME { + // if the value of the attribute is a literal if let Expr::Lit(lit) = &meta.value { // if the literal is a string if let Lit::Str(lit) = &lit.lit { @@ -45,13 +47,13 @@ pub fn poll_variant_macro(input: TokenStream) -> TokenStream { } else { panic!("{VALUE_T_NAME} must be set with a literal.") } - } else { - continue; - } - } else { - continue; - } - }; + } else { + continue; + } + } else { + continue; + } + } let value_type_ident = if let Some(val) = value_type_ident { val } else { @@ -59,56 +61,95 @@ pub fn poll_variant_macro(input: TokenStream) -> TokenStream { }; // ----- Build publisher generics ---------------------------------- - //TODO: Get rid of all this adding T stuff + // Check if generics contains T + const MUTEX_T_IDENT_STR: &str = "MutexT"; + let mut num_lifetimes: usize = 0; + let mut num_types: usize = 0; + let mut num_const: usize = 0; + let mut has_mutex_t = false; + for param in og_generics.params.iter() { + match param { + GenericParam::Lifetime(_) => num_lifetimes += 1, + GenericParam::Type(param) => { + num_types += 1; - // // Check if generics contains T - // const T_IDENT_STR: &str = "T"; - // const COPY_IDENT_STR: &str = "Copy"; - // let mut first_type_index: Option = None; - // let mut has_t = false; - // for (index, param) in og_generics.params.iter().enumerate() { - // // If the generic parameter is a type - // if let GenericParam::Type(param) = param { - // if first_type_index.is_none() { - // first_type_index = Some(index); - // } - // // If the generic parameter is T - // if param.ident == T_IDENT_STR { - // has_t = true; - // } - // } - // } + // If the generic parameter is T + if param.ident == MUTEX_T_IDENT_STR { + has_mutex_t = true; + } + }, + GenericParam::Const(_) => num_const += 1, + } + } - // let mut generics_with_t = og_generics.clone(); - // // If T is not a generic parameter, add it - // if !has_t { - // let first_type_index = first_type_index.unwrap_or(0); - // let t_ident = Ident::new(T_IDENT_STR, Span::call_site()); - // let copy_ident = Ident::new(COPY_IDENT_STR, Span::call_site()); - // let mut t_bounds: Punctuated = Punctuated::new(); - // - // let t_bound = TraitBound { - // paren_token: None, - // modifier: TraitBoundModifier::None, - // lifetimes: None, - // path: copy_ident.into(), - // }; - // t_bounds.push(t_bound.into()); - // - // let t_param: TypeParam = TypeParam { - // attrs: Vec::new(), - // ident: t_ident, - // colon_token: Some(Colon::default()), - // bounds: t_bounds, - // eq_token: None, - // default: None, - // }; - // generics_with_t - // .params - // .insert(first_type_index, t_param.into()); - // } + let mut publish_generics = og_generics.clone(); + // If MutexT is not a generic parameter, add it + if !has_mutex_t { + let mutex_t_ident = Ident::new(MUTEX_T_IDENT_STR, Span::call_site()); + let mut mutex_t_bounds: Punctuated = Punctuated::new(); + + let mut path: Punctuated = Punctuated::new(); + path.push(Ident::new("embassy_sync", Span::call_site()).into()); + path.push(Ident::new("blocking_mutex", Span::call_site()).into()); + path.push(Ident::new("raw", Span::call_site()).into()); + path.push(Ident::new("RawMutex", Span::call_site()).into()); + let raw_mutex_path = Path { + leading_colon: None, + segments: path, + }; + + let mutex_t_bound = TraitBound { + paren_token: None, + modifier: TraitBoundModifier::None, + lifetimes: None, + path: raw_mutex_path, + }; + mutex_t_bounds.push(mutex_t_bound.into()); + + let mutex_t_param: TypeParam = TypeParam { + attrs: Vec::new(), + ident: mutex_t_ident, + colon_token: Some(Colon::default()), + bounds: mutex_t_bounds, + eq_token: None, + default: None, + }; + + let num_generics = num_lifetimes + num_types + num_const; + // If there are generics + if num_generics > 0 { + // If all generics are lifetimes + if num_lifetimes == num_generics { + // Add MutexT after the lifetimes + publish_generics.params.push(mutex_t_param.into()); + // If no generics are lifetimes + } else if num_lifetimes == 0 { + // Insert MutexT at the front + publish_generics + .params + .insert(0, mutex_t_param.into()); + // If there are lifetimes and other generics + } else { + // Insert MutexT after the lifetimes + publish_generics + .params + .insert(num_lifetimes, mutex_t_param.into()); + } + // If there are no generics + } else { + // Add MutexT + publish_generics.params.push(mutex_t_param.into()); + } + } + // const generics + + + let (publ_impl_generics, publ_type_generics, publ_where_clause) = + &publish_generics.split_for_impl(); let stateful_ident = Ident::new(format!("{ident}Stateful").deref(), ident.span()); + let publish_ident = Ident::new(format!("{ident}Publish").deref(), ident.span()); + let state_pub_ident = Ident::new(format!("{ident}StatePub").deref(), ident.span()); let expanded = quote! { // ----- Stateful struct ---------------------------------- @@ -144,6 +185,11 @@ pub fn poll_variant_macro(input: TokenStream) -> TokenStream { } // ----- Publish struct ---------------------------------- + #[cfg(feature = "embassy-sync")] + #vis struct #publish_ident #publish_generics #publ_where_clause { + pub poll: #ident #og_type_generics, + pub publisher: physical_node::transducer::Publisher<#value_type_ident, MutexT, CAPACITY, NUM_SUBS>, + } // ----- Publish impl ----------------------------------