• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

alan-turing-institute / deepsensor / 19842460617

08 Oct 2025 10:03AM UTC coverage: 81.663%. Remained the same
19842460617

push

github

web-flow
Update README.md, adding reference to GIANT project

2053 of 2514 relevant lines covered (81.66%)

1.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.58
/deepsensor/active_learning/algorithms.py
1
import copy
2✔
2

3
from deepsensor.data.loader import TaskLoader
2✔
4
from deepsensor.data.processor import (
2✔
5
    xarray_to_coord_array_normalised,
6
    mask_coord_array_normalised,
7
    da1_da2_same_grid,
8
    interp_da1_to_da2,
9
    process_X_mask_for_X,
10
)
11
from deepsensor.model.model import (
2✔
12
    DeepSensorModel,
13
)
14
from deepsensor.model.pred import create_empty_spatiotemporal_xarray
2✔
15
from deepsensor.data.task import Task, append_obs_to_task
2✔
16
from deepsensor.active_learning.acquisition_fns import (
2✔
17
    AcquisitionFunction,
18
    AcquisitionFunctionParallel,
19
    AcquisitionFunctionOracle,
20
)
21

22
import numpy as np
2✔
23
import xarray as xr
2✔
24
import pandas as pd
2✔
25
from tqdm import tqdm
2✔
26

27
from typing import Union, List, Tuple, Optional
2✔
28

29

30
class GreedyAlgorithm:
    """Greedy active learning sensor placement algorithm.

    Given a set of :class:`~.data.task.Task` objects containing existing context data, the algorithm
    iteratively (i.e. 'greedily') proposes $N$ locations for new context points
    from a search grid, using active learning with a DeepSensorModel.

    Within each greedy iteration, the algorithm evaluates an acquisition function
    over the search grid. The acquisition function value at a given query location
    relates to the merit of a new observation at that point, and is averaged over
    all :class:`~.data.task.Task` objects. The algorithm then
    selects the context location with the 'best' (max or min) acquisition function value.
    A new context observation is added to each :class:`~.data.task.Task` at that location.
    This process is repeated until $N$ new context locations have been proposed.

    The algorithm either computes the acquisition function values
    in parallel over all query locations, or sequentially. This is dictated by the
    type of acquisition function passed to the algorithm:

        1. :class:`~.active_learning.acquisition_fns.AcquisitionFunction`:
        Returns a scalar acquisition function for
        a given query location. For example, the model's mean standard deviation
        over target locations (``MeanStddev``). For a given :class:`~.data.task.Task`
        this requires running the model *once for every query location* with a new
        context point at that location, so these acquisition functions can be slow.

        2. :class:`~.active_learning.acquisition_fns.AcquisitionFunctionParallel`:
        Returns all acquisition function values in parallel.
        For example, the model's standard deviation at query locations given
        the existing context data, which only requires running the model once for a
        given :class:`~.data.task.Task`. These acquisition functions are faster than
        their sequential counterparts but are likely less informative.

    Acquisition functions that inherit from
    :class:`~.active_learning.acquisition_fns.AcquisitionFunctionOracle`
    require ground truth target values at target locations. In this case, the algorithm
    must be provided with a :class:`~.data.loader.TaskLoader` object to sample these values.

    .. note::
        The algorithm is described in more detail in 'Environmental Sensor Placement with
        Convolutional Gaussian Neural Processes' (2023), https://doi.org/10.1017/eds.2023.22.

    Args:
        model (:class:`~.model.model.DeepSensorModel`):
            Model to use for proposing new context points.
        X_s (:class:`xarray.Dataset` | :class:`xarray.DataArray`):
            Xarray object containing the spatial coordinates that define the search grid.
        X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | pd.DataFrame):
            Target spatial coordinates. Can either be an xarray object containing the spatial
            coordinates of the target grid, or a pandas DataFrame containing a set of off-grid
            target locations.
        X_s_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional):
            Optional 2D mask for gridded search coordinates to ignore. If provided, the acquisition
            function will only be computed at locations where the mask is True. Defaults to None.
        X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional):
            Optional 2D mask (for gridded target coordinates) to ignore.
            Useful e.g. if you only care about improving the model's predictions over a certain
            area. Defaults to None.
        N_new_context (int, optional):
            Number of new context points to propose (i.e. number of greedy iterations), defaults to 1.
        X_normalised (bool, optional):
            Whether the coordinates of the X_* arguments above have been normalised
            by a :class:`~.data.processor.DataProcessor`. Defaults to False.
        model_infill_method (str, optional):
            Method for generating pseudo observations from the model at search points,
            which are appended to Tasks when computing acquisition functions or at the
            end of a greedy iteration (unless overridden by ``query_infill`` or ``proposed_infill`` below).
            Currently, only "mean" infilling is supported. Defaults to "mean".
        query_infill (:class:`xarray.DataArray`, optional):
            Gridded xarray object containing observations to use when querying candidate context
            points. Must have all the same time points as the :class:`~.data.task.Task` objects
            the algorithm is called with. If not on the same grid as ``X_s``, it will be linearly
            interpolated to the same grid. Useful for providing the model with true observations
            rather than its own predictions. Defaults to None.
        proposed_infill (:class:`xarray.DataArray`, optional):
            Similar to ``query_infill``, but used when infilling pseudo observations at the end
            of a greedy iteration (rather than using model predictions). Useful e.g. to
            simulate the case where the model can obtain ground truth after requesting
            a sensor placement. Defaults to None.
        context_set_idx (int, optional):
            Context set index to run the sensor placement algorithm on. E.g. if a model
            ingests two context sets ["aux_data", "sensor_data"], this should be set to 1
            (corresponding to the sensor context set). Defaults to 0.
        target_set_idx (int, optional):
            Target set index corresponding to predictions of the context set that the
            algorithm is run on. Defaults to 0.
        progress_bar (bool, optional):
            Whether to display a progress bar when running the algorithm. Defaults to False.
        task_loader (:class:`~.data.loader.TaskLoader`, optional):
            If using an :class:`~.active_learning.acquisition_fns.AcquisitionFunctionOracle`,
            a TaskLoader object is required to sample ground truth target values at target
            locations. Defaults to None.
        verbose (bool, optional):
            Whether to print some status messages. Defaults to False.

    Raises:
        ValueError:
            If the ``model`` passed does not inherit from
            :class:`~.model.model.DeepSensorModel`, or if ``N_new_context`` is not
            greater than zero and less than the number of search points.
        TypeError:
            If ``X_s`` or ``X_t`` is of an unsupported type.
        NotImplementedError:
            If ``X_s`` is a pandas object (off-grid search points are not yet supported).
    """

    def __init__(
        self,
        model: DeepSensorModel,
        X_s: Union[xr.Dataset, xr.DataArray],
        X_t: Union[xr.Dataset, xr.DataArray, pd.DataFrame],
        X_s_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        N_new_context: int = 1,
        X_normalised: bool = False,
        model_infill_method: str = "mean",
        query_infill: Optional[xr.DataArray] = None,
        proposed_infill: Optional[xr.DataArray] = None,
        context_set_idx: int = 0,
        target_set_idx: int = 0,
        progress_bar: bool = False,
        task_loader: Optional[
            TaskLoader
        ] = None,  # OPTIONAL for oracle acquisition functions only
        verbose: bool = False,
    ):
        if not isinstance(model, DeepSensorModel):
            raise ValueError(
                "`model` must inherit from DeepSensorModel, but parent "
                f"classes are {model.__class__.__bases__}"
            )

        self._validate_n_new_context(X_s, N_new_context)

        # Only gridded (xarray) search points are supported. Pandas search points
        # pass the size validation above but cannot be turned into a search array
        # below, so fail early with a clear message instead of an UnboundLocalError.
        if not isinstance(X_s, (xr.Dataset, xr.DataArray)):
            raise NotImplementedError(
                "Pandas support for active learning search points X_s not yet "
                "implemented."
            )

        self.model = model
        self.N_new_context = N_new_context
        self.progress_bar = progress_bar
        self.model_infill_method = model_infill_method
        self.context_set_idx = context_set_idx
        self.target_set_idx = target_set_idx
        self.task_loader = task_loader
        self.pbar = None

        self.x1_name = self.model.data_processor.config["coords"]["x1"]["name"]
        self.x2_name = self.model.data_processor.config["coords"]["x2"]["name"]

        # Normalise search and target coordinates (and their masks) if needed
        if not X_normalised:
            X_t = model.data_processor.map_coords(X_t)
            X_s = model.data_processor.map_coords(X_s)
            if X_s_mask is not None:
                X_s_mask = model.data_processor.map_coords(X_s_mask)
            if X_t_mask is not None:
                X_t_mask = model.data_processor.map_coords(X_t_mask)

        self.X_s = X_s
        self.X_t = X_t
        self.X_s_mask = X_s_mask
        self.X_t_mask = X_t_mask

        # Interpolate masks onto search and target coords
        if self.X_s_mask is not None:
            self.X_s_mask = process_X_mask_for_X(self.X_s_mask, self.X_s)
        if self.X_t_mask is not None:
            self.X_t_mask = process_X_mask_for_X(self.X_t_mask, self.X_t)

        # Interpolate overridden infill datasets at search points if necessary
        if query_infill is not None and not da1_da2_same_grid(query_infill, X_s):
            if verbose:
                print("query_infill not on search grid, interpolating.")
            query_infill = interp_da1_to_da2(query_infill, self.X_s)
        if proposed_infill is not None and not da1_da2_same_grid(proposed_infill, X_s):
            if verbose:
                print("proposed_infill not on search grid, interpolating.")
            proposed_infill = interp_da1_to_da2(proposed_infill, self.X_s)
        self.query_infill = query_infill
        self.proposed_infill = proposed_infill
        # Remember the user-supplied overrides separately: ``self.query_infill`` and
        # ``self.proposed_infill`` are overwritten with model predictions during
        # ``__call__``, and those predictions must be recomputed for each new call.
        self._query_infill_override = query_infill
        self._proposed_infill_override = proposed_infill

        # Convert target coords to numpy arrays and assign to tasks
        if isinstance(X_t, (xr.Dataset, xr.DataArray)):
            # Targets on grid
            self.X_t_arr = xarray_to_coord_array_normalised(X_t)
            if self.X_t_mask is not None:
                # Remove points that lie outside the mask
                self.X_t_arr = mask_coord_array_normalised(self.X_t_arr, self.X_t_mask)
        elif isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index)):
            # Targets off-grid. ``pd.Index`` has no ``reset_index``, so turn its
            # levels into columns with ``to_frame`` instead.
            if isinstance(X_t, pd.Index):
                X_t_df = X_t.to_frame(index=False)
            else:
                X_t_df = X_t.reset_index()
            self.X_t_arr = X_t_df[["x1", "x2"]].values.T
        else:
            raise TypeError(f"Unsupported type for X_t: {type(X_t)}")

        # Construct search array (X_s is guaranteed to be xarray here)
        X_s_arr = xarray_to_coord_array_normalised(X_s)
        if self.X_s_mask is not None:
            X_s_arr = mask_coord_array_normalised(X_s_arr, self.X_s_mask)
        self.X_s_arr = X_s_arr

        self.X_new = []  # List of new proposed context locations

    @classmethod
    def _validate_n_new_context(
        cls, X_s: Union[xr.Dataset, xr.DataArray], N_new_context: int
    ):
        """Check that ``0 < N_new_context < N_s``, where N_s is the number of search points.

        For xarray input, N_s is the product of the last two dimension sizes; for
        pandas input it is ``len(X_s)``.

        Raises:
            TypeError: If ``X_s`` is neither an xarray nor a pandas object.
            ValueError: If ``N_new_context`` is out of range.
        """
        if isinstance(X_s, (xr.Dataset, xr.DataArray)):
            if isinstance(X_s, xr.Dataset):
                X_s = X_s.to_array()
            N_s = X_s.shape[-2] * X_s.shape[-1]
        elif isinstance(X_s, (pd.DataFrame, pd.Series, pd.Index)):
            N_s = len(X_s)
        else:
            # Previously fell through and failed with an UnboundLocalError on N_s
            raise TypeError(f"Unsupported type for X_s: {type(X_s)}")

        if not 0 < N_new_context < N_s:
            raise ValueError(
                f"Number of new context ({N_new_context}) must be greater "
                f"than zero and less than the number of search points ({N_s})"
            )

    def _get_times_from_tasks(self):
        """Return the ``"time"`` of each task in ``self.tasks``.

        Raises:
            ValueError: If two or more tasks share the same time.
        """
        times = [task["time"] for task in self.tasks]
        # Check for any repeats
        if len(times) != len(set(times)):
            # TODO unit test this
            raise ValueError(
                f"The {len(times)} tasks have duplicate times ({len(set(times))} "
                f"unique times)"
            )
        return times

    def _model_infill_at_search_points(
        self,
        X_s: Union[xr.Dataset, xr.DataArray, pd.DataFrame, pd.Series, pd.Index],
    ):
        """Compute model infill y-values over the whole search grid.

        With ``model_infill_method == "mean"``, returns the model's predicted mean
        (in normalised units) for target set ``self.target_set_idx``, conditioned
        on ``self.tasks``.

        Raises:
            NotImplementedError: For the planned "sample" and "zeros" methods.
            ValueError: For any other unrecognised method.
        """
        if self.model_infill_method == "mean":
            pred = self.model.predict(
                self.tasks,
                X_s,
                X_t_is_normalised=True,
                unnormalise=False,
            )
            infill_ds = pred[self.target_set_idx]["mean"]

        elif self.model_infill_method == "sample":
            # pred = self.model.predict(
            #     self.tasks, X_s, X_t_normalised=True, unnormalise=False,
            #     n_samples=self.model_infill_samples,
            # )
            # infill_ds = pred[self.target_set_idx]["samples"]
            raise NotImplementedError("TODO")

        elif self.model_infill_method == "zeros":
            # TODO generate empty prediction xarray
            raise NotImplementedError("TODO")

        else:
            raise ValueError(
                f"Unsupported model_infill_method: {self.model_infill_method}"
            )

        return infill_ds

    def _sample_y_infill(self, infill_ds, time, x1, x2):
        """Sample infill values at a single location and time.

        Returns an array of shape ``(1, y.size)`` unless ``infill_ds`` has a
        ``"sample"`` dimension, in which case the selected values are returned
        as-is.

        Raises:
            TypeError: If ``infill_ds`` is not an xarray object.
        """
        # Explicit check rather than ``assert`` so it is not stripped under ``python -O``
        if not isinstance(infill_ds, (xr.Dataset, xr.DataArray)):
            raise TypeError(
                f"infill_ds must be an xarray Dataset or DataArray, got {type(infill_ds)}"
            )
        y = infill_ds.sel(time=time, x1=x1, x2=x2)
        if isinstance(y, xr.Dataset):
            y = y.to_array()
        y = y.data
        if "sample" not in infill_ds.dims:
            return y.reshape(1, y.size)  # 1 observation with N_target_dims
        else:
            # TODO confirm or force that dim ordering is (N_samples, N_target_dims)
            return y

    def _build_acquisition_fn_ds(self, X_s: Union[xr.Dataset, xr.DataArray]):
        """Initialise an all-NaN xr.DataArray for storing acquisition function
        values on the search grid, with an extra leading ``iteration`` dimension
        and one time per task.
        """
        prepend_dims = ["iteration"]  # , "sample"]  # MC sample TODO
        prepend_coords = {
            "iteration": range(self.N_new_context),
            # "sample": range(self.n_samples_or_1),  # MC sample TODO
        }
        acquisition_fn_ds = create_empty_spatiotemporal_xarray(
            X=X_s,
            dates=self._get_times_from_tasks(),
            coord_names={"x1": self.x1_name, "x2": self.x2_name},
            data_vars=["acquisition_fn"],
            prepend_dims=prepend_dims,
            prepend_coords=prepend_coords,
        )["acquisition_fn"]
        acquisition_fn_ds.data[:] = np.nan

        return acquisition_fn_ds

    def _init_acquisition_fn_object(self, X_s: xr.Dataset):
        """Create ``self.acquisition_fn_ds`` on the unnormalised search grid.

        Raises:
            NotImplementedError: If ``X_s`` is a pandas object.
            TypeError: If ``X_s`` is of any other unsupported type.
        """
        # Unnormalise before instantiating
        X_s = self.model.data_processor.map_coords(X_s, unnorm=True)
        if isinstance(X_s, (xr.Dataset, xr.DataArray)):
            # xr.Dataset storing acquisition function values
            self.acquisition_fn_ds = self._build_acquisition_fn_ds(X_s)
        elif isinstance(X_s, (pd.DataFrame, pd.Series, pd.Index)):
            raise NotImplementedError(
                "Pandas support for active learning search points X_s not yet "
                "implemented."
            )
        else:
            raise TypeError(f"Unsupported type for X_s: {type(X_s)}")

    def _search(self, acquisition_fn: AcquisitionFunction):
        """Run one greedy pass by looping over each point in ``X_s`` and
        computing the acquisition function.

        Per-task values are written into ``self.acquisition_fn_ds`` for the
        current ``self.iteration``. Returns the acquisition values averaged over
        tasks, one value per search point in ``self.X_s_arr``.
        """
        importances_list = []

        for task in self.tasks:
            # Parallel computation
            if isinstance(acquisition_fn, AcquisitionFunctionParallel):
                importances = acquisition_fn(task, self.X_s_arr)
                if self.pbar:
                    self.pbar.update(1)

            # Sequential computation
            elif isinstance(acquisition_fn, AcquisitionFunction):
                importances = []

                if self.diff:
                    # Baseline value without any new context point
                    importance_bef = acquisition_fn(task)

                # Add size-1 dim after row dim to preserve row dim for passing to
                #   acquisition_fn. Also roll final axis to first axis for looping over search points.
                for x_query in np.rollaxis(self.X_s_arr[:, np.newaxis], 2):
                    y_query = self._sample_y_infill(
                        self.query_infill,
                        time=task["time"],
                        x1=x_query[0],
                        x2=x_query[1],
                    )
                    task_with_new = append_obs_to_task(
                        task, x_query, y_query, self.context_set_idx
                    )
                    # TODO this is a hack to add the auxiliary variable to the context set
                    if (
                        self.task_loader is not None
                        and self.task_loader.aux_at_contexts
                    ):
                        # Add auxiliary variable sampled at context set as a new context variable
                        X_c = task_with_new["X_c"][self.task_loader.aux_at_contexts[0]]
                        Y_c_aux = self.task_loader.sample_offgrid_aux(
                            X_c, self.task_loader.aux_at_contexts[1]
                        )
                        task_with_new["X_c"][-1] = X_c
                        task_with_new["Y_c"][-1] = Y_c_aux

                    importance = acquisition_fn(task_with_new)

                    if self.diff:
                        importance = importance - importance_bef

                    importances.append(importance)

                    if self.pbar:
                        self.pbar.update(1)

            else:
                allowed_classes = [
                    AcquisitionFunction,
                    AcquisitionFunctionParallel,
                    AcquisitionFunctionOracle,
                ]
                raise ValueError(
                    f"Acquisition function needs to inherit from one of {allowed_classes}."
                )

            importances = np.array(importances)
            importances_list.append(importances)

            if self.X_s_mask is not None:
                # Only the unmasked grid cells were searched, so scatter into them
                self.acquisition_fn_ds.loc[self.iteration, task["time"]].data[
                    self.X_s_mask.data
                ] = importances
            else:
                self.acquisition_fn_ds.loc[self.iteration, task["time"]] = (
                    importances.reshape(self.acquisition_fn_ds.shape[-2:])
                )

        return np.mean(importances_list, axis=0)

    def _select_best(self, importances, X_s_arr):
        """Select the context location corresponding to the best importance value.

        "Best" is the minimum or maximum according to ``self.min_or_max``.
        Appends the chosen index into ``self.X_s_arr`` to ``self.best_idxs_all``
        and returns the chosen location as a ``(2, 1)`` slice of ``X_s_arr``.

        Raises:
            ValueError: If ``self.min_or_max`` is not "min" or "max".
        """
        if self.min_or_max == "min":
            best_idx = np.argmin(importances)
        elif self.min_or_max == "max":
            best_idx = np.argmax(importances)
        else:
            # Guard against use outside ``__call__``, which validates this already
            raise ValueError(
                f"min_or_max must be either 'min' or 'max', got {self.min_or_max}."
            )

        best_x_query = X_s_arr[:, best_idx : best_idx + 1]

        # Index into original search space of chosen context location
        self.best_idxs_all.append(
            np.where((self.X_s_arr == best_x_query).all(axis=0))[0][0]
        )

        return best_x_query

    def _single_greedy_iteration(self, acquisition_fn: AcquisitionFunction):
        """Run a single greedy grid search iteration and append the optimal
        context location to ``self.X_new``.
        """
        importances = self._search(acquisition_fn)
        best_x_query = self._select_best(importances, self.X_s_arr)

        self.X_new.append(best_x_query)

        return best_x_query

    def _prepare_tasks(self, tasks, acquisition_fn):
        """Return deep copies of ``tasks`` with the algorithm's target set attached.

        Each copy gets ``X_t = [self.X_t_arr]``, ground-truth ``Y_t`` when an
        oracle acquisition function is used, and ``Y_t_aux`` when the task
        loader provides auxiliary data at targets.

        Raises:
            ValueError: If no tasks are given, or ``Y_t_aux`` is present but no
                TaskLoader was provided.
        """
        if isinstance(tasks, Task):
            tasks = [tasks]

        # Make deepcopies so that original tasks are not modified
        tasks = copy.deepcopy(tasks)
        if len(tasks) == 0:
            raise ValueError("At least one Task must be passed to GreedyAlgorithm.")

        # Add target set to tasks
        for task in tasks:
            task["X_t"] = [self.X_t_arr]
            if isinstance(acquisition_fn, AcquisitionFunctionOracle):
                # Sample ground truth y-values at target points `self.X_t_arr` using `self.task_loader`
                date = task["time"]
                task_with_Y_t = self.task_loader(
                    date, context_sampling=0, target_sampling=self.X_t_arr
                )
                task["Y_t"] = task_with_Y_t["Y_t"]

            if "Y_t_aux" in task and self.task_loader is None:
                raise ValueError(
                    "Model expects Y_t_aux data but a TaskLoader isn't "
                    "provided to GreedyAlgorithm."
                )
            if self.task_loader is not None and self.task_loader.aux_at_target_dims > 0:
                task["Y_t_aux"] = self.task_loader.sample_offgrid_aux(
                    self.X_t_arr, self.task_loader.aux_at_targets
                )

        return tasks

    def _resolve_infill(self):
        """Set ``self.query_infill`` and ``self.proposed_infill`` for the current call.

        Overrides given to the constructor are used as-is. Any that were not
        given are filled with fresh model predictions conditioned on the
        current ``self.tasks``, so repeated calls never reuse stale infill.
        """
        query_infill = self._query_infill_override
        proposed_infill = self._proposed_infill_override
        if query_infill is None or proposed_infill is None:
            model_infill = self._model_infill_at_search_points(self.X_s)
            if query_infill is None:
                query_infill = model_infill
            if proposed_infill is None:
                proposed_infill = model_infill
        self.query_infill = query_infill
        self.proposed_infill = proposed_infill

    def __call__(
        self,
        acquisition_fn: AcquisitionFunction,
        tasks: Union[List[Task], Task],
        diff: bool = False,
    ) -> Tuple[pd.DataFrame, xr.DataArray]:
        """Iteratively propose new context points using the greedy sensor placement algorithm.

        Args:
            acquisition_fn (:class:`~.active_learning.acquisition_fns.AcquisitionFunction`):
                The acquisition function to optimise.
            tasks (List[:class:`~.data.task.Task`] | :class:`~.data.task.Task`):
                Tasks containing existing context data. If a list of Tasks, the acquisition
                function will be averaged over Tasks.
            diff (bool, optional):
                For sequential acquisition functions only: Whether to compute the *change* in
                acquisition function value after adding the new context point, i.e.
                ``acquisition_fn(task_with_new) - acquisition_fn(task)``. Can be useful
                for making the acquisition function values more interpretable, or for
                comparing with the change in metric that the acquisition function targets
                (see https://doi.org/10.1017/eds.2023.22). Defaults to False.

        Returns:
            Tuple[:class:`pandas.DataFrame`, :class:`xarray.DataArray`]:
                A tuple containing two objects:

                - **X_new_df** (:class:`pandas.DataFrame`):
                Proposed sensor placements. Columns are the x1 and x2 coordinates of the
                sensor placements, and the index is the index of the greedy iteration
                at which the sensor placement was proposed (which can be interpreted as
                a priority order, with iteration 0 being the highest priority).

                - **acquisition_fn_ds** (:class:`xarray.DataArray`):
                Gridded acquisition function values at each search point. Dimensions
                are ``iteration``, ``time`` (inferred from the input ``tasks``), followed
                by the x1 and x2 coordinates of the spatial grid.

        Raises:
            ValueError:
                If ``acquisition_fn`` is an
                :class:`~.active_learning.acquisition_fns.AcquisitionFunctionOracle`
                and ``task_loader`` is None.
            ValueError:
                If ``min_or_max`` is not ``"min"`` or ``"max"``.
            ValueError:
                If ``diff`` is True for a parallel acquisition function.
            ValueError:
                If ``tasks`` is empty.
            ValueError:
                If ``Y_t_aux`` is in ``tasks`` but ``task_loader`` is None.
        """
        if (
            isinstance(acquisition_fn, AcquisitionFunctionOracle)
            and self.task_loader is None
        ):
            raise ValueError(
                "AcquisitionFunctionOracle requires a task_loader function to "
                "be passed to the GreedyAlgorithm constructor."
            )

        self.min_or_max = acquisition_fn.min_or_max
        if self.min_or_max not in ["min", "max"]:
            raise ValueError(
                f"min_or_max must be either 'min' or 'max', got {self.min_or_max}."
            )

        if diff and isinstance(acquisition_fn, AcquisitionFunctionParallel):
            raise ValueError(
                "diff=True is not valid for parallel acquisition functions."
            )
        self.diff = diff

        self.tasks = self._prepare_tasks(tasks, acquisition_fn)

        # Generate infill values at search points if not overridden (recomputed
        # on every call, since they depend on the tasks' context data and times)
        self._resolve_infill()

        # Instantiate empty acquisition function object
        self._init_acquisition_fn_object(self.X_s)

        # Proposed locations are per-call results, so start from scratch
        self.X_new = []

        # Dataframe for storing proposed context locations
        self.X_new_df = pd.DataFrame(columns=[self.x1_name, self.x2_name])
        self.X_new_df.index.name = "iteration"

        # List to track indexes into original search grid of chosen sensor locations
        #   as optimisation progresses. Used for filling y-values at chosen
        #   sensor locations, `self.X_new`
        self.best_idxs_all = []

        # Total iterations are number of new context points * number of tasks * number of search
        #   points (if not parallel) * number of Monte Carlo samples (if using MC)
        total_iterations = self.N_new_context * len(self.tasks)
        if not isinstance(acquisition_fn, AcquisitionFunctionParallel):
            total_iterations *= self.X_s_arr.shape[-1]
        # TODO make class attribute for list of sample-based infill methods
        if self.model_infill_method in ["sample", "ar_sample"]:
            # `n_samples` is not set anywhere yet (sample-based infill is not
            # implemented); default to 1 instead of raising AttributeError.
            total_iterations *= getattr(self, "n_samples", 1)

        with tqdm(total=total_iterations, disable=not self.progress_bar) as self.pbar:
            for iteration in range(self.N_new_context):
                self.iteration = iteration
                x_new = self._single_greedy_iteration(acquisition_fn)

                # Append new proposed context points to each task
                for i, task in enumerate(self.tasks):
                    y_new = self._sample_y_infill(
                        self.proposed_infill,
                        time=task["time"],
                        x1=x_new[0],
                        x2=x_new[1],
                    )
                    self.tasks[i] = append_obs_to_task(
                        task, x_new, y_new, self.context_set_idx
                    )

                # Append new proposed context points to dataframe
                x_new_unnorm = self.model.data_processor.map_coord_array(
                    x_new, unnorm=True
                )
                self.X_new_df.loc[self.iteration] = x_new_unnorm.ravel()

        return self.X_new_df, self.acquisition_fn_ds
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc