diff --git a/rust/arrow/Cargo.toml b/rust/arrow/Cargo.toml index a29583e5550..fe1baa90109 100644 --- a/rust/arrow/Cargo.toml +++ b/rust/arrow/Cargo.toml @@ -133,3 +133,7 @@ harness = false [[bench]] name = "concatenate_kernel" harness = false + +[[bench]] +name = "mutable_array" +harness = false diff --git a/rust/arrow/benches/filter_kernels.rs b/rust/arrow/benches/filter_kernels.rs index 1348238b074..a7f4e405570 100644 --- a/rust/arrow/benches/filter_kernels.rs +++ b/rust/arrow/benches/filter_kernels.rs @@ -14,128 +14,144 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +extern crate arrow; + +use arrow::{compute::Filter, util::test_util::seedable_rng}; +use rand::{ + distributions::{Alphanumeric, Standard}, + prelude::Distribution, + Rng, +}; use arrow::array::*; -use arrow::compute::{filter, FilterContext}; +use arrow::compute::{build_filter, filter}; use arrow::datatypes::ArrowNumericType; +use arrow::datatypes::{Float32Type, UInt8Type}; + use criterion::{criterion_group, criterion_main, Criterion}; -fn create_primitive_array(size: usize, value_fn: F) -> PrimitiveArray +fn create_primitive_array(size: usize, null_density: f32) -> PrimitiveArray where T: ArrowNumericType, - F: Fn(usize) -> T::Native, + Standard: Distribution, { + // use random numbers to avoid spurious compiler optimizations wrt to branching + let mut rng = seedable_rng(); let mut builder = PrimitiveArray::::builder(size); - for i in 0..size { - builder.append_value(value_fn(i)).unwrap(); + + for _ in 0..size { + if rng.gen::() < null_density { + builder.append_null().unwrap(); + } else { + builder.append_value(rng.gen()).unwrap(); + } } builder.finish() } -fn create_u8_array_with_nulls(size: usize) -> UInt8Array { - let mut builder = UInt8Builder::new(size); - for i in 0..size { - if i % 2 == 0 { - builder.append_value(1).unwrap(); - } else { +fn create_string_array(size: usize, null_density: f32) -> StringArray { + // use random numbers to avoid spurious compiler optimizations wrt to branching + let mut rng = seedable_rng(); + let mut builder = StringBuilder::new(size); + + for _ in 0..size { + if rng.gen::() < null_density { builder.append_null().unwrap(); + } else { + let value = (&mut rng) + .sample_iter(&Alphanumeric) + .take(10) + .collect::(); + builder.append_value(&value).unwrap(); } } builder.finish() } -fn create_bool_array(size: usize, value_fn: F) -> BooleanArray -where - F: Fn(usize) -> bool, -{ +fn create_bool_array(size: usize, trues_density: f32) -> BooleanArray { + let mut rng = seedable_rng(); let mut builder = BooleanBuilder::new(size); - for i in 0..size { - builder.append_value(value_fn(i)).unwrap(); + for _ in 0..size { + let value = rng.gen::() < trues_density; + builder.append_value(value).unwrap(); } builder.finish() } -fn bench_filter_u8(data_array: &UInt8Array, filter_array: &BooleanArray) { - filter( - criterion::black_box(data_array), - criterion::black_box(filter_array), - ) - .unwrap(); -} - -// fn bench_filter_f32(data_array: &Float32Array, filter_array: &BooleanArray) { -// filter(criterion::black_box(data_array), criterion::black_box(filter_array)).unwrap(); -// } - -fn bench_filter_context_u8(data_array: &UInt8Array, filter_context: &FilterContext) { - filter_context - .filter(criterion::black_box(data_array)) - .unwrap(); +fn bench_filter(data_array: &dyn Array, filter_array: &BooleanArray) { + criterion::black_box(filter(data_array, filter_array).unwrap()); } -fn 
bench_filter_context_f32(data_array: &Float32Array, filter_context: &FilterContext) { - filter_context - .filter(criterion::black_box(data_array)) - .unwrap(); +fn bench_built_filter<'a>(filter: &Filter<'a>, data: &impl Array) { + criterion::black_box(filter(&data.data())); } fn add_benchmark(c: &mut Criterion) { let size = 65536; - let filter_array = create_bool_array(size, |i| matches!(i % 2, 0)); - let sparse_filter_array = create_bool_array(size, |i| matches!(i % 8000, 0)); - let dense_filter_array = create_bool_array(size, |i| !matches!(i % 8000, 0)); + let filter_array = create_bool_array(size, 0.5); + let dense_filter_array = create_bool_array(size, 1.0 - 1.0 / 1024.0); + let sparse_filter_array = create_bool_array(size, 1.0 / 1024.0); - let filter_context = FilterContext::new(&filter_array).unwrap(); - let sparse_filter_context = FilterContext::new(&sparse_filter_array).unwrap(); - let dense_filter_context = FilterContext::new(&dense_filter_array).unwrap(); + let filter = build_filter(&filter_array).unwrap(); + let dense_filter = build_filter(&dense_filter_array).unwrap(); + let sparse_filter = build_filter(&sparse_filter_array).unwrap(); - let data_array = create_primitive_array(size, |i| match i % 2 { - 0 => 1, - _ => 0, - }); - c.bench_function("filter u8 low selectivity", |b| { - b.iter(|| bench_filter_u8(&data_array, &filter_array)) + let data_array = create_primitive_array::(size, 0.0); + + c.bench_function("filter u8", |b| { + b.iter(|| bench_filter(&data_array, &filter_array)) }); c.bench_function("filter u8 high selectivity", |b| { - b.iter(|| bench_filter_u8(&data_array, &sparse_filter_array)) + b.iter(|| bench_filter(&data_array, &dense_filter_array)) }); - c.bench_function("filter u8 very low selectivity", |b| { - b.iter(|| bench_filter_u8(&data_array, &dense_filter_array)) + c.bench_function("filter u8 low selectivity", |b| { + b.iter(|| bench_filter(&data_array, &sparse_filter_array)) }); - c.bench_function("filter context u8 low selectivity", |b| { - b.iter(|| bench_filter_context_u8(&data_array, &filter_context)) + c.bench_function("filter context u8", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) }); c.bench_function("filter context u8 high selectivity", |b| { - b.iter(|| bench_filter_context_u8(&data_array, &sparse_filter_context)) + b.iter(|| bench_built_filter(&dense_filter, &data_array)) }); - c.bench_function("filter context u8 very low selectivity", |b| { - b.iter(|| bench_filter_context_u8(&data_array, &dense_filter_context)) + c.bench_function("filter context u8 low selectivity", |b| { + b.iter(|| bench_built_filter(&sparse_filter, &data_array)) }); - let data_array = create_u8_array_with_nulls(size); - c.bench_function("filter context u8 w NULLs low selectivity", |b| { - b.iter(|| bench_filter_context_u8(&data_array, &filter_context)) + let data_array = create_primitive_array::(size, 0.5); + c.bench_function("filter context u8 w NULLs", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) }); c.bench_function("filter context u8 w NULLs high selectivity", |b| { - b.iter(|| bench_filter_context_u8(&data_array, &sparse_filter_context)) + b.iter(|| bench_built_filter(&dense_filter, &data_array)) }); - c.bench_function("filter context u8 w NULLs very low selectivity", |b| { - b.iter(|| bench_filter_context_u8(&data_array, &dense_filter_context)) + c.bench_function("filter context u8 w NULLs low selectivity", |b| { + b.iter(|| bench_built_filter(&sparse_filter, &data_array)) }); - let data_array = create_primitive_array(size, |i| match i % 2 
{ - 0 => 1.0, - _ => 0.0, + let data_array = create_primitive_array::(size, 0.5); + c.bench_function("filter f32", |b| { + b.iter(|| bench_filter(&data_array, &filter_array)) }); - c.bench_function("filter context f32 low selectivity", |b| { - b.iter(|| bench_filter_context_f32(&data_array, &filter_context)) + c.bench_function("filter context f32", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) }); c.bench_function("filter context f32 high selectivity", |b| { - b.iter(|| bench_filter_context_f32(&data_array, &sparse_filter_context)) + b.iter(|| bench_built_filter(&dense_filter, &data_array)) + }); + c.bench_function("filter context f32 low selectivity", |b| { + b.iter(|| bench_built_filter(&sparse_filter, &data_array)) + }); + + let data_array = create_string_array(size, 0.5); + c.bench_function("filter context string", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) + }); + c.bench_function("filter context string high selectivity", |b| { + b.iter(|| bench_built_filter(&dense_filter, &data_array)) }); - c.bench_function("filter context f32 very low selectivity", |b| { - b.iter(|| bench_filter_context_f32(&data_array, &dense_filter_context)) + c.bench_function("filter context string low selectivity", |b| { + b.iter(|| bench_built_filter(&sparse_filter, &data_array)) }); } diff --git a/rust/arrow/benches/mutable_array.rs b/rust/arrow/benches/mutable_array.rs new file mode 100644 index 00000000000..df067169a39 --- /dev/null +++ b/rust/arrow/benches/mutable_array.rs @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
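The new benchmark below exercises `MutableArrayData`, the building block the filter kernel is being rewritten on top of. For orientation, here is a minimal sketch of that API outside the benchmark; it is illustrative only, mirrors the `new`/`extend`/`freeze` calls used later in this diff, and `copy_ranges` is a made-up helper name:

```rust
use std::sync::Arc;

use arrow::array::{make_array, Array, ArrayRef, Int32Array, MutableArrayData};

fn copy_ranges() -> ArrayRef {
    let source: ArrayRef = Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4, 5]));
    // `false`: do not force a validity buffer; `4` is a capacity hint in slots.
    let mut mutable = MutableArrayData::new(vec![source.data_ref()], false, 4);
    mutable.extend(0, 1, 3); // copy slots [1, 3) from source array 0
    mutable.extend(0, 4, 6); // copy slots [4, 6) from source array 0
    make_array(Arc::new(mutable.freeze())) // -> Int32Array of [1, 2, 4, 5]
}
```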
+ +#[macro_use] +extern crate criterion; +use criterion::Criterion; + +use rand::distributions::Alphanumeric; +use rand::Rng; + +use std::sync::Arc; + +extern crate arrow; + +use arrow::array::*; +use arrow::util::test_util::seedable_rng; + +fn create_strings(size: usize, null_density: f32) -> ArrayRef { + let rng = &mut seedable_rng(); + + let mut builder = StringBuilder::new(size); + for _ in 0..size { + let x = rng.gen::(); + if x < null_density { + let value = rng.sample_iter(&Alphanumeric).take(4).collect::(); + builder.append_value(&value).unwrap(); + } else { + builder.append_null().unwrap() + } + } + Arc::new(builder.finish()) +} + +fn create_slices(size: usize) -> Vec<(usize, usize)> { + let rng = &mut seedable_rng(); + + (0..size) + .map(|_| { + let start = rng.gen_range(0, size / 2); + let end = rng.gen_range(start + 1, size); + (start, end) + }) + .collect() +} + +fn bench(v1: &ArrayRef, slices: &[(usize, usize)]) { + let mut mutable = MutableArrayData::new(vec![v1.data_ref()], false, 5); + for (start, end) in slices { + mutable.extend(0, *start, *end) + } + mutable.freeze(); +} + +fn add_benchmark(c: &mut Criterion) { + let v1 = create_strings(1024, 0.0); + let v2 = create_slices(1024); + c.bench_function("mutable str 1024", |b| b.iter(|| bench(&v1, &v2))); + + let v1 = create_strings(1024, 0.5); + let v2 = create_slices(1024); + c.bench_function("mutable str nulls 1024", |b| b.iter(|| bench(&v1, &v2))); +} + +criterion_group!(benches, add_benchmark); +criterion_main!(benches); diff --git a/rust/arrow/src/array/transform/boolean.rs b/rust/arrow/src/array/transform/boolean.rs index 660d4cde310..cfe485b7b70 100644 --- a/rust/arrow/src/array/transform/boolean.rs +++ b/rust/arrow/src/array/transform/boolean.rs @@ -26,7 +26,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { let values = array.buffers()[0].data(); Box::new( move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { - let buffer = &mut mutable.buffers[0]; + let buffer = &mut mutable.buffer1; reserve_for_bits(buffer, mutable.len + len); set_bits( &mut buffer.data_mut(), @@ -40,6 +40,6 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { } pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { - let buffer = &mut mutable.buffers[0]; + let buffer = &mut mutable.buffer1; reserve_for_bits(buffer, mutable.len + len); } diff --git a/rust/arrow/src/array/transform/fixed_binary.rs b/rust/arrow/src/array/transform/fixed_binary.rs index 84cef62ef95..8899113ede7 100644 --- a/rust/arrow/src/array/transform/fixed_binary.rs +++ b/rust/arrow/src/array/transform/fixed_binary.rs @@ -30,7 +30,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { // fast case where we can copy regions without null issues Box::new( move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { - let buffer = &mut mutable.buffers[0]; + let buffer = &mut mutable.buffer1; buffer.extend_from_slice(&values[start * size..(start + len) * size]); }, ) @@ -38,7 +38,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { Box::new( move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { // nulls present: append item by item, ignoring null entries - let values_buffer = &mut mutable.buffers[0]; + let values_buffer = &mut mutable.buffer1; (start..start + len).for_each(|i| { if array.is_valid(i) { @@ -60,6 +60,6 @@ pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { _ => unreachable!(), }; - let values_buffer = &mut mutable.buffers[0]; + let values_buffer = &mut 
mutable.buffer1; values_buffer.extend(len * size); } diff --git a/rust/arrow/src/array/transform/list.rs b/rust/arrow/src/array/transform/list.rs index 43c7287fcf2..300afa94bf4 100644 --- a/rust/arrow/src/array/transform/list.rs +++ b/rust/arrow/src/array/transform/list.rs @@ -20,7 +20,10 @@ use crate::{ datatypes::ToByteSlice, }; -use super::{Extend, _MutableArrayData, utils::extend_offsets}; +use super::{ + Extend, _MutableArrayData, + utils::{extend_offsets, get_last_offset}, +}; pub(super) fn build_extend(array: &ArrayData) -> Extend { let offsets = array.buffer::(0); @@ -31,11 +34,14 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { index: usize, start: usize, len: usize| { - let mutable_offsets = mutable.buffer::(0); - let last_offset = mutable_offsets[mutable_offsets.len() - 1]; + let offset_buffer = &mut mutable.buffer1; + + // this is safe due to how offset is built. See details on `get_last_offset` + let last_offset: T = unsafe { get_last_offset(offset_buffer) }; + // offsets extend_offsets::( - &mut mutable.buffers[0], + offset_buffer, last_offset, &offsets[start..start + len + 1], ); @@ -54,12 +60,14 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { index: usize, start: usize, len: usize| { - let mutable_offsets = mutable.buffer::(0); - let mut last_offset = mutable_offsets[mutable_offsets.len() - 1]; + let offset_buffer = &mut mutable.buffer1; + + // this is safe due to how offset is built. See details on `get_last_offset` + let mut last_offset: T = unsafe { get_last_offset(offset_buffer) }; - let buffer = &mut mutable.buffers[0]; let delta_len = array.len() - array.null_count(); - buffer.reserve(buffer.len() + delta_len * std::mem::size_of::()); + offset_buffer + .reserve(offset_buffer.len() + delta_len * std::mem::size_of::()); let child = &mut mutable.child_data[0]; (start..start + len).for_each(|i| { @@ -75,7 +83,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { ); } // append offset - buffer.extend_from_slice(last_offset.to_byte_slice()); + offset_buffer.extend_from_slice(last_offset.to_byte_slice()); }) }, ) @@ -86,10 +94,10 @@ pub(super) fn extend_nulls( mutable: &mut _MutableArrayData, len: usize, ) { - let mutable_offsets = mutable.buffer::(0); - let last_offset = mutable_offsets[mutable_offsets.len() - 1]; + let offset_buffer = &mut mutable.buffer1; - let offset_buffer = &mut mutable.buffers[0]; + // this is safe due to how offset is built. See details on `get_last_offset` + let last_offset: T = unsafe { get_last_offset(offset_buffer) }; let offsets = vec![last_offset; len]; offset_buffer.extend_from_slice(offsets.to_byte_slice()); diff --git a/rust/arrow/src/array/transform/mod.rs b/rust/arrow/src/array/transform/mod.rs index 3fe76d63f54..28be14eee7b 100644 --- a/rust/arrow/src/array/transform/mod.rs +++ b/rust/arrow/src/array/transform/mod.rs @@ -46,16 +46,23 @@ struct _MutableArrayData<'a> { pub len: usize, pub null_buffer: MutableBuffer, - pub buffers: Vec, + // arrow specification only allows up to 3 buffers (2 ignoring the nulls above). + // Thus, we place them in the stack to avoid bound checks and greater data locality. 
+ pub buffer1: MutableBuffer, + pub buffer2: MutableBuffer, pub child_data: Vec>, } impl<'a> _MutableArrayData<'a> { fn freeze(self, dictionary: Option) -> ArrayData { - let mut buffers = Vec::with_capacity(self.buffers.len()); - for buffer in self.buffers { - buffers.push(buffer.freeze()); - } + let buffers = match self.data_type { + DataType::Struct(_) => vec![], + DataType::Utf8 + | DataType::Binary + | DataType::LargeUtf8 + | DataType::LargeBinary => vec![self.buffer1.freeze(), self.buffer2.freeze()], + _ => vec![self.buffer1.freeze()], + }; let child_data = match self.data_type { DataType::Dictionary(_, _) => vec![dictionary.unwrap()], @@ -81,18 +88,6 @@ impl<'a> _MutableArrayData<'a> { child_data, ) } - - /// Returns the buffer `buffer` as a slice of type `T`. When the expected buffer is bit-packed, - /// the slice is not offset. - #[inline] - pub(super) fn buffer(&self, buffer: usize) -> &[T] { - let values = unsafe { self.buffers[buffer].data().align_to::() }; - if !values.0.is_empty() || !values.2.is_empty() { - // this is unreachable because - unreachable!("The buffer is not byte-aligned with its interpretation") - }; - &values.1 - } } fn build_extend_null_bits(array: &ArrayData, use_nulls: bool) -> ExtendNullBits { @@ -225,7 +220,6 @@ fn build_extend(array: &ArrayData) -> Extend { /* DataType::Null => {} DataType::FixedSizeList(_, _) => {} - DataType::Struct(_) => {} DataType::Union(_) => {} */ _ => todo!("Take and filter operations still not supported for this datatype"), @@ -298,75 +292,132 @@ impl<'a> MutableArrayData<'a> { use_nulls = true; }; - let buffers = match &data_type { + let empty_buffer = MutableBuffer::new(0); + let [buffer1, buffer2] = match &data_type { DataType::Boolean => { let bytes = bit_util::ceil(capacity, 8); let buffer = MutableBuffer::new(bytes).with_bitset(bytes, false); - vec![buffer] + [buffer, empty_buffer] } - DataType::UInt8 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::UInt16 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::UInt32 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::UInt64 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::Int8 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::Int16 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::Int32 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::Int64 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::Float32 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::Float64 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::Date32(_) | DataType::Time32(_) => { - vec![MutableBuffer::new(capacity * size_of::())] + DataType::UInt8 => { + [MutableBuffer::new(capacity * size_of::()), empty_buffer] } + DataType::UInt16 => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::UInt32 => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::UInt64 => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::Int8 => { + [MutableBuffer::new(capacity * size_of::()), empty_buffer] + } + DataType::Int16 => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::Int32 => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::Int64 => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::Float32 => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::Float64 => [ + 
MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::Date32(_) | DataType::Time32(_) => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], DataType::Date64(_) | DataType::Time64(_) | DataType::Duration(_) - | DataType::Timestamp(_, _) => { - vec![MutableBuffer::new(capacity * size_of::())] - } - DataType::Interval(IntervalUnit::YearMonth) => { - vec![MutableBuffer::new(capacity * size_of::())] - } - DataType::Interval(IntervalUnit::DayTime) => { - vec![MutableBuffer::new(capacity * size_of::())] - } + | DataType::Timestamp(_, _) => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::Interval(IntervalUnit::YearMonth) => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::Interval(IntervalUnit::DayTime) => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], DataType::Utf8 | DataType::Binary => { let mut buffer = MutableBuffer::new((1 + capacity) * size_of::()); + // safety: `unsafe` code assumes that this buffer is initialized with one element buffer.extend_from_slice(&[0i32].to_byte_slice()); - vec![buffer, MutableBuffer::new(capacity * size_of::())] + [buffer, MutableBuffer::new(capacity * size_of::())] } DataType::LargeUtf8 | DataType::LargeBinary => { let mut buffer = MutableBuffer::new((1 + capacity) * size_of::()); + // safety: `unsafe` code assumes that this buffer is initialized with one element buffer.extend_from_slice(&[0i64].to_byte_slice()); - vec![buffer, MutableBuffer::new(capacity * size_of::())] + [buffer, MutableBuffer::new(capacity * size_of::())] } DataType::List(_) => { // offset buffer always starts with a zero let mut buffer = MutableBuffer::new((1 + capacity) * size_of::()); buffer.extend_from_slice(0i32.to_byte_slice()); - vec![buffer] + [buffer, empty_buffer] } DataType::LargeList(_) => { // offset buffer always starts with a zero let mut buffer = MutableBuffer::new((1 + capacity) * size_of::()); buffer.extend_from_slice(&[0i64].to_byte_slice()); - vec![buffer] + [buffer, empty_buffer] } DataType::FixedSizeBinary(size) => { - vec![MutableBuffer::new(capacity * *size as usize)] + [MutableBuffer::new(capacity * *size as usize), empty_buffer] } DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { - DataType::UInt8 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::UInt16 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::UInt32 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::UInt64 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::Int8 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::Int16 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::Int32 => vec![MutableBuffer::new(capacity * size_of::())], - DataType::Int64 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::UInt8 => { + [MutableBuffer::new(capacity * size_of::()), empty_buffer] + } + DataType::UInt16 => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::UInt32 => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::UInt64 => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::Int8 => { + [MutableBuffer::new(capacity * size_of::()), empty_buffer] + } + DataType::Int16 => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::Int32 => [ + MutableBuffer::new(capacity * size_of::()), + empty_buffer, + ], + DataType::Int64 => [ + MutableBuffer::new(capacity * 
size_of::()), + empty_buffer, + ], _ => unreachable!(), }, DataType::Float16 => unreachable!(), - DataType::Struct(_) => vec![], + DataType::Struct(_) => [empty_buffer, MutableBuffer::new(0)], _ => { todo!("Take and filter operations still not supported for this datatype") } @@ -443,7 +494,8 @@ impl<'a> MutableArrayData<'a> { len: 0, null_count: 0, null_buffer, - buffers, + buffer1, + buffer2, child_data, }; Self { diff --git a/rust/arrow/src/array/transform/primitive.rs b/rust/arrow/src/array/transform/primitive.rs index a00ae4e91fe..01bbd1a4788 100644 --- a/rust/arrow/src/array/transform/primitive.rs +++ b/rust/arrow/src/array/transform/primitive.rs @@ -28,8 +28,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { let start = start * size_of::(); let len = len * size_of::(); let bytes = &values[start..start + len]; - let buffer = &mut mutable.buffers[0]; - buffer.extend_from_slice(bytes); + mutable.buffer1.extend_from_slice(bytes); }, ) } @@ -38,7 +37,5 @@ pub(super) fn extend_nulls( mutable: &mut _MutableArrayData, len: usize, ) { - let buffer = &mut mutable.buffers[0]; - let bytes = vec![0u8; len * size_of::()]; - buffer.extend_from_slice(&bytes); + mutable.buffer1.extend(len * size_of::()); } diff --git a/rust/arrow/src/array/transform/utils.rs b/rust/arrow/src/array/transform/utils.rs index df9ce2453be..933ec0da1c6 100644 --- a/rust/arrow/src/array/transform/utils.rs +++ b/rust/arrow/src/array/transform/utils.rs @@ -61,3 +61,18 @@ pub(super) fn extend_offsets( buffer.extend_from_slice(last_offset.to_byte_slice()); }); } + +#[inline] +pub(super) unsafe fn get_last_offset( + offset_buffer: &MutableBuffer, +) -> T { + // JUSTIFICATION + // Benefit + // 20% performance improvement extend of variable sized arrays (see bench `mutable_array`) + // Soundness + // * offset buffer is always extended in slices of T and aligned accordingly. + // * Buffer[0] is initialized with one element, 0, and thus `mutable_offsets.len() - 1` is always valid. + let (prefix, offsets, suffix) = offset_buffer.data().align_to::(); + debug_assert!(prefix.is_empty() && suffix.is_empty()); + *offsets.get_unchecked(offsets.len() - 1) +} diff --git a/rust/arrow/src/array/transform/variable_size.rs b/rust/arrow/src/array/transform/variable_size.rs index 6735c8471ff..3a18b6fe5ee 100644 --- a/rust/arrow/src/array/transform/variable_size.rs +++ b/rust/arrow/src/array/transform/variable_size.rs @@ -21,8 +21,12 @@ use crate::{ datatypes::ToByteSlice, }; -use super::{Extend, _MutableArrayData, utils::extend_offsets}; +use super::{ + Extend, _MutableArrayData, + utils::{extend_offsets, get_last_offset}, +}; +#[inline] fn extend_offset_values( buffer: &mut MutableBuffer, offsets: &[T], @@ -43,33 +47,33 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { // fast case where we can copy regions without null issues Box::new( move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { - let mutable_offsets = mutable.buffer::(0); - let last_offset = mutable_offsets[mutable_offsets.len() - 1]; - // offsets - let buffer = &mut mutable.buffers[0]; + let offset_buffer = &mut mutable.buffer1; + let values_buffer = &mut mutable.buffer2; + + // this is safe due to how offset is built. 
See details on `get_last_offset` + let last_offset = unsafe { get_last_offset(offset_buffer) }; + extend_offsets::( - buffer, + offset_buffer, last_offset, &offsets[start..start + len + 1], ); // values - let buffer = &mut mutable.buffers[1]; - extend_offset_values::(buffer, offsets, values, start, len); + extend_offset_values::(values_buffer, offsets, values, start, len); }, ) } else { Box::new( move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { - let mutable_offsets = mutable.buffer::(0); - let mut last_offset = mutable_offsets[mutable_offsets.len() - 1]; + let offset_buffer = &mut mutable.buffer1; + let values_buffer = &mut mutable.buffer2; + + // this is safe due to how offset is built. See details on `get_last_offset` + let mut last_offset: T = unsafe { get_last_offset(offset_buffer) }; // nulls present: append item by item, ignoring null entries - let (offset_buffer, values_buffer) = mutable.buffers.split_at_mut(1); - let offset_buffer = &mut offset_buffer[0]; - let values_buffer = &mut values_buffer[0]; - offset_buffer.reserve( - offset_buffer.len() + array.len() * std::mem::size_of::(), - ); + offset_buffer + .reserve(offset_buffer.len() + len * std::mem::size_of::()); (start..start + len).for_each(|i| { if array.is_valid(i) { @@ -96,10 +100,10 @@ pub(super) fn extend_nulls( mutable: &mut _MutableArrayData, len: usize, ) { - let mutable_offsets = mutable.buffer::(0); - let last_offset = mutable_offsets[mutable_offsets.len() - 1]; + let offset_buffer = &mut mutable.buffer1; - let offset_buffer = &mut mutable.buffers[0]; + // this is safe due to how offset is built. See details on `get_last_offset` + let last_offset: T = unsafe { get_last_offset(offset_buffer) }; let offsets = vec![last_offset; len]; offset_buffer.extend_from_slice(offsets.to_byte_slice()); diff --git a/rust/arrow/src/buffer.rs b/rust/arrow/src/buffer.rs index 8af93960efa..2cd79b070ca 100644 --- a/rust/arrow/src/buffer.rs +++ b/rust/arrow/src/buffer.rs @@ -183,7 +183,7 @@ impl Buffer { /// in larger chunks and starting at arbitrary bit offsets. /// Note that both `offset` and `length` are measured in bits. pub fn bit_chunks(&self, offset: usize, len: usize) -> BitChunks { - BitChunks::new(&self, offset, len) + BitChunks::new(&self.data.as_slice()[self.offset..], offset, len) } /// Returns the number of 1-bits in this buffer. diff --git a/rust/arrow/src/compute/kernels/filter.rs b/rust/arrow/src/compute/kernels/filter.rs index 31d3a1a18ae..c0f6299f0d1 100644 --- a/rust/arrow/src/compute/kernels/filter.rs +++ b/rust/arrow/src/compute/kernels/filter.rs @@ -17,841 +17,243 @@ //! Defines miscellaneous array kernels. 
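The rewrite below replaces the `FilterContext` state machine with two public entry points: a one-shot `filter` kernel and `build_filter`, which analyses a boolean mask once and returns a reusable closure. A rough usage sketch, not part of the patch (the helper name `filter_columns` and the sample data are illustrative):

```rust
use std::sync::Arc;

use arrow::array::{make_array, Array, ArrayRef, BooleanArray, Int32Array, StringArray};
use arrow::compute::kernels::filter::{build_filter, filter};
use arrow::error::Result;

fn filter_columns() -> Result<Vec<ArrayRef>> {
    let mask = BooleanArray::from(vec![true, false, true, false, true]);

    // one-off call: analyses the mask and copies the selected slices in a single pass
    let ints = Int32Array::from(vec![1, 2, 3, 4, 5]);
    let once = filter(&ints, &mask)?;
    assert_eq!(once.len(), 3);

    // repeated use: pay the mask analysis once, then apply the closure per column
    let columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
        Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])),
    ];
    let predicate = build_filter(&mask)?;
    Ok(columns
        .iter()
        .map(|c| make_array(Arc::new(predicate(&c.data()))))
        .collect())
}
```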
-use crate::array::*;
-use crate::datatypes::*;
-use crate::error::{ArrowError, Result};
+use crate::error::Result;
 use crate::record_batch::RecordBatch;
-use crate::{
-    bitmap::Bitmap,
-    buffer::{Buffer, MutableBuffer},
-    util::bit_util,
-};
-use std::{mem, sync::Arc};
-
-/// trait for copying filtered null bitmap bits
-trait CopyNullBit {
-    fn copy_null_bit(&mut self, source_index: usize);
-    fn copy_null_bits(&mut self, source_index: usize, count: usize);
-    fn null_count(&self) -> usize;
-    fn null_buffer(&mut self) -> Buffer;
-}
-
-/// no-op null bitmap copy implementation,
-/// used when the filtered data array doesn't have a null bitmap
-struct NullBitNoop {}
-
-impl NullBitNoop {
-    fn new() -> Self {
-        NullBitNoop {}
-    }
-}
-
-impl CopyNullBit for NullBitNoop {
-    #[inline]
-    fn copy_null_bit(&mut self, _source_index: usize) {
-        // do nothing
-    }
-
-    #[inline]
-    fn copy_null_bits(&mut self, _source_index: usize, _count: usize) {
-        // do nothing
-    }
-
-    fn null_count(&self) -> usize {
-        0
-    }
-
-    fn null_buffer(&mut self) -> Buffer {
-        Buffer::from([0u8; 0])
-    }
+use crate::{array::*, util::bit_chunk_iterator::BitChunkIterator};
+use std::{iter::Enumerate, sync::Arc};
+
+/// Function that can filter arbitrary arrays
+pub type Filter<'a> = Box<dyn Fn(&ArrayData) -> ArrayData + 'a>;
+
+/// Internal state of [SlicesIterator]
+#[derive(Debug, PartialEq)]
+enum State {
+    // it is iterating over bits of a mask (`u64`, steps of size of 1 slot)
+    Bits(u64),
+    // it is iterating over chunks (steps of size of 64 slots)
+    Chunks,
+    // it is iterating over the remaining bits (steps of size of 1 slot)
+    Remainder,
+    // nothing more to iterate.
+    Finish,
 }
-/// null bitmap copy implementation,
-/// used when the filtered data array has a null bitmap
-struct NullBitSetter<'a> {
-    target_buffer: MutableBuffer,
-    source_bytes: &'a [u8],
-    target_index: usize,
-    null_count: usize,
+/// An iterator of `(usize, usize)` each representing an interval `[start,end[` whose
+/// slots of a [BooleanArray] are true. Each interval corresponds to a contiguous region of memory to be
+/// "taken" from an array to be filtered.
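To make the contract above concrete: for a mask like `[false, true, true, false, true]` the iterator is meant to yield `(1, 3)` and `(4, 5)`. Below is a naive standalone sketch of the same `[start, end)` semantics over a plain `&[bool]`, illustrative only; the actual implementation that follows walks the validity buffer in 64-bit chunks and is private to this module:

```rust
fn slices(mask: &[bool]) -> Vec<(usize, usize)> {
    let mut out = Vec::new();
    let mut start = None;
    for (i, &set) in mask.iter().enumerate() {
        match (set, start) {
            // a run of selected slots begins
            (true, None) => start = Some(i),
            // a run ends one past its last selected slot
            (false, Some(s)) => {
                out.push((s, i));
                start = None;
            }
            _ => {}
        }
    }
    if let Some(s) = start {
        // the last run extends to the end of the mask
        out.push((s, mask.len()));
    }
    out
}

// slices(&[false, true, true, false, true]) == vec![(1, 3), (4, 5)]
```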
+#[derive(Debug)] +struct SlicesIterator<'a> { + iter: Enumerate>, + state: State, + filter_count: usize, + remainder_mask: u64, + remainder_len: usize, + chunk_len: usize, + len: usize, + start: usize, + on_region: bool, + current_chunk: usize, + current_bit: usize, } -impl<'a> NullBitSetter<'a> { - fn new(null_bitmap: &'a Bitmap) -> Self { - let null_bytes = null_bitmap.buffer_ref().data(); - // create null bitmap buffer with same length and initialize null bitmap buffer to 1s - let null_buffer = - MutableBuffer::new(null_bytes.len()).with_bitset(null_bytes.len(), true); - NullBitSetter { - source_bytes: null_bytes, - target_buffer: null_buffer, - target_index: 0, - null_count: 0, +impl<'a> SlicesIterator<'a> { + fn new(filter: &'a BooleanArray) -> Self { + let values = &filter.data_ref().buffers()[0]; + + // this operation is performed before iteration + // because it is fast and allows reserving all the needed memory + let filter_count = values.count_set_bits_offset(filter.offset(), filter.len()); + + let chunks = values.bit_chunks(filter.offset(), filter.len()); + + Self { + iter: chunks.iter().enumerate(), + state: State::Chunks, + filter_count, + remainder_len: chunks.remainder_len(), + chunk_len: chunks.chunk_len(), + remainder_mask: chunks.remainder_bits(), + len: 0, + start: 0, + on_region: false, + current_chunk: 0, + current_bit: 0, } } -} -impl<'a> CopyNullBit for NullBitSetter<'a> { #[inline] - fn copy_null_bit(&mut self, source_index: usize) { - if !bit_util::get_bit(self.source_bytes, source_index) { - bit_util::unset_bit(self.target_buffer.data_mut(), self.target_index); - self.null_count += 1; - } - self.target_index += 1; + fn current_start(&self) -> usize { + self.current_chunk * 64 + self.current_bit } #[inline] - fn copy_null_bits(&mut self, source_index: usize, count: usize) { - for i in 0..count { - self.copy_null_bit(source_index + i); + fn iterate_bits(&mut self, mask: u64, max: usize) -> Option<(usize, usize)> { + while self.current_bit < max { + if (mask & (1 << self.current_bit)) != 0 { + if !self.on_region { + self.start = self.current_start(); + self.on_region = true; + } + self.len += 1; + } else if self.on_region { + let result = (self.start, self.start + self.len); + self.len = 0; + self.on_region = false; + self.current_bit += 1; + return Some(result); + } + self.current_bit += 1; } + self.current_bit = 0; + None } - fn null_count(&self) -> usize { - self.null_count - } - - fn null_buffer(&mut self) -> Buffer { - self.target_buffer.resize(self.target_index); - // use mem::replace to detach self.target_buffer from self so that it can be returned - let target_buffer = mem::replace(&mut self.target_buffer, MutableBuffer::new(0)); - target_buffer.freeze() - } -} - -fn get_null_bit_setter<'a>(data_array: &'a impl Array) -> Box { - if let Some(null_bitmap) = data_array.data_ref().null_bitmap() { - // only return an actual null bit copy implementation if null_bitmap is set - Box::new(NullBitSetter::new(null_bitmap)) - } else { - // otherwise return a no-op copy null bit implementation - // for improved performance when the filtered array doesn't contain NULLs - Box::new(NullBitNoop::new()) - } -} - -// transmute filter array to u64 -// - optimize filtering with highly selective filters by skipping entire batches of 64 filter bits -// - if the data array being filtered doesn't have a null bitmap, no time is wasted to copy a null bitmap -fn filter_array_impl( - filter_context: &FilterContext, - data_array: &impl Array, - array_type: DataType, - value_size: usize, 
-) -> Result { - if filter_context.filter_len > data_array.len() { - return Err(ArrowError::ComputeError( - "Filter array cannot be larger than data array".to_string(), - )); - } - let filtered_count = filter_context.filtered_count; - let filter_mask = &filter_context.filter_mask; - let filter_u64 = &filter_context.filter_u64; - let data_bytes = data_array.data_ref().buffers()[0].data(); - let mut target_buffer = MutableBuffer::new(filtered_count * value_size); - target_buffer.resize(filtered_count * value_size); - let target_bytes = target_buffer.data_mut(); - let mut target_byte_index: usize = 0; - let mut null_bit_setter = get_null_bit_setter(data_array); - let null_bit_setter = null_bit_setter.as_mut(); - let all_ones_batch = !0u64; - let data_array_offset = data_array.offset(); - - for (i, filter_batch) in filter_u64.iter().enumerate() { - // foreach u64 batch - let filter_batch = *filter_batch; - if filter_batch == 0 { - // if batch == 0, all items are filtered out, so skip entire batch - continue; - } else if filter_batch == all_ones_batch { - // if batch == all 1s: copy all 64 values in one go - let data_index = (i * 64) + data_array_offset; - null_bit_setter.copy_null_bits(data_index, 64); - let data_byte_index = data_index * value_size; - let data_len = value_size * 64; - target_bytes[target_byte_index..(target_byte_index + data_len)] - .copy_from_slice( - &data_bytes[data_byte_index..(data_byte_index + data_len)], - ); - target_byte_index += data_len; - continue; - } - for (j, filter_mask) in filter_mask.iter().enumerate() { - // foreach bit in batch: - if (filter_batch & *filter_mask) != 0 { - let data_index = (i * 64) + j + data_array_offset; - null_bit_setter.copy_null_bit(data_index); - // if filter bit == 1: copy data value bytes - let data_byte_index = data_index * value_size; - target_bytes[target_byte_index..(target_byte_index + value_size)] - .copy_from_slice( - &data_bytes[data_byte_index..(data_byte_index + value_size)], - ); - target_byte_index += value_size; + /// iterates over chunks. + #[inline] + fn iterate_chunks(&mut self) -> Option<(usize, usize)> { + while let Some((i, mask)) = self.iter.next() { + self.current_chunk = i; + if mask == 0 { + if self.on_region { + let result = (self.start, self.start + self.len); + self.len = 0; + self.on_region = false; + return Some(result); + } + } else if mask == 18446744073709551615u64 { + // = !0u64 + if !self.on_region { + self.start = self.current_start(); + self.on_region = true; + } + self.len += 64; + } else { + // there is a chunk that has a non-trivial mask => iterate over bits. + self.state = State::Bits(mask); + return None; } } + // no more chunks => start iterating over the remainder + self.current_chunk = self.chunk_len; + self.state = State::Remainder; + None } - - let mut array_data_builder = ArrayDataBuilder::new(array_type) - .len(filtered_count) - .add_buffer(target_buffer.freeze()); - if null_bit_setter.null_count() > 0 { - array_data_builder = array_data_builder - .null_count(null_bit_setter.null_count()) - .null_bit_buffer(null_bit_setter.null_buffer()); - } - - Ok(array_data_builder) -} - -/// FilterContext can be used to improve performance when -/// filtering multiple data arrays with the same filter array. -#[derive(Debug)] -pub struct FilterContext { - filter_u64: Vec, - filter_len: usize, - filtered_count: usize, - filter_mask: Vec, -} - -macro_rules! 
filter_primitive_array { - ($context:expr, $array:expr, $array_type:ident) => {{ - let input_array = $array.as_any().downcast_ref::<$array_type>().unwrap(); - let output_array = $context.filter_primitive_array(input_array)?; - Ok(Arc::new(output_array)) - }}; -} - -macro_rules! filter_dictionary_array { - ($context:expr, $array:expr, $array_type:ident) => {{ - let input_array = $array.as_any().downcast_ref::<$array_type>().unwrap(); - let output_array = $context.filter_dictionary_array(input_array)?; - Ok(Arc::new(output_array)) - }}; } -macro_rules! filter_boolean_item_list_array { - ($context:expr, $array:expr, $list_type:ident, $list_builder_type:ident) => {{ - let input_array = $array.as_any().downcast_ref::<$list_type>().unwrap(); - let values_builder = BooleanBuilder::new($context.filtered_count); - let mut builder = $list_builder_type::new(values_builder); - for i in 0..$context.filter_u64.len() { - // foreach u64 batch - let filter_batch = $context.filter_u64[i]; - if filter_batch == 0 { - // if batch == 0, all items are filtered out, so skip entire batch - continue; - } - for j in 0..64 { - // foreach bit in batch: - if (filter_batch & $context.filter_mask[j]) != 0 { - let data_index = (i * 64) + j; - if input_array.is_null(data_index) { - builder.append(false)?; - } else { - let this_inner_list = input_array.value(data_index); - let inner_list = this_inner_list - .as_any() - .downcast_ref::() - .unwrap(); - for k in 0..inner_list.len() { - if inner_list.is_null(k) { - builder.values().append_null()?; - } else { - builder.values().append_value(inner_list.value(k))?; - } - } - builder.append(true)?; +impl<'a> Iterator for SlicesIterator<'a> { + type Item = (usize, usize); + + fn next(&mut self) -> Option { + match self.state { + State::Chunks => { + match self.iterate_chunks() { + None => { + // iterating over chunks does not yield any new slice => continue to the next + self.current_bit = 0; + self.next() } + other => other, } } - } - Ok(Arc::new(builder.finish())) - }}; -} - -macro_rules! filter_primitive_item_list_array { - ($context:expr, $array:expr, $item_type:ident, $list_type:ident, $list_builder_type:ident) => {{ - let input_array = $array.as_any().downcast_ref::<$list_type>().unwrap(); - let values_builder = PrimitiveBuilder::<$item_type>::new($context.filtered_count); - let mut builder = $list_builder_type::new(values_builder); - for i in 0..$context.filter_u64.len() { - // foreach u64 batch - let filter_batch = $context.filter_u64[i]; - if filter_batch == 0 { - // if batch == 0, all items are filtered out, so skip entire batch - continue; - } - for j in 0..64 { - // foreach bit in batch: - if (filter_batch & $context.filter_mask[j]) != 0 { - let data_index = (i * 64) + j; - if input_array.is_null(data_index) { - builder.append(false)?; - } else { - let this_inner_list = input_array.value(data_index); - let inner_list = this_inner_list - .as_any() - .downcast_ref::>() - .unwrap(); - for k in 0..inner_list.len() { - if inner_list.is_null(k) { - builder.values().append_null()?; - } else { - builder.values().append_value(inner_list.value(k))?; - } - } - builder.append(true)?; + State::Bits(mask) => { + match self.iterate_bits(mask, 64) { + None => { + // iterating over bits does not yield any new slice => change back + // to chunks and continue to the next + self.state = State::Chunks; + self.next() } + other => other, } } - } - Ok(Arc::new(builder.finish())) - }}; -} - -macro_rules! 
filter_non_primitive_item_list_array { - ($context:expr, $array:expr, $item_array_type:ident, $item_builder:ident, $list_type:ident, $list_builder_type:ident) => {{ - let input_array = $array.as_any().downcast_ref::<$list_type>().unwrap(); - let values_builder = $item_builder::new($context.filtered_count); - let mut builder = $list_builder_type::new(values_builder); - for i in 0..$context.filter_u64.len() { - // foreach u64 batch - let filter_batch = $context.filter_u64[i]; - if filter_batch == 0 { - // if batch == 0, all items are filtered out, so skip entire batch - continue; - } - for j in 0..64 { - // foreach bit in batch: - if (filter_batch & $context.filter_mask[j]) != 0 { - let data_index = (i * 64) + j; - if input_array.is_null(data_index) { - builder.append(false)?; - } else { - let this_inner_list = input_array.value(data_index); - let inner_list = this_inner_list - .as_any() - .downcast_ref::<$item_array_type>() - .unwrap(); - for k in 0..inner_list.len() { - if inner_list.is_null(k) { - builder.values().append_null()?; - } else { - builder.values().append_value(inner_list.value(k))?; - } + State::Remainder => { + match self.iterate_bits(self.remainder_mask, self.remainder_len) { + None => { + self.state = State::Finish; + if self.on_region { + Some((self.start, self.start + self.len)) + } else { + None } - builder.append(true)?; } + other => other, } } + State::Finish => None, } - Ok(Arc::new(builder.finish())) - }}; -} - -impl FilterContext { - /// Returns a new instance of FilterContext - pub fn new(filter_array: &BooleanArray) -> Result { - if filter_array.offset() > 0 { - return Err(ArrowError::ComputeError( - "Filter array cannot have offset > 0".to_string(), - )); - } - let filter_mask: Vec = (0..64).map(|x| 1u64 << x).collect(); - let filter_buffer = &filter_array.data_ref().buffers()[0]; - let filtered_count = filter_buffer.count_set_bits_offset(0, filter_array.len()); - - let filter_bytes = filter_buffer.data(); - - // add to the resulting len so is is a multiple of the size of u64 - let pad_addional_len = 8 - filter_bytes.len() % 8; - - // transmute filter_bytes to &[u64] - let mut u64_buffer = MutableBuffer::new(filter_bytes.len() + pad_addional_len); - - u64_buffer.extend_from_slice(filter_bytes); - u64_buffer.extend_from_slice(&vec![0; pad_addional_len]); - let mut filter_u64 = u64_buffer.typed_data_mut::().to_owned(); - - // mask of any bits outside of the given len - if filter_array.len() % 64 != 0 { - let last_idx = filter_u64.len() - 1; - let mask = u64::MAX >> (64 - filter_array.len() % 64); - filter_u64[last_idx] &= mask; - } - - Ok(FilterContext { - filter_u64, - filter_len: filter_array.len(), - filtered_count, - filter_mask, - }) - } - - /// Returns a new array, containing only the elements matching the filter - pub fn filter(&self, array: &Array) -> Result { - match array.data_type() { - DataType::UInt8 => filter_primitive_array!(self, array, UInt8Array), - DataType::UInt16 => filter_primitive_array!(self, array, UInt16Array), - DataType::UInt32 => filter_primitive_array!(self, array, UInt32Array), - DataType::UInt64 => filter_primitive_array!(self, array, UInt64Array), - DataType::Int8 => filter_primitive_array!(self, array, Int8Array), - DataType::Int16 => filter_primitive_array!(self, array, Int16Array), - DataType::Int32 => filter_primitive_array!(self, array, Int32Array), - DataType::Int64 => filter_primitive_array!(self, array, Int64Array), - DataType::Float32 => filter_primitive_array!(self, array, Float32Array), - DataType::Float64 => 
filter_primitive_array!(self, array, Float64Array), - DataType::Boolean => { - let input_array = array.as_any().downcast_ref::().unwrap(); - let mut builder = BooleanArray::builder(self.filtered_count); - for i in 0..self.filter_u64.len() { - // foreach u64 batch - let filter_batch = self.filter_u64[i]; - if filter_batch == 0 { - // if batch == 0, all items are filtered out, so skip entire batch - continue; - } - for j in 0..64 { - // foreach bit in batch: - if (filter_batch & self.filter_mask[j]) != 0 { - let data_index = (i * 64) + j; - if input_array.is_null(data_index) { - builder.append_null()?; - } else { - builder.append_value(input_array.value(data_index))?; - } - } - } - } - Ok(Arc::new(builder.finish())) - }, - DataType::Date32(_) => filter_primitive_array!(self, array, Date32Array), - DataType::Date64(_) => filter_primitive_array!(self, array, Date64Array), - DataType::Time32(TimeUnit::Second) => { - filter_primitive_array!(self, array, Time32SecondArray) - } - DataType::Time32(TimeUnit::Millisecond) => { - filter_primitive_array!(self, array, Time32MillisecondArray) - } - DataType::Time64(TimeUnit::Microsecond) => { - filter_primitive_array!(self, array, Time64MicrosecondArray) - } - DataType::Time64(TimeUnit::Nanosecond) => { - filter_primitive_array!(self, array, Time64NanosecondArray) - } - DataType::Duration(TimeUnit::Second) => { - filter_primitive_array!(self, array, DurationSecondArray) - } - DataType::Duration(TimeUnit::Millisecond) => { - filter_primitive_array!(self, array, DurationMillisecondArray) - } - DataType::Duration(TimeUnit::Microsecond) => { - filter_primitive_array!(self, array, DurationMicrosecondArray) - } - DataType::Duration(TimeUnit::Nanosecond) => { - filter_primitive_array!(self, array, DurationNanosecondArray) - } - DataType::Timestamp(TimeUnit::Second, _) => { - filter_primitive_array!(self, array, TimestampSecondArray) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - filter_primitive_array!(self, array, TimestampMillisecondArray) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - filter_primitive_array!(self, array, TimestampMicrosecondArray) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - filter_primitive_array!(self, array, TimestampNanosecondArray) - } - DataType::Binary => { - let input_array = array.as_any().downcast_ref::().unwrap(); - let mut values: Vec> = Vec::with_capacity(self.filtered_count); - for i in 0..self.filter_u64.len() { - // foreach u64 batch - let filter_batch = self.filter_u64[i]; - if filter_batch == 0 { - // if batch == 0, all items are filtered out, so skip entire batch - continue; - } - for j in 0..64 { - // foreach bit in batch: - if (filter_batch & self.filter_mask[j]) != 0 { - let data_index = (i * 64) + j; - if input_array.is_null(data_index) { - values.push(None) - } else { - values.push(Some(input_array.value(data_index))) - } - } - } - } - Ok(Arc::new(BinaryArray::from(values))) - } - DataType::Utf8 => { - let input_array = array.as_any().downcast_ref::().unwrap(); - let mut values: Vec> = Vec::with_capacity(self.filtered_count); - for i in 0..self.filter_u64.len() { - // foreach u64 batch - let filter_batch = self.filter_u64[i]; - if filter_batch == 0 { - // if batch == 0, all items are filtered out, so skip entire batch - continue; - } - for j in 0..64 { - // foreach bit in batch: - if (filter_batch & self.filter_mask[j]) != 0 { - let data_index = (i * 64) + j; - if input_array.is_null(data_index) { - values.push(None) - } else { - values.push(Some(input_array.value(data_index))) - } 
- } - } - } - Ok(Arc::new(StringArray::from(values))) - } - DataType::Dictionary(ref key_type, ref value_type) => match (key_type.as_ref(), value_type.as_ref()) { - (key_type, DataType::Utf8) => match key_type { - DataType::UInt8 => filter_dictionary_array!(self, array, UInt8DictionaryArray), - DataType::UInt16 => filter_dictionary_array!(self, array, UInt16DictionaryArray), - DataType::UInt32 => filter_dictionary_array!(self, array, UInt32DictionaryArray), - DataType::UInt64 => filter_dictionary_array!(self, array, UInt64DictionaryArray), - DataType::Int8 => filter_dictionary_array!(self, array, Int8DictionaryArray), - DataType::Int16 => filter_dictionary_array!(self, array, Int16DictionaryArray), - DataType::Int32 => filter_dictionary_array!(self, array, Int32DictionaryArray), - DataType::Int64 => filter_dictionary_array!(self, array, Int64DictionaryArray), - other => Err(ArrowError::ComputeError(format!( - "filter not supported for string dictionary with key of type {:?}", - other - ))) - } - (key_type, value_type) => Err(ArrowError::ComputeError(format!( - "filter not supported for Dictionary({:?}, {:?})", - key_type, value_type - ))) - } - DataType::List(dt) => match dt.data_type() { - DataType::UInt8 => { - filter_primitive_item_list_array!(self, array, UInt8Type, ListArray, ListBuilder) - } - DataType::UInt16 => { - filter_primitive_item_list_array!(self, array, UInt16Type, ListArray, ListBuilder) - } - DataType::UInt32 => { - filter_primitive_item_list_array!(self, array, UInt32Type, ListArray, ListBuilder) - } - DataType::UInt64 => { - filter_primitive_item_list_array!(self, array, UInt64Type, ListArray, ListBuilder) - } - DataType::Int8 => filter_primitive_item_list_array!(self, array, Int8Type, ListArray, ListBuilder), - DataType::Int16 => { - filter_primitive_item_list_array!(self, array, Int16Type, ListArray, ListBuilder) - } - DataType::Int32 => { - filter_primitive_item_list_array!(self, array, Int32Type, ListArray, ListBuilder) - } - DataType::Int64 => { - filter_primitive_item_list_array!(self, array, Int64Type, ListArray, ListBuilder) - } - DataType::Float32 => { - filter_primitive_item_list_array!(self, array, Float32Type, ListArray, ListBuilder) - } - DataType::Float64 => { - filter_primitive_item_list_array!(self, array, Float64Type, ListArray, ListBuilder) - } - DataType::Boolean => { - filter_boolean_item_list_array!(self, array, ListArray, ListBuilder) - } - DataType::Date32(_) => { - filter_primitive_item_list_array!(self, array, Date32Type, ListArray, ListBuilder) - } - DataType::Date64(_) => { - filter_primitive_item_list_array!(self, array, Date64Type, ListArray, ListBuilder) - } - DataType::Time32(TimeUnit::Second) => { - filter_primitive_item_list_array!(self, array, Time32SecondType, ListArray, ListBuilder) - } - DataType::Time32(TimeUnit::Millisecond) => { - filter_primitive_item_list_array!(self, array, Time32MillisecondType, ListArray, ListBuilder) - } - DataType::Time64(TimeUnit::Microsecond) => { - filter_primitive_item_list_array!(self, array, Time64MicrosecondType, ListArray, ListBuilder) - } - DataType::Time64(TimeUnit::Nanosecond) => { - filter_primitive_item_list_array!(self, array, Time64NanosecondType, ListArray, ListBuilder) - } - DataType::Duration(TimeUnit::Second) => { - filter_primitive_item_list_array!(self, array, DurationSecondType, ListArray, ListBuilder) - } - DataType::Duration(TimeUnit::Millisecond) => { - filter_primitive_item_list_array!(self, array, DurationMillisecondType, ListArray, ListBuilder) - } - 
DataType::Duration(TimeUnit::Microsecond) => { - filter_primitive_item_list_array!(self, array, DurationMicrosecondType, ListArray, ListBuilder) - } - DataType::Duration(TimeUnit::Nanosecond) => { - filter_primitive_item_list_array!(self, array, DurationNanosecondType, ListArray, ListBuilder) - } - DataType::Timestamp(TimeUnit::Second, _) => { - filter_primitive_item_list_array!(self, array, TimestampSecondType, ListArray, ListBuilder) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - filter_primitive_item_list_array!(self, array, TimestampMillisecondType, ListArray, ListBuilder) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - filter_primitive_item_list_array!(self, array, TimestampMicrosecondType, ListArray, ListBuilder) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - filter_primitive_item_list_array!(self, array, TimestampNanosecondType, ListArray, ListBuilder) - } - DataType::Binary => filter_non_primitive_item_list_array!( - self, - array, - BinaryArray, - BinaryBuilder, - ListArray, - ListBuilder - ), - DataType::LargeBinary => filter_non_primitive_item_list_array!( - self, - array, - LargeBinaryArray, - LargeBinaryBuilder, - ListArray, - ListBuilder - ), - DataType::Utf8 => filter_non_primitive_item_list_array!( - self, - array, - StringArray, - StringBuilder, - ListArray - ,ListBuilder - ), - DataType::LargeUtf8 => filter_non_primitive_item_list_array!( - self, - array, - LargeStringArray, - LargeStringBuilder, - ListArray, - ListBuilder - ), - other => { - Err(ArrowError::ComputeError(format!( - "filter not supported for List({:?})", - other - ))) - } - } - DataType::LargeList(dt) => match dt.data_type() { - DataType::UInt8 => { - filter_primitive_item_list_array!(self, array, UInt8Type, LargeListArray, LargeListBuilder) - } - DataType::UInt16 => { - filter_primitive_item_list_array!(self, array, UInt16Type, LargeListArray, LargeListBuilder) - } - DataType::UInt32 => { - filter_primitive_item_list_array!(self, array, UInt32Type, LargeListArray, LargeListBuilder) - } - DataType::UInt64 => { - filter_primitive_item_list_array!(self, array, UInt64Type, LargeListArray, LargeListBuilder) - } - DataType::Int8 => filter_primitive_item_list_array!(self, array, Int8Type, LargeListArray, LargeListBuilder), - DataType::Int16 => { - filter_primitive_item_list_array!(self, array, Int16Type, LargeListArray, LargeListBuilder) - } - DataType::Int32 => { - filter_primitive_item_list_array!(self, array, Int32Type, LargeListArray, LargeListBuilder) - } - DataType::Int64 => { - filter_primitive_item_list_array!(self, array, Int64Type, LargeListArray, LargeListBuilder) - } - DataType::Float32 => { - filter_primitive_item_list_array!(self, array, Float32Type, LargeListArray, LargeListBuilder) - } - DataType::Float64 => { - filter_primitive_item_list_array!(self, array, Float64Type, LargeListArray, LargeListBuilder) - } - DataType::Boolean => { - filter_boolean_item_list_array!(self, array, LargeListArray, LargeListBuilder) - } - DataType::Date32(_) => { - filter_primitive_item_list_array!(self, array, Date32Type, LargeListArray, LargeListBuilder) - } - DataType::Date64(_) => { - filter_primitive_item_list_array!(self, array, Date64Type, LargeListArray, LargeListBuilder) - } - DataType::Time32(TimeUnit::Second) => { - filter_primitive_item_list_array!(self, array, Time32SecondType, LargeListArray, LargeListBuilder) - } - DataType::Time32(TimeUnit::Millisecond) => { - filter_primitive_item_list_array!(self, array, Time32MillisecondType, LargeListArray, LargeListBuilder) - } - 
DataType::Time64(TimeUnit::Microsecond) => { - filter_primitive_item_list_array!(self, array, Time64MicrosecondType, LargeListArray, LargeListBuilder) - } - DataType::Time64(TimeUnit::Nanosecond) => { - filter_primitive_item_list_array!(self, array, Time64NanosecondType, LargeListArray, LargeListBuilder) - } - DataType::Duration(TimeUnit::Second) => { - filter_primitive_item_list_array!(self, array, DurationSecondType, LargeListArray, LargeListBuilder) - } - DataType::Duration(TimeUnit::Millisecond) => { - filter_primitive_item_list_array!(self, array, DurationMillisecondType, LargeListArray, LargeListBuilder) - } - DataType::Duration(TimeUnit::Microsecond) => { - filter_primitive_item_list_array!(self, array, DurationMicrosecondType, LargeListArray, LargeListBuilder) - } - DataType::Duration(TimeUnit::Nanosecond) => { - filter_primitive_item_list_array!(self, array, DurationNanosecondType, LargeListArray, LargeListBuilder) - } - DataType::Timestamp(TimeUnit::Second, _) => { - filter_primitive_item_list_array!(self, array, TimestampSecondType, LargeListArray, LargeListBuilder) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - filter_primitive_item_list_array!(self, array, TimestampMillisecondType, LargeListArray, LargeListBuilder) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - filter_primitive_item_list_array!(self, array, TimestampMicrosecondType, LargeListArray, LargeListBuilder) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - filter_primitive_item_list_array!(self, array, TimestampNanosecondType, LargeListArray, LargeListBuilder) - } - DataType::Binary => filter_non_primitive_item_list_array!( - self, - array, - BinaryArray, - BinaryBuilder, - LargeListArray, - LargeListBuilder - ), - DataType::LargeBinary => filter_non_primitive_item_list_array!( - self, - array, - LargeBinaryArray, - LargeBinaryBuilder, - LargeListArray, - LargeListBuilder - ), - DataType::Utf8 => filter_non_primitive_item_list_array!( - self, - array, - StringArray, - StringBuilder, - LargeListArray, - LargeListBuilder - ), - DataType::LargeUtf8 => filter_non_primitive_item_list_array!( - self, - array, - LargeStringArray, - LargeStringBuilder, - LargeListArray, - LargeListBuilder - ), - other => { - Err(ArrowError::ComputeError(format!( - "filter not supported for LargeList({:?})", - other - ))) - } - } - other => Err(ArrowError::ComputeError(format!( - "filter not supported for {:?}", - other - ))), - } - } - - /// Returns a new PrimitiveArray containing only those values from the array passed as the data_array parameter, - /// selected by the BooleanArray passed as the filter_array parameter - pub fn filter_primitive_array( - &self, - data_array: &PrimitiveArray, - ) -> Result> - where - T: ArrowNumericType, - { - let array_type = T::DATA_TYPE; - let value_size = mem::size_of::(); - let array_data_builder = - filter_array_impl(self, data_array, array_type, value_size)?; - let data = array_data_builder.build(); - Ok(PrimitiveArray::::from(data)) - } - - /// Returns a new DictionaryArray containing only those keys from the array passed as the data_array parameter, - /// selected by the BooleanArray passed as the filter_array parameter. The values are cloned from the data_array. 
-
-    /// Returns a new DictionaryArray containing only those keys from the array passed as the data_array parameter,
-    /// selected by the BooleanArray passed as the filter_array parameter. The values are cloned from the data_array.
-    pub fn filter_dictionary_array<T>(
-        &self,
-        data_array: &DictionaryArray<T>,
-    ) -> Result<DictionaryArray<T>>
-    where
-        T: ArrowNumericType,
-    {
-        let array_type = data_array.data_type().clone();
-        let value_size = mem::size_of::<T::Native>();
-        let mut array_data_builder =
-            filter_array_impl(self, data_array, array_type, value_size)?;
-        // copy dictionary values from input array
-        array_data_builder =
-            array_data_builder.add_child_data(data_array.values().data());
-        let data = array_data_builder.build();
-        Ok(DictionaryArray::<T>::from(data))
     }
 }
 
-/// Returns a new array, containing only the elements matching the filter.
-pub fn filter(array: &Array, filter: &BooleanArray) -> Result<ArrayRef> {
-    FilterContext::new(filter)?.filter(array)
+/// Returns a prepared function optimized to filter multiple arrays.
+/// Creating this function requires time, but using it is faster than [filter] when the
+/// same filter needs to be applied to multiple arrays (e.g. a multi-column `RecordBatch`).
+/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
+/// Therefore, it is considered undefined behavior to pass `filter` with null values.
+pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
+    let iter = SlicesIterator::new(filter);
+    let filter_count = iter.filter_count;
+    let chunks = iter.collect::<Vec<_>>();
+
+    Ok(Box::new(move |array: &ArrayData| {
+        let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
+        chunks
+            .iter()
+            .for_each(|(start, end)| mutable.extend(0, *start, *end));
+        mutable.freeze()
+    }))
 }
 
-/// Returns a new PrimitiveArray containing only those values from the array passed as the data_array parameter,
-/// selected by the BooleanArray passed as the filter_array parameter
-pub fn filter_primitive_array<T>(
-    data_array: &PrimitiveArray<T>,
-    filter_array: &BooleanArray,
-) -> Result<PrimitiveArray<T>>
-where
-    T: ArrowNumericType,
-{
-    FilterContext::new(filter_array)?.filter_primitive_array(data_array)
-}
+/// Filters an [Array], returning elements matching the filter (i.e. where the values are true).
+/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
+/// Therefore, it is considered undefined behavior to pass `filter` with null values.
+/// # Example
+/// ```rust
+/// # use arrow::array::{Int32Array, BooleanArray};
+/// # use arrow::error::Result;
+/// # use arrow::compute::kernels::filter::filter;
+/// # fn main() -> Result<()> {
+/// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
+/// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
+/// let c = filter(&array, &filter_array)?;
+/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
+/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
+/// # Ok(())
+/// # }
+/// ```
+pub fn filter(array: &Array, filter: &BooleanArray) -> Result<ArrayRef> {
+    let iter = SlicesIterator::new(filter);
 
-/// Returns a new DictionaryArray containing only those keys from the array passed as the data_array parameter,
-/// selected by the BooleanArray passed as the filter_array parameter. The values are cloned from the data_array.
-pub fn filter_dictionary_array<T>(
-    data_array: &DictionaryArray<T>,
-    filter_array: &BooleanArray,
-) -> Result<DictionaryArray<T>>
-where
-    T: ArrowNumericType,
-{
-    FilterContext::new(filter_array)?.filter_dictionary_array(data_array)
+    let mut mutable =
+        MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count);
+    iter.for_each(|(start, end)| mutable.extend(0, start, end));
+    let data = mutable.freeze();
+    Ok(make_array(Arc::new(data)))
 }
 
-/// Returns a new RecordBatch with arrays containing only values matching the filter.
-/// The same FilterContext is re-used when filtering arrays in the RecordBatch for better performance.
+/// Returns a new [RecordBatch] with arrays containing only values matching the filter.
+/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
+/// Therefore, it is considered undefined behavior to pass `filter` with null values.
 pub fn filter_record_batch(
     record_batch: &RecordBatch,
-    filter_array: &BooleanArray,
+    filter: &BooleanArray,
 ) -> Result<RecordBatch> {
-    let filter_context = FilterContext::new(filter_array)?;
+    let filter = build_filter(filter)?;
     let filtered_arrays = record_batch
         .columns()
         .iter()
-        .map(|a| filter_context.filter(a.as_ref()))
-        .collect::<Result<Vec<ArrayRef>>>()?;
+        .map(|a| make_array(Arc::new(filter(&a.data()))))
+        .collect();
     RecordBatch::try_new(record_batch.schema(), filtered_arrays)
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::buffer::Buffer;
     use crate::datatypes::ToByteSlice;
+    use crate::{
+        buffer::Buffer,
+        datatypes::{DataType, Field},
+    };
 
     macro_rules! def_temporal_test {
         ($test:ident, $array_type: ident, $data: expr) => {
@@ -939,17 +341,6 @@ mod tests {
         TimestampNanosecondArray::from_vec(vec![1, 2, 3, 4], None)
     );
 
-    #[test]
-    fn test_filter_array() {
-        let a = Int32Array::from(vec![5, 6, 7, 8, 9]);
-        let b = BooleanArray::from(vec![true, false, false, true, false]);
-        let c = filter(&a, &b).unwrap();
-        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
-        assert_eq!(2, d.len());
-        assert_eq!(5, d.value(0));
-        assert_eq!(8, d.value(1));
-    }
-
     #[test]
     fn test_filter_array_slice() {
         let a_slice = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
@@ -1009,7 +400,7 @@ mod tests {
     }
 
     #[test]
-    fn test_filter_string_array() {
+    fn test_filter_string_array_simple() {
         let a = StringArray::from(vec!["hello", " ", "world", "!"]);
         let b = BooleanArray::from(vec![true, false, true, false]);
         let c = filter(&a, &b).unwrap();
@@ -1131,36 +522,64 @@ mod tests {
         // a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
         let a = LargeListArray::from(list_data);
         let b = BooleanArray::from(vec![false, true, false, true]);
-        let c = filter(&a, &b).unwrap();
-        let d = c
-            .as_ref()
-            .as_any()
-            .downcast_ref::<LargeListArray>()
-            .unwrap();
+        let result = filter(&a, &b).unwrap();
 
-        assert_eq!(DataType::Int32, d.value_type());
+        // expected: [[3, 4, 5], null]
+        let value_data = ArrayData::builder(DataType::Int32)
+            .len(3)
+            .add_buffer(Buffer::from(&[3, 4, 5].to_byte_slice()))
+            .build();
 
-        // result should be [[3, 4, 5], null]
-        assert_eq!(2, d.len());
-        assert_eq!(1, d.null_count());
-        assert_eq!(true, d.is_null(1));
+        let value_offsets = Buffer::from(&[0i64, 3, 3].to_byte_slice());
+
+        let list_data_type =
+            DataType::LargeList(Box::new(Field::new("item", DataType::Int32, false)));
+        let expected = ArrayData::builder(list_data_type)
+            .len(2)
+            .add_buffer(value_offsets)
+            .add_child_data(value_data)
+            .null_bit_buffer(Buffer::from([0b00000001]))
+            .build();
+
+        assert_eq!(&make_array(expected), &result);
+    }
+
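Both the new `filter` kernel and the closure returned by `build_filter` materialize their output through `MutableArrayData`: the selected `(start, end)` slices are copied from the source and then frozen into a new `ArrayData`, which is what the slice-iterator tests that follow exercise. A minimal sketch of that copy pattern, assuming `MutableArrayData` and `make_array` are publicly exported from `arrow::array`; the slice bounds below are illustrative:

```rust
use std::sync::Arc;

use arrow::array::{make_array, Array, Int32Array, MutableArrayData};

fn main() {
    let source = Int32Array::from(vec![10, 20, 30, 40, 50]);

    // Copy the half-open ranges [1, 3) and [4, 5) from the source, i.e. what a
    // filter whose slices are (1, 3) and (4, 5) would keep (3 values in total).
    let mut mutable = MutableArrayData::new(vec![source.data_ref()], false, 3);
    mutable.extend(0, 1, 3);
    mutable.extend(0, 4, 5);

    // Freeze into immutable ArrayData and wrap it back into a typed array.
    let filtered = make_array(Arc::new(mutable.freeze()));
    let filtered = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(filtered, &Int32Array::from(vec![20, 30, 50]));
}
```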
+    #[test]
+    fn test_slice_iterator_bits() {
+        let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
+        let filter = BooleanArray::from(filter_values);
+
+        let iter = SlicesIterator::new(&filter);
+        let filter_count = iter.filter_count;
+        let chunks = iter.collect::<Vec<_>>();
+
+        assert_eq!(chunks, vec![(1, 2)]);
+        assert_eq!(filter_count, 1);
+    }
+
+    #[test]
+    fn test_slice_iterator_bits1() {
+        let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
+        let filter = BooleanArray::from(filter_values);
+
+        let iter = SlicesIterator::new(&filter);
+        let filter_count = iter.filter_count;
+        let chunks = iter.collect::<Vec<_>>();
+
+        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
+        assert_eq!(filter_count, 64 - 1);
+    }
+
+    #[test]
+    fn test_slice_iterator_chunk_and_bits() {
+        let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
+        let filter = BooleanArray::from(filter_values);
+
+        let iter = SlicesIterator::new(&filter);
+        let filter_count = iter.filter_count;
+        let chunks = iter.collect::<Vec<_>>();
 
-        assert_eq!(0, d.value_offset(0));
-        assert_eq!(3, d.value_length(0));
-        assert_eq!(3, d.value_offset(1));
-        assert_eq!(0, d.value_length(1));
-        assert_eq!(
-            Buffer::from(&[3, 4, 5].to_byte_slice()),
-            d.values().data().buffers()[0].clone()
-        );
-        assert_eq!(
-            Buffer::from(&[0i64, 3, 3].to_byte_slice()),
-            d.data().buffers()[0].clone()
-        );
-        let inner_list = d.value(0);
-        let inner_list = inner_list.as_any().downcast_ref::<Int32Array>().unwrap();
-        assert_eq!(3, inner_list.len());
-        assert_eq!(0, inner_list.null_count());
-        assert_eq!(inner_list, &Int32Array::from(vec![3, 4, 5]));
+        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
+        assert_eq!(filter_count, 61 + 61 + 5);
+    }
 }
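The intent of the new API above is that `build_filter` analyzes the boolean mask once and the returned closure is then applied to every column, which is exactly what the new `filter_record_batch` does internally. A minimal usage sketch, not part of the patch; the column contents and variable names are illustrative:

```rust
use std::sync::Arc;

use arrow::array::{make_array, Array, ArrayRef, BooleanArray, Int32Array, StringArray};
use arrow::compute::build_filter;
use arrow::error::Result;

fn main() -> Result<()> {
    let ids = Int32Array::from(vec![1, 2, 3, 4]);
    let names = StringArray::from(vec!["a", "b", "c", "d"]);
    let predicate = BooleanArray::from(vec![true, false, true, false]);

    // Analyze the mask once...
    let keep = build_filter(&predicate)?;

    // ...then apply the prepared filter to each column, as filter_record_batch does.
    let kept_ids: ArrayRef = make_array(Arc::new(keep(&ids.data())));
    let kept_names: ArrayRef = make_array(Arc::new(keep(&names.data())));

    assert_eq!(kept_ids.len(), 2);
    assert_eq!(kept_names.len(), 2);
    Ok(())
}
```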
diff --git a/rust/arrow/src/util/bit_chunk_iterator.rs b/rust/arrow/src/util/bit_chunk_iterator.rs
index 801c38a243f..b9145b7af86 100644
--- a/rust/arrow/src/util/bit_chunk_iterator.rs
+++ b/rust/arrow/src/util/bit_chunk_iterator.rs
@@ -14,14 +14,12 @@
 // KIND, either express or implied. See the License for the
 // specific language governing permissions and limitations
 // under the License.
-use crate::buffer::Buffer;
 use crate::util::bit_util::ceil;
 use std::fmt::Debug;
 
 #[derive(Debug)]
 pub struct BitChunks<'a> {
-    buffer: &'a Buffer,
-    raw_data: *const u8,
+    buffer: &'a [u8],
     /// offset inside a byte, guaranteed to be between 0 and 7 (inclusive)
     bit_offset: usize,
     /// number of complete u64 chunks
@@ -31,22 +29,19 @@
 }
 
 impl<'a> BitChunks<'a> {
-    pub fn new(buffer: &'a Buffer, offset: usize, len: usize) -> Self {
+    pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self {
         assert!(ceil(offset + len, 8) <= buffer.len() * 8);
 
         let byte_offset = offset / 8;
         let bit_offset = offset % 8;
 
-        let raw_data = unsafe { buffer.raw_data().add(byte_offset) };
-
         let chunk_bits = 8 * std::mem::size_of::<u64>();
 
         let chunk_len = len / chunk_bits;
         let remainder_len = len & (chunk_bits - 1);
 
         BitChunks::<'a> {
-            buffer: &buffer,
-            raw_data,
+            buffer: &buffer[byte_offset..],
             bit_offset,
             chunk_len,
             remainder_len,
@@ -56,8 +51,7 @@ impl<'a> BitChunks<'a> {
 
 #[derive(Debug)]
 pub struct BitChunkIterator<'a> {
-    buffer: &'a Buffer,
-    raw_data: *const u8,
+    buffer: &'a [u8],
     bit_offset: usize,
     chunk_len: usize,
     index: usize,
@@ -70,6 +64,12 @@ impl<'a> BitChunks<'a> {
         self.remainder_len
     }
 
+    /// Returns the number of chunks
+    #[inline]
+    pub const fn chunk_len(&self) -> usize {
+        self.chunk_len
+    }
+
     /// Returns the bitmask of remaining bits
     #[inline]
     pub fn remainder_bits(&self) -> u64 {
@@ -83,7 +83,8 @@ impl<'a> BitChunks<'a> {
 
         let byte_len = ceil(bit_len + bit_offset, 8);
         // pointer to remainder bytes after all complete chunks
         let base = unsafe {
-            self.raw_data
+            self.buffer
+                .as_ptr()
                 .add(self.chunk_len * std::mem::size_of::<u64>())
         };
@@ -102,7 +103,6 @@ impl<'a> BitChunks<'a> {
     pub const fn iter(&self) -> BitChunkIterator<'a> {
         BitChunkIterator::<'a> {
             buffer: self.buffer,
-            raw_data: self.raw_data,
             bit_offset: self.bit_offset,
             chunk_len: self.chunk_len,
             index: 0,
@@ -131,7 +131,7 @@ impl Iterator for BitChunkIterator<'_> {
 
         // cast to *const u64 should be fine since we are using read_unaligned below
         #[allow(clippy::cast_ptr_alignment)]
-        let raw_data = self.raw_data as *const u64;
+        let raw_data = self.buffer.as_ptr() as *const u64;
 
         // bit-packed buffers are stored starting with the least-significant byte first
         // so when reading as u64 on a big-endian machine, the bytes need to be swapped
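The change above moves `BitChunks` from a `Buffer` plus raw pointer to a plain byte slice. A short sketch of the slice-based API, assuming the module is publicly reachable as `arrow::util::bit_chunk_iterator`; the bit pattern is illustrative:

```rust
use arrow::util::bit_chunk_iterator::BitChunks;

fn main() {
    // 72 bits spread over 9 bytes: one full 64-bit chunk plus an 8-bit remainder.
    let bytes: Vec<u8> = vec![0b0000_0001; 9];
    let chunks = BitChunks::new(&bytes, 0, 72);

    assert_eq!(chunks.chunk_len(), 1);
    assert_eq!(chunks.remainder_len(), 8);

    // Chunks are read least-significant byte first, so the first u64 packs the
    // first 8 bytes and therefore carries 8 set bits.
    let first: u64 = chunks.iter().next().unwrap();
    assert_eq!(first.count_ones(), 8);

    // The remainder mask covers the trailing 8 bits (one set bit here).
    assert_eq!(chunks.remainder_bits().count_ones(), 1);
}
```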