Initial node implementation (#4)
Reviewed-on: #4 Co-authored-by: Zack <zack@bfpower.io> Co-committed-by: Zack <zack@bfpower.io>
This commit is contained in:
28
macros/node-poll-variants/Cargo.toml
Normal file
28
macros/node-poll-variants/Cargo.toml
Normal file
@ -0,0 +1,28 @@
|
||||
[package]
|
||||
name = "node-poll-variants"
|
||||
description = "Macros for physical nodes."
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
repository.workspace = true
|
||||
readme.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[[test]]
|
||||
name = "tests"
|
||||
path = "tests/test_build.rs"
|
||||
|
||||
[dependencies.syn]
|
||||
workspace = true
|
||||
[dependencies.quote]
|
||||
workspace = true
|
||||
|
||||
[dev-dependencies.trybuild]
|
||||
workspace = true
|
||||
[dev-dependencies.physical-node]
|
||||
path = "../../node"
|
||||
features = ["embassy-sync"]
|
||||
[dev-dependencies.embassy-sync]
|
||||
workspace = true
|
294
macros/node-poll-variants/src/lib.rs
Normal file
294
macros/node-poll-variants/src/lib.rs
Normal file
@ -0,0 +1,294 @@
|
||||
use proc_macro::TokenStream;
|
||||
use quote::quote;
|
||||
use std::ops::Deref;
|
||||
use quote::__private::parse_spanned;
|
||||
use syn::__private::{str, Span};
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::token::{Colon, Comma, PathSep, Plus};
|
||||
use syn::{parse_macro_input, parse_quote, parse_quote_spanned, parse_str, Data, DeriveInput, Expr, GenericParam, Generics, Ident, Lit, LitStr, Meta, Path, PathSegment, Token, TraitBound, TraitBoundModifier, TypeParam, TypeParamBound, Type};
|
||||
|
||||
#[proc_macro_derive(PollVariants, attributes(value_type, error_type))]
|
||||
pub fn poll_variant_macro(input: TokenStream) -> TokenStream {
|
||||
// ----- Parse input ----------------------------------
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
let attrs = &input.attrs;
|
||||
let vis = &input.vis;
|
||||
let ident = &input.ident;
|
||||
let data = &input.data;
|
||||
let og_generics = &input.generics;
|
||||
let (og_impl_generics, og_type_generics, og_where_clause) = &og_generics.split_for_impl();
|
||||
|
||||
// ----- Check that item the macro was used on is a struct ----------------------------------
|
||||
match data {
|
||||
Data::Struct(struct_data) => struct_data,
|
||||
Data::Enum(_) => panic!("Stateful struct cannot be derived from an enum."),
|
||||
Data::Union(_) => panic!("Stateful struct cannot be derived from a union."),
|
||||
};
|
||||
|
||||
// ----- Extract attribute information ----------------------------------
|
||||
const VALUE_T_NAME: &str = "value_type";
|
||||
const ERROR_T_NAME: &str = "error_type";
|
||||
let mut value_type: Option<Type> = None;
|
||||
let mut error_type: Option<Type> = None;
|
||||
for attribute in attrs.iter() {
|
||||
// if the attribute is a named value
|
||||
if let Meta::NameValue(meta) = &attribute.meta {
|
||||
// if the name of the attribute is value_type
|
||||
if meta.path.segments[0].ident == VALUE_T_NAME {
|
||||
// if the value of the attribute is a literal
|
||||
if let Expr::Lit(lit) = &meta.value {
|
||||
// if the literal is a string
|
||||
if let Lit::Str(lit) = &lit.lit {
|
||||
let span = lit.span();
|
||||
let string = lit.token().to_string();
|
||||
let string = string.trim_matches('"').to_string();
|
||||
let _value_type: Type = parse_str(string.deref()).unwrap();
|
||||
let _value_type: Type = parse_quote_spanned!(span=> #_value_type);
|
||||
|
||||
value_type = Some(_value_type);
|
||||
|
||||
} else {
|
||||
panic!("{VALUE_T_NAME} must be set with a string literal.")
|
||||
}
|
||||
} else {
|
||||
panic!("{VALUE_T_NAME} must be set with a literal.")
|
||||
}
|
||||
}
|
||||
|
||||
// if the name of the attribute is error_type
|
||||
if meta.path.segments[0].ident == ERROR_T_NAME {
|
||||
// if the value of the attribute is a literal
|
||||
if let Expr::Lit(lit) = &meta.value {
|
||||
// if the literal is a string
|
||||
if let Lit::Str(lit) = &lit.lit {
|
||||
let span = lit.span();
|
||||
let string = lit.token().to_string();
|
||||
let string = string.trim_matches('"').to_string();
|
||||
let _error_type: Type = parse_str(string.deref()).unwrap();
|
||||
let _error_type: Type = parse_quote_spanned!(span=> #_error_type);
|
||||
|
||||
error_type = Some(_error_type);
|
||||
} else {
|
||||
panic!("{ERROR_T_NAME} must be set with a string literal.")
|
||||
}
|
||||
} else {
|
||||
panic!("{ERROR_T_NAME} must be set with a literal.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let value_type = value_type
|
||||
.expect(format!("Need attribute {VALUE_T_NAME}: #[{VALUE_T_NAME} = \"type\"]").deref());
|
||||
let error_type = error_type
|
||||
.expect(format!("Need attribute {ERROR_T_NAME}: #[{ERROR_T_NAME} = \"type\"]").deref());
|
||||
|
||||
// ----- Add publisher generics ----------------------------------
|
||||
// MutexT
|
||||
const MUTEX_T_NAME: &str = "PublishMutexT";
|
||||
let mutex_t_ident = Ident::new(MUTEX_T_NAME, Span::call_site());
|
||||
const CAPACITY_NAME: &str = "CAPACITY";
|
||||
let capacity_ident = Ident::new(CAPACITY_NAME, Span::call_site());
|
||||
const NUM_SUBS_NAME: &str = "NUM_SUBS";
|
||||
let num_subs_ident = Ident::new(NUM_SUBS_NAME, Span::call_site());
|
||||
let mut num_lifetimes: usize = 0;
|
||||
let mut num_types: usize = 0;
|
||||
let mut num_const: usize = 0;
|
||||
let mut has_mutex_t = false;
|
||||
for param in og_generics.params.iter() {
|
||||
match param {
|
||||
GenericParam::Lifetime(_) => num_lifetimes += 1,
|
||||
GenericParam::Type(param) => {
|
||||
num_types += 1;
|
||||
|
||||
// If the generic parameter is MutexT
|
||||
if param.ident == MUTEX_T_NAME {
|
||||
has_mutex_t = true;
|
||||
}
|
||||
},
|
||||
GenericParam::Const(_) => num_const += 1,
|
||||
}
|
||||
}
|
||||
|
||||
let mut publish_generics = og_generics.clone();
|
||||
// If MutexT is not a generic parameter, add it
|
||||
if !has_mutex_t {
|
||||
let mutex_t_param: GenericParam =
|
||||
parse_quote!(#mutex_t_ident: embassy_sync::blocking_mutex::raw::RawMutex);
|
||||
|
||||
let num_generics = num_lifetimes + num_types + num_const;
|
||||
// If there are generics
|
||||
if num_generics > 0 {
|
||||
// If all generics are lifetimes
|
||||
if num_lifetimes == num_generics {
|
||||
// Add MutexT after the lifetimes
|
||||
publish_generics.params.push(mutex_t_param);
|
||||
// If no generics are lifetimes
|
||||
} else if num_lifetimes == 0 {
|
||||
// Insert MutexT at the front
|
||||
publish_generics.params.insert(0, mutex_t_param);
|
||||
// If there are lifetimes and other generics
|
||||
} else {
|
||||
// Insert MutexT after the lifetimes
|
||||
publish_generics.params.insert(num_lifetimes, mutex_t_param);
|
||||
}
|
||||
// If there are no generics
|
||||
} else {
|
||||
// Add MutexT
|
||||
publish_generics.params.push(mutex_t_param);
|
||||
}
|
||||
}
|
||||
// const generics
|
||||
let capacity_param: GenericParam = parse_quote!(const #capacity_ident: usize);
|
||||
let num_subs_param: GenericParam = parse_quote!(const #num_subs_ident: usize);
|
||||
publish_generics.params.push(capacity_param);
|
||||
publish_generics.params.push(num_subs_param);
|
||||
|
||||
let (publ_impl_generics, publ_type_generics, publ_where_clause) =
|
||||
&publish_generics.split_for_impl();
|
||||
|
||||
let pubsub_error_path: Path = parse_quote!(embassy_sync::pubsub::Error);
|
||||
let pubsub_sub_path: Path = parse_quote!(embassy_sync::pubsub::Subscriber);
|
||||
|
||||
let stateful_variant_ident = Ident::new(format!("{ident}Stateful").deref(), ident.span());
|
||||
let publish_variant_ident = Ident::new(format!("{ident}Publish").deref(), ident.span());
|
||||
let state_pub_variant_ident = Ident::new(format!("{ident}StatePub").deref(), ident.span());
|
||||
|
||||
let poll_path: Path = parse_quote!(physical_node::transducer::input::Poll);
|
||||
let stateful_path: Path = parse_quote!(physical_node::transducer::Stateful);
|
||||
let publish_path: Path = parse_quote!(physical_node::transducer::Publish);
|
||||
|
||||
let state_path: Path = parse_quote!(physical_node::transducer::State);
|
||||
let publisher_path: Path = parse_quote!(physical_node::transducer::Publisher);
|
||||
let cellview_path: Path = parse_quote!(physical_node::cell::CellView);
|
||||
|
||||
let error_path: Path = parse_quote!(physical_node::CriticalError);
|
||||
|
||||
let expanded = quote! {
|
||||
// ----- Stateful struct ----------------------------------
|
||||
#vis struct #stateful_variant_ident #og_generics #og_where_clause {
|
||||
pub poll: #ident #og_type_generics,
|
||||
pub state: #state_path<#value_type>,
|
||||
}
|
||||
|
||||
// ----- Stateful impls ----------------------------------
|
||||
impl #og_impl_generics #poll_path for #stateful_variant_ident #og_type_generics #og_where_clause {
|
||||
type Value = #value_type;
|
||||
type Error = #error_type;
|
||||
|
||||
#[inline]
|
||||
async fn poll(&self) -> Result<Self::Value, Self::Error> {
|
||||
let result = self.poll.poll().await;
|
||||
if let Ok(value) = result {
|
||||
self.state.update(value);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl #og_impl_generics #stateful_path for #stateful_variant_ident #og_type_generics #og_where_clause {
|
||||
type Value = #value_type;
|
||||
|
||||
#[inline(always)]
|
||||
fn state_cell(&self) -> #cellview_path<Self::Value> {
|
||||
self.state.state_cell()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn state(&self) -> Self::Value {
|
||||
self.state.state()
|
||||
}
|
||||
}
|
||||
|
||||
// ----- Publish struct ----------------------------------
|
||||
#[cfg(feature = "embassy-sync")]
|
||||
#vis struct #publish_variant_ident #publish_generics #publ_where_clause {
|
||||
pub poll: #ident #og_type_generics,
|
||||
pub publisher: #publisher_path<#value_type, #mutex_t_ident, #capacity_ident, #num_subs_ident>,
|
||||
}
|
||||
|
||||
// ----- Publish impl ----------------------------------
|
||||
#[cfg(feature = "embassy-sync")]
|
||||
impl #publ_impl_generics #poll_path for #publish_variant_ident #publ_type_generics #publ_where_clause {
|
||||
type Value = #value_type;
|
||||
type Error = #error_type;
|
||||
|
||||
#[inline]
|
||||
async fn poll(&self) -> Result<Self::Value, Self::Error> {
|
||||
let result = self.poll.poll().await;
|
||||
if let Ok(value) = result {
|
||||
self.publisher.update(value);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "embassy-sync")]
|
||||
impl #publ_impl_generics #publish_path<#capacity_ident, #num_subs_ident> for #publish_variant_ident #publ_type_generics #publ_where_clause {
|
||||
type Value = #value_type;
|
||||
type Mutex = #mutex_t_ident;
|
||||
|
||||
#[inline(always)]
|
||||
fn subscribe(
|
||||
&self,
|
||||
) -> Result<#pubsub_sub_path<Self::Mutex, Self::Value, #capacity_ident, #num_subs_ident, 0>, #pubsub_error_path> {
|
||||
self.publisher.subscribe()
|
||||
}
|
||||
}
|
||||
|
||||
// ----- StatePub struct ----------------------------------
|
||||
#[cfg(feature = "embassy-sync")]
|
||||
#vis struct #state_pub_variant_ident #publish_generics #publ_where_clause {
|
||||
pub poll: #ident #og_type_generics,
|
||||
pub state: #state_path<#value_type>,
|
||||
pub publisher: #publisher_path<#value_type, #mutex_t_ident, #capacity_ident, #num_subs_ident>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "embassy-sync")]
|
||||
impl #publ_impl_generics #poll_path for #state_pub_variant_ident #publ_type_generics #publ_where_clause {
|
||||
type Value = #value_type;
|
||||
type Error = #error_type;
|
||||
|
||||
#[inline]
|
||||
async fn poll(&self) -> Result<Self::Value, Self::Error> {
|
||||
let result = self.poll.poll().await;
|
||||
if let Ok(value) = result {
|
||||
self.state.update(value);
|
||||
self.publisher.update(value);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "embassy-sync")]
|
||||
impl #publ_impl_generics #stateful_path for #state_pub_variant_ident #publ_type_generics #publ_where_clause {
|
||||
type Value = #value_type;
|
||||
|
||||
#[inline(always)]
|
||||
fn state_cell(&self) -> #cellview_path<Self::Value> {
|
||||
self.state.state_cell()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn state(&self) -> Self::Value {
|
||||
self.state.state()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "embassy-sync")]
|
||||
impl #publ_impl_generics #publish_path<#capacity_ident, #num_subs_ident> for #state_pub_variant_ident #publ_type_generics #publ_where_clause {
|
||||
type Value = #value_type;
|
||||
type Mutex = #mutex_t_ident;
|
||||
|
||||
#[inline(always)]
|
||||
fn subscribe(
|
||||
&self,
|
||||
) -> Result<#pubsub_sub_path<Self::Mutex, Self::Value, #capacity_ident, #num_subs_ident, 0>, #pubsub_error_path> {
|
||||
self.publisher.subscribe()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
TokenStream::from(expanded)
|
||||
}
|
31
macros/node-poll-variants/tests/generate.rs
Normal file
31
macros/node-poll-variants/tests/generate.rs
Normal file
@ -0,0 +1,31 @@
|
||||
#![feature(async_fn_in_trait, impl_trait_projections, never_type)]
|
||||
|
||||
use node_poll_variants::PollVariants;
|
||||
use physical_node::transducer::input::Poll;
|
||||
|
||||
#[derive(PollVariants)]
|
||||
#[value_type = "SecondT"]
|
||||
#[error_type = "!"]
|
||||
struct ExamplePoll<'a, FirstT, SecondT>
|
||||
where
|
||||
SecondT: Copy,
|
||||
{
|
||||
a: &'a i32,
|
||||
b: i32,
|
||||
first: FirstT,
|
||||
second: SecondT,
|
||||
}
|
||||
|
||||
impl<'a, FirstT, SecondT> Poll for ExamplePoll<'a, FirstT, SecondT>
|
||||
where
|
||||
SecondT: Copy,
|
||||
{
|
||||
type Value = SecondT;
|
||||
type Error = !;
|
||||
|
||||
async fn poll(&self) -> Result<Self::Value, Self::Error> {
|
||||
Ok(self.second)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {}
|
5
macros/node-poll-variants/tests/test_build.rs
Normal file
5
macros/node-poll-variants/tests/test_build.rs
Normal file
@ -0,0 +1,5 @@
|
||||
#[test]
|
||||
fn tests() {
|
||||
let t = trybuild::TestCases::new();
|
||||
t.pass("tests/generate.rs");
|
||||
}
|
Reference in New Issue
Block a user