blob: edd9d9c196cf369d2fecef160889307cf527c989 [file]
use std::str::FromStr;
use rustc_ast::LitKind;
use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode};
use rustc_feature::{AttributeTemplate, template};
use rustc_hir::attrs::{AttributeKind, RustcAutodiff};
use rustc_hir::{MethodKind, Target};
use rustc_span::{Symbol, sym};
use thin_vec::ThinVec;
use crate::attributes::prelude::Allow;
use crate::attributes::{OnDuplicate, SingleAttributeParser};
use crate::context::{AcceptContext, Stage};
use crate::parser::{ArgParser, MetaItemOrLitParser};
use crate::target_checking::AllowedTargets;
/// Parser for the `rustc_autodiff` attribute.
///
/// This is a stateless marker type; the parsing logic is provided by its
/// `SingleAttributeParser` implementation.
pub(crate) struct RustcAutodiffParser;
impl<S: Stage> SingleAttributeParser<S> for RustcAutodiffParser {
const PATH: &[Symbol] = &[sym::rustc_autodiff];
const ON_DUPLICATE: OnDuplicate<S> = OnDuplicate::Error;
const ALLOWED_TARGETS: AllowedTargets = AllowedTargets::AllowList(&[
Allow(Target::Fn),
Allow(Target::Method(MethodKind::Inherent)),
Allow(Target::Method(MethodKind::Trait { body: true })),
Allow(Target::Method(MethodKind::Trait { body: false })),
Allow(Target::Method(MethodKind::TraitImpl)),
]);
const TEMPLATE: AttributeTemplate = template!(
List: &["MODE", "WIDTH", "INPUT_ACTIVITIES", "OUTPUT_ACTIVITY"],
"https://doc.rust-lang.org/std/autodiff/index.html"
);
fn convert(cx: &mut AcceptContext<'_, '_, S>, args: &ArgParser) -> Option<AttributeKind> {
let list = match args {
ArgParser::NoArgs => return Some(AttributeKind::RustcAutodiff(None)),
ArgParser::List(list) => list,
ArgParser::NameValue(_) => {
cx.expected_list_or_no_args(cx.attr_span);
return None;
}
};
let mut items = list.mixed().peekable();
// Parse name
let Some(mode) = items.next() else {
cx.expected_at_least_one_argument(list.span);
return None;
};
let Some(mode) = mode.meta_item() else {
cx.expected_identifier(mode.span());
return None;
};
let Ok(()) = mode.args().no_args() else {
cx.expected_identifier(mode.span());
return None;
};
let Some(mode) = mode.path().word() else {
cx.expected_identifier(mode.span());
return None;
};
let Ok(mode) = DiffMode::from_str(mode.as_str()) else {
cx.expected_specific_argument(mode.span, DiffMode::all_modes());
return None;
};
// Parse width
let width = if let Some(width) = items.peek()
&& let MetaItemOrLitParser::Lit(width) = width
&& let LitKind::Int(width, _) = width.kind
&& let Ok(width) = width.0.try_into()
{
_ = items.next();
width
} else {
1
};
// Parse activities
let mut activities = ThinVec::new();
for activity in items {
let MetaItemOrLitParser::MetaItemParser(activity) = activity else {
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
return None;
};
let Ok(()) = activity.args().no_args() else {
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
return None;
};
let Some(activity) = activity.path().word() else {
cx.expected_specific_argument(activity.span(), DiffActivity::all_activities());
return None;
};
let Ok(activity) = DiffActivity::from_str(activity.as_str()) else {
cx.expected_specific_argument(activity.span, DiffActivity::all_activities());
return None;
};
activities.push(activity);
}
let Some(ret_activity) = activities.pop() else {
cx.expected_specific_argument(
list.span.with_lo(list.span.hi()),
DiffActivity::all_activities(),
);
return None;
};
Some(AttributeKind::RustcAutodiff(Some(Box::new(RustcAutodiff {
mode,
width,
input_activity: activities,
ret_activity,
}))))
}
}