// blob: 9dbd2d45a93af91e087e22f1dc677d3fea67891c [file] [edit]
use rustc_ast::token::{Delimiter, Token, TokenKind};
use rustc_ast::tokenstream::{DelimSpan, Spacing, TokenStream, TokenTree};
use rustc_ast::{AttrItem, ast};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_session::config::Offload;
use rustc_span::{Ident, Span, sym};
use thin_vec::thin_vec;
use crate::errors;
/// Returns `true` when the session's unstable `offload` option lists [`Offload::Device`].
fn compile_for_device(ecx: &mut ExtCtxt<'_>) -> bool {
    let requested_modes = &ecx.sess.opts.unstable_opts.offload;
    requested_modes.contains(&Offload::Device)
}
/// Wraps a clone of `kind` into an outer-style (`#[...]`) attribute carrying the given
/// `id` and `span`.
fn outer_normal_attr(
    kind: &Box<rustc_ast::NormalAttr>,
    id: rustc_ast::AttrId,
    span: Span,
) -> rustc_ast::Attribute {
    rustc_ast::Attribute {
        kind: rustc_ast::AttrKind::Normal(kind.clone()),
        id,
        style: rustc_ast::AttrStyle::Outer,
        span,
    }
}
/// Pulls the pieces of a free function out of `item`.
///
/// Returns clones of the visibility, signature, generics and (optional) body together with
/// the function's identifier, or `None` when `item` is not an `Annotatable::Item` whose kind
/// is `ItemKind::Fn`.
fn extract_fn(
    item: &Annotatable,
) -> Option<(ast::Visibility, ast::FnSig, Ident, ast::Generics, Option<Box<ast::Block>>)> {
    // Anything that is not a plain item (trait/impl items, statements, ...) is rejected.
    let Annotatable::Item(fn_item) = item else {
        return None;
    };
    match &fn_item.kind {
        ast::ItemKind::Fn(ast::Fn { sig, ident, generics, body, .. }) => {
            let parts = (fn_item.vis.clone(), sig.clone(), *ident, generics.clone(), body.clone());
            Some(parts)
        }
        _ => None,
    }
}
/// The `offload_kernel` macro expands the annotated function into one of two definitions,
/// depending on whether the session is compiling for the offload device: a host-side stub
/// that keeps the symbol and signature available, or the device-side kernel that carries
/// the original body.
///
/// ```ignore (illustrative)
/// #[offload_kernel]
/// fn foo(a: &[f32], b: &[f32], c: *mut f32) {
///     *c = a[0] + b[0];
/// }
/// ```
///
/// When compiling for the host, this expands to:
///
/// ```ignore (illustrative)
/// #[unsafe(no_mangle)]
/// #[inline(never)]
/// fn foo(_: &[f32], _: &[f32], _: *mut f32) {
///     ::std::unimplemented!()
/// }
/// ```
///
/// When compiling for the device, this expands to:
///
/// ```ignore (illustrative)
/// #[rustc_offload_kernel]
/// #[unsafe(no_mangle)]
/// unsafe extern "gpu-kernel" fn foo(a: &[f32], b: &[f32], c: *mut f32) {
///     *c = a[0] + b[0];
/// }
/// ```
///
/// If the macro is applied to anything other than a free function item, an error is emitted
/// and the item is returned unchanged.
pub(crate) fn expand_kernel(
    ecx: &mut ExtCtxt<'_>,
    expand_span: Span,
    _meta_item: &ast::MetaItem,
    item: Annotatable,
) -> Vec<Annotatable> {
    let Some((vis, sig, ident, generics, body)) = extract_fn(&item) else {
        // NOTE(review): this reuses the autodiff "invalid application" diagnostic; a dedicated
        // offload diagnostic would presumably produce a more accurate message — confirm
        // against the `errors` module.
        ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
        return vec![item];
    };
    let span = ecx.with_def_site_ctxt(expand_span);

    // Only one of the two definitions is ever returned, so only that one is built. This
    // avoids cloning the signature/generics and allocating attribute ids for an item that
    // would be thrown away.
    let expanded = if compile_for_device(ecx) {
        offload_device_item(ecx, span, vis, sig, ident, generics, body)
    } else {
        offload_host_item(ecx, span, vis, sig, ident, generics)
    };
    vec![expanded]
}

/// Builds a `#[unsafe(no_mangle)]` outer attribute with a freshly generated attribute id.
fn offload_no_mangle_attr(ecx: &ExtCtxt<'_>, span: Span) -> rustc_ast::Attribute {
    let no_mangle_item = AttrItem {
        unsafety: ast::Safety::Unsafe(span),
        path: ast::Path::from_ident(Ident::new(sym::no_mangle, span)),
        args: ast::AttrItemKind::Unparsed(ast::AttrArgs::Empty),
        tokens: None,
    };
    let no_mangle_attr = Box::new(ast::NormalAttr { item: no_mangle_item, tokens: None });
    outer_normal_attr(&no_mangle_attr, ecx.sess.psess.attr_id_generator.mk_attr_id(), span)
}

/// Builds an `#[inline(never)]` outer attribute with a freshly generated attribute id.
fn offload_inline_never_attr(ecx: &ExtCtxt<'_>, span: Span) -> rustc_ast::Attribute {
    // The attribute argument is the single identifier token `never` inside parentheses.
    let never_token = TokenTree::Token(
        Token::new(TokenKind::Ident(sym::never, false.into()), span),
        Spacing::Joint,
    );
    let never_arg = ast::DelimArgs {
        dspan: DelimSpan::from_single(span),
        delim: Delimiter::Parenthesis,
        tokens: TokenStream::from_iter([never_token]),
    };
    let inline_item = ast::AttrItem {
        unsafety: ast::Safety::Default,
        path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
        args: ast::AttrItemKind::Unparsed(ast::AttrArgs::Delimited(never_arg)),
        tokens: None,
    };
    let inline_never_attr = Box::new(ast::NormalAttr { item: inline_item, tokens: None });
    outer_normal_attr(&inline_never_attr, ecx.sess.psess.attr_id_generator.mk_attr_id(), span)
}

/// Builds the device-side kernel: the original function (body included) turned into an
/// `unsafe extern "gpu-kernel"` function and tagged with `#[rustc_offload_kernel]` and
/// `#[unsafe(no_mangle)]`.
fn offload_device_item(
    ecx: &ExtCtxt<'_>,
    span: Span,
    vis: ast::Visibility,
    mut sig: ast::FnSig,
    ident: Ident,
    generics: ast::Generics,
    body: Option<Box<ast::Block>>,
) -> Annotatable {
    // Rewrite the header in place: `extern "gpu-kernel"` ABI and `unsafe` qualifier.
    sig.header.ext = ast::Extern::from_abi(
        Some(ast::StrLit {
            symbol: sym::gpu_kernel,
            suffix: None,
            symbol_unescaped: sym::gpu_kernel,
            style: ast::StrStyle::Cooked,
            span,
        }),
        span,
    );
    sig.header.safety = ast::Safety::Unsafe(span);

    let device_fn = Box::new(ast::Fn {
        defaultness: ast::Defaultness::Implicit,
        sig,
        ident,
        generics,
        contract: None,
        body,
        define_opaque: None,
        eii_impls: Default::default(),
    });

    // `#[rustc_offload_kernel]` marker attribute.
    let marker_attr =
        Box::new(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_offload_kernel)));
    let rustc_offload_kernel =
        outer_normal_attr(&marker_attr, ecx.sess.psess.attr_id_generator.mk_attr_id(), span);
    let unsafe_no_mangle = offload_no_mangle_attr(ecx, span);

    let mut device_item = ecx.item(
        span,
        thin_vec![rustc_offload_kernel, unsafe_no_mangle],
        ast::ItemKind::Fn(device_fn),
    );
    // `ecx.item` does not take a visibility, so restore the original one afterwards.
    device_item.vis = vis;
    Annotatable::Item(device_item)
}

/// Builds the host-side stub: same name, generics and parameter types as the original
/// function, but with every parameter pattern replaced by `_`, the body replaced by
/// `::std::unimplemented!()`, and `#[unsafe(no_mangle)]` / `#[inline(never)]` attached.
fn offload_host_item(
    ecx: &ExtCtxt<'_>,
    span: Span,
    vis: ast::Visibility,
    mut sig: ast::FnSig,
    ident: Ident,
    generics: ast::Generics,
) -> Annotatable {
    // Stub body: a single `::std::unimplemented!()` call.
    let unimplemented_call = ecx.expr_macro_call(
        span,
        ecx.macro_call(
            span,
            ecx.path_global(
                span,
                [sym::std, sym::unimplemented].map(|s| Ident::new(s, span)).to_vec(),
            ),
            Delimiter::Parenthesis,
            TokenStream::default(),
        ),
    );
    let stub_body = ecx.block(span, thin_vec![ecx.stmt_expr(unimplemented_call)]);

    // The stub never reads its arguments, so bind each of them to a wildcard pattern
    // while keeping the declared types.
    for param in sig.decl.inputs.iter_mut() {
        param.pat = Box::new(ecx.pat_wild(param.pat.span));
    }

    let host_fn = Box::new(ast::Fn {
        defaultness: ast::Defaultness::Implicit,
        sig,
        ident,
        generics,
        contract: None,
        body: Some(stub_body),
        define_opaque: None,
        eii_impls: Default::default(),
    });

    let inline_never = offload_inline_never_attr(ecx, span);
    let unsafe_no_mangle = offload_no_mangle_attr(ecx, span);

    let mut host_item =
        ecx.item(span, thin_vec![unsafe_no_mangle, inline_never], ast::ItemKind::Fn(host_fn));
    // `ecx.item` does not take a visibility, so restore the original one afterwards.
    host_item.vis = vis;
    Annotatable::Item(host_item)
}