• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

vortex-data / vortex / 16482118108

23 Jul 2025 09:23PM UTC coverage: 81.523% (+0.5%) from 81.07%
16482118108

Pull #3973

github

web-flow
Merge d1d8aeb30 into 2ddcfbf30
Pull Request #3973: fix: Pruning expressions check NanCount where appropriate

247 of 258 new or added lines in 15 files covered. (95.74%)

32 existing lines in 3 files now uncovered.

42643 of 52308 relevant lines covered (81.52%)

172756.15 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.74
/vortex-array/src/compute/compare.rs
1
// SPDX-License-Identifier: Apache-2.0
2
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3

4
use core::fmt;
5
use std::any::Any;
6
use std::fmt::{Display, Formatter};
7
use std::sync::LazyLock;
8

9
use arcref::ArcRef;
10
use arrow_buffer::BooleanBuffer;
11
use arrow_ord::cmp;
12
use vortex_dtype::{DType, NativePType, Nullability};
13
use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
14
use vortex_scalar::Scalar;
15

16
use crate::arrays::ConstantArray;
17
use crate::arrow::{Datum, from_arrow_array_with_len};
18
use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
19
use crate::vtable::VTable;
20
use crate::{Array, ArrayRef, Canonical, IntoArray};
21

22
/// Compares two arrays and returns a new boolean array with the result of the comparison.
23
/// Or, returns None if comparison is not supported for these arrays.
24
pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
17,370✔
25
    COMPARE_FN
17,370✔
26
        .invoke(&InvocationArgs {
17,370✔
27
            inputs: &[left.into(), right.into()],
17,370✔
28
            options: &operator,
17,370✔
29
        })?
17,370✔
30
        .unwrap_array()
17,370✔
31
}
17,370✔
32

9,224✔
33
/// A binary comparison operator.
///
/// Note that the variant order determines the derived [`PartialOrd`] ordering.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)]
pub enum Operator {
    /// Equal to (`=`).
    Eq,
    /// Not equal to (`!=`).
    NotEq,
    /// Strictly greater than (`>`).
    Gt,
    /// Greater than or equal to (`>=`).
    Gte,
    /// Strictly less than (`<`).
    Lt,
    /// Less than or equal to (`<=`).
    Lte,
}

impl Display for Operator {
    /// Formats the operator as its infix symbol, e.g. `>=` for [`Operator::Gte`].
    ///
    /// Width, fill and alignment flags are honoured, because the symbol is written via `str`'s
    /// own `Display` implementation.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let symbol = match *self {
            Self::Eq => "=",
            Self::NotEq => "!=",
            Self::Gt => ">",
            Self::Gte => ">=",
            Self::Lt => "<",
            Self::Lte => "<=",
        };
        Display::fmt(symbol, f)
    }
}

impl Operator {
    /// Returns the logical negation of this operator.
    ///
    /// For non-null, totally ordered values, `a op b` holds exactly when `a op.inverse() b` does
    /// not.
    pub fn inverse(self) -> Self {
        match self {
            Self::Eq => Self::NotEq,
            Self::NotEq => Self::Eq,
            Self::Gt => Self::Lte,
            Self::Lte => Self::Gt,
            Self::Gte => Self::Lt,
            Self::Lt => Self::Gte,
        }
    }

    /// Returns the operator to use when the two operands trade places, so that `a op b` is
    /// equivalent to `b op.swap() a`.
    ///
    /// The symmetric operators (`=` and `!=`) are returned unchanged.
    pub fn swap(self) -> Self {
        match self {
            Self::Eq | Self::NotEq => self,
            Self::Gt => Self::Lt,
            Self::Lt => Self::Gt,
            Self::Gte => Self::Lte,
            Self::Lte => Self::Gte,
        }
    }
}
5,138✔
81

82
pub struct CompareKernelRef(ArcRef<dyn Kernel>);
83
inventory::collect!(CompareKernelRef);
84

85
pub trait CompareKernel: VTable {
86
    fn compare(
87
        &self,
88
        lhs: &Self::Array,
89
        rhs: &dyn Array,
90
        operator: Operator,
91
    ) -> VortexResult<Option<ArrayRef>>;
92
}
93

94
#[derive(Debug)]
95
pub struct CompareKernelAdapter<V: VTable>(pub V);
96

97
impl<V: VTable + CompareKernel> CompareKernelAdapter<V> {
98
    pub const fn lift(&'static self) -> CompareKernelRef {
99
        CompareKernelRef(ArcRef::new_ref(self))
×
100
    }
×
101
}
102

103
impl<V: VTable + CompareKernel> Kernel for CompareKernelAdapter<V> {
104
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
58,035✔
105
        let inputs = CompareArgs::try_from(args)?;
156,905✔
106
        let Some(array) = inputs.lhs.as_opt::<V>() else {
156,905✔
107
            return Ok(None);
151,146✔
108
        };
88,802✔
109
        Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
5,759✔
110
    }
68,103✔
111
}
98,870✔
112

113
pub static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
2,761✔
114
    let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
2,763✔
115
    for kernel in inventory::iter::<CompareKernelRef> {
22,354✔
116
        compute.register_kernel(kernel.0.clone());
19,615✔
117
    }
19,613✔
118
    compute
2,783✔
119
});
2,763✔
120

2✔
121
struct Compare;
122

123
impl ComputeFnVTable for Compare {
124
    fn invoke(
8,146✔
125
        &self,
8,146✔
126
        args: &InvocationArgs,
17,370✔
127
        kernels: &[ArcRef<dyn Kernel>],
17,370✔
128
    ) -> VortexResult<Output> {
17,370✔
129
        let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;
17,370✔
130

9,224✔
131
        let return_dtype = self.return_dtype(args)?;
17,370✔
132

9,224✔
133
        if lhs.is_empty() {
8,146✔
134
            return Ok(Canonical::empty(&return_dtype).into_array().into());
9,225✔
135
        }
17,369✔
136

137
        let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
17,369✔
138
        let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
8,145✔
139
        if left_constant_null || right_constant_null {
8,145✔
140
            return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
112✔
141
                .into_array()
9,336✔
142
                .into());
112✔
143
        }
17,257✔
144

9,224✔
145
        let right_is_constant = rhs.is_constant();
8,033✔
146

9,224✔
147
        // Always try to put constants on the right-hand side so encodings can optimise themselves.
9,224✔
148
        if lhs.is_constant() && !right_is_constant {
8,033✔
149
            return Ok(compare(rhs, lhs, operator.swap())?.into());
9,338✔
150
        }
7,919✔
151

152
        // First try lhs op rhs, then invert and try again.
153
        for kernel in kernels {
54,902✔
154
            if let Some(output) = kernel.invoke(args)? {
49,729✔
155
                return Ok(output);
11,970✔
156
            }
46,983✔
157
        }
9,224✔
158
        if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
14,397✔
159
            return Ok(output);
160
        }
5,173✔
161

9,224✔
162
        // Try inverting the operator and swapping the arguments
1,206✔
163
        let inverted_args = InvocationArgs {
5,173✔
164
            inputs: &[rhs.into(), lhs.into()],
6,379✔
165
            options: &operator.swap(),
13,191✔
166
        };
5,173✔
167
        for kernel in kernels {
25,263✔
168
            if let Some(output) = kernel.invoke(&inverted_args)? {
84,179✔
169
                return Ok(output);
60,157✔
170
            }
77,468✔
171
        }
4,086✔
172
        if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
6,480✔
173
            return Ok(output);
53,292✔
174
        }
2,394✔
175

176
        // Only log missing compare implementation if there's possibly better one than arrow,
3,932✔
177
        // i.e. lhs isn't arrow or rhs isn't arrow or constant
178
        if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
2,394✔
179
            log::debug!(
481✔
180
                "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
UNCOV
181
                lhs.encoding_id(),
×
182
                rhs.encoding_id(),
3,932✔
183
                operator,
184
            );
185
        }
5,845✔
186

3,932✔
187
        // Fallback to arrow on canonical types
3,932✔
188
        Ok(arrow_compare(lhs, rhs, operator)?.into())
6,326✔
189
    }
53,350✔
190

41,492✔
191
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
57,784✔
192
        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
16,512✔
193

220✔
194
        if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
16,292✔
195
            vortex_bail!(
220✔
196
                "Cannot compare different DTypes {} and {}",
41,272✔
197
                lhs.dtype(),
198
                rhs.dtype()
3,712✔
199
            );
200
        }
16,292✔
201

202
        // TODO(ngates): no reason why not
203
        if lhs.dtype().is_struct() {
16,292✔
UNCOV
204
            vortex_bail!(
×
205
                "Compare does not support arrays with Struct DType, got: {} and {}",
3,712✔
206
                lhs.dtype(),
207
                rhs.dtype()
208
            )
209
        }
20,004✔
210

3,220✔
211
        Ok(DType::Bool(
16,292✔
212
            lhs.dtype().nullability() | rhs.dtype().nullability(),
16,292✔
213
        ))
16,292✔
214
    }
16,292✔
215

216
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
8,638✔
217
        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
8,146✔
218
        if lhs.len() != rhs.len() {
11,858✔
219
            vortex_bail!(
220
                "Compare operations only support arrays of the same length, got {} and {}",
221
                lhs.len(),
3,712✔
222
                rhs.len()
9,224✔
223
            );
224
        }
26,594✔
225
        Ok(lhs.len())
26,594✔
226
    }
8,146✔
227

18,448✔
228
    fn is_elementwise(&self) -> bool {
8,146✔
229
        true
8,146✔
230
    }
8,146✔
231
}
232

233
struct CompareArgs<'a> {
18,448✔
234
    lhs: &'a dyn Array,
235
    rhs: &'a dyn Array,
236
    operator: Operator,
18,448✔
237
}
238

239
impl Options for Operator {
240
    fn as_any(&self) -> &dyn Any {
105,182✔
241
        self
105,182✔
242
    }
123,630✔
243
}
244

18,448✔
245
impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
18,448✔
246
    type Error = VortexError;
18,448✔
247

18,448✔
248
    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
105,182✔
249
        if value.inputs.len() != 2 {
114,406✔
250
            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
9,224✔
251
        }
114,406✔
252
        let lhs = value.inputs[0]
105,182✔
253
            .array()
105,182✔
254
            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
105,182✔
255
        let rhs = value.inputs[1]
105,182✔
256
            .array()
105,182✔
257
            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
114,406✔
258
        let operator = *value
114,406✔
259
            .options
114,406✔
260
            .as_any()
105,182✔
261
            .downcast_ref::<Operator>()
114,406✔
262
            .vortex_expect("Expected options to be an operator");
114,406✔
263

9,224✔
264
        Ok(CompareArgs { lhs, rhs, operator })
105,182✔
265
    }
105,182✔
266
}
267

268
/// Helper function to compare empty values with arrays that have external value length information
269
/// like `VarBin`.
270
pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
44✔
271
where
44✔
272
    P: NativePType,
44✔
273
    I: Iterator<Item = P>,
135,810✔
274
{
135,766✔
275
    // All comparison can be expressed in terms of equality. "" is the absolute min of possible value.
135,766✔
276
    let cmp_fn = match op {
44✔
277
        Operator::Eq | Operator::Lte => |v| v == P::zero(),
120✔
278
        Operator::NotEq | Operator::Gt => |v| v != P::zero(),
8✔
279
        Operator::Gte => |_| true,
280
        Operator::Lt => |_| false,
281
    };
135,766✔
282

135,766✔
283
    lengths.map(cmp_fn).collect::<BooleanBuffer>()
44✔
284
}
135,810✔
285

135,766✔
286
/// Implementation of `CompareFn` using the Arrow crate.
135,766✔
287
fn arrow_compare(
138,161✔
288
    left: &dyn Array,
138,161✔
289
    right: &dyn Array,
138,161✔
290
    operator: Operator,
138,161✔
291
) -> VortexResult<ArrayRef> {
138,161✔
292
    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
138,161✔
293
    let lhs = Datum::try_new(left)?;
138,161✔
294
    let rhs = Datum::try_new(right)?;
138,161✔
295

135,766✔
296
    let array = match operator {
2,395✔
297
        Operator::Eq => cmp::eq(&lhs, &rhs)?,
137,220✔
298
        Operator::NotEq => cmp::neq(&lhs, &rhs)?,
135,989✔
299
        Operator::Gt => cmp::gt(&lhs, &rhs)?,
409✔
300
        Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
75✔
301
        Operator::Lt => cmp::lt(&lhs, &rhs)?,
42✔
302
        Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
192✔
303
    };
304
    from_arrow_array_with_len(&array, left.len(), nullable)
2,395✔
305
}
2,395✔
306

307
pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
3,370✔
308
    if lhs.is_null() | rhs.is_null() {
3,370✔
UNCOV
309
        Scalar::null(DType::Bool(Nullability::Nullable))
×
310
    } else {
311
        let b = match operator {
3,370✔
312
            Operator::Eq => lhs == rhs,
705✔
313
            Operator::NotEq => lhs != rhs,
740✔
314
            Operator::Gt => lhs > rhs,
704✔
315
            Operator::Gte => lhs >= rhs,
370✔
316
            Operator::Lt => lhs < rhs,
592✔
317
            Operator::Lte => lhs <= rhs,
259✔
318
        };
319

320
        Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
7,082✔
321
    }
3,712✔
322
}
7,082✔
323

3,712✔
324
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use rstest::rstest;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{BoolArray, ConstantArray, VarBinArray, VarBinViewArray};
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    #[test]
    fn test_bool_basic_comparisons() {
        // Runs `lhs op rhs` through `compare` and returns the indices of the `true` results.
        let true_indices = |lhs: &BoolArray, rhs: &BoolArray, op: Operator| {
            let result = compare(lhs.as_ref(), rhs.as_ref(), op)
                .unwrap()
                .to_bool()
                .unwrap();
            to_int_indices(result).unwrap()
        };

        // Index 0 is null in both arrays, so it is excluded from every result below.
        let arr = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );
        assert_eq!(true_indices(&arr, &arr, Operator::Eq), [1u64, 2, 3, 4]);

        let no_matches: [u64; 0] = [];
        assert_eq!(true_indices(&arr, &arr, Operator::NotEq), no_matches);

        let other = BoolArray::new(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );
        assert_eq!(true_indices(&arr, &other, Operator::Lte), [2u64, 3, 4]);
        assert_eq!(true_indices(&arr, &other, Operator::Lt), [4u64]);
        assert_eq!(true_indices(&other, &arr, Operator::Gte), [2u64, 3, 4]);
        assert_eq!(true_indices(&other, &arr, Operator::Gt), [4u64]);
    }

    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        // 2 > 10 is false everywhere, and both paths keep the result constant.
        let via_dispatch = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        assert_eq!(
            via_dispatch.as_constant().unwrap().as_bool().value(),
            Some(false)
        );
        assert_eq!(via_dispatch.len(), 10);

        let via_arrow =
            arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        assert_eq!(
            via_arrow.as_constant().unwrap().as_bool().value(),
            Some(false)
        );
        assert_eq!(via_arrow.len(), 10);
    }

    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        // Only the last value (length 0) is empty.
        let lengths = [1i32, 5, 7, 0];

        let actual: Vec<bool> = compare_lengths_to_empty(lengths.iter().copied(), op)
            .iter()
            .collect();
        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        // Identical contents in different string/binary encodings must compare equal everywhere.
        let equality = compare(&left, &right, Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();
        let equal_count = equality.boolean_buffer().count_set_bits();
        assert_eq!(equal_count, left.len());
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc