• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

alan-turing-institute / deepsensor / 9923346244

13 Jul 2024 10:38PM UTC coverage: 81.333%. Remained the same
9923346244

push

github

tom-andersson
Update GreedyAlgorithm API reference

9 of 11 new or added lines in 1 file covered. (81.82%)

2 existing lines in 1 file now uncovered.

1965 of 2416 relevant lines covered (81.33%)

1.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.58
/deepsensor/active_learning/algorithms.py
1
import copy
2✔
2

3
from deepsensor.data.loader import TaskLoader
2✔
4
from deepsensor.data.processor import (
2✔
5
    xarray_to_coord_array_normalised,
6
    mask_coord_array_normalised,
7
    da1_da2_same_grid,
8
    interp_da1_to_da2,
9
    process_X_mask_for_X,
10
)
11
from deepsensor.model.model import (
2✔
12
    DeepSensorModel,
13
)
14
from deepsensor.model.pred import create_empty_spatiotemporal_xarray
2✔
15
from deepsensor.data.task import Task, append_obs_to_task
2✔
16
from deepsensor.active_learning.acquisition_fns import (
2✔
17
    AcquisitionFunction,
18
    AcquisitionFunctionParallel,
19
    AcquisitionFunctionOracle,
20
)
21

22
import numpy as np
2✔
23
import xarray as xr
2✔
24
import pandas as pd
2✔
25
from tqdm import tqdm
2✔
26

27
from typing import Union, List, Tuple, Optional
2✔
28

29

30
class GreedyAlgorithm:
2✔
31
    """Greedy active learning sensor placement algorithm.
32

33
    Given a set of :class:`~.data.task.Task` objects containing existing context data, the algorithm
34
    iteratively (i.e. 'greedily') proposes $N$ locations for new context points
35
    from a search grid, using active learning with a DeepSensorModel.
36

37
    Within each greedy iteration, the algorithm evaluates an acquisition function
38
    over the search grid. The acquisition function value at a given query location
39
    relates to the merit of a new observation at that point, and is averaged over
40
    all :class:`~.data.task.Task` objects. The algorithm then
41
    selects the context location with the 'best' (max or min) acquisition function value.
42
    A new context observation is added to each :class:`~.data.task.Task` at that location.
43
    This process is repeated until $N$ new context locations have been proposed.
44

45
    The algorithm either computes the acquisition function values
46
    in parallel over all query locations, or sequentially. This is dictated by the
47
    type of acquisition function passed to the algorithm:
48

49
        1. :class:`~.active_learning.acquisition_fns.AcquisitionFunction`:
50
        Returns a scalar acquisition function for
51
        a given query location. For example, the model's mean standard deviation
52
        over target locations (``MeanStddev``). For a given :class:`~.data.task.Task`
53
        this requires running the model *once for every query location* with a new
54
        context point at that location, so these acquisition functions can be slow.
55

56
        2. :class:`~.active_learning.acquisition_fns.AcquisitionFunctionParallel`:
57
        Returns all acquisition function values in parallel.
58
        For example, the model's standard deviation at query locations given
59
        the existing context data, which only requires running the model once for a
60
        given :class:`~.data.task.Task`. These acquisition functions are faster than
61
        their sequential counterparts but are likely less informative.
62

63
    Acquisition functions that inherit from
64
    :class:`~.active_learning.acquisition_fns.AcquisitionFunctionOracle`
65
    require ground truth target values at target locations. In this case, the algorithm
66
    must be provided with a :class:`~.data.loader.TaskLoader` object to sample these values.
67

68
    .. note::
69
        The algorithm is described in more detail in 'Environmental Sensor Placement with
70
        Convolutional Gaussian Neural Processes' (2023), https://doi.org/10.1017/eds.2023.22.
71

72
    Args:
73
        model (:class:`~.model.model.DeepSensorModel`):
74
            Model to use for proposing new context points.
75
        X_s (:class:`xarray.Dataset` | :class:`xarray.DataArray`):
76
            Xarray object containing the spatial coordinates that define the search grid.
77
        X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | pd.DataFrame):
78
            Target spatial coordinates. Can either be an xarray object containing the spatial
79
            coordinates of the target grid, or a pandas DataFrame containing a set of off-grid
80
            target locations.
81
        X_s_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional):
82
            Optional 2D mask for gridded search coordinates to ignore. If provided, the acquisition
83
            function will only be computed at locations where the mask is True. Defaults to None.
84
        X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional):
85
            Optional 2D mask (for gridded target coordinates) to ignore.
86
            Useful e.g. if you only care about improving the model's predictions over a certain
87
            area. Defaults to None.
88
        N_new_context (int, optional):
89
            Number of new context points to propose (i.e. number of greedy iterations), defaults to 1.
90
        X_normalised (bool, optional):
91
            Whether the coordinates of the X_* arguments above have been normalised
92
            by a :class:`~.data.processor.DataProcessor`. Defaults to False.
93
        model_infill_method (str, optional):
94
            Method for generating pseudo observations from the model at search points,
95
            which are appended to Tasks when computing acquisition functions or at the
96
            end of a greedy iteration (unless overridden by ``query_infill`` or ``proposed_infill`` below).
97
            Currently, only "mean" infilling is supported. Defaults to "mean".
98
        query_infill (:class:`xarray.DataArray`, optional):
99
            Gridded xarray object containing observations to use when querying candidate context
100
            points. Must have all the same time points as the :class:`~.data.task.Task` objects
101
            the algorithm is called with. If not on the same grid as ``X_s``, it will be linearly
102
            interpolated to the same grid. Useful for providing the model with true observations
103
            rather than its own predictions. Defaults to None.
104
        proposed_infill (:class:`xarray.DataArray`, optional):
105
            Similar to ``query_infill``, but used when infilling pseudo observations at the end
106
            of a greedy iteration (rather than using model predictions). Useful e.g. to
107
            simulate the case where the model can obtain ground truth after requesting
108
            a sensor placement. Defaults to None.
109
        context_set_idx (int, optional):
110
            Context set index to run the sensor placement algorithm on. E.g. if a model
111
            ingests two context sets ["aux_data", "sensor_data"], this should be set to 1
112
            (corresponding to the sensor context set). Defaults to 0.
113
        target_set_idx (int, optional):
114
            Target set index corresponding to predictions of the context set that the
115
            algorithm is run on. Defaults to 0.
116
        progress_bar (bool, optional):
117
            Whether to display a progress bar when running the algorithm. Defaults to False.
118
        task_loader (:class:`~.data.loader.TaskLoader`, optional):
119
            If using an :class:`~.active_learning.acquisition_fns.AcquisitionFunctionOracle`,
120
            a TaskLoader object is required to sample ground truth target values at target
121
            locations. Defaults to None.
122
        verbose (bool, optional):
123
            Whether to print some status messages. Defaults to False.
124

125
    Raises:
126
        ValueError:
127
            If the ``model`` passed does not inherit from
128
            :class:`~.model.model.DeepSensorModel`.
129
    """
130

131
    def __init__(
        self,
        model: DeepSensorModel,
        X_s: Union[xr.Dataset, xr.DataArray],
        X_t: Union[xr.Dataset, xr.DataArray, pd.DataFrame],
        X_s_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        N_new_context: int = 1,
        X_normalised: bool = False,
        model_infill_method: str = "mean",
        query_infill: Optional[xr.DataArray] = None,
        proposed_infill: Optional[xr.DataArray] = None,
        context_set_idx: int = 0,
        target_set_idx: int = 0,
        progress_bar: bool = False,
        task_loader: Optional[
            TaskLoader
        ] = None,  # OPTIONAL for oracle acquisition functions only
        verbose: bool = False,
    ):
        """Validate the inputs, normalise coordinates and cache search/target arrays.

        See the class docstring for a description of each argument.

        Raises:
            ValueError:
                If ``model`` does not inherit from DeepSensorModel, or if
                ``N_new_context`` is rejected by ``_validate_n_new_context``.
            NotImplementedError:
                If the search points ``X_s`` are a pandas object.
            TypeError:
                If ``X_s`` or ``X_t`` have an unsupported type.
        """
        if not isinstance(model, DeepSensorModel):
            raise ValueError(
                "`model` must inherit from DeepSensorModel, but parent "
                f"classes are {model.__class__.__bases__}"
            )

        self._validate_n_new_context(X_s, N_new_context)

        self.model = model
        self.N_new_context = N_new_context
        self.progress_bar = progress_bar
        self.model_infill_method = model_infill_method
        self.context_set_idx = context_set_idx
        self.target_set_idx = target_set_idx
        self.task_loader = task_loader
        self.pbar = None

        # Names of the spatial coordinates as configured in the data processor
        self.x1_name = self.model.data_processor.config["coords"]["x1"]["name"]
        self.x2_name = self.model.data_processor.config["coords"]["x2"]["name"]

        # Normalised search and target coordinates
        if not X_normalised:
            X_t = model.data_processor.map_coords(X_t)
            X_s = model.data_processor.map_coords(X_s)
            if X_s_mask is not None:
                X_s_mask = model.data_processor.map_coords(X_s_mask)
            if X_t_mask is not None:
                X_t_mask = model.data_processor.map_coords(X_t_mask)

        self.X_s = X_s
        self.X_t = X_t
        self.X_s_mask = X_s_mask
        self.X_t_mask = X_t_mask

        # Interpolate masks onto search and target coords
        if self.X_s_mask is not None:
            self.X_s_mask = process_X_mask_for_X(self.X_s_mask, self.X_s)
        if self.X_t_mask is not None:
            self.X_t_mask = process_X_mask_for_X(self.X_t_mask, self.X_t)

        # Interpolate overridden infill datasets at search points if necessary
        if query_infill is not None and not da1_da2_same_grid(query_infill, X_s):
            if verbose:
                print("query_infill not on search grid, interpolating.")
            query_infill = interp_da1_to_da2(query_infill, self.X_s)
        if proposed_infill is not None and not da1_da2_same_grid(proposed_infill, X_s):
            if verbose:
                print("proposed_infill not on search grid, interpolating.")
            proposed_infill = interp_da1_to_da2(proposed_infill, self.X_s)
        self.query_infill = query_infill
        self.proposed_infill = proposed_infill

        # Convert target coords to numpy arrays and assign to tasks
        if isinstance(X_t, (xr.Dataset, xr.DataArray)):
            # Targets on grid
            self.X_t_arr = xarray_to_coord_array_normalised(X_t)
            if self.X_t_mask is not None:
                # Remove points that lie outside the mask
                self.X_t_arr = mask_coord_array_normalised(self.X_t_arr, self.X_t_mask)
        elif isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index)):
            # Targets off-grid: stack the x1/x2 values as rows of a 2-row array
            self.X_t_arr = X_t.reset_index()[["x1", "x2"]].values.T
        else:
            raise TypeError(f"Unsupported type for X_t: {type(X_t)}")

        # Construct search array. Only gridded (xarray) search points are
        # supported; fail explicitly for anything else rather than hitting an
        # unbound local variable below.
        if isinstance(X_s, (xr.Dataset, xr.DataArray)):
            X_s_arr = xarray_to_coord_array_normalised(X_s)
            if X_s_mask is not None:
                X_s_arr = mask_coord_array_normalised(X_s_arr, self.X_s_mask)
        elif isinstance(X_s, (pd.DataFrame, pd.Series, pd.Index)):
            raise NotImplementedError(
                "Pandas support for active learning search points X_s not yet "
                "implemented."
            )
        else:
            raise TypeError(f"Unsupported type for X_s: {type(X_s)}")
        self.X_s_arr = X_s_arr

        self.X_new = []  # List of new proposed context locations
224

225
    @classmethod
    def _validate_n_new_context(
        cls, X_s: Union[xr.Dataset, xr.DataArray], N_new_context: int
    ):
        """Check that ``N_new_context`` is compatible with the number of search points.

        For xarray inputs the number of search points is the product of the
        sizes of the last two dimensions; for pandas inputs it is ``len(X_s)``.

        Raises:
            TypeError:
                If ``X_s`` is neither an xarray nor a pandas object.
            ValueError:
                If ``N_new_context`` is not strictly between zero and the
                number of search points.
        """
        if isinstance(X_s, (xr.Dataset, xr.DataArray)):
            if isinstance(X_s, xr.Dataset):
                X_s = X_s.to_array()
            N_s = X_s.shape[-2] * X_s.shape[-1]
        elif isinstance(X_s, (pd.DataFrame, pd.Series, pd.Index)):
            N_s = len(X_s)
        else:
            # Without this branch N_s would be unbound in the check below
            raise TypeError(f"Unsupported type for X_s: {type(X_s)}")

        if not 0 < N_new_context < N_s:
            raise ValueError(
                f"Number of new context ({N_new_context}) must be greater "
                f"than zero and less than the number of search points ({N_s})"
            )
241

242
    def _get_times_from_tasks(self):
2✔
243
        """Get times from tasks"""
244
        times = [task["time"] for task in self.tasks]
2✔
245
        # Check for any repeats
246
        if len(times) != len(set(times)):
2✔
247
            # TODO unit test this
248
            raise ValueError(
×
249
                f"The {len(times)} tasks have duplicate times ({len(set(times))} "
250
                f"unique times)"
251
            )
252
        return times
2✔
253

254
    def _model_infill_at_search_points(
        self,
        X_s: Union[xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index],
    ):
        """Compute model infill y-values over the whole search grid.

        Only ``self.model_infill_method == "mean"`` is implemented: the model is
        run on ``self.tasks`` at ``X_s`` (kept in normalised space) and the
        ``"mean"`` prediction for ``self.target_set_idx`` is returned.

        Raises:
            NotImplementedError:
                If the infill method is ``"sample"`` or ``"zeros"``.
            ValueError:
                If the infill method is anything else.
        """
        method = self.model_infill_method

        if method == "mean":
            prediction = self.model.predict(
                self.tasks,
                X_s,
                X_t_is_normalised=True,
                unnormalise=False,
            )
            return prediction[self.target_set_idx]["mean"]

        # TODO "sample": infill from model samples rather than the mean
        # TODO "zeros": generate empty prediction xarray
        if method in ("sample", "zeros"):
            raise NotImplementedError("TODO")

        raise ValueError(
            f"Unsupported model_infill_method: {self.model_infill_method}"
        )
286

287
    def _sample_y_infill(self, infill_ds, time, x1, x2):
        """Look up the infill values at a single time and (x1, x2) location.

        Returns the selected values reshaped to ``(1, y.size)`` when
        ``infill_ds`` has no ``"sample"`` dimension, otherwise the selected
        array unchanged.
        """
        assert isinstance(infill_ds, (xr.Dataset, xr.DataArray))

        selected = infill_ds.sel(time=time, x1=x1, x2=x2)
        # Datasets are converted to a single array before pulling out the data
        values = (
            selected.to_array() if isinstance(selected, xr.Dataset) else selected
        ).data

        if "sample" in infill_ds.dims:
            # TODO confirm or force that dim ordering is (N_samples, N_target_dims)
            return values
        return values.reshape(1, values.size)  # 1 observation with N_target_dims
299

300
    def _build_acquisition_fn_ds(self, X_s: Union[xr.Dataset, xr.DataArray]):
        """
        Create the NaN-filled xr.DataArray used to store acquisition function
        values on the search grid, with a leading ``iteration`` dimension of
        length ``self.N_new_context``.
        """
        # A Monte Carlo "sample" dimension could be prepended here too (TODO)
        iteration_coords = {"iteration": range(self.N_new_context)}

        empty_ds = create_empty_spatiotemporal_xarray(
            X=X_s,
            dates=self._get_times_from_tasks(),
            coord_names={"x1": self.x1_name, "x2": self.x2_name},
            data_vars=["acquisition_fn"],
            prepend_dims=list(iteration_coords),
            prepend_coords=iteration_coords,
        )

        acquisition_fn_da = empty_ds["acquisition_fn"]
        acquisition_fn_da.data[:] = np.nan
        return acquisition_fn_da
321

322
    def _init_acquisition_fn_object(self, X_s: xr.Dataset):
        """Create ``self.acquisition_fn_ds`` for the unnormalised search grid.

        Raises:
            NotImplementedError:
                If the unnormalised search points are a pandas object.
            TypeError:
                If they are of any other unsupported type.
        """
        # Unnormalise before instantiating
        X_s_unnorm = self.model.data_processor.map_coords(X_s, unnorm=True)

        if isinstance(X_s_unnorm, (pd.DataFrame, pd.Series, pd.Index)):
            raise NotImplementedError(
                "Pandas support for active learning search points X_s not yet "
                "implemented."
            )
        if not isinstance(X_s_unnorm, (xr.Dataset, xr.DataArray)):
            raise TypeError(f"Unsupported type for X_s: {type(X_s_unnorm)}")

        # xarray object storing acquisition function values
        self.acquisition_fn_ds = self._build_acquisition_fn_ds(X_s_unnorm)
336

337
    def _search(self, acquisition_fn: AcquisitionFunction):
        """
        Run one greedy pass: evaluate the acquisition function over the search
        points in ``self.X_s_arr`` for every task, store the values in
        ``self.acquisition_fn_ds`` for the current iteration, and return the
        values averaged over tasks.
        """
        per_task_importances = []
        loader = self.task_loader

        for task in self.tasks:
            if isinstance(acquisition_fn, AcquisitionFunctionParallel):
                # Parallel computation: one call covers all search points
                task_importances = acquisition_fn(task, self.X_s_arr)
                if self.pbar:
                    self.pbar.update(1)

            elif isinstance(acquisition_fn, AcquisitionFunction):
                # Sequential computation: one call per search point
                task_importances = []

                if self.diff:
                    baseline = acquisition_fn(task)

                # Add size-1 dim after row dim to preserve row dim for passing to
                #   acquisition_fn. Also roll final axis to first axis for looping
                #   over search points.
                query_points = np.rollaxis(self.X_s_arr[:, np.newaxis], 2)
                for x_query in query_points:
                    y_query = self._sample_y_infill(
                        self.query_infill,
                        time=task["time"],
                        x1=x_query[0],
                        x2=x_query[1],
                    )
                    task_with_new = append_obs_to_task(
                        task, x_query, y_query, self.context_set_idx
                    )

                    # TODO this is a hack to add the auxiliary variable to the context set
                    if loader is not None and loader.aux_at_contexts:
                        # Add auxiliary variable sampled at context set as a new
                        #   context variable (written to the last context set)
                        X_c = task_with_new["X_c"][loader.aux_at_contexts[0]]
                        Y_c_aux = loader.sample_offgrid_aux(
                            X_c, loader.aux_at_contexts[1]
                        )
                        task_with_new["X_c"][-1] = X_c
                        task_with_new["Y_c"][-1] = Y_c_aux

                    importance = acquisition_fn(task_with_new)
                    if self.diff:
                        # Report the change relative to the unmodified task
                        importance = importance - baseline
                    task_importances.append(importance)

                    if self.pbar:
                        self.pbar.update(1)

            else:
                allowed_classes = [
                    AcquisitionFunction,
                    AcquisitionFunctionParallel,
                    AcquisitionFunctionOracle,
                ]
                raise ValueError(
                    f"Acquisition function needs to inherit from one of {allowed_classes}."
                )

            task_importances = np.array(task_importances)
            per_task_importances.append(task_importances)

            # Record this task's values for the current greedy iteration
            if self.X_s_mask is None:
                self.acquisition_fn_ds.loc[self.iteration, task["time"]] = (
                    task_importances.reshape(self.acquisition_fn_ds.shape[-2:])
                )
            else:
                # Only the unmasked grid cells receive values
                self.acquisition_fn_ds.loc[self.iteration, task["time"]].data[
                    self.X_s_mask.data
                ] = task_importances

        return np.mean(per_task_importances, axis=0)
416

417
    def _select_best(self, importances, X_s_arr):
2✔
418
        """Select context location corresponding to the best importance value.
419

420
        Appends the chosen search index to a list of chosen search indexes.
421
        """
422
        if self.min_or_max == "min":
2✔
423
            best_idx = np.argmin(importances)
2✔
424
        elif self.min_or_max == "max":
2✔
425
            best_idx = np.argmax(importances)
2✔
426

427
        best_x_query = X_s_arr[:, best_idx : best_idx + 1]
2✔
428

429
        # Index into original search space of chosen context location
430
        self.best_idxs_all.append(
2✔
431
            np.where((self.X_s_arr == best_x_query).all(axis=0))[0][0]
432
        )
433

434
        return best_x_query
2✔
435

436
    def _single_greedy_iteration(self, acquisition_fn: AcquisitionFunction):
        """
        Perform one greedy grid search: score the search points, pick the best
        one, record it in ``self.X_new`` and return it.
        """
        best_x_query = self._select_best(
            self._search(acquisition_fn), self.X_s_arr
        )
        self.X_new.append(best_x_query)
        return best_x_query
447

448
    def __call__(
        self,
        acquisition_fn: AcquisitionFunction,
        tasks: Union[List[Task], Task],
        diff: bool = False,
    ) -> Tuple[pd.DataFrame, xr.DataArray]:
        """
        Iteratively propose new context points using the greedy sensor placement algorithm.

        Args:
            acquisition_fn (:class:`~.active_learning.acquisition_fns.AcquisitionFunction`):
                The acquisition function to optimise.
            tasks (List[:class:`~.data.task.Task`] | :class:`~.data.task.Task`):
                Tasks containing existing context data. If a list of Tasks, the acquisition
                function will be averaged over Tasks.
            diff (bool, optional):
                For sequential acquisition functions only: Whether to compute the *change* in
                acquisition function value after adding the new context point, i.e.
                ``acquisition_fn(task_with_new) - acquisition_fn(task)``. Can be useful
                for making the acquisition function values more interpretable, or for
                comparing with the change in metric that the acquisition function targets
                (see https://doi.org/10.1017/eds.2023.22). Defaults to False.

        Returns:
            Tuple[:class:`pandas.DataFrame`, :class:`xarray.DataArray`]:
                A tuple containing two objects:

                - **X_new_df** (:class:`pandas.DataFrame`):
                Proposed sensor placements. Columns are the x1 and x2 coordinates of the
                sensor placements, and the index is the index of the greedy iteration
                at which the sensor placement was proposed (which can be interpreted as
                a priority order, with iteration 0 being the highest priority).

                - **acquisition_fn_ds** (:class:`xarray.DataArray`):
                Gridded acquisition function values at each search point. Dimensions
                are ``iteration``, ``time`` (inferred from the input ``tasks``), followed
                by the x1 and x2 coordinates of the spatial grid.

        Raises:
            ValueError:
                If ``acquisition_fn`` is an
                :class:`~.active_learning.acquisition_fns.AcquisitionFunctionOracle`
                and ``task_loader`` is None.
            ValueError:
                If ``min_or_max`` is not ``"min"`` or ``"max"``.
            ValueError:
                If ``diff`` is True and ``acquisition_fn`` is an
                :class:`~.active_learning.acquisition_fns.AcquisitionFunctionParallel`.
            ValueError:
                If ``Y_t_aux`` is in ``tasks`` but ``task_loader`` is None.
        """
        if (
            isinstance(acquisition_fn, AcquisitionFunctionOracle)
            and self.task_loader is None
        ):
            raise ValueError(
                "AcquisitionFunctionOracle requires a task_loader function to "
                "be passed to the GreedyAlgorithm constructor."
            )

        self.min_or_max = acquisition_fn.min_or_max
        if self.min_or_max not in ["min", "max"]:
            raise ValueError(
                f"min_or_max must be either 'min' or 'max', got {self.min_or_max}."
            )

        if diff and isinstance(acquisition_fn, AcquisitionFunctionParallel):
            raise ValueError(
                "diff=True is not valid for parallel acquisition functions."
            )
        self.diff = diff

        if isinstance(tasks, Task):
            tasks = [tasks]

        # Make deepcopys so that original tasks are not modified
        tasks = copy.deepcopy(tasks)

        # Add target set to tasks
        for task in tasks:
            task["X_t"] = [self.X_t_arr]
            if isinstance(acquisition_fn, AcquisitionFunctionOracle):
                # Sample ground truth y-values at target points `self.X_t_arr`
                #   using `self.task_loader`
                task_with_Y_t = self.task_loader(
                    task["time"], context_sampling=0, target_sampling=self.X_t_arr
                )
                task["Y_t"] = task_with_Y_t["Y_t"]

            if "Y_t_aux" in task and self.task_loader is None:
                raise ValueError(
                    "Model expects Y_t_aux data but a TaskLoader isn't "
                    "provided to GreedyAlgorithm."
                )
            if self.task_loader is not None and self.task_loader.aux_at_target_dims > 0:
                task["Y_t_aux"] = self.task_loader.sample_offgrid_aux(
                    self.X_t_arr, self.task_loader.aux_at_targets
                )

        self.tasks = tasks

        # Generate infill values at search points if not overridden by the user.
        # Model infill is computed from the tasks of *this* call, so an infill
        # that a previous call generated from the model is recomputed instead of
        # being silently reused with different tasks. User-supplied infill
        # objects are never replaced.
        previous_model_infill = getattr(self, "_model_infill", None)
        need_query_infill = (
            self.query_infill is None or self.query_infill is previous_model_infill
        )
        need_proposed_infill = (
            self.proposed_infill is None
            or self.proposed_infill is previous_model_infill
        )
        if need_query_infill or need_proposed_infill:
            self._model_infill = self._model_infill_at_search_points(self.X_s)
            if need_query_infill:
                self.query_infill = self._model_infill
            if need_proposed_infill:
                self.proposed_infill = self._model_infill

        # Instantiate empty acquisition function object
        self._init_acquisition_fn_object(self.X_s)

        # Dataframe for storing proposed context locations
        self.X_new_df = pd.DataFrame(columns=[self.x1_name, self.x2_name])
        self.X_new_df.index.name = "iteration"

        # List to track indexes into original search grid of chosen sensor locations
        #   as optimisation progresses. Used for filling y-values at chosen
        #   sensor locations, `self.X_new`
        # NOTE(review): `self.X_new` itself is created in `__init__` and is not
        #   reset here, so it keeps growing across repeated calls.
        self.best_idxs_all = []

        # Total iterations are number of new context points * number of tasks * number of search
        #   points (if not parallel) * number of Monte Carlo samples (if using MC)
        total_iterations = self.N_new_context * len(self.tasks)
        if not isinstance(acquisition_fn, AcquisitionFunctionParallel):
            total_iterations *= self.X_s_arr.shape[-1]
        # TODO make class attribute for list of sample-based infill methods
        if self.model_infill_method in ["sample", "ar_sample"]:
            # `n_samples` is not set anywhere in this class yet; fall back to 1
            #   so that sizing the progress bar cannot raise AttributeError.
            total_iterations *= getattr(self, "n_samples", 1)

        with tqdm(total=total_iterations, disable=not self.progress_bar) as self.pbar:
            for iteration in range(self.N_new_context):
                self.iteration = iteration
                x_new = self._single_greedy_iteration(acquisition_fn)

                # Append new proposed context points to each task
                for i, task in enumerate(self.tasks):
                    y_new = self._sample_y_infill(
                        self.proposed_infill,
                        time=task["time"],
                        x1=x_new[0],
                        x2=x_new[1],
                    )
                    self.tasks[i] = append_obs_to_task(
                        task, x_new, y_new, self.context_set_idx
                    )

                # Append new proposed context points to dataframe
                x_new_unnorm = self.model.data_processor.map_coord_array(
                    x_new, unnorm=True
                )
                self.X_new_df.loc[self.iteration] = x_new_unnorm.ravel()

        return self.X_new_df, self.acquisition_fn_ds
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc