use proc_macro::TokenStream; use quote::quote; use std::ops::Deref; use syn::__private::Span; use syn::punctuated::Punctuated; use syn::token::{Colon, Comma, PathSep, Plus}; use syn::{ parse_macro_input, parse_quote, parse_quote_spanned, Data, DeriveInput, Expr, GenericParam, Generics, Ident, Lit, Meta, Path, PathSegment, Token, TraitBound, TraitBoundModifier, TypeParam, TypeParamBound, }; // Struct name: Ads1256PollStateful, Ads1256PollPublish, Ads1256PollStatePub #[proc_macro_derive(PollVariants, attributes(value_type))] pub fn poll_variant_macro(input: TokenStream) -> TokenStream { // ----- Parse input ---------------------------------- let input = parse_macro_input!(input as DeriveInput); let attrs = &input.attrs; let vis = &input.vis; let ident = &input.ident; let data = &input.data; let og_generics = &input.generics; let (og_impl_generics, og_type_generics, og_where_clause) = &og_generics.split_for_impl(); // ----- Check that item the macro was used on is a struct ---------------------------------- match data { Data::Struct(struct_data) => struct_data, Data::Enum(_) => panic!("Stateful struct cannot be derived from an enum."), Data::Union(_) => panic!("Stateful struct cannot be derived from a union."), }; // ----- Extract value type attribute ---------------------------------- const VALUE_T_NAME: &str = "value_type"; let mut value_type_ident: Option = None; for attribute in attrs.iter() { // if the attribute is a named value if let Meta::NameValue(meta) = &attribute.meta { // if the name of the attribute is value_t if meta.path.segments[0].ident == VALUE_T_NAME { // if the value of the attribute is a literal if let Expr::Lit(lit) = &meta.value { // if the literal is a string if let Lit::Str(lit) = &lit.lit { value_type_ident = Some(Ident::new(lit.value().deref(), Span::call_site())); } else { panic!("{VALUE_T_NAME} must be set with a string literal.") } } else { panic!("{VALUE_T_NAME} must be set with a literal.") } } else { continue; } } else { continue; } } let value_type_ident = if let Some(val) = value_type_ident { val } else { panic!("Need attribute {VALUE_T_NAME}: #[{VALUE_T_NAME} = \"type\"]") }; // ----- Add publisher generics ---------------------------------- // MutexT const MUTEX_T_NAME: &str = "PublishMutexT"; let mutex_t_ident = Ident::new(MUTEX_T_NAME, Span::call_site()); const CAPACITY_NAME: &str = "CAPACITY"; let capacity_ident = Ident::new(CAPACITY_NAME, Span::call_site()); const NUM_SUBS_NAME: &str = "NUM_SUBS"; let num_subs_ident = Ident::new(NUM_SUBS_NAME, Span::call_site()); let mut num_lifetimes: usize = 0; let mut num_types: usize = 0; let mut num_const: usize = 0; let mut has_mutex_t = false; for param in og_generics.params.iter() { match param { GenericParam::Lifetime(_) => num_lifetimes += 1, GenericParam::Type(param) => { num_types += 1; // If the generic parameter is MutexT if param.ident == MUTEX_T_NAME { has_mutex_t = true; } }, GenericParam::Const(_) => num_const += 1, } } let mut publish_generics = og_generics.clone(); // If MutexT is not a generic parameter, add it if !has_mutex_t { let mutex_t_param: GenericParam = parse_quote!(#mutex_t_ident: embassy_sync::blocking_mutex::raw::RawMutex); let num_generics = num_lifetimes + num_types + num_const; // If there are generics if num_generics > 0 { // If all generics are lifetimes if num_lifetimes == num_generics { // Add MutexT after the lifetimes publish_generics.params.push(mutex_t_param); // If no generics are lifetimes } else if num_lifetimes == 0 { // Insert MutexT at the front publish_generics.params.insert(0, mutex_t_param); // If there are lifetimes and other generics } else { // Insert MutexT after the lifetimes publish_generics.params.insert(num_lifetimes, mutex_t_param); } // If there are no generics } else { // Add MutexT publish_generics.params.push(mutex_t_param); } } // const generics let capacity_param: GenericParam = parse_quote!(const #capacity_ident: usize); let num_subs_param: GenericParam = parse_quote!(const #num_subs_ident: usize); publish_generics.params.push(capacity_param); publish_generics.params.push(num_subs_param); let (publ_impl_generics, publ_type_generics, publ_where_clause) = &publish_generics.split_for_impl(); let pubsub_error_path: Path = parse_quote!(embassy_sync::pubsub::Error); let pubsub_sub_path: Path = parse_quote!(embassy_sync::pubsub::Subscriber); let stateful_variant_ident = Ident::new(format!("{ident}Stateful").deref(), ident.span()); let publish_variant_ident = Ident::new(format!("{ident}Publish").deref(), ident.span()); let state_pub_variant_ident = Ident::new(format!("{ident}StatePub").deref(), ident.span()); let poll_path: Path = parse_quote!(physical_node::transducer::input::Poll); let stateful_path: Path = parse_quote!(physical_node::transducer::Stateful); let publish_path: Path = parse_quote!(physical_node::transducer::Publish); let state_path: Path = parse_quote!(physical_node::transducer::State); let publisher_path: Path = parse_quote!(physical_node::transducer::Publisher); let cellview_path: Path = parse_quote!(physical_node::cell::CellView); let expanded = quote! { // ----- Stateful struct ---------------------------------- #vis struct #stateful_variant_ident #og_generics #og_where_clause { pub poll: #ident #og_type_generics, pub state: #state_path<#value_type_ident>, } // ----- Stateful impls ---------------------------------- impl #og_impl_generics #poll_path for #stateful_variant_ident #og_type_generics #og_where_clause { type Value = #value_type_ident; #[inline(always)] async fn poll(&self) -> Self::Value { let value = self.poll.poll().await; self.state.update(value); value } } impl #og_impl_generics #stateful_path for #stateful_variant_ident #og_type_generics #og_where_clause { type Value = #value_type_ident; #[inline(always)] fn state_cell(&self) -> #cellview_path { self.state.state_cell() } #[inline(always)] fn state(&self) -> Self::Value { self.state.state() } } // ----- Publish struct ---------------------------------- #[cfg(feature = "embassy-sync")] #vis struct #publish_variant_ident #publish_generics #publ_where_clause { pub poll: #ident #og_type_generics, pub publisher: #publisher_path<#value_type_ident, #mutex_t_ident, #capacity_ident, #num_subs_ident>, } // ----- Publish impl ---------------------------------- #[cfg(feature = "embassy-sync")] impl #publ_impl_generics #poll_path for #publish_variant_ident #publ_type_generics #publ_where_clause { type Value = #value_type_ident; #[inline(always)] async fn poll(&self) -> Self::Value { let value = self.poll.poll().await; self.publisher.update(value); value } } #[cfg(feature = "embassy-sync")] impl #publ_impl_generics #publish_path<#capacity_ident, #num_subs_ident> for #publish_variant_ident #publ_type_generics #publ_where_clause { type Value = #value_type_ident; type Mutex = #mutex_t_ident; #[inline(always)] fn subscribe( &self, ) -> Result<#pubsub_sub_path, #pubsub_error_path> { self.publisher.subscribe() } } // ----- StatePub struct ---------------------------------- #[cfg(feature = "embassy-sync")] #vis struct #state_pub_variant_ident #publish_generics #publ_where_clause { pub poll: #ident #og_type_generics, pub state: #state_path<#value_type_ident>, pub publisher: #publisher_path<#value_type_ident, #mutex_t_ident, #capacity_ident, #num_subs_ident>, } #[cfg(feature = "embassy-sync")] impl #publ_impl_generics #poll_path for #state_pub_variant_ident #publ_type_generics #publ_where_clause { type Value = #value_type_ident; #[inline] async fn poll(&self) -> Self::Value { let value = self.poll.poll().await; self.state.update(value); self.publisher.update(value); value } } #[cfg(feature = "embassy-sync")] impl #publ_impl_generics #stateful_path for #state_pub_variant_ident #publ_type_generics #publ_where_clause { type Value = #value_type_ident; #[inline(always)] fn state_cell(&self) -> #cellview_path { self.state.state_cell() } #[inline(always)] fn state(&self) -> Self::Value { self.state.state() } } #[cfg(feature = "embassy-sync")] impl #publ_impl_generics #publish_path<#capacity_ident, #num_subs_ident> for #state_pub_variant_ident #publ_type_generics #publ_where_clause { type Value = #value_type_ident; type Mutex = #mutex_t_ident; #[inline(always)] fn subscribe( &self, ) -> Result<#pubsub_sub_path, #pubsub_error_path> { self.publisher.subscribe() } } }; TokenStream::from(expanded) }