• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

vortex-data / vortex / 16483072032

23 Jul 2025 10:22PM UTC coverage: 81.63% (+0.6%) from 81.07%
16483072032

Pull #3973

github

web-flow
Merge 1079410ed into 2ddcfbf30
Pull Request #3973: fix: Pruning expressions check NanCount where appropriate

254 of 267 new or added lines in 16 files covered. (95.13%)

32 existing lines in 4 files now uncovered.

42713 of 52325 relevant lines covered (81.63%)

172810.71 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.51
/vortex-array/src/compute/mod.rs
1
// SPDX-License-Identifier: Apache-2.0
2
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3

4
//! Compute kernels on top of Vortex Arrays.
5
//!
6
//! We aim to provide a basic set of compute kernels that can be used to efficiently index, slice,
7
//! and filter Vortex Arrays in their encoded forms.
8
//!
9
//! Every array encoding has the ability to implement their own efficient implementations of these
10
//! operators, else we will decode, and perform the equivalent operator from Arrow.
11

12
use std::any::{Any, type_name};
13
use std::fmt::{Debug, Formatter};
14

15
use arcref::ArcRef;
16
pub use between::*;
17
pub use boolean::*;
18
pub use cast::*;
19
pub use compare::*;
20
pub use fill_null::*;
21
pub use filter::*;
22
pub use invert::*;
23
pub use is_constant::*;
24
pub use is_sorted::*;
25
use itertools::Itertools;
26
pub use like::*;
27
pub use list_contains::*;
28
pub use mask::*;
29
pub use min_max::*;
30
pub use nan_count::*;
31
pub use numeric::*;
32
use parking_lot::RwLock;
33
pub use sum::*;
34
pub use take::*;
35
use vortex_dtype::DType;
36
use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
37
use vortex_mask::Mask;
38
use vortex_scalar::Scalar;
39

40
use crate::builders::ArrayBuilder;
41
use crate::{Array, ArrayRef};
42

43
#[cfg(feature = "arbitrary")]
44
mod arbitrary;
45
mod between;
46
mod boolean;
47
mod cast;
48
mod compare;
49
#[cfg(feature = "test-harness")]
50
pub mod conformance;
51
mod fill_null;
52
mod filter;
53
mod invert;
54
mod is_constant;
55
mod is_sorted;
56
mod like;
57
mod list_contains;
58
mod mask;
59
mod min_max;
60
mod nan_count;
61
mod numeric;
62
mod sum;
63
mod take;
64

65
/// An instance of a compute function holding the implementation vtable and a set of registered
66
/// compute kernels.
67
pub struct ComputeFn {
68
    id: ArcRef<str>,
69
    vtable: ArcRef<dyn ComputeFnVTable>,
70
    kernels: RwLock<Vec<ArcRef<dyn Kernel>>>,
71
}
72

73
impl ComputeFn {
74
    /// Create a new compute function from the given [`ComputeFnVTable`].
75
    pub fn new(id: ArcRef<str>, vtable: ArcRef<dyn ComputeFnVTable>) -> Self {
37,129✔
76
        Self {
37,129✔
77
            id,
37,129✔
78
            vtable,
37,129✔
79
            kernels: Default::default(),
37,129✔
80
        }
37,129✔
81
    }
37,129✔
82

83
    /// Returns the string identifier of the compute function.
84
    pub fn id(&self) -> &ArcRef<str> {
×
85
        &self.id
×
86
    }
×
87

88
    /// Register a kernel for the compute function.
89
    pub fn register_kernel(&self, kernel: ArcRef<dyn Kernel>) {
328,457✔
90
        self.kernels.write().push(kernel);
328,457✔
91
    }
328,457✔
92

93
    /// Invokes the compute function with the given arguments.
94
    pub fn invoke(&self, args: &InvocationArgs) -> VortexResult<Output> {
264,994✔
95
        // Perform some pre-condition checks against the arguments and the function properties.
189,728✔
96
        if self.is_elementwise() {
454,722✔
97
            // For element-wise functions, all input arrays must be the same length.
98
            if !args
234,511✔
99
                .inputs
68,207✔
100
                .iter()
44,783✔
101
                .filter_map(|input| input.array())
103,871✔
102
                .map(|array| array.len())
82,988✔
103
                .all_equal()
68,207✔
104
            {
45,358✔
105
                vortex_bail!(
36,204✔
106
                    "Compute function {} is elementwise but input arrays have different lengths",
23,424✔
107
                    self.id
108
                );
109
            }
44,783✔
110
        }
220,211✔
111

112
        let expected_dtype = self.vtable.return_dtype(args)?;
288,418✔
113
        let expected_len = self.vtable.return_len(args)?;
431,297✔
114

115
        let output = self.vtable.invoke(args, &self.kernels.read())?;
454,721✔
116

189,728✔
117
        if output.dtype() != &expected_dtype {
453,189✔
118
            vortex_bail!(
189,728✔
119
                "Internal error: compute function {} returned a result of type {} but expected {}\n{}",
120
                self.id,
189,728✔
121
                output.dtype(),
189,728✔
122
                &expected_dtype,
123
                args.inputs
189,728✔
124
                    .iter()
×
125
                    .filter_map(|input| input.array())
×
126
                    .format_with(",", |array, f| f(&array.display_tree()))
127
            );
128
        }
263,461✔
129
        if output.len() != expected_len {
263,461✔
130
            vortex_bail!(
×
131
                "Internal error: compute function {} returned a result of length {} but expected {}",
×
132
                self.id,
133
                output.len(),
134
                expected_len
189,728✔
135
            );
189,728✔
136
        }
263,461✔
137

138
        Ok(output)
263,461✔
139
    }
264,994✔
140

141
    /// Compute the return type of the function given the input arguments.
142
    pub fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
189,743✔
143
        self.vtable.return_dtype(args)
15✔
144
    }
189,743✔
145

189,728✔
146
    /// Compute the return length of the function given the input arguments.
147
    pub fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
UNCOV
148
        self.vtable.return_len(args)
×
UNCOV
149
    }
×
150

151
    /// Returns whether the compute function is elementwise, i.e. the output is the same shape as
    /// the input and no information is shared between elements.
152
    pub fn is_elementwise(&self) -> bool {
265,024✔
153
        // TODO(ngates): should this just be a constant passed in the constructor?
154
        self.vtable.is_elementwise()
265,024✔
155
    }
265,024✔
156

157
    /// Returns the compute function's kernels.
158
    pub fn kernels(&self) -> Vec<ArcRef<dyn Kernel>> {
193,763✔
159
        self.kernels.read().to_vec()
4,035✔
160
    }
193,763✔
161
}
189,728✔
162

163
/// VTable for the implementation of a compute function.
164
pub trait ComputeFnVTable: 'static + Send + Sync {
526✔
165
    /// Invokes the compute function entry-point with the given input arguments and options.
526✔
166
    ///
526✔
167
    /// The entry-point logic can short-circuit compute using statistics, update result array
168
    /// statistics, search for relevant compute kernels, and canonicalize the inputs in order
169
    /// to successfully compute a result.
170
    fn invoke(&self, args: &InvocationArgs, kernels: &[ArcRef<dyn Kernel>])
171
    -> VortexResult<Output>;
172

173
    /// Computes the return type of the function given the input arguments.
174
    ///
175
    /// All kernel implementations will be validated to return the [`DType`] as computed here.
176
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType>;
177

178
    /// Computes the return length of the function given the input arguments.
179
    ///
180
    /// All kernel implementations will be validated to return the len as computed here.
181
    /// Scalars are considered to have length 1.
182
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize>;
183

184
    /// Returns whether the function operates elementwise, i.e. the output is the same shape as the
185
    /// input and no information is shared between elements.
186
    ///
187
    /// Examples include `add`, `subtract`, `and`, `cast`, `fill_null` etc.
188
    /// Examples that are not elementwise include `sum`, `count`, `min`, `fill_forward` etc.
189
    ///
190
    /// All input arrays to an elementwise function *must* have the same length.
191
    fn is_elementwise(&self) -> bool;
192
}
193

194
/// Arguments to a compute function invocation.
195
#[derive(Clone)]
196
pub struct InvocationArgs<'a> {
197
    pub inputs: &'a [Input<'a>],
198
    pub options: &'a dyn Options,
199
}
200

201
/// For unary compute functions, it's useful to just have this short-cut.
202
pub struct UnaryArgs<'a, O: Options> {
203
    pub array: &'a dyn Array,
204
    pub options: &'a O,
205
}
206

207
impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for UnaryArgs<'a, O> {
208
    type Error = VortexError;
209

210
    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
410,407✔
211
        if value.inputs.len() != 1 {
410,407✔
212
            vortex_bail!("Expected 1 input, found {}", value.inputs.len());
213
        }
410,407✔
214
        let array = value.inputs[0]
410,407✔
215
            .array()
410,407✔
216
            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
503,235✔
217
        let options =
503,235✔
218
            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
410,407✔
219
                vortex_err!("Expected options to be of type {}", type_name::<O>())
92,828✔
220
            })?;
92,828✔
221
        Ok(UnaryArgs { array, options })
503,235✔
222
    }
503,235✔
223
}
92,828✔
224

92,828✔
225
/// For binary compute functions, it's useful to just have this short-cut.
226
pub struct BinaryArgs<'a, O: Options> {
227
    pub lhs: &'a dyn Array,
92,828✔
228
    pub rhs: &'a dyn Array,
92,828✔
229
    pub options: &'a O,
230
}
231

232
impl<'a, O: Options> TryFrom<&InvocationArgs<'a>> for BinaryArgs<'a, O> {
233
    type Error = VortexError;
234

235
    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
1,106✔
236
        if value.inputs.len() != 2 {
1,106✔
237
            vortex_bail!("Expected 2 input, found {}", value.inputs.len());
238
        }
1,106✔
239
        let lhs = value.inputs[0]
1,106✔
240
            .array()
1,106✔
241
            .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
1,338✔
242
        let rhs = value.inputs[1]
1,338✔
243
            .array()
1,106✔
244
            .ok_or_else(|| vortex_err!("Expected input 1 to be an array"))?;
1,338✔
245
        let options =
1,338✔
246
            value.options.as_any().downcast_ref::<O>().ok_or_else(|| {
1,338✔
247
                vortex_err!("Expected options to be of type {}", type_name::<O>())
232✔
248
            })?;
232✔
249
        Ok(BinaryArgs { lhs, rhs, options })
1,338✔
250
    }
1,338✔
251
}
232✔
252

232✔
253
/// Input to a compute function.
254
pub enum Input<'a> {
255
    Scalar(&'a Scalar),
232✔
256
    Array(&'a dyn Array),
232✔
257
    Mask(&'a Mask),
258
    Builder(&'a mut dyn ArrayBuilder),
259
    DType(&'a DType),
260
}
261

262
impl Debug for Input<'_> {
263
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
264
        let mut f = f.debug_struct("Input");
265
        match self {
266
            Input::Scalar(scalar) => f.field("Scalar", scalar),
267
            Input::Array(array) => f.field("Array", array),
268
            Input::Mask(mask) => f.field("Mask", mask),
269
            Input::Builder(builder) => f.field("Builder", &builder.len()),
×
270
            Input::DType(dtype) => f.field("DType", dtype),
×
271
        };
272
        f.finish()
×
273
    }
×
274
}
275

276
impl<'a> From<&'a dyn Array> for Input<'a> {
277
    fn from(value: &'a dyn Array) -> Self {
406,045✔
278
        Input::Array(value)
406,045✔
279
    }
406,045✔
280
}
281

282
impl<'a> From<&'a Scalar> for Input<'a> {
283
    fn from(value: &'a Scalar) -> Self {
270,155✔
284
        Input::Scalar(value)
270,155✔
285
    }
270,155✔
286
}
287

288
impl<'a> From<&'a Mask> for Input<'a> {
289
    fn from(value: &'a Mask) -> Self {
9,996✔
290
        Input::Mask(value)
9,996✔
291
    }
9,996✔
292
}
293

294
impl<'a> From<&'a DType> for Input<'a> {
295
    fn from(value: &'a DType) -> Self {
26,047✔
296
        Input::DType(value)
26,047✔
297
    }
26,047✔
298
}
299

300
impl<'a> Input<'a> {
301
    /// Returns the wrapped [`Scalar`] if this input is the `Scalar` variant, otherwise `None`.
    pub fn scalar(&self) -> Option<&'a Scalar> {
12,073✔
302
        match self {
12,073✔
303
            Input::Scalar(scalar) => Some(*scalar),
12,073✔
304
            _ => None,
305
        }
306
    }
6,727✔
307

11,424✔
308
    /// Returns the wrapped [`Array`] if this input is the `Array` variant, otherwise `None`.
    pub fn array(&self) -> Option<&'a dyn Array> {
1,609,524✔
309
        match self {
1,609,524✔
310
            Input::Array(array) => Some(*array),
1,577,217✔
311
            _ => None,
20,883✔
312
        }
11,424✔
313
    }
1,598,100✔
314

1,177,094✔
315
    /// Returns the wrapped [`Mask`] if this input is the `Mask` variant, otherwise `None`.
    pub fn mask(&self) -> Option<&'a Mask> {
1,231,960✔
316
        match self {
1,222,806✔
317
            Input::Mask(mask) => Some(*mask),
64,020✔
318
            _ => None,
319
        }
1,177,094✔
320
    }
54,866✔
321

131,006✔
322
    /// Returns the wrapped [`ArrayBuilder`] if this input is the `Builder` variant, otherwise
    /// `None`.
    ///
    /// Unlike the other accessors this takes `&'a mut self`, since the builder is handed out as a
    /// mutable reference.
    pub fn builder(&'a mut self) -> Option<&'a mut dyn ArrayBuilder> {
131,006✔
323
        match self {
131,006✔
324
            Input::Builder(builder) => Some(*builder),
×
325
            _ => None,
326
        }
131,006✔
327
    }
328

329
    /// Returns the wrapped [`DType`] if this input is the `DType` variant, otherwise `None`.
    pub fn dtype(&self) -> Option<&'a DType> {
71,473✔
330
        match self {
71,473✔
331
            Input::DType(dtype) => Some(*dtype),
71,473✔
332
            _ => None,
333
        }
334
    }
71,473✔
335
}
18,212✔
336

18,212✔
337
/// Output from a compute function.
18,212✔
338
#[derive(Debug)]
339
pub enum Output {
340
    Scalar(Scalar),
18,212✔
341
    Array(ArrayRef),
342
}
343

344
#[allow(clippy::len_without_is_empty)]
345
impl Output {
346
    /// Returns the [`DType`] of the output, delegating to the contained scalar or array.
    pub fn dtype(&self) -> &DType {
263,461✔
347
        match self {
263,461✔
348
            Output::Scalar(scalar) => scalar.dtype(),
200,015✔
349
            Output::Array(array) => array.dtype(),
63,446✔
350
        }
351
    }
263,461✔
352

189,728✔
353
    /// Returns the length of the output: `1` for a scalar, otherwise the length of the array.
    pub fn len(&self) -> usize {
453,189✔
354
        match self {
397,239✔
355
            Output::Scalar(_) => 1,
255,965✔
356
            Output::Array(array) => array.len(),
63,446✔
357
        }
189,728✔
358
    }
263,461✔
359

189,728✔
360
    pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
483,104✔
361
        match self {
427,154✔
362
            Output::Array(_) => vortex_bail!("Expected array output, got Array"),
55,950✔
363
            Output::Scalar(scalar) => Ok(scalar),
293,376✔
364
        }
189,728✔
365
    }
293,376✔
366

189,160✔
367
    /// Consumes the output and returns the contained [`ArrayRef`].
    ///
    /// # Errors
    ///
    /// Fails via `vortex_bail!` if the output is a scalar rather than an array.
    pub fn unwrap_array(self) -> VortexResult<ArrayRef> {
256,604✔
368
        match self {
67,444✔
369
            Output::Array(array) => Ok(array),
256,604✔
370
            Output::Scalar(_) => vortex_bail!("Expected array output, got Scalar"),
371
        }
189,160✔
372
    }
67,444✔
373
}
56,464✔
374

56,464✔
375
impl From<ArrayRef> for Output {
56,464✔
376
    fn from(value: ArrayRef) -> Self {
67,137✔
377
        Output::Array(value)
67,137✔
378
    }
123,601✔
379
}
380

381
impl From<Scalar> for Output {
382
    fn from(value: Scalar) -> Self {
349,782✔
383
        Output::Scalar(value)
349,782✔
384
    }
349,782✔
385
}
386

387
/// Options for a compute function invocation.
388
pub trait Options: 'static {
189,160✔
389
    fn as_any(&self) -> &dyn Any;
189,160✔
390
}
189,160✔
391

392
impl Options for () {
393
    fn as_any(&self) -> &dyn Any {
411,513✔
394
        self
411,513✔
395
    }
411,513✔
396
}
397

398
/// Compute functions can ask arrays for compute kernels for a given invocation.
399
///
93,060✔
400
/// The kernel is invoked with the input arguments and options, and can return `None` if it is
93,060✔
401
/// unable to compute the result for the given inputs due to missing implementation logic.
93,060✔
402
/// For example, the kernel may not support the `LTE` operator. By returning `None`, the kernel
403
/// is indicating that it cannot compute the result for the given inputs, and another kernel should
404
/// be tried. It does *not* mean that the given inputs are invalid for the compute function.
405
///
406
/// If the kernel fails to compute a result, it should return an `Err` with the error.
407
pub trait Kernel: 'static + Send + Sync + Debug {
408
    /// Invokes the kernel with the given input arguments and options.
409
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>>;
410
}
411

412
/// Register a kernel for a compute function.
413
/// See each compute function for the correct type of kernel to register.
414
#[macro_export]
415
macro_rules! register_kernel {
416
    ($T:expr) => {
417
        $crate::aliases::inventory::submit!($T);
418
    };
419
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc