• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nci / scores / 19917845555

04 Dec 2025 04:37AM UTC coverage: 99.882% (-0.1%) from 100.0%
19917845555

Pull #932

github

web-flow
Merge 345275a40 into 40baf9150
Pull Request #932: Issue 632 rev

500 of 500 branches covered (100.0%)

Branch coverage included in aggregate %.

281 of 285 new or added lines in 5 files covered. (98.6%)

3 existing lines in 1 file now uncovered.

2876 of 2880 relevant lines covered (99.86%)

3.99 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

99.0
/src/scores/processing/aggregation.py
1
"""
2
Functions related to aggregating data
3
"""
4

5
import warnings
4✔
6
from typing import Optional
4✔
7

8
import xarray as xr
4✔
9

10
from scores.processing.matching import broadcast_and_match_nan
4✔
11
from scores.typing import FlexibleDimensionTypes, XarrayLike
4✔
12
from scores.utils import HAS_DASK, check_weights
4✔
13

14
if HAS_DASK:
4✔
15
    import dask.array as da
4✔
16

17

18
def _add_assertion_dependency(result: xr.DataArray, assertion_graph: xr.DataArray) -> xr.DataArray:
    """
    Ties a deferred Dask assertion to ``result`` so that the check runs on compute.

    When Dask is available and ``result`` is Dask-backed, a new Dask array is built
    with ``dask.array.blockwise`` in which every output block takes the assertion as
    an extra input. Computing the returned object therefore also computes the
    assertion. The values and dtype of ``result`` are passed through untouched.

    Args:
        result: The aggregated data to which the check is attached.
        assertion_graph: The deferred check. Its ``.data`` is handed to ``blockwise``
            with an empty index tuple, i.e. it is treated as zero-dimensional.

    Returns:
        A copy of ``result`` whose data depends on ``assertion_graph`` when
        ``result`` is Dask-backed; otherwise ``result`` itself, unchanged.
    """
    if not (HAS_DASK and hasattr(result.data, "dask")):
        return result

    def _pass_through(block, _check):
        # Receiving ``_check`` as an argument is what creates the graph dependency:
        # blockwise must materialise every input before calling this function.
        # Returning ``block`` as-is (rather than arithmetic such as
        # ``block * 1.0 + check * 0.0``) keeps the block dtype equal to the declared
        # ``dtype`` and prevents a NaN/inf check value from leaking into the result.
        return block

    # One positional index label per dimension of ``result``.
    block_axes = tuple(range(result.ndim))

    combined = da.blockwise(
        _pass_through,
        block_axes,  # Output has the same dimensions as ``result``
        result.data,
        block_axes,  # First input: the result itself
        assertion_graph.data,
        (),  # Second input: scalar check (no dims)
        dtype=result.dtype,
    )

    return result.copy(data=combined)
×
45

46

47
def aggregate(
4✔
48
    values: XarrayLike,
49
    *,
50
    reduce_dims: FlexibleDimensionTypes | None,
51
    weights: Optional[XarrayLike] = None,
52
    method: str = "mean",
53
) -> XarrayLike:
54
    """
55
    Computes a weighted or unweighted aggregation of the input data across specified dimensions.
56
    The input data is typically the "score" at each point.
57

58
    This function applies a mean reduction or a sum over the dimensions given by ``reduce_dims`` on
59
    the input ``values``, optionally using weights to compute a weighted mean or sum.
60
    The ``method`` arg specifies if you want to produce a weighted mean or weighted sum.
61

62
    If `reduce_dims` is None, no aggregation is performed and the original ``values`` are
63
    returned unchanged.
64

65
    If ``weights`` is None, an unweighted mean or sum is computed. If weights are provided, negative
66
    weights are not allowed and will raise a ``ValueError``.
67

68
    If weights are provided but ``reduce_dims`` is None (i.e., no reduction), a ``UserWarning``
69
    is emitted since the weights will be ignored.
70

71
    Weights must not contain NaN values. Missing values can be filled by ``weights.fillna(0)``
72
    if you would like to assign a weight of zero to those points (e.g., masking).
73

74
    Args:
75
        values: Input data to be aggregated.
76
        reduce_dims: Dimensions over which to apply the mean. Can be a string, list of
77
            strings, or None. If None, no reduction is performed.
78
        weights: Weights to apply for weighted averaging.
79
            Must be broadcastable to ``values`` and contain no negative values. If None,
80
            an unweighted mean is calculated. Defaults to None.
81
        method: Aggregation method to use. Either "mean" or "sum". Defaults to "mean".
82

83
    Returns:
84
        An xarray object (same type as the input) with (un)weighted mean or sum of ``values``
85

86
    Raises:
87
        ValueError: If ``weights`` contains any negative values.
88
        ValueError: if ``weights`` contains any NaN values
89
        ValueError: if ``method`` is not 'mean' or 'sum'
90
        ValueError: if ``weights`` is an xr.Dataset when ``values`` is an xr.DataArray
91

92
    Warnings:
93
        UserWarning: If weights are provided but no reduction is performed (``reduce_dims`` is None),
94
        a warning is issued since weights are ignored.
95

96
    Examples:
97
        >>> import xarray as xr
98
        >>> import numpy as np
99
        >>> da = xr.DataArray(np.arange(6).reshape(2, 3), dims=['x', 'y'])
100
        >>> weights = xr.DataArray([1, 2], dims=['x'])
101
        >>> apply_weighted_mean(da, reduce_dims=['x'], weights=weights)
102
        <xarray.DataArray (y: 3)>
103
        array([2., 3., 4.])
104
        Dimensions without coordinates: y
105

106
    """
107
    assertion_graph = _check_aggregate_inputs(values, reduce_dims, weights, method)
4✔
108

109
    if reduce_dims is None:
4✔
110
        return values
4✔
111

112
    # Perform the actual aggregation
113
    match method:
4✔
114
        case "mean":
4✔
115
            if weights is not None:
4✔
116
                result = _weighted_mean(values, weights, reduce_dims)
4✔
117
            else:
118
                result = values.mean(reduce_dims)
4✔
119
        case "sum":
4✔
120
            if weights is not None:
4✔
121
                result = _weighted_sum(values, weights, reduce_dims)
4✔
122
            else:
123
                result = values.sum(reduce_dims)
4✔
124
        case _:  # pragma: no cover
125
            raise ValueError(f"Unsupported method {method}. Expected 'mean' or 'sum'.")
126

127
    # Add the deferred assertion dependency if it exists
128
    if assertion_graph is not None:
4✔
129
        if isinstance(result, xr.DataArray):
4✔
130
            result = _add_assertion_dependency(result, assertion_graph)
4✔
131
        else:  # xr.Dataset case
132
            new_vars = {}
4✔
133
            for name, da in result.data_vars.items():
4✔
134
                new_vars[name] = _add_assertion_dependency(da, assertion_graph[name])
4✔
135
            result = xr.Dataset(new_vars)
4✔
136

137
    return result
4✔
138

139

140
def _weighted_mean(
    values: XarrayLike,
    weights: XarrayLike,
    reduce_dims: FlexibleDimensionTypes,
) -> XarrayLike:
    """
    Returns the weighted mean of ``values`` over ``reduce_dims``.

    When ``weights`` is not an ``xr.Dataset`` the work is delegated to xarray's
    ``.weighted(...).mean(...)``. xarray doesn't allow ``.weighted`` to take an
    ``xr.Dataset`` as weights, so in that case the mean is computed by hand for each
    data variable, pairing it with the weights variable of the same name.
    """
    if not isinstance(weights, xr.Dataset):
        return values.weighted(weights).mean(reduce_dims)

    per_variable = {}
    for var_name, var_values in values.data_vars.items():
        matched_values, matched_weights = broadcast_and_match_nan(var_values, weights[var_name])

        weighted_total = (matched_values * matched_weights).sum(dim=reduce_dims)
        weight_total = matched_weights.sum(dim=reduce_dims)

        # There is no guard against a zero weight total: wherever the weights sum to
        # zero over ``reduce_dims`` this is a division by zero, so such points do not
        # yield a finite mean (0 / 0 gives NaN).
        per_variable[var_name] = weighted_total / weight_total

    return xr.Dataset(per_variable)
4✔
167

168

169
def _weighted_sum(
    values: XarrayLike,
    weights: XarrayLike,
    reduce_dims: FlexibleDimensionTypes,
) -> XarrayLike:
    """
    Calculates the weighted sum of ``values`` using ``weights`` over ``reduce_dims``.

    When ``weights`` is not an ``xr.Dataset`` the sum comes from xarray's
    ``.weighted(...).sum(...)`` and is masked wherever the corresponding weighted
    mean is NaN. When ``weights`` is an ``xr.Dataset`` the sum is computed by hand for
    each data variable and masked wherever the weights sum to zero.
    """
    if not isinstance(weights, xr.Dataset):
        weighted_values = values.weighted(weights)
        total = weighted_values.sum(reduce_dims)
        # Handle NaNs in `values`: a sum would still report a number at points where
        # the weighted mean is NaN, so those points are masked to NaN.
        mean_is_defined = ~xr.ufuncs.isnan(weighted_values.mean(reduce_dims))
        return total.where(mean_is_defined)

    per_variable = {}
    for var_name, var_values in values.data_vars.items():
        matched_values, matched_weights = broadcast_and_match_nan(var_values, weights[var_name])
        weighted_total = (matched_values * matched_weights).sum(dim=reduce_dims)
        # A point whose weights sum to zero over ``reduce_dims`` should come out as
        # NaN rather than zero.
        weight_total = matched_weights.sum(dim=reduce_dims)
        per_variable[var_name] = weighted_total.where(weight_total != 0)

    return xr.Dataset(per_variable)
4✔
194

195

196
def _check_aggregate_inputs(
    values: XarrayLike, reduce_dims: FlexibleDimensionTypes | None, weights: XarrayLike | None, method: str
):
    """
    Validates the arguments passed to :py:func:`aggregate`.

    The checks are performed in this order:

    - ``method`` must be either 'mean' or 'sum'
    - if ``weights`` were provided they are passed to ``check_weights``
    - if ``weights`` were provided but ``reduce_dims`` is None, a ``UserWarning`` is
      emitted because the weights will not be used
    - ``weights`` must not be an xr.Dataset when ``values`` is an xr.DataArray
    - if ``weights`` is an xr.Dataset, it must contain every data variable of ``values``

    Args:
        values: The input data to be reduced in :py:func:`aggregate`.
        reduce_dims: The dimensions over which to aggregate in :py:func:`aggregate`.
        weights: The weights to apply for weighted aggregation in :py:func:`aggregate`.
        method: The aggregation method to use, either "mean" or "sum" in :py:func:`aggregate`.

    Returns:
        Whatever ``check_weights`` returned for ``weights``, or None when no weights
        were provided. :py:func:`aggregate` treats a non-None value as a deferred
        check to attach to its result.

    Raises:
        ValueError: if ``method`` is not 'mean' or 'sum'
        ValueError: if ``weights`` is an xr.Dataset when ``values`` is an xr.DataArray
        KeyError: if ``weights`` is an xr.Dataset that lacks a variable present in ``values``
    """
    if method not in ("mean", "sum"):
        raise ValueError(f"Method must be either 'mean' or 'sum', got '{method}'")

    # Without weights there is nothing further to validate and nothing to defer.
    if weights is None:
        return None

    deferred_check = check_weights(weights)

    if reduce_dims is None:
        warnings.warn(
            """
            Weights were provided but the point-wise score across all dimensions is being preserved. 
            Weights will be ignored.
            """,
            UserWarning,
        )
        return deferred_check

    # The remaining checks only concern weights supplied as a Dataset.
    if not isinstance(weights, xr.Dataset):
        return deferred_check

    if isinstance(values, xr.DataArray):
        raise ValueError("`weights` cannot be an xr.Dataset when `values` is an xr.DataArray")

    for name in values.data_vars:
        if name not in weights:
            raise KeyError(f"No weights provided for variable '{name}'")

    return deferred_check
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc