Skip to content

Commit 1f43f0b

Browse files
committed
Aggregate Fns: Mean
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 9970bc9 commit 1f43f0b

1 file changed

Lines changed: 309 additions & 0 deletions

File tree

  • vortex-array/src/aggregate_fn/fns
Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
use vortex_error::vortex_bail;
6+
use vortex_mask::Mask;
7+
8+
use crate::ArrayRef;
9+
use crate::IntoArray;
10+
use crate::aggregate_fn::AggregateFnId;
11+
use crate::aggregate_fn::AggregateFnVTable;
12+
use crate::aggregate_fn::accumulator::Accumulator;
13+
use crate::arrays::PrimitiveArray;
14+
use crate::canonical::ToCanonical;
15+
use crate::dtype::DType;
16+
use crate::dtype::NativePType;
17+
use crate::dtype::Nullability;
18+
use crate::dtype::PType;
19+
use crate::dtype::StructFields;
20+
use crate::match_each_native_ptype;
21+
use crate::scalar::Scalar;
22+
use crate::scalar_fn::EmptyOptions;
23+
24+
/// Computes the arithmetic mean of numeric values.
25+
#[derive(Clone)]
26+
pub struct Mean;
27+
28+
impl AggregateFnVTable for Mean {
29+
type Options = EmptyOptions;
30+
31+
fn id(&self) -> AggregateFnId {
32+
AggregateFnId::new_ref("vortex.mean")
33+
}
34+
35+
fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
36+
if !input_dtype.is_int() && !input_dtype.is_float() {
37+
vortex_bail!("Mean requires numeric input, got {}", input_dtype);
38+
}
39+
Ok(DType::Primitive(PType::F64, Nullability::Nullable))
40+
}
41+
42+
fn state_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> VortexResult<DType> {
43+
if !input_dtype.is_int() && !input_dtype.is_float() {
44+
vortex_bail!("Mean requires numeric input, got {}", input_dtype);
45+
}
46+
Ok(DType::Struct(
47+
StructFields::from_iter([
48+
(
49+
"sum",
50+
DType::Primitive(PType::F64, Nullability::NonNullable),
51+
),
52+
(
53+
"count",
54+
DType::Primitive(PType::U64, Nullability::NonNullable),
55+
),
56+
]),
57+
Nullability::Nullable,
58+
))
59+
}
60+
61+
fn accumulator(
62+
&self,
63+
_options: &Self::Options,
64+
input_dtype: &DType,
65+
) -> VortexResult<Box<dyn Accumulator>> {
66+
if !input_dtype.is_int() && !input_dtype.is_float() {
67+
vortex_bail!("Mean requires numeric input, got {}", input_dtype);
68+
}
69+
Ok(Box::new(MeanAccumulator::new()))
70+
}
71+
}
72+
73+
struct MeanAccumulator {
74+
sum: f64,
75+
count: u64,
76+
results: Vec<Option<f64>>,
77+
}
78+
79+
impl MeanAccumulator {
80+
fn new() -> Self {
81+
Self {
82+
sum: 0.0,
83+
count: 0,
84+
results: Vec::new(),
85+
}
86+
}
87+
}
88+
89+
/// Accumulate all-valid values of type `T` into `sum` and `count`.
90+
fn accumulate_all_valid<T: NativePType>(values: &[T], sum: &mut f64, count: &mut u64) {
91+
for v in values {
92+
*sum += v.to_f64().unwrap_or(0.0);
93+
*count += 1;
94+
}
95+
}
96+
97+
/// Accumulate partially-valid values of type `T` into `sum` and `count`.
98+
fn accumulate_with_mask<T: NativePType>(
99+
values: &[T],
100+
mask: &vortex_mask::MaskValues,
101+
sum: &mut f64,
102+
count: &mut u64,
103+
) {
104+
for (val, valid) in values.iter().zip(mask.bit_buffer().iter()) {
105+
if valid {
106+
*sum += val.to_f64().unwrap_or(0.0);
107+
*count += 1;
108+
}
109+
}
110+
}
111+
112+
impl Accumulator for MeanAccumulator {
113+
fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> {
114+
let primitive = batch.to_primitive();
115+
let validity = primitive.validity_mask()?;
116+
117+
match_each_native_ptype!(primitive.ptype(), |T| {
118+
let values = primitive.as_slice::<T>();
119+
match &validity {
120+
Mask::AllTrue(_) => accumulate_all_valid(values, &mut self.sum, &mut self.count),
121+
Mask::AllFalse(_) => {}
122+
Mask::Values(v) => accumulate_with_mask(values, v, &mut self.sum, &mut self.count),
123+
}
124+
});
125+
126+
Ok(())
127+
}
128+
129+
fn merge(&mut self, state: &Scalar) -> VortexResult<()> {
130+
if state.is_null() {
131+
return Ok(());
132+
}
133+
134+
let s = state.as_struct();
135+
let Some(sum_scalar) = s.field_by_idx(0) else {
136+
vortex_bail!("Mean state struct missing sum field at index 0");
137+
};
138+
let Some(count_scalar) = s.field_by_idx(1) else {
139+
vortex_bail!("Mean state struct missing count field at index 1");
140+
};
141+
142+
self.sum += sum_scalar
143+
.as_primitive()
144+
.typed_value::<f64>()
145+
.unwrap_or(0.0);
146+
self.count += count_scalar
147+
.as_primitive()
148+
.typed_value::<u64>()
149+
.unwrap_or(0);
150+
Ok(())
151+
}
152+
153+
fn flush(&mut self) -> VortexResult<()> {
154+
if self.count == 0 {
155+
self.results.push(None);
156+
} else {
157+
self.results.push(Some(self.sum / self.count as f64));
158+
}
159+
self.sum = 0.0;
160+
self.count = 0;
161+
Ok(())
162+
}
163+
164+
fn finish(self: Box<Self>) -> VortexResult<ArrayRef> {
165+
Ok(PrimitiveArray::from_option_iter(self.results).into_array())
166+
}
167+
}
168+
169+
#[cfg(test)]
170+
mod tests {
171+
use vortex_buffer::buffer;
172+
use vortex_error::VortexResult;
173+
174+
use crate::ArrayRef;
175+
use crate::IntoArray;
176+
use crate::aggregate_fn::AggregateFnVTable;
177+
use crate::aggregate_fn::fns::mean::Mean;
178+
use crate::arrays::PrimitiveArray;
179+
use crate::dtype::DType;
180+
use crate::dtype::Nullability;
181+
use crate::dtype::PType;
182+
use crate::dtype::StructFields;
183+
use crate::scalar::Scalar;
184+
use crate::scalar_fn::EmptyOptions;
185+
use crate::validity::Validity;
186+
187+
fn run_mean(batch: &ArrayRef) -> VortexResult<ArrayRef> {
188+
let mut acc = Mean.accumulator(&EmptyOptions, batch.dtype())?;
189+
acc.accumulate(batch)?;
190+
acc.flush()?;
191+
acc.finish()
192+
}
193+
194+
fn get_f64_value(array: &ArrayRef, idx: usize) -> VortexResult<Option<f64>> {
195+
let scalar = array.scalar_at(idx)?;
196+
Ok(scalar.as_primitive().typed_value::<f64>())
197+
}
198+
199+
#[test]
200+
fn mean_i32() -> VortexResult<()> {
201+
let arr = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array();
202+
let result = run_mean(&arr)?;
203+
assert_eq!(get_f64_value(&result, 0)?, Some(2.5));
204+
Ok(())
205+
}
206+
207+
#[test]
208+
fn mean_f64() -> VortexResult<()> {
209+
let arr =
210+
PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array();
211+
let result = run_mean(&arr)?;
212+
assert_eq!(get_f64_value(&result, 0)?, Some(2.0));
213+
Ok(())
214+
}
215+
216+
#[test]
217+
fn mean_with_nulls() -> VortexResult<()> {
218+
let arr = PrimitiveArray::from_option_iter([Some(2i32), None, Some(4)]).into_array();
219+
let result = run_mean(&arr)?;
220+
assert_eq!(get_f64_value(&result, 0)?, Some(3.0));
221+
Ok(())
222+
}
223+
224+
#[test]
225+
fn mean_all_null() -> VortexResult<()> {
226+
let arr = PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array();
227+
let result = run_mean(&arr)?;
228+
assert_eq!(get_f64_value(&result, 0)?, None);
229+
Ok(())
230+
}
231+
232+
#[test]
233+
fn mean_empty_flush() -> VortexResult<()> {
234+
let mut acc = Mean.accumulator(
235+
&EmptyOptions,
236+
&DType::Primitive(PType::I32, Nullability::NonNullable),
237+
)?;
238+
acc.flush()?;
239+
let result = acc.finish()?;
240+
assert_eq!(get_f64_value(&result, 0)?, None);
241+
Ok(())
242+
}
243+
244+
#[test]
245+
fn mean_multi_group() -> VortexResult<()> {
246+
let mut acc = Mean.accumulator(
247+
&EmptyOptions,
248+
&DType::Primitive(PType::I32, Nullability::NonNullable),
249+
)?;
250+
251+
let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array();
252+
acc.accumulate(&batch1)?;
253+
acc.flush()?;
254+
255+
let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array();
256+
acc.accumulate(&batch2)?;
257+
acc.flush()?;
258+
259+
let result = acc.finish()?;
260+
assert_eq!(get_f64_value(&result, 0)?, Some(15.0));
261+
assert_eq!(get_f64_value(&result, 1)?, Some(6.0));
262+
Ok(())
263+
}
264+
265+
#[test]
266+
fn mean_merge() -> VortexResult<()> {
267+
let mut acc = Mean.accumulator(
268+
&EmptyOptions,
269+
&DType::Primitive(PType::I32, Nullability::NonNullable),
270+
)?;
271+
272+
let state_dtype = DType::Struct(
273+
StructFields::from_iter([
274+
(
275+
"sum",
276+
DType::Primitive(PType::F64, Nullability::NonNullable),
277+
),
278+
(
279+
"count",
280+
DType::Primitive(PType::U64, Nullability::NonNullable),
281+
),
282+
]),
283+
Nullability::Nullable,
284+
);
285+
286+
let state = Scalar::struct_(
287+
state_dtype.clone(),
288+
vec![
289+
Scalar::primitive(30.0f64, Nullability::NonNullable),
290+
Scalar::primitive(3u64, Nullability::NonNullable),
291+
],
292+
);
293+
acc.merge(&state)?;
294+
295+
let state2 = Scalar::struct_(
296+
state_dtype,
297+
vec![
298+
Scalar::primitive(20.0f64, Nullability::NonNullable),
299+
Scalar::primitive(2u64, Nullability::NonNullable),
300+
],
301+
);
302+
acc.merge(&state2)?;
303+
304+
acc.flush()?;
305+
let result = acc.finish()?;
306+
assert_eq!(get_f64_value(&result, 0)?, Some(10.0));
307+
Ok(())
308+
}
309+
}

0 commit comments

Comments
 (0)