diff --git a/crates/interpreter/src/instructions/bitwise.rs b/crates/interpreter/src/instructions/bitwise.rs index 586af76ce3..62edf11c5a 100644 --- a/crates/interpreter/src/instructions/bitwise.rs +++ b/crates/interpreter/src/instructions/bitwise.rs @@ -1,11 +1,10 @@ -use super::i256::{i256_cmp, i256_sign_compl, two_compl, Sign}; +use super::i256::i256_cmp; use crate::{ gas, primitives::{Spec, U256}, Host, Interpreter, }; use core::cmp::Ordering; -use revm_primitives::uint; pub fn lt(interpreter: &mut Interpreter, _host: &mut H) { gas!(interpreter, gas::VERYLOW); @@ -85,7 +84,12 @@ pub fn shl(interpreter: &mut Interpreter, _host: & check!(interpreter, CONSTANTINOPLE); gas!(interpreter, gas::VERYLOW); pop_top!(interpreter, op1, op2); - *op2 <<= as_usize_saturated!(op1); + let shift = as_usize_saturated!(op1); + *op2 = if shift < 256 { + *op2 << shift + } else { + U256::ZERO + } } /// EIP-145: Bitwise shifting instructions in EVM @@ -93,7 +97,12 @@ pub fn shr(interpreter: &mut Interpreter, _host: & check!(interpreter, CONSTANTINOPLE); gas!(interpreter, gas::VERYLOW); pop_top!(interpreter, op1, op2); - *op2 >>= as_usize_saturated!(op1); + let shift = as_usize_saturated!(op1); + *op2 = if shift < 256 { + *op2 >> shift + } else { + U256::ZERO + } } /// EIP-145: Bitwise shifting instructions in EVM @@ -102,33 +111,32 @@ pub fn sar(interpreter: &mut Interpreter, _host: & gas!(interpreter, gas::VERYLOW); pop_top!(interpreter, op1, op2); - let value_sign = i256_sign_compl(op2); - - // If the shift count is 255+, we can short-circuit. This is because shifting by 255 bits is the - // maximum shift that still leaves 1 bit in the original 256-bit number. Shifting by 256 bits or - // more would mean that no original bits remain. The result depends on what the highest bit of - // the value is. - *op2 = if value_sign == Sign::Zero || op1 >= U256::from(255) { - match value_sign { - // value is 0 or >=1, pushing 0 - Sign::Plus | Sign::Zero => U256::ZERO, - // value is <0, pushing -1 - Sign::Minus => U256::MAX, + let shift = as_usize_saturated!(op1); + *op2 = if shift >= 256 { + // If the shift is 256 or more, the result depends on the sign of the last bit. + if op2.bit(255) { + U256::MAX // Negative number, all bits set to one. + } else { + U256::ZERO // Non-negative number, all bits set to zero. } } else { - const ONE: U256 = uint!(1_U256); - // SAFETY: shift count is checked above; it's less than 255. - let shift = usize::try_from(op1).unwrap(); - match value_sign { - Sign::Plus | Sign::Zero => op2.wrapping_shr(shift), - Sign::Minus => two_compl(op2.wrapping_sub(ONE).wrapping_shr(shift).wrapping_add(ONE)), + // Normal shift + if op2.bit(255) { + // Check the most significant bit. + // Arithmetic right shift for negative numbers. + let shifted_value = *op2 >> shift; + let mask = U256::MAX << (256 - shift); // Mask for the sign bits. + shifted_value | mask // Apply the mask to simulate the filling of sign bits. + } else { + // Logical right shift for non-negative numbers. + *op2 >> shift } }; } #[cfg(test)] mod tests { - use crate::instructions::bitwise::{sar, shl, shr}; + use crate::instructions::bitwise::{byte, sar, shl, shr}; use crate::{Contract, DummyHost, Interpreter}; use revm_primitives::{uint, Env, LatestSpec, U256}; @@ -399,4 +407,39 @@ mod tests { assert_eq!(res, test.expected); } } + + #[test] + fn test_byte() { + struct TestCase { + input: U256, + index: usize, + expected: U256, + } + + let mut host = DummyHost::new(Env::default()); + let mut interpreter = Interpreter::new(Contract::default(), u64::MAX, false); + + let input_value = U256::from(0x1234567890abcdef1234567890abcdef_u128); + let test_cases = (0..32) + .map(|i| { + let byte_pos = 31 - i; + + let shift_amount = U256::from(byte_pos * 8); + let byte_value = (input_value >> shift_amount) & U256::from(0xFF); + TestCase { + input: input_value, + index: i, + expected: byte_value, + } + }) + .collect::>(); + + for test in test_cases.iter() { + push!(interpreter, test.input); + push!(interpreter, U256::from(test.index)); + byte(&mut interpreter, &mut host); + pop!(interpreter, res); + assert_eq!(res, test.expected, "Failed at index: {}", test.index); + } + } }