blob: 04a7cd2c75f28cb50153c844b96f324f82947c72 [file] [log] [blame]
use bitflags::bitflags;
use crate::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
pub struct OffloadMetadata {
pub payload_size: u64,
pub mode: MappingFlags,
}
bitflags! {
    /// Mirrors `OpenMPOffloadMappingFlags` from Clang/OpenMP.
    ///
    /// The numeric values must stay identical to the OpenMP definition. Presumably they
    /// are handed unchanged to the offload runtime (NOTE(review): confirm at the use site).
    ///
    /// `NONE` has the value `0`, so under bitflags semantics `flags.contains(NONE)` is
    /// always `true`. Use `is_empty()` to test for "no flags".
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
    #[repr(transparent)]
    pub struct MappingFlags: u64 {
        /// No flags.
        const NONE = 0x0;
        /// Allocate memory on the device and move data from host to device.
        const TO = 0x01;
        /// Allocate memory on the device and move data from device to host.
        const FROM = 0x02;
        /// Always perform the requested mapping action, even if already mapped.
        const ALWAYS = 0x04;
        /// Delete the element from the device environment, ignoring ref count.
        const DELETE = 0x08;
        /// The element being mapped is a pointer-pointee pair.
        const PTR_AND_OBJ = 0x10;
        /// The base address should be passed to the target kernel as argument.
        const TARGET_PARAM = 0x20;
        /// The runtime must return the device pointer.
        const RETURN_PARAM = 0x40;
        /// The reference being passed is a pointer to private data.
        const PRIVATE = 0x80;
        /// Pass the element by value.
        const LITERAL = 0x100;
        /// Implicit map (generated by compiler, not explicit in code).
        const IMPLICIT = 0x200;
        /// Hint to allocate memory close to the target device.
        const CLOSE = 0x400;
        /// Reserved (0x800 in OpenMP for XLC compatibility).
        const RESERVED = 0x800;
        /// Require that the data is already allocated on the device.
        const PRESENT = 0x1000;
        /// Increment/decrement a separate ref counter (OpenACC compatibility).
        const OMPX_HOLD = 0x2000;
        /// Used for non-contiguous list items in target update (bit 44).
        const NON_CONTIG = 0x100000000000;
        /// 16 MSBs indicate membership in a struct.
        const MEMBER_OF = 0xffff000000000000;
    }
}
impl OffloadMetadata {
    /// Computes the offload metadata for an argument of type `ty`.
    ///
    /// The payload size comes from `get_payload_size`. The mapping direction comes from
    /// `MappingFlags::from_ty`, which may emit an error for types that cannot be offloaded.
    pub fn from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self {
        // The size is computed before the mode, matching the original evaluation order.
        let payload_size = get_payload_size(tcx, ty);
        let mode = MappingFlags::from_ty(tcx, ty);
        Self { payload_size, mode }
    }
}
// FIXME(Sa4dUs): implement a solid logic to determine the payload size
fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 {
match ty.kind() {
ty::RawPtr(inner, _) | ty::Ref(_, inner, _) => get_payload_size(tcx, *inner),
_ => tcx
.layout_of(PseudoCanonicalInput {
typing_env: TypingEnv::fully_monomorphized(),
value: ty,
})
.unwrap()
.size
.bytes(),
}
}
impl MappingFlags {
fn from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self {
use rustc_ast::Mutability::*;
match ty.kind() {
ty::Bool
| ty::Char
| ty::Int(_)
| ty::Uint(_)
| ty::Float(_)
| ty::Adt(_, _)
| ty::Tuple(_)
| ty::Array(_, _)
| ty::Alias(_, _)
| ty::Param(_) => MappingFlags::TO,
ty::RawPtr(_, Not) | ty::Ref(_, _, Not) => MappingFlags::TO,
ty::RawPtr(_, Mut) | ty::Ref(_, _, Mut) => MappingFlags::TO | MappingFlags::FROM,
ty::Slice(_) | ty::Str | ty::Dynamic(_, _) => MappingFlags::TO | MappingFlags::FROM,
ty::Foreign(_) | ty::Pat(_, _) | ty::UnsafeBinder(_) => {
MappingFlags::TO | MappingFlags::FROM
}
ty::FnDef(_, _)
| ty::FnPtr(_, _)
| ty::Closure(_, _)
| ty::CoroutineClosure(_, _)
| ty::Coroutine(_, _)
| ty::CoroutineWitness(_, _)
| ty::Never
| ty::Bound(_, _)
| ty::Placeholder(_)
| ty::Infer(_)
| ty::Error(_) => {
tcx.dcx()
.span_err(rustc_span::DUMMY_SP, format!("type `{ty:?}` cannot be offloaded"));
MappingFlags::empty()
}
}
}
}