diff --git a/compiler/rustc_middle/src/mir/interpret/value.rs b/compiler/rustc_middle/src/mir/interpret/value.rs index 20861d5ffa405..47831d44de4e3 100644 --- a/compiler/rustc_middle/src/mir/interpret/value.rs +++ b/compiler/rustc_middle/src/mir/interpret/value.rs @@ -93,6 +93,10 @@ impl<'tcx> ConstValue<'tcx> { ConstValue::Scalar(Scalar::from_bool(b)) } + pub fn from_u32(i: u32) -> Self { + ConstValue::Scalar(Scalar::from_u32(i)) + } + pub fn from_u64(i: u64) -> Self { ConstValue::Scalar(Scalar::from_u64(i)) } diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs index 3e023ccdead67..0d327c332ac00 100644 --- a/compiler/rustc_middle/src/ty/sty.rs +++ b/compiler/rustc_middle/src/ty/sty.rs @@ -1872,6 +1872,10 @@ impl<'tcx> Region<'tcx> { /// Constructors for `Ty` impl<'tcx> Ty<'tcx> { + pub fn new_bool(tcx: TyCtxt<'tcx>) -> Ty<'tcx> { + Ty::new(tcx, TyKind::Bool) + } + // Avoid this in favour of more specific `new_*` methods, where possible. #[allow(rustc::usage_of_ty_tykind)] #[inline] diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 734321e97d892..a4c769be31966 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -95,6 +95,7 @@ mod required_consts; mod reveal_all; mod separate_const_switch; mod shim; +mod simplify_pow_of_two; mod ssa; // This pass is public to allow external drivers to perform MIR cleanup mod check_alignment; @@ -546,6 +547,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first &unreachable_prop::UnreachablePropagation, &uninhabited_enum_branching::UninhabitedEnumBranching, + &simplify_pow_of_two::SimplifyPowOfTwo, &o1(simplify::SimplifyCfg::AfterUninhabitedEnumBranching), &inline::Inline, &remove_storage_markers::RemoveStorageMarkers, diff --git a/compiler/rustc_mir_transform/src/simplify_pow_of_two.rs b/compiler/rustc_mir_transform/src/simplify_pow_of_two.rs new file mode 100644 index 0000000000000..a500b45651c88 --- /dev/null +++ b/compiler/rustc_mir_transform/src/simplify_pow_of_two.rs @@ -0,0 +1,212 @@ +//! A pass that checks for and simplifies calls to `pow` where the receiver is a power of +//! two. This can be done with `<<` instead. + +use crate::MirPass; +use rustc_const_eval::interpret::{ConstValue, Scalar}; +use rustc_middle::mir::patch::MirPatch; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Ty, TyCtxt, UintTy}; +use rustc_span::sym; +use rustc_target::abi::FieldIdx; + +pub struct SimplifyPowOfTwo; + +impl<'tcx> MirPass<'tcx> for SimplifyPowOfTwo { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let mut patch = MirPatch::new(body); + + for (i, bb) in body.basic_blocks.iter_enumerated() { + let term = bb.terminator(); + let source_info = term.source_info; + let span = source_info.span; + + if let TerminatorKind::Call { + func, + args, + destination, + target: Some(target), + call_source: CallSource::Normal, + .. + } = &term.kind + && let Some(def_id) = func.const_fn_def().map(|def| def.0) + && let def_path = tcx.def_path(def_id) + && tcx.crate_name(def_path.krate) == sym::core + && let [recv, exp] = args.as_slice() + && let Some(recv_const) = recv.constant() + && let ConstantKind::Val( + ConstValue::Scalar(Scalar::Int(recv_int)), + recv_ty, + ) = recv_const.literal + && recv_ty.is_integral() + && tcx.item_name(def_id) == sym::pow + && let Ok(recv_val) = match recv_ty.kind() { + ty::Int(_) => { + let result = recv_int.try_to_int(recv_int.size()).unwrap_or(-1).max(0); + if result > 0 { + Ok(result as u128) + } else { + continue; + } + }, + ty::Uint(_) => recv_int.try_to_uint(recv_int.size()), + _ => continue, + } + && let power_used = f32::log2(recv_val as f32) + // Precision loss means it's not a power of two + && power_used == (power_used as u32) as f32 + // `0` would be `1.pow()`, which we shouldn't try to optimize as it's + // already entirely optimized away + && power_used != 0.0 + // `-inf` would be `0.pow()` + && power_used.is_finite() + { + let power_used = power_used as u32; + let loc = Location { block: i, statement_index: bb.statements.len() }; + let exp_ty = Ty::new(tcx, ty::Uint(UintTy::U32)); + let checked_mul = + patch.new_temp(Ty::new_tup(tcx, &[exp_ty, Ty::new_bool(tcx)]), span); + + // If this is not `2.pow(...)`, we need to multiply the number of times we + // shift the bits left by the receiver's power of two used, e.g.: + // + // > 2 -> 1 + // > 4 -> 2 + // > 16 -> 4 + // > 256 -> 8 + // + // If this is `1`, then we *could* remove this entirely but it'll be + // optimized out anyway by later passes (or perhaps LLVM) so it's entirely + // unnecessary to do so. + patch.add_assign( + loc, + checked_mul.into(), + Rvalue::CheckedBinaryOp( + BinOp::Mul, + Box::new(( + exp.clone(), + Operand::Constant(Box::new(Constant { + span, + user_ty: None, + literal: ConstantKind::Val( + ConstValue::from_u32(power_used), + exp_ty, + ), + })), + )), + ), + ); + + let num_shl = tcx.mk_place_field(checked_mul.into(), FieldIdx::from_u32(0), exp_ty); + let mul_result = tcx.mk_place_field( + checked_mul.into(), + FieldIdx::from_u32(1), + Ty::new_bool(tcx), + ); + let shl_result = patch.new_temp(Ty::new_bool(tcx), span); + + // Whether the shl will overflow, if so we return 0. We can do this rather + // than doing a shr because only one bit is set on any power of two + patch.add_assign( + loc, + shl_result.into(), + Rvalue::BinaryOp( + BinOp::Lt, + Box::new(( + Operand::Copy(num_shl), + Operand::Constant(Box::new(Constant { + span, + user_ty: None, + literal: ConstantKind::Val( + ConstValue::from_u32(recv_int.size().bits() as u32), + exp_ty, + ), + })), + )), + ), + ); + + let fine_bool = patch.new_temp(Ty::new_bool(tcx), span); + let fine = patch.new_temp(recv_ty, span); + + patch.add_assign( + loc, + fine_bool.into(), + Rvalue::BinaryOp( + BinOp::BitOr, + Box::new(( + Operand::Copy(mul_result.into()), + Operand::Copy(shl_result.into()), + )), + ), + ); + + patch.add_assign( + loc, + fine.into(), + Rvalue::Cast(CastKind::IntToInt, Operand::Copy(fine_bool.into()), recv_ty), + ); + + let shl = patch.new_temp(recv_ty, span); + + patch.add_assign( + loc, + shl.into(), + Rvalue::BinaryOp( + BinOp::Shl, + Box::new(( + Operand::Constant(Box::new(Constant { + span, + user_ty: None, + literal: ConstantKind::Val( + ConstValue::Scalar(Scalar::from_uint(1u128, recv_int.size())), + recv_ty, + ), + })), + Operand::Copy(num_shl.into()), + )), + ), + ); + + patch.add_assign( + loc, + *destination, + Rvalue::BinaryOp( + BinOp::MulUnchecked, + Box::new((Operand::Copy(shl.into()), Operand::Copy(fine.into()))), + ), + ); + + // FIXME(Centri3): Do we use `debug_assertions` or `overflow_checks` here? + if tcx.sess.opts.debug_assertions { + patch.patch_terminator( + i, + TerminatorKind::Assert { + cond: Operand::Copy(fine_bool.into()), + expected: true, + msg: Box::new(AssertMessage::Overflow( + // For consistency with the previous error message, though + // it's technically incorrect + BinOp::Mul, + Operand::Constant(Box::new(Constant { + span, + user_ty: None, + literal: ConstantKind::Val( + ConstValue::Scalar(Scalar::from_u32(1)), + exp_ty, + ), + })), + Operand::Copy(num_shl.into()), + )), + target: *target, + unwind: UnwindAction::Continue, + }, + ); + } else { + patch.patch_terminator(i, TerminatorKind::Goto { target: *target }); + } + } + } + + patch.apply(body); + } +} diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index d3739733c1d10..c6add72fde56c 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -1054,6 +1054,7 @@ symbols! { not, notable_trait, note, + num, object_safe_for_dispatch, of, offset, @@ -1121,6 +1122,7 @@ symbols! { poll, position, post_dash_lto: "post-lto", + pow, powerpc_target_feature, powf32, powf64, diff --git a/tests/codegen/simplify-pow-of-two-debug-assertions.rs b/tests/codegen/simplify-pow-of-two-debug-assertions.rs new file mode 100644 index 0000000000000..c1d2a0b2acd01 --- /dev/null +++ b/tests/codegen/simplify-pow-of-two-debug-assertions.rs @@ -0,0 +1,19 @@ +// compile-flags: -Copt-level=3 -Cdebug-assertions=true + +// CHECK-LABEL: @slow_2_u( +#[no_mangle] +fn slow_2_u(a: u32) -> u32 { + // CHECK: %_3 = icmp ult i32 %a, 32 + // CHECK-NEXT: br i1 %_3, label %bb1, label %panic, !prof !{{[0-9]+}} + // CHECK-EMPTY: + // CHECK-NEXT: bb1: + // CHECK-NEXT: %_01 = shl nuw i32 1, %a + // CHECK-NEXT: ret i32 %_0 + // CHECK-EMPTY: + // CHECK-NEXT: panic: + 2u32.pow(a) +} + +fn main() { + slow_2_u(2); +} diff --git a/tests/codegen/simplify-pow-of-two.rs b/tests/codegen/simplify-pow-of-two.rs new file mode 100644 index 0000000000000..657e4af8dbc10 --- /dev/null +++ b/tests/codegen/simplify-pow-of-two.rs @@ -0,0 +1,16 @@ +// compile-flags: -Copt-level=3 + +// CHECK-LABEL: @slow_2_u( +#[no_mangle] +fn slow_2_u(a: u32) -> u32 { + // CHECK: %_3 = icmp ult i32 %a, 32 + // CHECK-NEXT: %_5 = zext i1 %_3 to i32 + // CHECK-NEXT: %0 = and i32 %a, 31 + // CHECK-NEXT: %_01 = shl nuw i32 %_5, %0 + // CHECK-NEXT: ret i32 %_01 + 2u32.pow(a) +} + +fn main() { + slow_2_u(2); +} diff --git a/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.rs b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.rs new file mode 100644 index 0000000000000..1315aa8a60107 --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.rs @@ -0,0 +1,51 @@ +// unit-test: SimplifyPowOfTwo +// compile-flags: -Cdebug-assertions=false + +// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_2_u.SimplifyPowOfTwo.after.mir +fn slow_2_u(a: u32) -> u32 { + 2u32.pow(a) +} + +// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_2_i.SimplifyPowOfTwo.after.mir +fn slow_2_i(a: u32) -> i32 { + 2i32.pow(a) +} + +// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_4_u.SimplifyPowOfTwo.after.mir +fn slow_4_u(a: u32) -> u32 { + 4u32.pow(a) +} + +// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_4_i.SimplifyPowOfTwo.after.mir +fn slow_4_i(a: u32) -> i32 { + 4i32.pow(a) +} + +// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_256_u.SimplifyPowOfTwo.after.mir +fn slow_256_u(a: u32) -> u32 { + 256u32.pow(a) +} + +// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_256_i.SimplifyPowOfTwo.after.mir +fn slow_256_i(a: u32) -> i32 { + 256i32.pow(a) +} + +fn main() { + slow_2_u(0); + slow_2_i(0); + slow_2_u(1); + slow_2_i(1); + slow_2_u(2); + slow_2_i(2); + slow_4_u(4); + slow_4_i(4); + slow_4_u(15); + slow_4_i(15); + slow_4_u(16); + slow_4_i(16); + slow_4_u(17); + slow_4_i(17); + slow_256_u(2); + slow_256_i(2); +} diff --git a/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_256_i.SimplifyPowOfTwo.after.mir b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_256_i.SimplifyPowOfTwo.after.mir new file mode 100644 index 0000000000000..23dc99e9d846a --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_256_i.SimplifyPowOfTwo.after.mir @@ -0,0 +1,29 @@ +// MIR for `slow_256_i` after SimplifyPowOfTwo + +fn slow_256_i(_1: u32) -> i32 { + debug a => _1; + let mut _0: i32; + let mut _2: u32; + let mut _3: (u32, bool); + let mut _4: bool; + let mut _5: bool; + let mut _6: i32; + let mut _7: i32; + + bb0: { + StorageLive(_2); + _2 = _1; + _3 = CheckedMul(move _2, const 8_u32); + _4 = Lt((_3.0: u32), const 32_u32); + _5 = BitOr((_3.1: bool), _4); + _6 = _5 as i32 (IntToInt); + _7 = Shl(const 1_i32, (_3.0: u32)); + _0 = MulUnchecked(_7, _6); + goto -> bb1; + } + + bb1: { + StorageDead(_2); + return; + } +} diff --git a/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_256_u.SimplifyPowOfTwo.after.mir b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_256_u.SimplifyPowOfTwo.after.mir new file mode 100644 index 0000000000000..49d3236f4ca49 --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_256_u.SimplifyPowOfTwo.after.mir @@ -0,0 +1,29 @@ +// MIR for `slow_256_u` after SimplifyPowOfTwo + +fn slow_256_u(_1: u32) -> u32 { + debug a => _1; + let mut _0: u32; + let mut _2: u32; + let mut _3: (u32, bool); + let mut _4: bool; + let mut _5: bool; + let mut _6: u32; + let mut _7: u32; + + bb0: { + StorageLive(_2); + _2 = _1; + _3 = CheckedMul(move _2, const 8_u32); + _4 = Lt((_3.0: u32), const 32_u32); + _5 = BitOr((_3.1: bool), _4); + _6 = _5 as u32 (IntToInt); + _7 = Shl(const 1_u32, (_3.0: u32)); + _0 = MulUnchecked(_7, _6); + goto -> bb1; + } + + bb1: { + StorageDead(_2); + return; + } +} diff --git a/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_2_i.SimplifyPowOfTwo.after.mir b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_2_i.SimplifyPowOfTwo.after.mir new file mode 100644 index 0000000000000..73c9e5607d22c --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_2_i.SimplifyPowOfTwo.after.mir @@ -0,0 +1,29 @@ +// MIR for `slow_2_i` after SimplifyPowOfTwo + +fn slow_2_i(_1: u32) -> i32 { + debug a => _1; + let mut _0: i32; + let mut _2: u32; + let mut _3: (u32, bool); + let mut _4: bool; + let mut _5: bool; + let mut _6: i32; + let mut _7: i32; + + bb0: { + StorageLive(_2); + _2 = _1; + _3 = CheckedMul(move _2, const 1_u32); + _4 = Lt((_3.0: u32), const 32_u32); + _5 = BitOr((_3.1: bool), _4); + _6 = _5 as i32 (IntToInt); + _7 = Shl(const 1_i32, (_3.0: u32)); + _0 = MulUnchecked(_7, _6); + goto -> bb1; + } + + bb1: { + StorageDead(_2); + return; + } +} diff --git a/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_2_u.SimplifyPowOfTwo.after.mir b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_2_u.SimplifyPowOfTwo.after.mir new file mode 100644 index 0000000000000..216e3909a0479 --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_2_u.SimplifyPowOfTwo.after.mir @@ -0,0 +1,29 @@ +// MIR for `slow_2_u` after SimplifyPowOfTwo + +fn slow_2_u(_1: u32) -> u32 { + debug a => _1; + let mut _0: u32; + let mut _2: u32; + let mut _3: (u32, bool); + let mut _4: bool; + let mut _5: bool; + let mut _6: u32; + let mut _7: u32; + + bb0: { + StorageLive(_2); + _2 = _1; + _3 = CheckedMul(move _2, const 1_u32); + _4 = Lt((_3.0: u32), const 32_u32); + _5 = BitOr((_3.1: bool), _4); + _6 = _5 as u32 (IntToInt); + _7 = Shl(const 1_u32, (_3.0: u32)); + _0 = MulUnchecked(_7, _6); + goto -> bb1; + } + + bb1: { + StorageDead(_2); + return; + } +} diff --git a/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_4_i.SimplifyPowOfTwo.after.mir b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_4_i.SimplifyPowOfTwo.after.mir new file mode 100644 index 0000000000000..bfe055d0be5c6 --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_4_i.SimplifyPowOfTwo.after.mir @@ -0,0 +1,29 @@ +// MIR for `slow_4_i` after SimplifyPowOfTwo + +fn slow_4_i(_1: u32) -> i32 { + debug a => _1; + let mut _0: i32; + let mut _2: u32; + let mut _3: (u32, bool); + let mut _4: bool; + let mut _5: bool; + let mut _6: i32; + let mut _7: i32; + + bb0: { + StorageLive(_2); + _2 = _1; + _3 = CheckedMul(move _2, const 2_u32); + _4 = Lt((_3.0: u32), const 32_u32); + _5 = BitOr((_3.1: bool), _4); + _6 = _5 as i32 (IntToInt); + _7 = Shl(const 1_i32, (_3.0: u32)); + _0 = MulUnchecked(_7, _6); + goto -> bb1; + } + + bb1: { + StorageDead(_2); + return; + } +} diff --git a/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_4_u.SimplifyPowOfTwo.after.mir b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_4_u.SimplifyPowOfTwo.after.mir new file mode 100644 index 0000000000000..238528a0b5275 --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_no_overflow_checks.slow_4_u.SimplifyPowOfTwo.after.mir @@ -0,0 +1,29 @@ +// MIR for `slow_4_u` after SimplifyPowOfTwo + +fn slow_4_u(_1: u32) -> u32 { + debug a => _1; + let mut _0: u32; + let mut _2: u32; + let mut _3: (u32, bool); + let mut _4: bool; + let mut _5: bool; + let mut _6: u32; + let mut _7: u32; + + bb0: { + StorageLive(_2); + _2 = _1; + _3 = CheckedMul(move _2, const 2_u32); + _4 = Lt((_3.0: u32), const 32_u32); + _5 = BitOr((_3.1: bool), _4); + _6 = _5 as u32 (IntToInt); + _7 = Shl(const 1_u32, (_3.0: u32)); + _0 = MulUnchecked(_7, _6); + goto -> bb1; + } + + bb1: { + StorageDead(_2); + return; + } +} diff --git a/tests/mir-opt/simplify_pow_of_two_overflow_checks.rs b/tests/mir-opt/simplify_pow_of_two_overflow_checks.rs new file mode 100644 index 0000000000000..c8d219755c747 --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_overflow_checks.rs @@ -0,0 +1,51 @@ +// unit-test: SimplifyPowOfTwo +// compile-flags: -Cdebug-assertions=true + +// EMIT_MIR simplify_pow_of_two_overflow_checks.slow_2_u.SimplifyPowOfTwo.after.mir +fn slow_2_u(a: u32) -> u32 { + 2u32.pow(a) +} + +// EMIT_MIR simplify_pow_of_two_overflow_checks.slow_2_i.SimplifyPowOfTwo.after.mir +fn slow_2_i(a: u32) -> i32 { + 2i32.pow(a) +} + +// EMIT_MIR simplify_pow_of_two_overflow_checks.slow_4_u.SimplifyPowOfTwo.after.mir +fn slow_4_u(a: u32) -> u32 { + 4u32.pow(a) +} + +// EMIT_MIR simplify_pow_of_two_overflow_checks.slow_4_i.SimplifyPowOfTwo.after.mir +fn slow_4_i(a: u32) -> i32 { + 4i32.pow(a) +} + +// EMIT_MIR simplify_pow_of_two_overflow_checks.slow_256_u.SimplifyPowOfTwo.after.mir +fn slow_256_u(a: u32) -> u32 { + 256u32.pow(a) +} + +// EMIT_MIR simplify_pow_of_two_overflow_checks.slow_256_i.SimplifyPowOfTwo.after.mir +fn slow_256_i(a: u32) -> i32 { + 256i32.pow(a) +} + +fn main() { + slow_2_u(0); + slow_2_i(0); + slow_2_u(1); + slow_2_i(1); + slow_2_u(2); + slow_2_i(2); + slow_4_u(4); + slow_4_i(4); + slow_4_u(15); + slow_4_i(15); + slow_4_u(16); + slow_4_i(16); + slow_4_u(17); + slow_4_i(17); + slow_256_u(2); + slow_256_i(2); +} diff --git a/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_256_i.SimplifyPowOfTwo.after.mir b/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_256_i.SimplifyPowOfTwo.after.mir new file mode 100644 index 0000000000000..e6b7a74bfbab5 --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_256_i.SimplifyPowOfTwo.after.mir @@ -0,0 +1,29 @@ +// MIR for `slow_256_i` after SimplifyPowOfTwo + +fn slow_256_i(_1: u32) -> i32 { + debug a => _1; + let mut _0: i32; + let mut _2: u32; + let mut _3: (u32, bool); + let mut _4: bool; + let mut _5: bool; + let mut _6: i32; + let mut _7: i32; + + bb0: { + StorageLive(_2); + _2 = _1; + _3 = CheckedMul(move _2, const 8_u32); + _4 = Lt((_3.0: u32), const 32_u32); + _5 = BitOr((_3.1: bool), _4); + _6 = _5 as i32 (IntToInt); + _7 = Shl(const 1_i32, (_3.0: u32)); + _0 = MulUnchecked(_7, _6); + assert(_5, "attempt to compute `{} * {}`, which would overflow", const 1_u32, (_3.0: u32)) -> [success: bb1, unwind continue]; + } + + bb1: { + StorageDead(_2); + return; + } +} diff --git a/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_256_u.SimplifyPowOfTwo.after.mir b/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_256_u.SimplifyPowOfTwo.after.mir new file mode 100644 index 0000000000000..2c8921d0d2add --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_256_u.SimplifyPowOfTwo.after.mir @@ -0,0 +1,29 @@ +// MIR for `slow_256_u` after SimplifyPowOfTwo + +fn slow_256_u(_1: u32) -> u32 { + debug a => _1; + let mut _0: u32; + let mut _2: u32; + let mut _3: (u32, bool); + let mut _4: bool; + let mut _5: bool; + let mut _6: u32; + let mut _7: u32; + + bb0: { + StorageLive(_2); + _2 = _1; + _3 = CheckedMul(move _2, const 8_u32); + _4 = Lt((_3.0: u32), const 32_u32); + _5 = BitOr((_3.1: bool), _4); + _6 = _5 as u32 (IntToInt); + _7 = Shl(const 1_u32, (_3.0: u32)); + _0 = MulUnchecked(_7, _6); + assert(_5, "attempt to compute `{} * {}`, which would overflow", const 1_u32, (_3.0: u32)) -> [success: bb1, unwind continue]; + } + + bb1: { + StorageDead(_2); + return; + } +} diff --git a/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_2_i.SimplifyPowOfTwo.after.mir b/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_2_i.SimplifyPowOfTwo.after.mir new file mode 100644 index 0000000000000..f525bcb8309ba --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_2_i.SimplifyPowOfTwo.after.mir @@ -0,0 +1,29 @@ +// MIR for `slow_2_i` after SimplifyPowOfTwo + +fn slow_2_i(_1: u32) -> i32 { + debug a => _1; + let mut _0: i32; + let mut _2: u32; + let mut _3: (u32, bool); + let mut _4: bool; + let mut _5: bool; + let mut _6: i32; + let mut _7: i32; + + bb0: { + StorageLive(_2); + _2 = _1; + _3 = CheckedMul(move _2, const 1_u32); + _4 = Lt((_3.0: u32), const 32_u32); + _5 = BitOr((_3.1: bool), _4); + _6 = _5 as i32 (IntToInt); + _7 = Shl(const 1_i32, (_3.0: u32)); + _0 = MulUnchecked(_7, _6); + assert(_5, "attempt to compute `{} * {}`, which would overflow", const 1_u32, (_3.0: u32)) -> [success: bb1, unwind continue]; + } + + bb1: { + StorageDead(_2); + return; + } +} diff --git a/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_2_u.SimplifyPowOfTwo.after.mir b/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_2_u.SimplifyPowOfTwo.after.mir new file mode 100644 index 0000000000000..905b2797b3ff2 --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_2_u.SimplifyPowOfTwo.after.mir @@ -0,0 +1,29 @@ +// MIR for `slow_2_u` after SimplifyPowOfTwo + +fn slow_2_u(_1: u32) -> u32 { + debug a => _1; + let mut _0: u32; + let mut _2: u32; + let mut _3: (u32, bool); + let mut _4: bool; + let mut _5: bool; + let mut _6: u32; + let mut _7: u32; + + bb0: { + StorageLive(_2); + _2 = _1; + _3 = CheckedMul(move _2, const 1_u32); + _4 = Lt((_3.0: u32), const 32_u32); + _5 = BitOr((_3.1: bool), _4); + _6 = _5 as u32 (IntToInt); + _7 = Shl(const 1_u32, (_3.0: u32)); + _0 = MulUnchecked(_7, _6); + assert(_5, "attempt to compute `{} * {}`, which would overflow", const 1_u32, (_3.0: u32)) -> [success: bb1, unwind continue]; + } + + bb1: { + StorageDead(_2); + return; + } +} diff --git a/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_4_i.SimplifyPowOfTwo.after.mir b/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_4_i.SimplifyPowOfTwo.after.mir new file mode 100644 index 0000000000000..388d52e3abfc9 --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_4_i.SimplifyPowOfTwo.after.mir @@ -0,0 +1,29 @@ +// MIR for `slow_4_i` after SimplifyPowOfTwo + +fn slow_4_i(_1: u32) -> i32 { + debug a => _1; + let mut _0: i32; + let mut _2: u32; + let mut _3: (u32, bool); + let mut _4: bool; + let mut _5: bool; + let mut _6: i32; + let mut _7: i32; + + bb0: { + StorageLive(_2); + _2 = _1; + _3 = CheckedMul(move _2, const 2_u32); + _4 = Lt((_3.0: u32), const 32_u32); + _5 = BitOr((_3.1: bool), _4); + _6 = _5 as i32 (IntToInt); + _7 = Shl(const 1_i32, (_3.0: u32)); + _0 = MulUnchecked(_7, _6); + assert(_5, "attempt to compute `{} * {}`, which would overflow", const 1_u32, (_3.0: u32)) -> [success: bb1, unwind continue]; + } + + bb1: { + StorageDead(_2); + return; + } +} diff --git a/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_4_u.SimplifyPowOfTwo.after.mir b/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_4_u.SimplifyPowOfTwo.after.mir new file mode 100644 index 0000000000000..a4536affbc5c4 --- /dev/null +++ b/tests/mir-opt/simplify_pow_of_two_overflow_checks.slow_4_u.SimplifyPowOfTwo.after.mir @@ -0,0 +1,29 @@ +// MIR for `slow_4_u` after SimplifyPowOfTwo + +fn slow_4_u(_1: u32) -> u32 { + debug a => _1; + let mut _0: u32; + let mut _2: u32; + let mut _3: (u32, bool); + let mut _4: bool; + let mut _5: bool; + let mut _6: u32; + let mut _7: u32; + + bb0: { + StorageLive(_2); + _2 = _1; + _3 = CheckedMul(move _2, const 2_u32); + _4 = Lt((_3.0: u32), const 32_u32); + _5 = BitOr((_3.1: bool), _4); + _6 = _5 as u32 (IntToInt); + _7 = Shl(const 1_u32, (_3.0: u32)); + _0 = MulUnchecked(_7, _6); + assert(_5, "attempt to compute `{} * {}`, which would overflow", const 1_u32, (_3.0: u32)) -> [success: bb1, unwind continue]; + } + + bb1: { + StorageDead(_2); + return; + } +} diff --git a/tests/ui/mir/simplify-pow-of-two.rs b/tests/ui/mir/simplify-pow-of-two.rs new file mode 100644 index 0000000000000..70df62717ea11 --- /dev/null +++ b/tests/ui/mir/simplify-pow-of-two.rs @@ -0,0 +1,78 @@ +// run-pass + +#![feature(core_intrinsics)] +#![feature(custom_mir)] + +use std::intrinsics::mir::*; + +fn slow_2_u(a: u32) -> u32 { + 2u32.pow(a) +} + +fn slow_2_i(a: u32) -> i32 { + 2i32.pow(a) +} + +fn slow_4_u(a: u32) -> u32 { + 4u32.pow(a) +} + +fn slow_4_i(a: u32) -> i32 { + 4i32.pow(a) +} + +fn slow_256_u(a: u32) -> u32 { + 256u32.pow(a) +} + +fn slow_256_i(a: u32) -> i32 { + 256i32.pow(a) +} + +fn main() { + assert_eq!(slow_2_u(0), 1); + assert_eq!(slow_2_i(0), 1); + assert_eq!(slow_2_u(1), 2); + assert_eq!(slow_2_i(1), 2); + assert_eq!(slow_2_u(2), 4); + assert_eq!(slow_2_i(2), 4); + assert_eq!(slow_4_u(4), 256); + assert_eq!(slow_4_i(4), 256); + assert_eq!(slow_4_u(15), 1073741824); + assert_eq!(slow_4_i(15), 1073741824); + assert_eq!(slow_4_u(16), 0); + assert_eq!(slow_4_i(16), 0); + assert_eq!(slow_4_u(17), 0); + assert_eq!(slow_4_i(17), 0); + assert_eq!(slow_256_u(2), 65536); + assert_eq!(slow_256_i(2), 65536); + + for i in 0..300 { + for j in 0..3000 { + let ix = 2u128.pow(i); + assert_eq!(ix.pow(j), test_mir(i, j), "{ix}, {j}"); + } + } +} + +/// num is the power used to get recv, will be calculated while building this +/// MIR but it's necessary here for testing +/// +/// You can test this out in the playground here: +/// https://play.rust-lang.org/?version=nightly&mode=release&edition=2021&gist=de34e2a6a8f9114ce01bfb62f9379413 +/// +/// This equals `ix.pow(j)` both with and without this optimization. +#[custom_mir(dialect = "built")] +pub fn test_mir(num: u32, exp: u32) -> u128 { + mir! { + { + let num_shl = Checked(exp * num); + let shl_result = num_shl.0 < 128; + let shl = 1u128 << num_shl.0; + let fine_bool = shl_result | num_shl.1; + let fine = fine_bool as u128; + RET = shl * fine; + Return() + } + } +}