• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

vortex-data / vortex / 16606068518

29 Jul 2025 07:55PM UTC coverage: 82.699% (+0.02%) from 82.684%
16606068518

Pull #4057

github

web-flow
Merge c0724f3d4 into 6fb0f3e49
Pull Request #4057: feat: `ArrayEquals` kernel

226 of 265 new or added lines in 2 files covered. (85.28%)

52 existing lines in 1 file now uncovered.

45442 of 54949 relevant lines covered (82.7%)

185640.1 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.98
/vortex-array/src/compute/array_equals.rs
1
// SPDX-License-Identifier: Apache-2.0
2
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3

4
use std::any::Any;
5
use std::sync::LazyLock;
6

7
use arcref::ArcRef;
8
use vortex_dtype::{DType, Nullability};
9
use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
10
use vortex_scalar::Scalar;
11

12
use crate::Array;
13
use crate::compute::{
14
    ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Options, Output, compare,
15
};
16
use crate::stats::{Precision, Stat, StatsProvider};
17
use crate::vtable::VTable;
18

19
pub fn array_equals(left: &dyn Array, right: &dyn Array) -> VortexResult<bool> {
17✔
20
    array_equals_opts(left, right, false)
17✔
21
}
17✔
22

23
pub fn array_equals_opts(
19✔
24
    left: &dyn Array,
19✔
25
    right: &dyn Array,
19✔
26
    ignore_nullability: bool,
19✔
27
) -> VortexResult<bool> {
19✔
28
    Ok(ARRAY_EQUALS_FN
19✔
29
        .invoke(&InvocationArgs {
19✔
30
            inputs: &[left.into(), right.into()],
19✔
31
            options: &ArrayEqualsOptions { ignore_nullability },
19✔
32
        })?
19✔
33
        .unwrap_scalar()?
19✔
34
        .as_bool()
19✔
35
        .value()
19✔
36
        .vortex_expect("non-nullable"))
19✔
37
}
19✔
38

39
#[derive(Clone, Copy)]
40
struct ArrayEqualsOptions {
41
    ignore_nullability: bool,
42
}
43

44
impl Options for ArrayEqualsOptions {
45
    fn as_any(&self) -> &dyn Any {
19✔
46
        self
19✔
47
    }
19✔
48
}
49

50
pub static ARRAY_EQUALS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
11✔
51
    let compute = ComputeFn::new("array_equals".into(), ArcRef::new_ref(&ArrayEquals));
11✔
52
    for kernel in inventory::iter::<ArrayEqualsKernelRef> {
11✔
NEW
53
        compute.register_kernel(kernel.0.clone());
×
NEW
54
    }
×
55
    compute
11✔
56
});
11✔
57

58
struct ArrayEquals;
59
impl ComputeFnVTable for ArrayEquals {
60
    fn invoke(
19✔
61
        &self,
19✔
62
        args: &InvocationArgs,
19✔
63
        kernels: &[ArcRef<dyn Kernel>],
19✔
64
    ) -> VortexResult<Output> {
19✔
65
        let ArrayEqualsArgs {
66
            left,
19✔
67
            right,
19✔
68
            ignore_nullability,
19✔
69
        } = ArrayEqualsArgs::try_from(args)?;
19✔
70

71
        if ignore_nullability && !left.dtype().eq_ignore_nullability(right.dtype()) {
19✔
NEW
72
            return Ok(Scalar::from(false).into());
×
73
        }
19✔
74

75
        if !ignore_nullability && !left.dtype().eq(right.dtype()) {
19✔
76
            return Ok(Scalar::from(false).into());
1✔
77
        }
18✔
78

79
        if left.len() != right.len() {
18✔
80
            return Ok(Scalar::from(false).into());
1✔
81
        }
17✔
82

83
        if let Some(l_scalar) = left.as_constant()
17✔
84
            && let Some(r_scalar) = right.as_constant()
3✔
85
        {
86
            return Ok(Scalar::from(l_scalar.eq(&r_scalar)).into());
3✔
87
        }
14✔
88

89
        if left.is_empty() && right.is_empty() {
14✔
90
            return Ok(Scalar::from(true).into());
1✔
91
        }
13✔
92

93
        for stat in [
104✔
94
            Stat::IsConstant,
13✔
95
            Stat::IsSorted,
13✔
96
            Stat::IsStrictSorted,
13✔
97
            Stat::Max, // todo: can we do that with e.g. float errors?
13✔
98
            Stat::Min,
13✔
99
            Stat::Sum,
13✔
100
            Stat::NullCount,
13✔
101
            Stat::NaNCount,
13✔
102
            // No Stat::UncompressedSizeInBytes because arrays may physically differ and has a different metric
103
        ] {
104
            let Some(Precision::Exact(left_v)) = left.statistics().get(stat) else {
104✔
105
                continue;
91✔
106
            };
107

108
            let Some(Precision::Exact(right_v)) = right.statistics().get(stat) else {
13✔
109
                continue;
13✔
110
            };
111

NEW
112
            if !left_v.eq(&right_v) {
×
NEW
113
                return Ok(Scalar::from(false).into());
×
NEW
114
            }
×
115
        }
116

117
        let args = InvocationArgs {
13✔
118
            inputs: &[left.into(), right.into()],
13✔
119
            options: &ArrayEqualsOptions { ignore_nullability },
13✔
120
        };
13✔
121

122
        for kernel in kernels {
13✔
NEW
123
            if let Some(output) = kernel.invoke(&args)? {
×
NEW
124
                return Ok(output);
×
NEW
125
            }
×
126
        }
127

128
        if let Some(output) = left.invoke(&ARRAY_EQUALS_FN, &args)? {
13✔
NEW
129
            return Ok(output);
×
130
        }
13✔
131

132
        // Try swapping arguments
133
        let swapped_args = InvocationArgs {
13✔
134
            inputs: &[right.into(), left.into()],
13✔
135
            options: &ArrayEqualsOptions { ignore_nullability },
13✔
136
        };
13✔
137
        if let Some(output) = right.invoke(&ARRAY_EQUALS_FN, &swapped_args)? {
13✔
NEW
138
            return Ok(output);
×
139
        }
13✔
140

141
        // Try canonical arrays if not already canonical
142
        let canonical_equals = if !left.is_canonical() || !right.is_canonical() {
13✔
143
            let left_canonical = left.to_canonical()?;
2✔
144
            let right_canonical = right.to_canonical()?;
2✔
145

146
            array_equals_opts(
2✔
147
                left_canonical.as_ref(),
2✔
148
                right_canonical.as_ref(),
2✔
149
                ignore_nullability,
2✔
NEW
150
            )?
×
151
        } else {
152
            // Fallback to chunked comparison
153
            const BATCH_SIZE: usize = 65536; // 64K elements per batch
154

155
            let mut offset = 0;
11✔
156
            while offset < left.len() {
18✔
157
                let end = (offset + BATCH_SIZE).min(left.len());
13✔
158

159
                let left_slice = left.slice(offset, end)?;
13✔
160
                let right_slice = right.slice(offset, end)?;
13✔
161

162
                let compare_result = compare(&left_slice, &right_slice, Operator::Eq)?;
13✔
163

164
                // For array equality, we need to check if all values are equal
165
                // This includes treating NULL == NULL as true
166
                let all_equal = if let Some(constant_scalar) = compare_result.as_constant() {
13✔
167
                    // If constant is true, all are equal
168
                    constant_scalar.is_valid() && constant_scalar.as_bool().value() == Some(true)
7✔
169
                } else {
170
                    // Not constant - need to check each value
171
                    let mut found_inequality = false;
6✔
172
                    for i in 0..compare_result.len() {
34,480✔
173
                        let cmp_scalar = compare_result.scalar_at(i)?;
34,480✔
174
                        if cmp_scalar.is_valid() && cmp_scalar.as_bool().value() == Some(false) {
34,480✔
175
                            // Found a definite inequality
176
                            found_inequality = true;
4✔
177
                            break;
4✔
178
                        }
34,476✔
179
                        // For null comparison results, we need to check the original values
180
                        if cmp_scalar.is_null() {
34,476✔
181
                            let left_val = left_slice.scalar_at(i)?;
2✔
182
                            let right_val = right_slice.scalar_at(i)?;
2✔
183
                            // If both are null, they're equal; if only one is null, they're not
184
                            if left_val.is_null() != right_val.is_null() {
2✔
185
                                found_inequality = true;
1✔
186
                                break;
1✔
187
                            }
1✔
188
                        }
34,474✔
189
                    }
190
                    !found_inequality
6✔
191
                };
192

193
                if !all_equal {
13✔
194
                    return Ok(Scalar::from(false).into());
6✔
195
                }
7✔
196

197
                offset = end;
7✔
198
            }
199

200
            true
5✔
201
        };
202

203
        Ok(Scalar::from(canonical_equals).into())
7✔
204
    }
19✔
205

206
    fn return_dtype(&self, _args: &InvocationArgs) -> VortexResult<DType> {
19✔
207
        Ok(DType::Bool(Nullability::NonNullable))
19✔
208
    }
19✔
209

210
    fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
19✔
211
        Ok(1)
19✔
212
    }
19✔
213

214
    fn is_elementwise(&self) -> bool {
19✔
215
        false
19✔
216
    }
19✔
217
}
218

219
// todo: statistics
220
pub trait ArrayEqualsKernel: VTable {
221
    fn compare_array(
222
        &self,
223
        array: &Self::Array,
224
        other: &dyn Array,
225
        ignore_nullability: bool,
226
    ) -> VortexResult<Option<bool>>;
227
}
228

229
struct ArrayEqualsArgs<'a> {
230
    left: &'a dyn Array,
231
    right: &'a dyn Array,
232
    ignore_nullability: bool,
233
}
234

235
impl<'a> TryFrom<&InvocationArgs<'a>> for ArrayEqualsArgs<'a> {
236
    type Error = VortexError;
237

238
    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
19✔
239
        if value.inputs.len() != 2 {
19✔
NEW
240
            vortex_bail!(
×
NEW
241
                "ArrayEquals function requires two arguments, got {}",
×
NEW
242
                value.inputs.len()
×
243
            );
244
        }
19✔
245
        let left = value.inputs[0]
19✔
246
            .array()
19✔
247
            .ok_or_else(|| vortex_err!("First argument must be an array"))?;
19✔
248

249
        let right = value.inputs[1]
19✔
250
            .array()
19✔
251
            .ok_or_else(|| vortex_err!("Second argument must be an array"))?;
19✔
252

253
        let options = value
19✔
254
            .options
19✔
255
            .as_any()
19✔
256
            .downcast_ref::<ArrayEqualsOptions>()
19✔
257
            .ok_or_else(|| vortex_err!("Invalid options type for array equals function"))?;
19✔
258

259
        Ok(ArrayEqualsArgs {
19✔
260
            left,
19✔
261
            right,
19✔
262
            ignore_nullability: options.ignore_nullability,
19✔
263
        })
19✔
264
    }
19✔
265
}
266

267
#[derive(Debug)]
268
pub struct ArrayEqualsKernelAdapter<V: VTable>(pub V);
269

270
pub struct ArrayEqualsKernelRef(ArcRef<dyn Kernel>);
271
inventory::collect!(ArrayEqualsKernelRef);
272

273
impl<V: VTable + ArrayEqualsKernel> ArrayEqualsKernelAdapter<V> {
NEW
274
    pub const fn lift(&'static self) -> ArrayEqualsKernelRef {
×
NEW
275
        ArrayEqualsKernelRef(ArcRef::new_ref(self))
×
NEW
276
    }
×
277
}
278

279
impl<V: VTable + ArrayEqualsKernel> Kernel for ArrayEqualsKernelAdapter<V> {
NEW
280
    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
×
281
        let ArrayEqualsArgs {
NEW
282
            left,
×
NEW
283
            right,
×
NEW
284
            ignore_nullability,
×
NEW
285
        } = ArrayEqualsArgs::try_from(args)?;
×
286

NEW
287
        let Some(left) = left.as_opt::<V>() else {
×
NEW
288
            return Ok(None);
×
289
        };
290

NEW
291
        let is_equal = V::compare_array(&self.0, left, right, ignore_nullability)?;
×
NEW
292
        Ok(is_equal.map(|b| Scalar::from(b).into()))
×
NEW
293
    }
×
294
}
295

296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arrays::{BoolArray, ConstantArray, PrimitiveArray, VarBinArray};
    use crate::validity::Validity;
    use vortex_dtype::{DType, Nullability};

    /// Runs `array_equals` and unwraps the result.
    fn equal(left: &dyn Array, right: &dyn Array) -> bool {
        array_equals(left, right).unwrap()
    }

    /// Builds a non-nullable `i32` primitive array.
    fn ints(values: &[i32]) -> PrimitiveArray {
        PrimitiveArray::from_iter(values.to_vec())
    }

    /// Builds a nullable `i32` primitive array.
    fn nullable_ints(values: &[Option<i32>]) -> PrimitiveArray {
        PrimitiveArray::from_option_iter(values.to_vec())
    }

    /// Builds a non-nullable UTF-8 `VarBinArray` from two words.
    fn utf8(words: [&str; 2]) -> VarBinArray {
        VarBinArray::from_vec(
            words.iter().map(|w| w.as_bytes()).collect::<Vec<&[u8]>>(),
            DType::Utf8(Nullability::NonNullable),
        )
    }

    #[test]
    fn test_simple_equals() {
        let base = ints(&[1, 2, 3, 4, 5]);
        let same = ints(&[1, 2, 3, 4, 5]);
        let last_differs = ints(&[1, 2, 3, 4, 6]);

        assert!(equal(base.as_ref(), same.as_ref()));
        assert!(!equal(base.as_ref(), last_differs.as_ref()));
    }

    #[test]
    fn test_stats_comparison() {
        // Arrays with different stats should be detected as different early
        let small = ints(&[1, 2, 3, 4, 5]);
        let large = ints(&[10, 20, 30, 40, 50]);

        assert!(!equal(small.as_ref(), large.as_ref()));
    }

    #[test]
    fn test_constant_arrays() {
        let forty_two = ConstantArray::new(Scalar::from(42i32), 100);
        let forty_two_again = ConstantArray::new(Scalar::from(42i32), 100);
        let forty_three = ConstantArray::new(Scalar::from(43i32), 100);

        assert!(equal(forty_two.as_ref(), forty_two_again.as_ref()));
        assert!(!equal(forty_two.as_ref(), forty_three.as_ref()));
    }

    #[test]
    fn test_different_types() {
        let integers = ints(&[1, 2, 3]);
        let floats = PrimitiveArray::from_iter(vec![1.0f32, 2.0, 3.0]);

        assert!(!equal(integers.as_ref(), floats.as_ref()));
    }

    #[test]
    fn test_with_nulls() {
        let with_null = nullable_ints(&[Some(1), None, Some(3), Some(4)]);
        let same_null = nullable_ints(&[Some(1), None, Some(3), Some(4)]);
        let no_null = nullable_ints(&[Some(1), Some(2), Some(3), Some(4)]);

        assert!(equal(with_null.as_ref(), same_null.as_ref()));
        assert!(!equal(with_null.as_ref(), no_null.as_ref()));
    }

    #[test]
    fn test_null_arrays() {
        let all_null = nullable_ints(&[None, None, None]);
        let all_null_again = nullable_ints(&[None, None, None]);

        assert!(equal(all_null.as_ref(), all_null_again.as_ref()));
    }

    #[test]
    fn test_bool_arrays() {
        use arrow_buffer::BooleanBuffer;

        let bools = |bits: [bool; 4]| {
            BoolArray::new(BooleanBuffer::from_iter(bits), Validity::AllValid)
        };

        let base = bools([true, false, true, false]);
        let same = bools([true, false, true, false]);
        let third_differs = bools([true, false, false, false]);

        assert!(equal(base.as_ref(), same.as_ref()));
        assert!(!equal(base.as_ref(), third_differs.as_ref()));
    }

    #[test]
    fn test_empty_arrays() {
        let first = ints(&[]);
        let second = ints(&[]);

        assert!(equal(first.as_ref(), second.as_ref()));
    }

    #[test]
    fn test_different_lengths() {
        let three = ints(&[1, 2, 3]);
        let four = ints(&[1, 2, 3, 4]);

        assert!(!equal(three.as_ref(), four.as_ref()));
    }

    #[test]
    fn test_large_arrays() {
        // Test arrays larger than BATCH_SIZE
        let baseline: Vec<i64> = (0..100_000).collect();
        let mut tampered = baseline.clone();
        tampered[99_999] = 999_999;

        let base = PrimitiveArray::from_iter(baseline.clone());
        let same = PrimitiveArray::from_iter(baseline);
        let last_differs = PrimitiveArray::from_iter(tampered);

        assert!(equal(base.as_ref(), same.as_ref()));
        assert!(!equal(base.as_ref(), last_differs.as_ref()));
    }

    #[test]
    fn test_non_canonical_arrays() {
        let hello_world = utf8(["hello", "world"]);
        let hello_world_again = utf8(["hello", "world"]);
        let hello_earth = utf8(["hello", "earth"]);

        assert!(equal(hello_world.as_ref(), hello_world_again.as_ref()));
        assert!(!equal(hello_world.as_ref(), hello_earth.as_ref()));
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc