diff --git a/src/index.rs b/src/index.rs index a6114572..0c1aac2c 100644 --- a/src/index.rs +++ b/src/index.rs @@ -11,7 +11,11 @@ use crate::{ table::{key::TableKey, SIZE_TIERS_BITS}, Key, }; -use std::convert::TryInto; +#[cfg(target_arch = "x86")] +use std::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use std::arch::x86_64::*; +use std::{cmp::max, convert::TryInto}; // Index chunk consists of 8 64-bit entries. const CHUNK_LEN: usize = CHUNK_ENTRIES * ENTRY_BYTES; // 512 bytes @@ -231,8 +235,56 @@ impl IndexTable { Ok(try_io!(Ok(&map[offset..offset + CHUNK_LEN]))) } - #[inline(never)] fn find_entry(&self, key_prefix: u64, sub_index: usize, chunk: &[u8]) -> (Entry, usize) { + if cfg!(target_feature = "sse2") { + self.find_entry_sse2(key_prefix, sub_index, chunk) + } else { + self.find_entry_base(key_prefix, sub_index, chunk) + } + } + + #[cfg(target_feature = "sse2")] + fn find_entry_sse2(&self, key_prefix: u64, sub_index: usize, chunk: &[u8]) -> (Entry, usize) { + assert!(chunk.len() >= CHUNK_ENTRIES * 8); // Bound checking (not done by SIMD instructions) + const _: () = assert!( + CHUNK_ENTRIES % 4 == 0, + "We assume here we got buffer with a number of elements that is a multiple of 4" + ); + + let shift = max(32, Entry::address_bits(self.id.index_bits())); + unsafe { + let target = _mm_set1_epi32(((key_prefix << self.id.index_bits()) >> shift) as i32); + let shift_mask = _mm_set_epi64x(0, shift.into()); + let mut i = (sub_index >> 2) << 2; // We keep an alignment of 4 + while i + 4 <= CHUNK_ENTRIES { + // We load the value 2 by 2 + // Then we remove the address by shifting such that the partial key is in the low + // part + let first_two = _mm_shuffle_epi32::<0b11011000>(_mm_srl_epi64( + _mm_loadu_si128(chunk[i * 8..].as_ptr() as *const __m128i), + shift_mask, + )); + let last_two = _mm_shuffle_epi32::<0b11011000>(_mm_srl_epi64( + _mm_loadu_si128(chunk[(i + 2) * 8..].as_ptr() as *const __m128i), + shift_mask, + )); + // We set into current the input low parts + let current = _mm_unpacklo_epi64(first_two, last_two); + let cmp = _mm_movemask_epi8(_mm_cmpeq_epi32(current, target)); + if cmp != 0 { + let position = i + (cmp.trailing_zeros() as usize) / 4; + if position >= sub_index { + // We need to check we are not reading again the same input + return (Self::read_entry(chunk, position), position) + } + } + i += 4; + } + } + (Entry::empty(), 0) + } + + fn find_entry_base(&self, key_prefix: u64, sub_index: usize, chunk: &[u8]) -> (Entry, usize) { assert!(chunk.len() >= CHUNK_ENTRIES * 8); let partial_key = Entry::extract_key(key_prefix, self.id.index_bits()); for i in sub_index..CHUNK_ENTRIES { @@ -532,6 +584,7 @@ impl IndexTable { #[cfg(test)] mod test { use super::*; + use std::path::PathBuf; #[test] fn test_entries() { @@ -552,4 +605,37 @@ mod test { assert!(IndexTable::transmute_chunk(chunk2) == chunk); } + + #[test] + fn test_find_entries() { + let partial_keys = [1, 1 << 10, 1 << 20]; + for index_bits in [16, 18, 20, 22] { + let index_table = IndexTable { + id: TableId(index_bits.into()), + map: RwLock::new(None), + path: PathBuf::new(), + }; + + let data_address = Address::from_u64((1 << index_bits) - 1); + + let mut chunk = [0; CHUNK_ENTRIES * 8]; + for (i, partial_key) in partial_keys.iter().enumerate() { + chunk[i * 8..(i + 1) * 8].copy_from_slice( + &Entry::new(data_address, *partial_key, index_bits).as_u64().to_le_bytes(), + ); + } + + for partial_key in &partial_keys { + let key_prefix = *partial_key << (CHUNK_ENTRIES_BITS + SIZE_TIERS_BITS); + assert_eq!( + index_table.find_entry_sse2(key_prefix, 0, &chunk).0.partial_key(index_bits), + *partial_key + ); + assert_eq!( + index_table.find_entry_base(key_prefix, 0, &chunk).0.partial_key(index_bits), + *partial_key + ); + } + } + } }