| mod enums; |
| mod parse; |
| mod shared; |
| |
| use parse::{Invocation, StructuredInput}; |
| use proc_macro as pm; |
| use proc_macro2::{self as pm2, Span}; |
| use quote::{ToTokens, quote}; |
| pub(crate) use shared::{ALL_OPERATIONS, FloatTy, MathOpInfo, Ty}; |
| use syn::spanned::Spanned; |
| use syn::visit_mut::VisitMut; |
| use syn::{Ident, ItemEnum}; |
| |
/// Identifiers accepted in the `emit_types` list of `for_each_function!`.
///
/// Each entry corresponds to one field that can be passed to the callback macro; specifying
/// `all` for `emit_types` expands to every entry of this list, in this order.
const KNOWN_TYPES: &[&str] = &[
    "FTy",
    "CFn",
    "CArgs",
    "CRet",
    "RustFn",
    "RustArgs",
    "RustRet",
    "public",
];
| |
| /// Populate an enum with a variant representing function. Names are in upper camel case. |
| /// |
| /// Applied to an empty enum. Expects one attribute `#[function_enum(BaseName)]` that provides |
| /// the name of the `BaseName` enum. |
| #[proc_macro_attribute] |
| pub fn function_enum(attributes: pm::TokenStream, tokens: pm::TokenStream) -> pm::TokenStream { |
| let item = syn::parse_macro_input!(tokens as ItemEnum); |
| let res = enums::function_enum(item, attributes.into()); |
| |
| match res { |
| Ok(ts) => ts, |
| Err(e) => e.into_compile_error(), |
| } |
| .into() |
| } |
| |
| /// Create an enum representing all possible base names, with names in upper camel case. |
| /// |
| /// Applied to an empty enum. |
| #[proc_macro_attribute] |
| pub fn base_name_enum(attributes: pm::TokenStream, tokens: pm::TokenStream) -> pm::TokenStream { |
| let item = syn::parse_macro_input!(tokens as ItemEnum); |
| let res = enums::base_name_enum(item, attributes.into()); |
| |
| match res { |
| Ok(ts) => ts, |
| Err(e) => e.into_compile_error(), |
| } |
| .into() |
| } |
| |
| /// Do something for each function present in this crate. |
| /// |
| /// Takes a callback macro and invokes it multiple times, once for each function that |
| /// this crate exports. This makes it easy to create generic tests, benchmarks, or other checks |
| /// and apply it to each symbol. |
| /// |
| /// Additionally, the `extra` and `fn_extra` patterns can make use of magic identifiers: |
| /// |
| /// - `MACRO_FN_NAME`: gets replaced with the name of the function on that invocation. |
| /// - `MACRO_FN_NAME_NORMALIZED`: similar to the above, but removes suffixes so e.g. `sinf` becomes |
| /// `sin`, `cosf128` becomes `cos`, etc. |
| /// |
| /// Invoke as: |
| /// |
| /// ``` |
| /// // Macro that is invoked once per function |
| /// macro_rules! callback_macro { |
| /// ( |
| /// // Name of that function |
| /// fn_name: $fn_name:ident, |
| /// // The basic float type for this function (e.g. `f32`, `f64`) |
| /// FTy: $FTy:ty, |
| /// // Function signature of the C version (e.g. `fn(f32, &mut f32) -> f32`) |
| /// CFn: $CFn:ty, |
| /// // A tuple representing the C version's arguments (e.g. `(f32, &mut f32)`) |
| /// CArgs: $CArgs:ty, |
| /// // The C version's return type (e.g. `f32`) |
| /// CRet: $CRet:ty, |
| /// // Function signature of the Rust version (e.g. `fn(f32) -> (f32, f32)`) |
| /// RustFn: $RustFn:ty, |
| /// // A tuple representing the Rust version's arguments (e.g. `(f32,)`) |
| /// RustArgs: $RustArgs:ty, |
| /// // The Rust version's return type (e.g. `(f32, f32)`) |
| /// RustRet: $RustRet:ty, |
| /// // True if this is part of `libm`'s public API |
| /// public: $public:expr, |
| /// // Attributes for the current function, if any |
| /// attrs: [$($attr:meta),*], |
| /// // Extra tokens passed directly (if any) |
| /// extra: [$extra:ident], |
| /// // Extra function-tokens passed directly (if any) |
| /// fn_extra: $fn_extra:expr, |
| /// ) => { }; |
| /// } |
| /// |
| /// // All fields except for `callback` are optional. |
| /// libm_macros::for_each_function! { |
| /// // The macro to invoke as a callback |
| /// callback: callback_macro, |
| /// // Which types to include either as a list (`[CFn, RustFn, RustArgs]`) or "all" |
| /// emit_types: all, |
| /// // Functions to skip, i.e. `callback` shouldn't be called at all for these. |
| /// skip: [sin, cos], |
| /// // Attributes passed as `attrs` for specific functions. For example, here the invocation |
| /// // with `sinf` and that with `cosf` will both get `meta1` and `meta2`, but no others will. |
| /// // |
| /// // Note that `f16_enabled` and `f128_enabled` will always get emitted regardless of whether |
| /// // or not this is specified. |
| /// attributes: [ |
| /// #[meta1] |
| /// #[meta2] |
| /// [sinf, cosf], |
| /// ], |
| /// // Any tokens that should be passed directly to all invocations of the callback. This can |
| /// // be used to pass local variables or other things the macro needs access to. |
| /// extra: [foo], |
| /// // Similar to `extra`, but allow providing a pattern for only specific functions. Uses |
| /// // a simplified match-like syntax. |
| /// fn_extra: match MACRO_FN_NAME { |
| /// hypot | hypotf => |x| x.hypot(), |
| /// // `ALL_*` magic matchers also work to extract specific types |
| /// ALL_F64 => |x| x, |
| /// // The default pattern gets applied to everything that did not match |
| /// _ => |x| x, |
| /// }, |
| /// } |
| /// ``` |
| #[proc_macro] |
| pub fn for_each_function(tokens: pm::TokenStream) -> pm::TokenStream { |
| let input = syn::parse_macro_input!(tokens as Invocation); |
| |
| let res = StructuredInput::from_fields(input) |
| .and_then(|mut s_in| validate(&mut s_in).map(|fn_list| (s_in, fn_list))) |
| .and_then(|(s_in, fn_list)| expand(s_in, &fn_list)); |
| |
| match res { |
| Ok(ts) => ts.into(), |
| Err(e) => e.into_compile_error().into(), |
| } |
| } |
| |
| /// Check for any input that is structurally correct but has other problems. |
| /// |
| /// Returns the list of function names that we should expand for. |
| fn validate(input: &mut StructuredInput) -> syn::Result<Vec<&'static MathOpInfo>> { |
| // Replace magic mappers with a list of relevant functions. |
| if let Some(map) = &mut input.fn_extra { |
| for (name, ty) in [ |
| ("ALL_F16", FloatTy::F16), |
| ("ALL_F32", FloatTy::F32), |
| ("ALL_F64", FloatTy::F64), |
| ("ALL_F128", FloatTy::F128), |
| ] { |
| let Some(k) = map.keys().find(|key| *key == name) else { |
| continue; |
| }; |
| |
| let key = k.clone(); |
| let val = map.remove(&key).unwrap(); |
| |
| for op in ALL_OPERATIONS.iter().filter(|op| op.float_ty == ty) { |
| map.insert(Ident::new(op.name, key.span()), val.clone()); |
| } |
| } |
| } |
| |
| // Collect lists of all functions that are provied as macro inputs in various fields (only, |
| // skip, attributes). |
| let attr_mentions = input |
| .attributes |
| .iter() |
| .flat_map(|map_list| map_list.iter()) |
| .flat_map(|attr_map| attr_map.names.iter()); |
| let only_mentions = input.only.iter().flat_map(|only_list| only_list.iter()); |
| let fn_extra_mentions = input |
| .fn_extra |
| .iter() |
| .flat_map(|v| v.keys()) |
| .filter(|name| *name != "_"); |
| let all_mentioned_fns = input |
| .skip |
| .iter() |
| .chain(only_mentions) |
| .chain(attr_mentions) |
| .chain(fn_extra_mentions); |
| |
| // Make sure that every function mentioned is a real function |
| for mentioned in all_mentioned_fns { |
| if !ALL_OPERATIONS.iter().any(|func| mentioned == func.name) { |
| let e = syn::Error::new( |
| mentioned.span(), |
| format!("unrecognized function name `{mentioned}`"), |
| ); |
| return Err(e); |
| } |
| } |
| |
| if !input.skip.is_empty() && input.only.is_some() { |
| let e = syn::Error::new( |
| input.only_span.unwrap(), |
| "only one of `skip` or `only` may be specified", |
| ); |
| return Err(e); |
| } |
| |
| // Construct a list of what we intend to expand |
| let mut fn_list = Vec::new(); |
| for func in ALL_OPERATIONS.iter() { |
| let fn_name = func.name; |
| // If we have an `only` list and it does _not_ contain this function name, skip it |
| if input |
| .only |
| .as_ref() |
| .is_some_and(|only| !only.iter().any(|o| o == fn_name)) |
| { |
| continue; |
| } |
| |
| // If there is a `skip` list that contains this function name, skip it |
| if input.skip.iter().any(|s| s == fn_name) { |
| continue; |
| } |
| |
| // Omit f16 and f128 functions if requested |
| if input.skip_f16_f128 && (func.float_ty == FloatTy::F16 || func.float_ty == FloatTy::F128) |
| { |
| continue; |
| } |
| |
| // Run everything else |
| fn_list.push(func); |
| } |
| |
| // Types that the user would like us to provide in the macro |
| let mut add_all_types = false; |
| for ty in &input.emit_types { |
| let ty_name = ty.to_string(); |
| if ty_name == "all" { |
| add_all_types = true; |
| continue; |
| } |
| |
| // Check that all requested types are valid |
| if !KNOWN_TYPES.contains(&ty_name.as_str()) { |
| let e = syn::Error::new( |
| ty_name.span(), |
| format!("unrecognized type identifier `{ty_name}`"), |
| ); |
| return Err(e); |
| } |
| } |
| |
| if add_all_types { |
| // Ensure that if `all` was specified that nothing else was |
| if input.emit_types.len() > 1 { |
| let e = syn::Error::new( |
| input.emit_types_span.unwrap(), |
| "if `all` is specified, no other type identifiers may be given", |
| ); |
| return Err(e); |
| } |
| |
| // ...and then add all types |
| input.emit_types.clear(); |
| for ty in KNOWN_TYPES { |
| let ident = Ident::new(ty, Span::call_site()); |
| input.emit_types.push(ident); |
| } |
| } |
| |
| if let Some(map) = &input.fn_extra |
| && !map.keys().any(|key| key == "_") |
| { |
| // No default provided; make sure every expected function is covered |
| let mut fns_not_covered = Vec::new(); |
| for func in &fn_list { |
| if !map.keys().any(|key| key == func.name) { |
| // `name` was not mentioned in the `match` statement |
| fns_not_covered.push(func); |
| } |
| } |
| |
| if !fns_not_covered.is_empty() { |
| let e = syn::Error::new( |
| input.fn_extra_span.unwrap(), |
| format!( |
| "`fn_extra`: no default `_` pattern specified and the following \ |
| patterns are not covered: {fns_not_covered:#?}" |
| ), |
| ); |
| return Err(e); |
| } |
| }; |
| |
| Ok(fn_list) |
| } |
| |
| /// Expand our structured macro input into invocations of the callback macro. |
| fn expand(input: StructuredInput, fn_list: &[&MathOpInfo]) -> syn::Result<pm2::TokenStream> { |
| let mut out = pm2::TokenStream::new(); |
| let default_ident = Ident::new("_", Span::call_site()); |
| let callback = input.callback; |
| |
| for func in fn_list { |
| let fn_name = Ident::new(func.name, Span::call_site()); |
| |
| // Prepare attributes in an `attrs: ...` field |
| let mut meta_fields = Vec::new(); |
| if let Some(attrs) = &input.attributes { |
| let meta_iter = attrs |
| .iter() |
| .filter(|map| map.names.contains(&fn_name)) |
| .flat_map(|map| &map.meta) |
| .map(|v| v.into_token_stream()); |
| |
| meta_fields.extend(meta_iter); |
| } |
| |
| // Always emit f16 and f128 meta so this doesn't need to be repeated everywhere |
| if func.rust_sig.args.contains(&Ty::F16) || func.rust_sig.returns.contains(&Ty::F16) { |
| let ts = quote! { cfg(f16_enabled) }; |
| meta_fields.push(ts); |
| } |
| if func.rust_sig.args.contains(&Ty::F128) || func.rust_sig.returns.contains(&Ty::F128) { |
| let ts = quote! { cfg(f128_enabled) }; |
| meta_fields.push(ts); |
| } |
| |
| let meta_field = quote! { attrs: [ #( #meta_fields ),* ], }; |
| |
| // Prepare extra in an `extra: ...` field, running the replacer |
| let extra_field = match input.extra.clone() { |
| Some(mut extra) => { |
| let mut v = MacroReplace::new(func.name); |
| v.visit_expr_mut(&mut extra); |
| v.finish()?; |
| |
| quote! { extra: #extra, } |
| } |
| None => pm2::TokenStream::new(), |
| }; |
| |
| // Prepare function-specific extra in a `fn_extra: ...` field, running the replacer |
| let fn_extra_field = match input.fn_extra { |
| Some(ref map) => { |
| let mut fn_extra = map |
| .get(&fn_name) |
| .or_else(|| map.get(&default_ident)) |
| .unwrap() |
| .clone(); |
| |
| let mut v = MacroReplace::new(func.name); |
| v.visit_expr_mut(&mut fn_extra); |
| v.finish()?; |
| |
| quote! { fn_extra: #fn_extra, } |
| } |
| None => pm2::TokenStream::new(), |
| }; |
| |
| let base_fty = func.float_ty; |
| let c_args = &func.c_sig.args; |
| let c_ret = &func.c_sig.returns; |
| let rust_args = &func.rust_sig.args; |
| let rust_ret = &func.rust_sig.returns; |
| let public = func.public; |
| |
| let mut ty_fields = Vec::new(); |
| for ty in &input.emit_types { |
| let field = match ty.to_string().as_str() { |
| "FTy" => quote! { FTy: #base_fty, }, |
| "CFn" => quote! { CFn: fn( #(#c_args),* ,) -> ( #(#c_ret),* ), }, |
| "CArgs" => quote! { CArgs: ( #(#c_args),* ,), }, |
| "CRet" => quote! { CRet: ( #(#c_ret),* ), }, |
| "RustFn" => quote! { RustFn: fn( #(#rust_args),* ,) -> ( #(#rust_ret),* ), }, |
| "RustArgs" => quote! { RustArgs: ( #(#rust_args),* ,), }, |
| "RustRet" => quote! { RustRet: ( #(#rust_ret),* ), }, |
| "public" => quote! { public: #public, }, |
| _ => unreachable!("checked in validation"), |
| }; |
| ty_fields.push(field); |
| } |
| |
| let new = quote! { |
| #callback! { |
| fn_name: #fn_name, |
| #( #ty_fields )* |
| #meta_field |
| #extra_field |
| #fn_extra_field |
| } |
| }; |
| |
| out.extend(new); |
| } |
| |
| Ok(out) |
| } |
| |
| /// Visitor to replace "magic" identifiers that we allow: `MACRO_FN_NAME` and |
| /// `MACRO_FN_NAME_NORMALIZED`. |
| struct MacroReplace { |
| fn_name: &'static str, |
| /// Remove the trailing `f` or `f128` to make |
| norm_name: String, |
| error: Option<syn::Error>, |
| } |
| |
| impl MacroReplace { |
| fn new(name: &'static str) -> Self { |
| let norm_name = base_name(name); |
| Self { |
| fn_name: name, |
| norm_name: norm_name.to_owned(), |
| error: None, |
| } |
| } |
| |
| fn finish(self) -> syn::Result<()> { |
| match self.error { |
| Some(e) => Err(e), |
| None => Ok(()), |
| } |
| } |
| |
| fn visit_ident_inner(&mut self, i: &mut Ident) { |
| let s = i.to_string(); |
| if !s.starts_with("MACRO") || self.error.is_some() { |
| return; |
| } |
| |
| match s.as_str() { |
| "MACRO_FN_NAME" => *i = Ident::new(self.fn_name, i.span()), |
| "MACRO_FN_NAME_NORMALIZED" => *i = Ident::new(&self.norm_name, i.span()), |
| _ => { |
| self.error = Some(syn::Error::new( |
| i.span(), |
| format!("unrecognized meta expression `{s}`"), |
| )); |
| } |
| } |
| } |
| } |
| |
| impl VisitMut for MacroReplace { |
| fn visit_ident_mut(&mut self, i: &mut Ident) { |
| self.visit_ident_inner(i); |
| syn::visit_mut::visit_ident_mut(self, i); |
| } |
| } |
| |
/// Return the unsuffixed version of a function name; e.g. `abs` and `absf` both return `abs`,
/// `lgamma_r` and `lgammaf_r` both return `lgamma_r`.
///
/// A handful of names are special-cased because plain suffix stripping would give the wrong
/// answer for them (for example `erf` and `modf` end in `f` without being suffixed). For
/// everything else the first matching suffix out of `f`, `f16` and `f128` is removed, and names
/// without any of these suffixes are returned unchanged.
fn base_name(name: &str) -> &str {
    match name {
        "erff" | "erf" => "erf",
        "lgammaf_r" => "lgamma_r",
        "modff" | "modf" => "modf",
        _ => ["f", "f16", "f128"]
            .iter()
            .find_map(|suffix| name.strip_suffix(*suffix))
            .unwrap_or(name),
    }
}
| |
| impl ToTokens for Ty { |
| fn to_tokens(&self, tokens: &mut pm2::TokenStream) { |
| let ts = match self { |
| Ty::F16 => quote! { f16 }, |
| Ty::F32 => quote! { f32 }, |
| Ty::F64 => quote! { f64 }, |
| Ty::F128 => quote! { f128 }, |
| Ty::I32 => quote! { i32 }, |
| Ty::CInt => quote! { ::core::ffi::c_int }, |
| Ty::MutF16 => quote! { &'a mut f16 }, |
| Ty::MutF32 => quote! { &'a mut f32 }, |
| Ty::MutF64 => quote! { &'a mut f64 }, |
| Ty::MutF128 => quote! { &'a mut f128 }, |
| Ty::MutI32 => quote! { &'a mut i32 }, |
| Ty::MutCInt => quote! { &'a mut core::ffi::c_int }, |
| }; |
| |
| tokens.extend(ts); |
| } |
| } |
| impl ToTokens for FloatTy { |
| fn to_tokens(&self, tokens: &mut pm2::TokenStream) { |
| let ts = match self { |
| FloatTy::F16 => quote! { f16 }, |
| FloatTy::F32 => quote! { f32 }, |
| FloatTy::F64 => quote! { f64 }, |
| FloatTy::F128 => quote! { f128 }, |
| }; |
| |
| tokens.extend(ts); |
| } |
| } |