diff --git a/crates/precompile/bench/blake2.rs b/crates/precompile/bench/blake2.rs
index b6e0621870..053552fad4 100644
--- a/crates/precompile/bench/blake2.rs
+++ b/crates/precompile/bench/blake2.rs
@@ -100,14 +100,14 @@ pub fn add_benches(group: &mut BenchmarkGroup<'_, criterion::measurement::WallTime
             0x1f83d9abfb41bd6bu64,
             0x5be0cd19137e2179u64,
         ];
-        let m = [0u8; 128];
+        let m = [0u64; 16];
         let t = [0u64, 0u64];
         b.iter(|| {
             let mut h_copy = h;
             blake2::algo::compress(
                 black_box(12),
                 &mut h_copy,
-                black_box(&m),
+                black_box(m),
                 black_box(t),
                 black_box(false),
             );
diff --git a/crates/precompile/src/blake2.rs b/crates/precompile/src/blake2.rs
index 3996b894a7..6c446152a9 100644
--- a/crates/precompile/src/blake2.rs
+++ b/crates/precompile/src/blake2.rs
@@ -16,44 +16,43 @@ pub fn run(input: &[u8], gas_limit: u64) -> PrecompileResult {
         return Err(PrecompileError::Blake2WrongLength);
     }
 
-    // Rounds 4 bytes
+    // Parse number of rounds (4 bytes)
     let rounds = u32::from_be_bytes(input[..4].try_into().unwrap()) as usize;
     let gas_used = rounds as u64 * F_ROUND;
     if gas_used > gas_limit {
         return Err(PrecompileError::OutOfGas);
     }
 
+    // Parse final block flag
     let f = match input[212] {
-        1 => true,
         0 => false,
+        1 => true,
         _ => return Err(PrecompileError::Blake2WrongFinalIndicatorFlag),
     };
 
+    // Parse state vector h (8 × u64)
     let mut h = [0u64; 8];
-    //let mut m = [0u64; 16];
-
-    let t;
-    // Optimized parsing using ptr::read_unaligned for potentially better performance
-
-    let m;
-    unsafe {
-        let ptr = input.as_ptr();
-
-        // Read h values
-        for (i, item) in h.iter_mut().enumerate() {
-            *item = u64::from_le_bytes(core::ptr::read_unaligned(
-                ptr.add(4 + i * 8) as *const [u8; 8]
-            ));
-        }
-
-        m = input[68..68 + 16 * size_of::<u64>()].try_into().unwrap();
-
-        t = [
-            u64::from_le_bytes(core::ptr::read_unaligned(ptr.add(196) as *const [u8; 8])),
-            u64::from_le_bytes(core::ptr::read_unaligned(ptr.add(204) as *const [u8; 8])),
-        ];
-    }
-    algo::compress(rounds, &mut h, m, t, f);
+    input[4..68]
+        .chunks_exact(8)
+        .enumerate()
+        .for_each(|(i, chunk)| {
+            h[i] = u64::from_le_bytes(chunk.try_into().unwrap());
+        });
+
+    // Parse message block m (16 × u64)
+    let mut m = [0u64; 16];
+    input[68..196]
+        .chunks_exact(8)
+        .enumerate()
+        .for_each(|(i, chunk)| {
+            m[i] = u64::from_le_bytes(chunk.try_into().unwrap());
+        });
+
+    // Parse offset counters
+    let t_0 = u64::from_le_bytes(input[196..204].try_into().unwrap());
+    let t_1 = u64::from_le_bytes(input[204..212].try_into().unwrap());
+
+    algo::compress(rounds, &mut h, m, [t_0, t_1], f);
 
     let mut out = [0u8; 64];
     for (i, h) in (0..64).step_by(8).zip(h.iter()) {
@@ -94,22 +93,26 @@ pub mod algo {
     #[inline(always)]
     #[allow(clippy::many_single_char_names)]
     /// G function:
-    pub fn g(v: &mut [u64], a: usize, b: usize, c: usize, d: usize, x: u64, y: u64) {
-        v[a] = v[a].wrapping_add(v[b]);
-        v[a] = v[a].wrapping_add(x);
-        v[d] ^= v[a];
-        v[d] = v[d].rotate_right(32);
-        v[c] = v[c].wrapping_add(v[d]);
-        v[b] ^= v[c];
-        v[b] = v[b].rotate_right(24);
-
-        v[a] = v[a].wrapping_add(v[b]);
-        v[a] = v[a].wrapping_add(y);
-        v[d] ^= v[a];
-        v[d] = v[d].rotate_right(16);
-        v[c] = v[c].wrapping_add(v[d]);
-        v[b] ^= v[c];
-        v[b] = v[b].rotate_right(63);
+    fn g(v: &mut [u64; 16], a: usize, b: usize, c: usize, d: usize, x: u64, y: u64) {
+        let mut va = v[a];
+        let mut vb = v[b];
+        let mut vc = v[c];
+        let mut vd = v[d];
+
+        va = va.wrapping_add(vb).wrapping_add(x);
+        vd = (vd ^ va).rotate_right(32);
+        vc = vc.wrapping_add(vd);
+        vb = (vb ^ vc).rotate_right(24);
+
+        va = va.wrapping_add(vb).wrapping_add(y);
+        vd = (vd ^ va).rotate_right(16);
+        vc = vc.wrapping_add(vd);
+        vb = (vb ^ vc).rotate_right(63);
+
+        v[a] = va;
+        v[b] = vb;
+        v[c] = vc;
+        v[d] = vd;
     }
 
     /// Compression function F takes as an argument the state vector "h",
@@ -119,15 +122,7 @@ pub mod algo {
     /// returns a new state vector. The number of rounds, "r", is 12 for
     /// BLAKE2b and 10 for BLAKE2s. Rounds are numbered from 0 to r - 1.
     #[allow(clippy::many_single_char_names)]
-    pub fn compress(
-        rounds: usize,
-        h: &mut [u64; 8],
-        m_slice: &[u8; 16 * size_of::<u64>()],
-        t: [u64; 2],
-        f: bool,
-    ) {
-        assert!(m_slice.len() == 16 * size_of::<u64>());
-
+    pub fn compress(rounds: usize, h: &mut [u64; 8], m: [u64; 16], t: [u64; 2], f: bool) {
         #[cfg(all(target_feature = "avx2", feature = "std"))]
         {
             // only if it is compiled with avx2 flag and it is std, we can use avx2.
@@ -136,7 +131,7 @@
             unsafe {
                 super::avx2::compress_block(
                     rounds,
-                    m_slice,
+                    &m,
                     h,
                     ((t[1] as u128) << 64) | (t[0] as u128),
                     if f { !0 } else { 0 },
@@ -149,14 +144,6 @@
 
         // if avx2 is not available, use the fallback portable implementation
 
-        // Read m values
-        let mut m = [0u64; 16];
-        for (i, item) in m.iter_mut().enumerate() {
-            *item = u64::from_le_bytes(unsafe {
-                core::ptr::read_unaligned(m_slice.as_ptr().add(i * 8) as *const [u8; 8])
-            });
-        }
-
         let mut v = [0u64; 16];
         v[..h.len()].copy_from_slice(h); // First half from state.
         v[h.len()..].copy_from_slice(&IV); // Second half from IV.
@@ -224,7 +211,7 @@ mod avx2 {
     #[inline(always)]
     pub(crate) unsafe fn compress_block(
         mut rounds: usize,
-        block: &[u8; BLOCKBYTES],
+        block: &[Word; 16],
        words: &mut [Word; 8],
        count: Count,
        last_block: Word,
@@ -238,6 +225,7 @@ mod avx2 {
         let flags = set4(count_low(count), count_high(count), last_block, last_node);
         let mut d = xor(loadu(iv_high), flags);
 
+        let block: &[u8; BLOCKBYTES] = std::mem::transmute(block);
         let msg_chunks = array_refs!(block, 16, 16, 16, 16, 16, 16, 16, 16);
         let m0 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.0));
         let m1 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.1));
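
With this patch, `algo::compress` takes the message block as `[u64; 16]` little-endian words rather than a raw `&[u8; 128]`, so any caller outside `run` has to do the same byte-to-word conversion that the new parsing in `run` performs. A minimal sketch of that conversion is below; the helper name `block_to_words` is illustrative and not part of this diff:

    /// Convert a 128-byte BLAKE2b message block into the 16 little-endian
    /// u64 words expected by the updated `algo::compress` signature.
    fn block_to_words(block: &[u8; 128]) -> [u64; 16] {
        let mut m = [0u64; 16];
        for (i, chunk) in block.chunks_exact(8).enumerate() {
            // Each 8-byte chunk becomes one little-endian u64 word.
            m[i] = u64::from_le_bytes(chunk.try_into().unwrap());
        }
        m
    }

    // Example usage, mirroring the benchmark above (all-zero block):
    // let m = block_to_words(&[0u8; 128]);
    // blake2::algo::compress(12, &mut h, m, [0, 0], false);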