blob: 33a87cb987306c5aad9be516d2f37e017839880a [file] [log] [blame]
use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange};
use rustc_hir::LangItem;
use rustc_index::IndexVec;
use rustc_middle::bug;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::layout::PrimitiveExt;
use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
use rustc_session::Session;
use tracing::debug;
/// MIR pass that guards transmutes into enum types with runtime validity checks.
///
/// This pass inserts checks for a valid enum discriminant where they are most
/// likely to find UB, because checking everywhere like Miri would generate too
/// much MIR.
pub(super) struct CheckEnums;
impl<'tcx> crate::MirPass<'tcx> for CheckEnums {
    /// The pass is only active when the session asks for UB checks.
    fn is_enabled(&self, sess: &Session) -> bool {
        sess.ub_checks()
    }

    /// This pass always reports itself as required.
    fn is_required(&self) -> bool {
        true
    }

    /// Visits every statement of `body` and, for each check reported by [`EnumFinder`],
    /// splits the containing block in front of that statement and emits the check there.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        // The inserted checks panic on failure. Without a panic implementation, emitting
        // them could make otherwise-valid code fail to compile, so leave the body alone.
        let Some(_) = tcx.lang_items().get(LangItem::PanicImpl) else {
            return;
        };

        let typing_env = body.typing_env(tcx);
        let basic_blocks = body.basic_blocks.as_mut();
        let local_decls = &mut body.local_decls;

        // Inserting a check appends new blocks and moves statements, which shifts the
        // `Location` of everything that follows the insertion point. Walking blocks and
        // statements back to front means only already-processed locations are affected,
        // so no bookkeeping of the current position is needed.
        for block in basic_blocks.indices().rev() {
            let statement_count = basic_blocks[block].statements.len();
            for statement_index in (0..statement_count).rev() {
                let location = Location { block, statement_index };

                // Collect the checks for this statement in a scope of its own so that the
                // borrows of the block list and the local declarations end before we mutate.
                let (source_info, checks) = {
                    let statement = &basic_blocks[block].statements[statement_index];
                    let mut finder = EnumFinder::new(tcx, local_decls, typing_env);
                    finder.visit_statement(statement, location);
                    (statement.source_info, finder.into_found_enums())
                };

                for check in checks {
                    debug!("Inserting enum check");
                    // `block` now ends right before the statement; `continuation` holds the
                    // statement itself, everything after it, and the old terminator.
                    let continuation = split_block(basic_blocks, location);

                    match check {
                        EnumCheckType::Uninhabited => insert_uninhabited_enum_check(
                            tcx,
                            local_decls,
                            &mut basic_blocks[block],
                            source_info,
                            continuation,
                        ),
                        EnumCheckType::Direct { source_op, discr, op_size, valid_discrs } => {
                            insert_direct_enum_check(
                                tcx,
                                local_decls,
                                basic_blocks,
                                block,
                                source_op,
                                discr,
                                op_size,
                                valid_discrs,
                                source_info,
                                continuation,
                            )
                        }
                        EnumCheckType::WithNiche {
                            source_op,
                            discr,
                            op_size,
                            offset,
                            valid_range,
                        } => insert_niche_check(
                            tcx,
                            local_decls,
                            &mut basic_blocks[block],
                            source_op,
                            valid_range,
                            discr,
                            op_size,
                            offset,
                            source_info,
                            continuation,
                        ),
                    }
                }
            }
        }
    }
}
/// Represents the different kinds of enum checks we can insert.
///
/// Values are produced by `EnumFinder::visit_rvalue` and consumed by
/// `CheckEnums::run_pass`, which emits the matching MIR for each variant.
enum EnumCheckType<'tcx> {
/// We know we try to create an uninhabited enum from an inhabited value.
/// This check fails unconditionally at runtime.
Uninhabited,
/// We know the enum does no niche optimizations and can thus easily compute
/// the valid discriminants.
Direct {
/// Copy of the operand that is being transmuted into the enum.
source_op: Operand<'tcx>,
/// Integer type and size of the enum's tag.
discr: TyAndSize<'tcx>,
/// Size of the transmuted operand's layout.
op_size: Size,
/// Discriminant values of all variants of the enum.
valid_discrs: Vec<u128>,
},
/// We try to construct an enum that has a niche.
WithNiche {
/// Copy of the operand that is being transmuted into the enum.
source_op: Operand<'tcx>,
/// Integer type and size of the enum's tag.
discr: TyAndSize<'tcx>,
/// Size of the transmuted operand's layout.
op_size: Size,
/// Offset of the tag field within the enum's layout.
offset: Size,
/// Valid range of the tag scalar, as recorded in the enum's layout.
valid_range: WrappingRange,
},
}
/// An integer type paired with its size; used here to describe an enum's tag.
#[derive(Debug, Copy, Clone)]
struct TyAndSize<'tcx> {
/// The integer type of the tag.
pub ty: Ty<'tcx>,
/// The size of the tag.
pub size: Size,
}
/// A [`Visitor`] that finds the construction of enums and evaluates which checks
/// we should apply.
struct EnumFinder<'a, 'tcx> {
/// Type context used for layout queries.
tcx: TyCtxt<'tcx>,
/// Local declarations of the visited body; used to compute operand types.
local_decls: &'a mut LocalDecls<'tcx>,
/// Typing environment under which layouts are computed.
typing_env: TypingEnv<'tcx>,
/// Checks collected so far, in the order they were discovered.
enums: Vec<EnumCheckType<'tcx>>,
}
impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
fn new(
tcx: TyCtxt<'tcx>,
local_decls: &'a mut LocalDecls<'tcx>,
typing_env: TypingEnv<'tcx>,
) -> Self {
EnumFinder { tcx, local_decls, typing_env, enums: Vec::new() }
}
/// Returns the found enum creations and which checks should be inserted.
fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
self.enums
}
}
impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
    /// Records an [`EnumCheckType`] when `rvalue` is a transmute whose target type is an
    /// enum and the enum's layout calls for a check. Any other rvalue is ignored.
    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
        // Only transmutes are interesting.
        let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue else {
            return;
        };

        // The transmute target has to be an ADT, and that ADT has to be an enum.
        let adt_def = match ty.kind() {
            ty::Adt(adt_def, _) if adt_def.is_enum() => adt_def,
            _ => return,
        };

        // Without layouts for both sides there is nothing we can check.
        let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
            return;
        };
        let Ok(op_layout) = self
            .tcx
            .layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))
        else {
            return;
        };

        match enum_layout.variants {
            Variants::Empty => {
                // Uninhabited-to-uninhabited needs no check.
                if op_layout.is_uninhabited() {
                    return;
                }
                // The enum layout is uninhabited but it is constructed from something
                // inhabited. This is always UB.
                self.enums.push(EnumCheckType::Uninhabited);
            }
            // Constructing an enum with a single variant is always fine.
            Variants::Single { .. } => {}
            // Several variants, tag stored directly (no niche optimization).
            Variants::Multiple {
                tag_encoding: TagEncoding::Direct,
                tag: Scalar::Initialized { value, .. },
                ..
            } => {
                let valid_discrs: Vec<u128> =
                    adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
                let tag = TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
                let check = EnumCheckType::Direct {
                    source_op: op.to_copy(),
                    discr: tag,
                    op_size: op_layout.size,
                    valid_discrs,
                };
                self.enums.push(check);
            }
            // Several variants, tag stored in a niche.
            Variants::Multiple {
                tag_encoding: TagEncoding::Niche { .. },
                tag: Scalar::Initialized { value, valid_range, .. },
                tag_field,
                ..
            } => {
                let tag = TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
                let tag_offset = enum_layout.fields.offset(tag_field.as_usize());
                let check = EnumCheckType::WithNiche {
                    source_op: op.to_copy(),
                    discr: tag,
                    op_size: op_layout.size,
                    offset: tag_offset,
                    valid_range,
                };
                self.enums.push(check);
            }
            // Any other layout is left unchecked.
            _ => return,
        }

        self.super_rvalue(rvalue, location);
    }
}
fn split_block(
basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
location: Location,
) -> BasicBlock {
let block_data = &mut basic_blocks[location.block];
// Drain every statement after this one and move the current terminator to a new basic block.
let new_block = BasicBlockData::new_stmts(
block_data.statements.split_off(location.statement_index),
block_data.terminator.take(),
block_data.is_cleanup,
);
basic_blocks.push(new_block)
}
/// Inserts the cast of an operand (any type) to a u128 value that holds the discriminant value.
fn insert_discr_cast_to_u128<'tcx>(
tcx: TyCtxt<'tcx>,
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
block_data: &mut BasicBlockData<'tcx>,
source_op: Operand<'tcx>,
discr: TyAndSize<'tcx>,
op_size: Size,
offset: Option<Size>,
source_info: SourceInfo,
) -> Place<'tcx> {
let get_ty_for_size = |tcx: TyCtxt<'tcx>, size: Size| -> Ty<'tcx> {
match size.bytes() {
1 => tcx.types.u8,
2 => tcx.types.u16,
4 => tcx.types.u32,
8 => tcx.types.u64,
16 => tcx.types.u128,
invalid => bug!("Found discriminant with invalid size, has {} bytes", invalid),
}
};
let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
// The discriminant is less wide than the operand, cast the operand into
// [MaybeUninit; N] and then index into it.
let mu = Ty::new_maybe_uninit(tcx, tcx.types.u8);
let array_len = op_size.bytes();
let mu_array_ty = Ty::new_array(tcx, mu, array_len);
let mu_array =
local_decls.push(LocalDecl::with_source_info(mu_array_ty, source_info)).into();
let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_array_ty);
block_data
.statements
.push(Statement::new(source_info, StatementKind::Assign(Box::new((mu_array, rvalue)))));
// Index into the array of MaybeUninit to get something that is actually
// as wide as the discriminant.
let offset = offset.unwrap_or(Size::ZERO);
let smaller_mu_array = mu_array.project_deeper(
&[ProjectionElem::Subslice {
from: offset.bytes(),
to: offset.bytes() + discr.size.bytes(),
from_end: false,
}],
tcx,
);
(CastKind::Transmute, Operand::Copy(smaller_mu_array))
} else {
let operand_int_ty = get_ty_for_size(tcx, op_size);
let op_as_int =
local_decls.push(LocalDecl::with_source_info(operand_int_ty, source_info)).into();
let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, operand_int_ty);
block_data.statements.push(Statement::new(
source_info,
StatementKind::Assign(Box::new((op_as_int, rvalue))),
));
(CastKind::IntToInt, Operand::Copy(op_as_int))
};
// Cast the resulting value to the actual discriminant integer type.
let rvalue = Rvalue::Cast(cast_kind, discr_ty_bits, discr.ty);
let discr_in_discr_ty =
local_decls.push(LocalDecl::with_source_info(discr.ty, source_info)).into();
block_data.statements.push(Statement::new(
source_info,
StatementKind::Assign(Box::new((discr_in_discr_ty, rvalue))),
));
// Cast the discriminant to a u128 (base for comparisons of enum discriminants).
let const_u128 = Ty::new_uint(tcx, ty::UintTy::U128);
let rvalue = Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_in_discr_ty), const_u128);
let discr = local_decls.push(LocalDecl::with_source_info(const_u128, source_info)).into();
block_data
.statements
.push(Statement::new(source_info, StatementKind::Assign(Box::new((discr, rvalue)))));
discr
}
fn insert_direct_enum_check<'tcx>(
tcx: TyCtxt<'tcx>,
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
current_block: BasicBlock,
source_op: Operand<'tcx>,
discr: TyAndSize<'tcx>,
op_size: Size,
discriminants: Vec<u128>,
source_info: SourceInfo,
new_block: BasicBlock,
) {
// Insert a new target block that is branched to in case of an invalid discriminant.
let invalid_discr_block_data = BasicBlockData::new(None, false);
let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
let block_data = &mut basic_blocks[current_block];
let discr_place = insert_discr_cast_to_u128(
tcx,
local_decls,
block_data,
source_op,
discr,
op_size,
None,
source_info,
);
// Mask out the bits of the discriminant type.
let mask = discr.size.unsigned_int_max();
let discr_masked =
local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
let rvalue = Rvalue::BinaryOp(
BinOp::BitAnd,
Box::new((
Operand::Copy(discr_place),
Operand::Constant(Box::new(ConstOperand {
span: source_info.span,
user_ty: None,
const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128),
})),
)),
);
block_data
.statements
.push(Statement::new(source_info, StatementKind::Assign(Box::new((discr_masked, rvalue)))));
// Branch based on the discriminant value.
block_data.terminator = Some(Terminator {
source_info,
kind: TerminatorKind::SwitchInt {
discr: Operand::Copy(discr_masked),
targets: SwitchTargets::new(
discriminants
.into_iter()
.map(|discr_val| (discr.size.truncate(discr_val), new_block)),
invalid_discr_block,
),
},
});
// Abort in case of an invalid enum discriminant.
basic_blocks[invalid_discr_block].terminator = Some(Terminator {
source_info,
kind: TerminatorKind::Assert {
cond: Operand::Constant(Box::new(ConstOperand {
span: source_info.span,
user_ty: None,
const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
})),
expected: true,
target: new_block,
msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))),
// This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
// We never want to insert an unwind into unsafe code, because unwinding could
// make a failing UB check turn into much worse UB when we start unwinding.
unwind: UnwindAction::Unreachable,
},
});
}
fn insert_uninhabited_enum_check<'tcx>(
tcx: TyCtxt<'tcx>,
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
block_data: &mut BasicBlockData<'tcx>,
source_info: SourceInfo,
new_block: BasicBlock,
) {
let is_ok: Place<'_> =
local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
block_data.statements.push(Statement::new(
source_info,
StatementKind::Assign(Box::new((
is_ok,
Rvalue::Use(Operand::Constant(Box::new(ConstOperand {
span: source_info.span,
user_ty: None,
const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
}))),
))),
));
block_data.terminator = Some(Terminator {
source_info,
kind: TerminatorKind::Assert {
cond: Operand::Copy(is_ok),
expected: true,
target: new_block,
msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Constant(Box::new(
ConstOperand {
span: source_info.span,
user_ty: None,
const_: Const::Val(ConstValue::from_u128(0), tcx.types.u128),
},
)))),
// This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
// We never want to insert an unwind into unsafe code, because unwinding could
// make a failing UB check turn into much worse UB when we start unwinding.
unwind: UnwindAction::Unreachable,
},
});
}
/// Emits the check for an enum whose tag is stored in a niche.
///
/// Statements are appended to `block_data` that read the tag of `source_op` (located at byte
/// `offset`, with the integer type and size given by `discr`), and the block is terminated
/// by an `Assert` that the tag lies inside `valid_range`. On success control continues in
/// `new_block`; on failure the invalid tag value is reported.
///
/// `valid_range` is a wrapping range, so membership is tested as
/// `tag - start <= end - start` using wrapping arithmetic.
///
/// New temporaries are pushed onto `local_decls`; all emitted items use `source_info`.
fn insert_niche_check<'tcx>(
    tcx: TyCtxt<'tcx>,
    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
    block_data: &mut BasicBlockData<'tcx>,
    source_op: Operand<'tcx>,
    valid_range: WrappingRange,
    discr: TyAndSize<'tcx>,
    op_size: Size,
    offset: Size,
    source_info: SourceInfo,
    new_block: BasicBlock,
) {
    let tag_size = discr.size;
    let discr_place = insert_discr_cast_to_u128(
        tcx,
        local_decls,
        block_data,
        source_op,
        discr,
        op_size,
        Some(offset),
        source_info,
    );

    // Mask out the bits of the discriminant type. The widening to `u128` goes through the
    // tag's own integer type; if that type is signed, a tag with its top bit set gets
    // sign-extended and would no longer be comparable with the bounds of `valid_range`.
    // Without the mask the comparison below can both reject valid tags and accept invalid
    // ones.
    let mask = tag_size.unsigned_int_max();
    let discr_masked: Place<'tcx> =
        local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
    let mask_rvalue = Rvalue::BinaryOp(
        BinOp::BitAnd,
        Box::new((
            Operand::Copy(discr_place),
            Operand::Constant(Box::new(ConstOperand {
                span: source_info.span,
                user_ty: None,
                const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128),
            })),
        )),
    );
    block_data.statements.push(Statement::new(
        source_info,
        StatementKind::Assign(Box::new((discr_masked, mask_rvalue))),
    ));

    // Compare the discriminant against the valid_range.
    let start_const = Operand::Constant(Box::new(ConstOperand {
        span: source_info.span,
        user_ty: None,
        const_: Const::Val(ConstValue::from_u128(valid_range.start), tcx.types.u128),
    }));
    let end_start_diff_const = Operand::Constant(Box::new(ConstOperand {
        span: source_info.span,
        user_ty: None,
        const_: Const::Val(
            ConstValue::from_u128(u128::wrapping_sub(valid_range.end, valid_range.start)),
            tcx.types.u128,
        ),
    }));

    // `discr_diff = tag - start`
    let discr_diff: Place<'_> =
        local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
    block_data.statements.push(Statement::new(
        source_info,
        StatementKind::Assign(Box::new((
            discr_diff,
            Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(discr_masked), start_const))),
        ))),
    ));

    // `is_ok = discr_diff <= end - start`
    let is_ok: Place<'_> =
        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
    block_data.statements.push(Statement::new(
        source_info,
        StatementKind::Assign(Box::new((
            is_ok,
            Rvalue::BinaryOp(
                // This is a `WrappingRange`, so make sure to get the wrapping right.
                BinOp::Le,
                Box::new((Operand::Copy(discr_diff), end_start_diff_const)),
            ),
        ))),
    ));

    block_data.terminator = Some(Terminator {
        source_info,
        kind: TerminatorKind::Assert {
            cond: Operand::Copy(is_ok),
            expected: true,
            target: new_block,
            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))),
            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
            // We never want to insert an unwind into unsafe code, because unwinding could
            // make a failing UB check turn into much worse UB when we start unwinding.
            unwind: UnwindAction::Unreachable,
        },
    });
}