diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index b339479b35e9d..6f3a2a99600e6 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -25,9 +25,8 @@ use std::sync::Arc; use arrow::array::{ Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, BooleanBufferBuilder, - PrimitiveArray, }; -use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::buffer::BooleanBuffer; use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions}; use arrow::datatypes::{ DataType, Date32Type, Date64Type, Decimal32Type, Decimal64Type, Decimal128Type, @@ -53,6 +52,10 @@ use datafusion_functions_aggregate_common::utils::get_sort_options; use datafusion_macros::user_doc; use datafusion_physical_expr_common::sort_expr::LexOrdering; +mod state; + +use state::{BytesValueState, PrimitiveValueState, ValueState}; + create_func!(FirstValue, first_value_udaf); create_func!(LastValue, last_value_udaf); @@ -76,6 +79,185 @@ pub fn last_value(expression: Expr, order_by: Vec) -> Expr { .unwrap() } +fn create_groups_primitive_accumulator( + args: &AccumulatorArgs, + is_first: bool, +) -> Result> { + let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else { + return internal_err!("Groups accumulator must have an ordering."); + }; + + let ordering_dtypes = ordering + .iter() + .map(|e| e.expr.data_type(args.schema)) + .collect::>>()?; + + Ok(Box::new(FirstLastGroupsAccumulator::try_new( + PrimitiveValueState::::new(args.return_field.data_type().clone()), + ordering, + args.ignore_nulls, + &ordering_dtypes, + is_first, + )?)) +} + +fn create_groups_bytes_accumulator( + args: &AccumulatorArgs, + is_first: bool, +) -> Result> { + let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else { + return internal_err!("Groups accumulator must have an ordering."); + }; + + let ordering_dtypes = ordering + .iter() + .map(|e| e.expr.data_type(args.schema)) + .collect::>>()?; + + Ok(Box::new(FirstLastGroupsAccumulator::try_new( + BytesValueState::try_new(args.return_field.data_type().clone())?, + ordering, + args.ignore_nulls, + &ordering_dtypes, + is_first, + )?)) +} + +fn create_groups_accumulator( + args: &AccumulatorArgs, + is_first: bool, + function_name: &str, +) -> Result> { + match args.return_field.data_type() { + DataType::Int8 => create_groups_primitive_accumulator::(args, is_first), + DataType::Int16 => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Int32 => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Int64 => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::UInt8 => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::UInt16 => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::UInt32 => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::UInt64 => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Float16 => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Float32 => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Float64 => { + create_groups_primitive_accumulator::(args, is_first) + } + + DataType::Decimal32(_, _) => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Decimal64(_, _) => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Decimal128(_, _) => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Decimal256(_, _) => { + create_groups_primitive_accumulator::(args, is_first) + } + + DataType::Timestamp(TimeUnit::Second, _) => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + create_groups_primitive_accumulator::( + args, is_first, + ) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + create_groups_primitive_accumulator::( + args, is_first, + ) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + create_groups_primitive_accumulator::(args, is_first) + } + + DataType::Date32 => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Date64 => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Time32(TimeUnit::Second) => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Time32(TimeUnit::Millisecond) => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Time64(TimeUnit::Microsecond) => { + create_groups_primitive_accumulator::(args, is_first) + } + DataType::Time64(TimeUnit::Nanosecond) => { + create_groups_primitive_accumulator::(args, is_first) + } + + DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Utf8View + | DataType::Binary + | DataType::LargeBinary + | DataType::BinaryView => create_groups_bytes_accumulator(args, is_first), + + _ => internal_err!( + "GroupsAccumulator not supported for {}({})", + function_name, + args.return_field.data_type() + ), + } +} + +fn groups_accumulator_supported(args: &AccumulatorArgs) -> bool { + use DataType::*; + !args.order_bys.is_empty() + && matches!( + args.return_field.data_type(), + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + | Utf8 + | LargeUtf8 + | Utf8View + | Binary + | LargeBinary + | BinaryView + ) +} + #[user_doc( doc_section(label = "General Functions"), description = "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.", @@ -183,110 +365,14 @@ impl AggregateUDFImpl for FirstValue { } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { - use DataType::*; - !args.order_bys.is_empty() - && matches!( - args.return_field.data_type(), - Int8 | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float16 - | Float32 - | Float64 - | Decimal32(_, _) - | Decimal64(_, _) - | Decimal128(_, _) - | Decimal256(_, _) - | Date32 - | Date64 - | Time32(_) - | Time64(_) - | Timestamp(_, _) - ) + groups_accumulator_supported(&args) } fn create_groups_accumulator( &self, args: AccumulatorArgs, ) -> Result> { - fn create_accumulator( - args: &AccumulatorArgs, - ) -> Result> { - let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else { - return internal_err!("Groups accumulator must have an ordering."); - }; - - let ordering_dtypes = ordering - .iter() - .map(|e| e.expr.data_type(args.schema)) - .collect::>>()?; - - FirstPrimitiveGroupsAccumulator::::try_new( - ordering, - args.ignore_nulls, - args.return_field.data_type(), - &ordering_dtypes, - true, - ) - .map(|acc| Box::new(acc) as _) - } - - match args.return_field.data_type() { - DataType::Int8 => create_accumulator::(&args), - DataType::Int16 => create_accumulator::(&args), - DataType::Int32 => create_accumulator::(&args), - DataType::Int64 => create_accumulator::(&args), - DataType::UInt8 => create_accumulator::(&args), - DataType::UInt16 => create_accumulator::(&args), - DataType::UInt32 => create_accumulator::(&args), - DataType::UInt64 => create_accumulator::(&args), - DataType::Float16 => create_accumulator::(&args), - DataType::Float32 => create_accumulator::(&args), - DataType::Float64 => create_accumulator::(&args), - - DataType::Decimal32(_, _) => create_accumulator::(&args), - DataType::Decimal64(_, _) => create_accumulator::(&args), - DataType::Decimal128(_, _) => create_accumulator::(&args), - DataType::Decimal256(_, _) => create_accumulator::(&args), - - DataType::Timestamp(TimeUnit::Second, _) => { - create_accumulator::(&args) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - create_accumulator::(&args) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - create_accumulator::(&args) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - create_accumulator::(&args) - } - - DataType::Date32 => create_accumulator::(&args), - DataType::Date64 => create_accumulator::(&args), - DataType::Time32(TimeUnit::Second) => { - create_accumulator::(&args) - } - DataType::Time32(TimeUnit::Millisecond) => { - create_accumulator::(&args) - } - - DataType::Time64(TimeUnit::Microsecond) => { - create_accumulator::(&args) - } - DataType::Time64(TimeUnit::Nanosecond) => { - create_accumulator::(&args) - } - - _ => internal_err!( - "GroupsAccumulator not supported for first_value({})", - args.return_field.data_type() - ), - } + create_groups_accumulator(&args, true, self.name()) } fn with_beneficial_ordering( @@ -316,13 +402,9 @@ impl AggregateUDFImpl for FirstValue { } } -// TODO: rename to PrimitiveGroupsAccumulator -struct FirstPrimitiveGroupsAccumulator -where - T: ArrowPrimitiveType + Send, -{ +struct FirstLastGroupsAccumulator { // ================ state =========== - vals: Vec, + state: S, // Stores ordering values, of the aggregator requirement corresponding to first value // of the aggregator. // The `orderings` are stored row-wise, meaning that `orderings[group_idx]` @@ -331,8 +413,6 @@ where // At the beginning, `is_sets[group_idx]` is false, which means `first` is not seen yet. // Once we see the first value, we set the `is_sets[group_idx]` flag is_sets: BooleanBufferBuilder, - // null_builder[group_idx] == false => vals[group_idx] is null - null_builder: BooleanBufferBuilder, // size of `self.orderings` // Calculating the memory usage of `self.orderings` using `ScalarValue::size_of_vec` is quite costly. // Therefore, we cache it and compute `size_of` only after each update @@ -342,8 +422,7 @@ where // buffer for `get_filtered_min_of_each_group` // filter_min_of_each_group_buf.0[group_idx] -> idx_in_val // only valid if filter_min_of_each_group_buf.1[group_idx] == true - // TODO: rename to extreme_of_each_group_buf - min_of_each_group_buf: (Vec, BooleanBufferBuilder), + extreme_of_each_group_buf: (Vec, BooleanBufferBuilder), // =========== option ============ @@ -356,19 +435,14 @@ where sort_options: Vec, // Ignore null values. ignore_nulls: bool, - /// The output type - data_type: DataType, default_orderings: Vec, } -impl FirstPrimitiveGroupsAccumulator -where - T: ArrowPrimitiveType + Send, -{ +impl FirstLastGroupsAccumulator { fn try_new( + state: S, ordering_req: LexOrdering, ignore_nulls: bool, - data_type: &DataType, ordering_dtypes: &[DataType], pick_first_in_group: bool, ) -> Result { @@ -380,17 +454,15 @@ where let sort_options = get_sort_options(&ordering_req); Ok(Self { - null_builder: BooleanBufferBuilder::new(0), ordering_req, sort_options, ignore_nulls, default_orderings, - data_type: data_type.clone(), - vals: Vec::new(), + state, orderings: Vec::new(), is_sets: BooleanBufferBuilder::new(0), size_of_orderings: 0, - min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)), + extreme_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)), pick_first_in_group, }) } @@ -429,32 +501,8 @@ where result } - fn take_need( - bool_buf_builder: &mut BooleanBufferBuilder, - emit_to: EmitTo, - ) -> BooleanBuffer { - let bool_buf = bool_buf_builder.finish(); - match emit_to { - EmitTo::All => bool_buf, - EmitTo::First(n) => { - // split off the first N values in seen_values - // - // TODO make this more efficient rather than two - // copies and bitwise manipulation - let first_n: BooleanBuffer = bool_buf.iter().take(n).collect(); - // reset the existing buffer - for b in bool_buf.iter().skip(n) { - bool_buf_builder.append(b); - } - first_n - } - } - } - fn resize_states(&mut self, new_size: usize) { - self.vals.resize(new_size, T::default_value()); - - self.null_builder.resize(new_size); + self.state.resize(new_size); if self.orderings.len() < new_size { let current_len = self.orderings.len(); @@ -473,44 +521,43 @@ where self.is_sets.resize(new_size); - self.min_of_each_group_buf.0.resize(new_size, 0); - self.min_of_each_group_buf.1.resize(new_size); + self.extreme_of_each_group_buf.0.resize(new_size, 0); + self.extreme_of_each_group_buf.1.resize(new_size); } fn update_state( &mut self, group_idx: usize, orderings: &[ScalarValue], - new_val: T::Native, - is_null: bool, - ) { - self.vals[group_idx] = new_val; + array: &ArrayRef, + idx: usize, + ) -> Result<()> { + self.state.update(group_idx, array, idx)?; self.is_sets.set_bit(group_idx, true); - self.null_builder.set_bit(group_idx, !is_null); - assert!(orderings.len() == self.ordering_req.len()); let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]); self.orderings[group_idx].clear(); self.orderings[group_idx].extend_from_slice(orderings); let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]); self.size_of_orderings = self.size_of_orderings - old_size + new_size; + Ok(()) } fn take_state( &mut self, emit_to: EmitTo, - ) -> (ArrayRef, Vec>, BooleanBuffer) { - emit_to.take_needed(&mut self.min_of_each_group_buf.0); - self.min_of_each_group_buf + ) -> Result<(ArrayRef, Vec>, BooleanBuffer)> { + emit_to.take_needed(&mut self.extreme_of_each_group_buf.0); + self.extreme_of_each_group_buf .1 - .truncate(self.min_of_each_group_buf.0.len()); + .truncate(self.extreme_of_each_group_buf.0.len()); - ( - self.take_vals_and_null_buf(emit_to), + Ok(( + self.state.take(emit_to)?, self.take_orderings(emit_to), - Self::take_need(&mut self.is_sets, emit_to), - ) + state::take_need(&mut self.is_sets, emit_to), + )) } // should be used in test only @@ -524,20 +571,19 @@ where /// Returns a vector of tuples `(group_idx, idx_in_val)` representing the index of the /// minimum value in `orderings` for each group, using lexicographical comparison. /// Values are filtered using `opt_filter` and `is_set_arr` if provided. - /// TODO: rename to get_filtered_extreme_of_each_group - fn get_filtered_min_of_each_group( + fn get_filtered_extreme_of_each_group( &mut self, orderings: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, - vals: &PrimitiveArray, + vals: &ArrayRef, is_set_arr: Option<&BooleanArray>, ) -> Result> { // Set all values in min_of_each_group_buf.1 to false. - self.min_of_each_group_buf.1.truncate(0); - self.min_of_each_group_buf + self.extreme_of_each_group_buf.1.truncate(0); + self.extreme_of_each_group_buf .1 - .append_n(self.vals.len(), false); + .append_n(self.is_sets.len(), false); // No need to call `clear` since `self.min_of_each_group_buf.0[group_idx]` // is only valid when `self.min_of_each_group_buf.1[group_idx] == true`. @@ -570,48 +616,35 @@ where continue; } - let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx); + let is_valid = self.extreme_of_each_group_buf.1.get_bit(group_idx); if !is_valid { - self.min_of_each_group_buf.1.set_bit(group_idx, true); - self.min_of_each_group_buf.0[group_idx] = idx_in_val; + self.extreme_of_each_group_buf.1.set_bit(group_idx, true); + self.extreme_of_each_group_buf.0[group_idx] = idx_in_val; } else { let ordering = comparator - .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val); + .compare(self.extreme_of_each_group_buf.0[group_idx], idx_in_val); if (ordering.is_gt() && self.pick_first_in_group) || (ordering.is_lt() && !self.pick_first_in_group) { - self.min_of_each_group_buf.0[group_idx] = idx_in_val; + self.extreme_of_each_group_buf.0[group_idx] = idx_in_val; } } } Ok(self - .min_of_each_group_buf + .extreme_of_each_group_buf .0 .iter() .enumerate() - .filter(|(group_idx, _)| self.min_of_each_group_buf.1.get_bit(*group_idx)) + .filter(|(group_idx, _)| self.extreme_of_each_group_buf.1.get_bit(*group_idx)) .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val)) .collect::>()) } - - fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef { - let r = emit_to.take_needed(&mut self.vals); - - let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to)); - - let values = PrimitiveArray::::new(r.into(), Some(null_buf)) // no copy - .with_data_type(self.data_type.clone()); - Arc::new(values) - } } -impl GroupsAccumulator for FirstPrimitiveGroupsAccumulator -where - T: ArrowPrimitiveType + Send, -{ +impl GroupsAccumulator for FirstLastGroupsAccumulator { fn update_batch( &mut self, // e.g. first_value(a order by b): values_and_order_cols will be [a, b] @@ -622,13 +655,13 @@ where ) -> Result<()> { self.resize_states(total_num_groups); - let vals = values_and_order_cols[0].as_primitive::(); + let vals = &values_and_order_cols[0]; let mut ordering_buf = Vec::with_capacity(self.ordering_req.len()); // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible. for (group_idx, idx) in self - .get_filtered_min_of_each_group( + .get_filtered_extreme_of_each_group( &values_and_order_cols[1..], group_indices, opt_filter, @@ -644,12 +677,7 @@ where )?; if self.should_update_state(group_idx, &ordering_buf)? { - self.update_state( - group_idx, - &ordering_buf, - vals.value(idx), - vals.is_null(idx), - ); + self.update_state(group_idx, &ordering_buf, vals, idx)?; } } @@ -657,11 +685,11 @@ where } fn evaluate(&mut self, emit_to: EmitTo) -> Result { - Ok(self.take_state(emit_to).0) + Ok(self.take_state(emit_to)?.0) } fn state(&mut self, emit_to: EmitTo) -> Result> { - let (val_arr, orderings, is_sets) = self.take_state(emit_to); + let (val_arr, orderings, is_sets) = self.take_state(emit_to)?; let mut result = Vec::with_capacity(self.orderings.len() + 2); result.push(val_arr); @@ -707,9 +735,9 @@ where let is_set_arr = as_boolean_array(is_set_arr)?; - let vals = values[0].as_primitive::(); + let vals = &values[0]; // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible. - let groups = self.get_filtered_min_of_each_group( + let groups = self.get_filtered_extreme_of_each_group( &val_and_order_cols[1..], group_indices, opt_filter, @@ -721,12 +749,7 @@ where extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?; if self.should_update_state(group_idx, &ordering_buf)? { - self.update_state( - group_idx, - &ordering_buf, - vals.value(idx), - vals.is_null(idx), - ); + self.update_state(group_idx, &ordering_buf, vals, idx)?; } } @@ -734,12 +757,11 @@ where } fn size(&self) -> usize { - self.vals.capacity() * size_of::() - + self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes - + self.is_sets.capacity() / 8 + self.state.size() + + self.is_sets.capacity() / 8 // capacity is in bits, so convert to bytes + self.size_of_orderings - + self.min_of_each_group_buf.0.capacity() * size_of::() - + self.min_of_each_group_buf.1.capacity() / 8 + + self.extreme_of_each_group_buf.0.capacity() * size_of::() + + self.extreme_of_each_group_buf.1.capacity() / 8 } fn supports_convert_to_state(&self) -> bool { @@ -1149,114 +1171,14 @@ impl AggregateUDFImpl for LastValue { } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { - use DataType::*; - !args.order_bys.is_empty() - && matches!( - args.return_field.data_type(), - Int8 | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float16 - | Float32 - | Float64 - | Decimal32(_, _) - | Decimal64(_, _) - | Decimal128(_, _) - | Decimal256(_, _) - | Date32 - | Date64 - | Time32(_) - | Time64(_) - | Timestamp(_, _) - ) + groups_accumulator_supported(&args) } fn create_groups_accumulator( &self, args: AccumulatorArgs, ) -> Result> { - fn create_accumulator( - args: &AccumulatorArgs, - ) -> Result> - where - T: ArrowPrimitiveType + Send, - { - let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else { - return internal_err!("Groups accumulator must have an ordering."); - }; - - let ordering_dtypes = ordering - .iter() - .map(|e| e.expr.data_type(args.schema)) - .collect::>>()?; - - Ok(Box::new(FirstPrimitiveGroupsAccumulator::::try_new( - ordering, - args.ignore_nulls, - args.return_field.data_type(), - &ordering_dtypes, - false, - )?)) - } - - match args.return_field.data_type() { - DataType::Int8 => create_accumulator::(&args), - DataType::Int16 => create_accumulator::(&args), - DataType::Int32 => create_accumulator::(&args), - DataType::Int64 => create_accumulator::(&args), - DataType::UInt8 => create_accumulator::(&args), - DataType::UInt16 => create_accumulator::(&args), - DataType::UInt32 => create_accumulator::(&args), - DataType::UInt64 => create_accumulator::(&args), - DataType::Float16 => create_accumulator::(&args), - DataType::Float32 => create_accumulator::(&args), - DataType::Float64 => create_accumulator::(&args), - - DataType::Decimal32(_, _) => create_accumulator::(&args), - DataType::Decimal64(_, _) => create_accumulator::(&args), - DataType::Decimal128(_, _) => create_accumulator::(&args), - DataType::Decimal256(_, _) => create_accumulator::(&args), - - DataType::Timestamp(TimeUnit::Second, _) => { - create_accumulator::(&args) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - create_accumulator::(&args) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - create_accumulator::(&args) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - create_accumulator::(&args) - } - - DataType::Date32 => create_accumulator::(&args), - DataType::Date64 => create_accumulator::(&args), - DataType::Time32(TimeUnit::Second) => { - create_accumulator::(&args) - } - DataType::Time32(TimeUnit::Millisecond) => { - create_accumulator::(&args) - } - - DataType::Time64(TimeUnit::Microsecond) => { - create_accumulator::(&args) - } - DataType::Time64(TimeUnit::Nanosecond) => { - create_accumulator::(&args) - } - - _ => { - internal_err!( - "GroupsAccumulator not supported for last_value({})", - args.return_field.data_type() - ) - } - } + create_groups_accumulator(&args, false, self.name()) } } @@ -1550,7 +1472,7 @@ mod tests { use std::iter::repeat_with; use arrow::{ - array::{BooleanArray, Int64Array, ListArray, StringArray}, + array::{BooleanArray, Int64Array, ListArray, PrimitiveArray, StringArray}, compute::SortOptions, datatypes::Schema, }; @@ -1677,10 +1599,10 @@ mod tests { options: SortOptions::default(), }]; - let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( + let mut group_acc = FirstLastGroupsAccumulator::try_new( + PrimitiveValueState::::new(DataType::Int64), sort_keys.into(), true, - &DataType::Int64, &[DataType::Int64], true, )?; @@ -1771,10 +1693,10 @@ mod tests { options: SortOptions::default(), }]; - let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( + let mut group_acc = FirstLastGroupsAccumulator::try_new( + PrimitiveValueState::::new(DataType::Int64), sort_keys.into(), true, - &DataType::Int64, &[DataType::Int64], true, )?; @@ -1852,10 +1774,10 @@ mod tests { options: SortOptions::default(), }]; - let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( + let mut group_acc = FirstLastGroupsAccumulator::try_new( + PrimitiveValueState::::new(DataType::Int64), sort_keys.into(), true, - &DataType::Int64, &[DataType::Int64], false, )?; diff --git a/datafusion/functions-aggregate/src/first_last/state.rs b/datafusion/functions-aggregate/src/first_last/state.rs new file mode 100644 index 0000000000000..d4b405276fdfc --- /dev/null +++ b/datafusion/functions-aggregate/src/first_last/state.rs @@ -0,0 +1,439 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::mem::size_of; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, ArrowPrimitiveType, AsArray, BinaryBuilder, BinaryViewBuilder, + BooleanBufferBuilder, LargeBinaryBuilder, LargeStringBuilder, PrimitiveArray, + StringBuilder, StringViewBuilder, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::datatypes::DataType; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::EmitTo; + +pub(crate) trait ValueState: Send + Sync { + /// Resizes the state to accommodate `new_size` groups. + fn resize(&mut self, new_size: usize); + /// Updates the state for the specified `group_idx` using the value at `idx` from the provided `array`. + /// + /// Note: While this is not a batch interface, it is not a performance bottleneck. + /// In heavy aggregation benchmarks, the overhead of this method is typically less than 1%. + /// + /// Benchmarked queries with < 1% `update` overhead: + /// ```sql + /// -- TPC-H SF10 + /// select l_shipmode, last_value(l_partkey order by l_orderkey, l_linenumber, l_comment, l_suppkey, l_tax) + /// from 'benchmarks/data/tpch_sf10/lineitem' + /// group by l_shipmode; + /// + /// -- H2O G1_1e8 + /// select t.id1, first_value(t.id3 order by t.id2, t.id4) as r2 + /// from 'benchmarks/data/h2o/G1_1e8_1e8_100_0.parquet' as t + /// group by t.id1, t.v1; + /// ``` + fn update(&mut self, group_idx: usize, array: &ArrayRef, idx: usize) -> Result<()>; + /// Takes the accumulated state and returns it as an [`ArrayRef`], respecting the `emit_to` strategy. + fn take(&mut self, emit_to: EmitTo) -> Result; + /// Returns the estimated memory size of the state in bytes. + fn size(&self) -> usize; +} + +pub(crate) struct PrimitiveValueState { + /// Values data + vals: Vec, + nulls: BooleanBufferBuilder, + data_type: DataType, +} + +impl PrimitiveValueState { + pub(crate) fn new(data_type: DataType) -> Self { + Self { + vals: vec![], + nulls: BooleanBufferBuilder::new(0), + data_type, + } + } +} + +impl ValueState for PrimitiveValueState { + fn resize(&mut self, new_size: usize) { + self.vals.resize(new_size, T::default_value()); + self.nulls.resize(new_size); + } + + fn update(&mut self, group_idx: usize, array: &ArrayRef, idx: usize) -> Result<()> { + let array = array.as_primitive::(); + self.vals[group_idx] = array.value(idx); + self.nulls.set_bit(group_idx, !array.is_null(idx)); + Ok(()) + } + + fn take(&mut self, emit_to: EmitTo) -> Result { + let values = emit_to.take_needed(&mut self.vals); + let null_buf = NullBuffer::new(take_need(&mut self.nulls, emit_to)); + let array: PrimitiveArray = + PrimitiveArray::::new(values.into(), Some(null_buf)) + .with_data_type(self.data_type.clone()); + Ok(Arc::new(array)) + } + + fn size(&self) -> usize { + self.vals.capacity() * size_of::() + self.nulls.capacity() / 8 + } +} + +/// Stores internal state for "bytes" types (Utf8, Binary, etc.). +/// +/// This implementation is similar to `MinMaxBytesState` in `min_max_bytes.rs`, but +/// it does not reuse it for two main reasons: +/// +/// 1. **Direct Overwrite**: `MinMaxBytesState::update_batch` is tightly coupled +/// with min/max comparison logic, whereas `FirstLast` performs its own comparisons +/// externally (using ordering columns) and only needs a simple interface to +/// unconditionally set/overwrite values for specific groups. +/// 2. **Different NULL Handling**: `MinMaxBytesState` always ignores `NULL` values +/// in the input, while `BytesValueState` needs to support setting `NULL` values +/// to correctly implement `RESPECT NULLS` behavior. +/// +pub(crate) struct BytesValueState { + vals: Vec>>, + data_type: DataType, + /// The sum of the capacities of all vectors in `vals`. + total_capacity: usize, +} + +impl BytesValueState { + pub(crate) fn try_new(data_type: DataType) -> Result { + if !matches!( + data_type, + DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Utf8View + | DataType::Binary + | DataType::LargeBinary + | DataType::BinaryView + ) { + return internal_err!("BytesValueState does not support {}", data_type); + } + Ok(Self { + vals: vec![], + data_type, + total_capacity: 0, + }) + } +} + +impl ValueState for BytesValueState { + fn resize(&mut self, new_size: usize) { + if new_size < self.vals.len() { + for v in self.vals[new_size..].iter().flatten() { + self.total_capacity -= v.capacity(); + } + } + self.vals.resize(new_size, None); + } + + fn update(&mut self, group_idx: usize, array: &ArrayRef, idx: usize) -> Result<()> { + if array.is_null(idx) { + self.vals[group_idx] = None; + } else { + let val = match self.data_type { + DataType::Utf8 => array.as_string::().value(idx).as_bytes(), + DataType::LargeUtf8 => array.as_string::().value(idx).as_bytes(), + DataType::Utf8View => array.as_string_view().value(idx).as_bytes(), + DataType::Binary => array.as_binary::().value(idx), + DataType::LargeBinary => array.as_binary::().value(idx), + DataType::BinaryView => array.as_binary_view().value(idx), + _ => { + return internal_err!( + "Unsupported data type for BytesValueState: {}", + self.data_type + ); + } + }; + + if let Some(v) = &mut self.vals[group_idx] { + self.total_capacity -= v.capacity(); + v.clear(); + v.extend_from_slice(val); + + self.total_capacity += v.capacity(); + } else { + let v = val.to_vec(); + self.total_capacity += v.capacity(); + self.vals[group_idx] = Some(v); + } + } + Ok(()) + } + + fn take(&mut self, emit_to: EmitTo) -> Result { + let values = emit_to.take_needed(&mut self.vals); + + let (total_len, taken_capacity) = values + .iter() + .flatten() + .fold((0, 0), |(len_acc, cap_acc), v| { + (len_acc + v.len(), cap_acc + v.capacity()) + }); + self.total_capacity -= taken_capacity; + + match self.data_type { + DataType::Utf8 => { + let mut builder = StringBuilder::with_capacity(values.len(), total_len); + for val in values { + match val { + Some(v) => builder.append_value( + // SAFETY: The bytes were originally from a valid UTF-8 array in `update` + unsafe { std::str::from_utf8_unchecked(&v) }, + ), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + DataType::LargeUtf8 => { + let mut builder = + LargeStringBuilder::with_capacity(values.len(), total_len); + for val in values { + match val { + Some(v) => builder.append_value( + // SAFETY: The bytes were originally from a valid UTF-8 array in `update` + unsafe { std::str::from_utf8_unchecked(&v) }, + ), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + DataType::Utf8View => { + let mut builder = StringViewBuilder::with_capacity(values.len()); + for val in values { + match val { + Some(v) => builder.append_value( + // SAFETY: The bytes were originally from a valid UTF-8 array in `update` + unsafe { std::str::from_utf8_unchecked(&v) }, + ), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + DataType::Binary => { + let mut builder = BinaryBuilder::with_capacity(values.len(), total_len); + for val in values { + match val { + Some(v) => builder.append_value(&v), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + DataType::LargeBinary => { + let mut builder = + LargeBinaryBuilder::with_capacity(values.len(), total_len); + for val in values { + match val { + Some(v) => builder.append_value(&v), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + DataType::BinaryView => { + let mut builder = BinaryViewBuilder::with_capacity(values.len()); + for val in values { + match val { + Some(v) => builder.append_value(&v), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + _ => internal_err!( + "Unsupported data type for BytesValueState: {}", + self.data_type + ), + } + } + + fn size(&self) -> usize { + self.vals.capacity() * size_of::>>() + self.total_capacity + } +} + +impl BytesValueState { + #[cfg(test)] + /// For testing only: strictly calculate the sum of capacities of all vectors in `vals`. + fn total_capacity_calculated(&self) -> usize { + self.vals.iter().flatten().map(|v| v.capacity()).sum() + } +} + +pub(crate) fn take_need( + bool_buf_builder: &mut BooleanBufferBuilder, + emit_to: EmitTo, +) -> BooleanBuffer { + let bool_buf = bool_buf_builder.finish(); + match emit_to { + EmitTo::All => bool_buf, + EmitTo::First(n) => { + // split off the first N values in seen_values + // + // TODO make this more efficient rather than two + // copies and bitwise manipulation + let first_n: BooleanBuffer = bool_buf.iter().take(n).collect(); + // reset the existing buffer + for b in bool_buf.iter().skip(n) { + bool_buf_builder.append(b); + } + first_n + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + BinaryArray, BinaryViewArray, LargeBinaryArray, LargeStringArray, StringArray, + StringViewArray, + }; + + #[test] + fn test_bytes_value_state_utf8() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::Utf8)?; + state.resize(2); + + let array: ArrayRef = Arc::new(StringArray::from(vec![ + Some("hello"), + Some("world"), + Some("longer_string_than_hello"), + ])); + + state.update(0, &array, 0)?; // group 0 = "hello" + state.update(1, &array, 1)?; // group 1 = "world" + + assert_eq!(state.total_capacity, state.total_capacity_calculated()); + + // Overwrite group 0 with a longer string (checks capacity update logic) + state.update(0, &array, 2)?; + assert_eq!(state.total_capacity, state.total_capacity_calculated()); + + let result = state.take(EmitTo::All)?; + let result = result.as_string::(); + assert_eq!(result.len(), 2); + assert_eq!(result.value(0), "longer_string_than_hello"); + assert_eq!(result.value(1), "world"); + + // After take all, size should be 0 (excluding vals vector capacity) + assert_eq!(state.total_capacity, 0); + assert_eq!(state.total_capacity, state.total_capacity_calculated()); + + Ok(()) + } + + #[test] + fn test_bytes_value_state_large_utf8() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::LargeUtf8)?; + state.resize(1); + let array: ArrayRef = Arc::new(LargeStringArray::from(vec!["large_utf8"])); + state.update(0, &array, 0)?; + let result = state.take(EmitTo::All)?; + assert_eq!(result.as_string::().value(0), "large_utf8"); + Ok(()) + } + + #[test] + fn test_bytes_value_state_utf8_view() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::Utf8View)?; + state.resize(1); + let array: ArrayRef = Arc::new(StringViewArray::from(vec!["Utf8View"])); + state.update(0, &array, 0)?; + let result = state.take(EmitTo::All)?; + assert_eq!(result.as_string_view().value(0), "Utf8View"); + Ok(()) + } + + #[test] + fn test_bytes_value_state_binary() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::Binary)?; + state.resize(1); + let array: ArrayRef = Arc::new(BinaryArray::from(vec![b"binary" as &[u8]])); + state.update(0, &array, 0)?; + let result = state.take(EmitTo::All)?; + assert_eq!(result.as_binary::().value(0), b"binary"); + Ok(()) + } + + #[test] + fn test_bytes_value_state_large_binary() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::LargeBinary)?; + state.resize(1); + let array: ArrayRef = + Arc::new(LargeBinaryArray::from(vec![b"large_binary" as &[u8]])); + state.update(0, &array, 0)?; + let result = state.take(EmitTo::All)?; + assert_eq!(result.as_binary::().value(0), b"large_binary"); + Ok(()) + } + + #[test] + fn test_bytes_value_state_binary_view() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::BinaryView)?; + state.resize(1); + + let data: Vec> = vec![Some(b"long_binary_value_to_test_view")]; + let array: ArrayRef = Arc::new(BinaryViewArray::from(data)); + + state.update(0, &array, 0)?; + + let result = state.take(EmitTo::All)?; + let result = result.as_binary_view(); + assert_eq!(result.value(0), b"long_binary_value_to_test_view"); + + Ok(()) + } + + #[test] + fn test_bytes_value_state_emit_first() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::Utf8)?; + state.resize(3); + + let array: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); + state.update(0, &array, 0)?; + state.update(1, &array, 1)?; + state.update(2, &array, 2)?; + + let result = state.take(EmitTo::First(2))?; + let result = result.as_string::(); + assert_eq!(result.len(), 2); + assert_eq!(result.value(0), "a"); + assert_eq!(result.value(1), "b"); + + // Remaining should be "c" + let result = state.take(EmitTo::All)?; + let result = result.as_string::(); + assert_eq!(result.len(), 1); + assert_eq!(result.value(0), "c"); + + Ok(()) + } +} diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index cf894a494ad90..3027927e1d1ff 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -8784,3 +8784,74 @@ DROP TABLE stream_test; # Restore default target partitions statement ok set datafusion.execution.target_partitions = 4; + + + +################# +# first_value on strings/binary with groups and ordering +################# + +statement ok +CREATE TABLE first_value_str_tests(id INT, sort_key INT, val TEXT) AS VALUES +(1, 10, 'apple'), +(1, 2, 'banana'), +(1, 5, 'cherry'), +(2, 100, 'dog'), +(2, 200, 'elephant'); + +# Utf8 +query IT +SELECT id, first_value(val ORDER BY sort_key) +FROM (SELECT id, sort_key, arrow_cast(val, 'Utf8') as val FROM first_value_str_tests) +GROUP BY id ORDER BY id; +---- +1 banana +2 dog + +# LargeUtf8 +query IT +SELECT id, first_value(val ORDER BY sort_key) +FROM (SELECT id, sort_key, arrow_cast(val, 'LargeUtf8') as val FROM first_value_str_tests) +GROUP BY id ORDER BY id; +---- +1 banana +2 dog + +# Utf8View +query IT +SELECT id, first_value(val ORDER BY sort_key) +FROM (SELECT id, sort_key, arrow_cast(val, 'Utf8View') as val FROM first_value_str_tests) +GROUP BY id ORDER BY id; +---- +1 banana +2 dog + +# Binary +query I? +SELECT id, first_value(val ORDER BY sort_key) +FROM (SELECT id, sort_key, arrow_cast(val, 'Binary') as val FROM first_value_str_tests) +GROUP BY id ORDER BY id; +---- +1 62616e616e61 +2 646f67 + +# LargeBinary +query I? +SELECT id, first_value(val ORDER BY sort_key) +FROM (SELECT id, sort_key, arrow_cast(val, 'LargeBinary') as val FROM first_value_str_tests) +GROUP BY id ORDER BY id; +---- +1 62616e616e61 +2 646f67 + +# BinaryView +query I? +SELECT id, first_value(val ORDER BY sort_key) +FROM (SELECT id, sort_key, arrow_cast(arrow_cast(val, 'Binary'), 'BinaryView') as val FROM first_value_str_tests) +GROUP BY id ORDER BY id; +---- +1 62616e616e61 +2 646f67 + +statement ok +DROP TABLE first_value_str_tests;