• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ghiggi / ximage / 8525291162

02 Apr 2024 03:18PM UTC coverage: 96.341% (+0.2%) from 96.149%
8525291162

push

github

ghiggi
Fix typo

790 of 820 relevant lines covered (96.34%)

0.96 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.47
/ximage/patch/labels_patch.py
1
# -----------------------------------------------------------------------------.
2
# MIT License
3

4
# Copyright (c) 2024 ximage developers
5
#
6
# This file is part of ximage.
7

8
# Permission is hereby granted, free of charge, to any person obtaining a copy
9
# of this software and associated documentation files (the "Software"), to deal
10
# in the Software without restriction, including without limitation the rights
11
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
# copies of the Software, and to permit persons to whom the Software is
13
# furnished to do so, subject to the following conditions:
14
#
15
# The above copyright notice and this permission notice shall be included in all
16
# copies or substantial portions of the Software.
17
#
18
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25

26
# -----------------------------------------------------------------------------.
27
"""Functions to extract patch around labels."""
1✔
28
import random
1✔
29
import warnings
1✔
30
from typing import Callable, Union
1✔
31

32
import matplotlib.pyplot as plt
1✔
33
import numpy as np
1✔
34
import xarray as xr
1✔
35

36
from ximage.labels.labels import highlight_label
1✔
37
from ximage.patch.checks import (
1✔
38
    are_all_natural_numbers,
39
    check_buffer,
40
    check_kernel_size,
41
    check_padding,
42
    check_partitioning_method,
43
    check_patch_size,
44
    check_stride,
45
)
46
from ximage.patch.plot2d import plot_label_patch_extraction_areas
1✔
47
from ximage.patch.slices import (
1✔
48
    enlarge_slices,
49
    get_nd_partitions_list_slices,
50
    get_slice_around_index,
51
    get_slice_from_idx_bounds,
52
    pad_slices,
53
)
54

55
# -----------------------------------------------------------------------------.
56
#### TODOs
57
## Partitioning
58
# - Option to bound min_start and max_stop to labels bbox
59
# - Option to define min_start and max_stop to be divisible by patch_size + stride
60
# - When tiling ... define start so to center tiles around label_bbox, instead of starting at label bbox start
61
# - Option: partition_only_when_label_bbox_exceed_patch_size
62

63
# - Add option that returns a flag if the point center is the actual identified one,
64
#   or was close to the boundary !
65

66
# -----------------------------------------------------------------------------.
67
# - Implement dilate option (to subset pixel within partitions).
68
#   --> slice(start, stop, step=dilate) ... with patch_size redefined at start to patch_size*dilate
69
#   --> Need updates of the enlarge_slices and pad_slices utilities (but first test current usage !)
70

71
# -----------------------------------------------------------------------------.
72

73
## Image sliding/tiling reconstruction
74
# - get_index_overlapping_slices
75
# - trim: bool, keyword only
76
#   Whether or not to trim stride elements from each block after calling the map function.
77
#   Set this to False if your mapping function already does this for you.
78
#   This for when merging !
79

80
####--------------------------------------------------------------------------.
81

82

83
def _check_label_arr(label_arr):
    """Validate the label array and return it as an in-memory float array.

    Cells equal to 0 are treated as "no label" and replaced by ``np.nan``.
    If the array only holds zeros or NaN, no label remains afterwards.

    Parameters
    ----------
    label_arr : array-like
        Array of label identifiers.

    Returns
    -------
    np.ndarray
        Float copy of the input with the 0 cells set to ``np.nan``.

    Raises
    ------
    ValueError
        If ``are_all_natural_numbers`` rejects the remaining (non-NaN) values.
    """
    # Load in memory and cast to float (an integer array can not hold np.nan)
    float_labels = np.asanyarray(label_arr).astype(float)

    # Flag the background (0) as missing
    float_labels[float_labels == 0] = np.nan

    # The labels left must be natural numbers >= 1
    is_label = ~np.isnan(float_labels)
    if not are_all_natural_numbers(np.unique(float_labels[is_label])):
        raise ValueError("The label array contains non positive natural numbers.")

    return float_labels
100

101

102
def _check_labels_id(labels_id, label_arr):
    """Check labels_id and return it as an integer array.

    Parameters
    ----------
    labels_id : None, int, np.integer, list or np.ndarray
        Label identifier(s) to select. If None, all the labels found in
        ``label_arr`` are returned.
    label_arr : np.ndarray
        Float label array where non-label cells are ``np.nan``.

    Returns
    -------
    np.ndarray
        Integer array of label identifiers.

    Raises
    ------
    TypeError
        If ``labels_id`` has an unsupported type.
    ValueError
        If ``labels_id`` contains 0, values rejected by
        ``are_all_natural_numbers``, values not present in ``label_arr``,
        or if it is empty.
    """
    # Check labels_id type
    # - NumPy integer scalars (i.e. np.int64) are accepted like Python int
    if not isinstance(labels_id, (type(None), int, np.integer, list, np.ndarray)):
        raise TypeError("labels_id must be None or a list or a np.array.")
    if isinstance(labels_id, (int, np.integer)):
        labels_id = [labels_id]
    # Get list of valid labels
    valid_labels = np.unique(label_arr[~np.isnan(label_arr)]).astype(int)
    # If labels_id is None, assign the valid_labels
    if labels_id is None:
        return valid_labels
    # Make labels_id an integer np.array
    # - atleast_1d ensures that a 0-d array behaves like a single label
    labels_id = np.atleast_1d(np.array(labels_id)).astype(int)
    # Check labels_id are natural number >= 1
    if np.any(labels_id == 0):
        raise ValueError("labels id must not contain the 0 value.")
    if not are_all_natural_numbers(labels_id):
        raise ValueError("labels id must be positive natural numbers.")
    # Check labels_id are number present in the label_arr
    invalid_labels = labels_id[~np.isin(labels_id, valid_labels)]
    if invalid_labels.size != 0:
        invalid_labels = invalid_labels.astype(int)
        raise ValueError(f"The following labels id are not valid: {invalid_labels}")
    # If no labels, no patch to extract
    if len(labels_id) == 0:
        raise ValueError("No labels available.")
    return labels_id
131

132

133
def _check_n_patches_per_partition(n_patches_per_partition, centered_on):
    """
    Validate the number of patches requested for each partition.

    More than one patch per partition is only accepted when ``centered_on``
    is ``'random'`` or is not a string (i.e. a callable).

    Parameters
    ----------
    n_patches_per_partition : int
        Number of patches to extract from each partition.
    centered_on : (str, callable)
        Method to extract the patch around a label point.

    Returns
    -------
    n_patches_per_partition: int
       The validated number of patches to extract from each partition.

    Raises
    ------
    ValueError
        If the value is smaller than 1, or larger than 1 with a string
        ``centered_on`` different from ``'random'``.
    """
    if n_patches_per_partition < 1:
        raise ValueError("n_patches_per_partitions must be a positive integer.")
    # String methods other than 'random' always find the same point
    is_fixed_point_method = isinstance(centered_on, str) and centered_on != "random"
    if is_fixed_point_method and n_patches_per_partition > 1:
        raise ValueError(
            "Only the pre-implemented centered_on='random' method allow n_patches_per_partition values > 1.",
        )
    return n_patches_per_partition
158

159

160
def _check_n_patches_per_label(n_patches_per_label, n_patches_per_partition):
    """Check that n_patches_per_label is not smaller than n_patches_per_partition.

    Returns
    -------
    n_patches_per_label
        The input value, unchanged.

    Raises
    ------
    ValueError
        If ``n_patches_per_label`` is smaller than ``n_patches_per_partition``.
    """
    if n_patches_per_partition > n_patches_per_label:
        raise ValueError("n_patches_per_label must be equal or larger to n_patches_per_partition.")
    return n_patches_per_label
164

165

166
def _check_callable_centered_on(centered_on):
    """Check validity of callable centered_on.

    The callable is probed with a (2, 3) array of zeros and then with a
    (2, 3) array of np.nan.

    Parameters
    ----------
    centered_on : callable
        Function receiving an array and returning a point coordinates
        tuple or None.

    Raises
    ------
    ValueError
        If the function returns something else than a tuple or None, if
        the returned point has the wrong number of coordinates, lies
        outside the array or contains np.nan, if the function fails on a
        np.nan array, or if it does not return None for a np.nan array.
    """
    input_shape = (2, 3)
    arr = np.zeros(input_shape)
    point = centered_on(arr)
    if not isinstance(point, (tuple, type(None))):
        raise ValueError("The 'centered_on' function should return a point coordinates tuple or None.")
    # None is an accepted output: only check the coordinates if a point is returned
    # (calling len(None) would otherwise raise an unrelated TypeError)
    if point is not None:
        if len(point) != len(input_shape):
            raise ValueError(
                "The 'centered_on' function should return point coordinates having same dimensions has input array.",
            )
        for c, max_value in zip(point, input_shape):
            if c < 0:
                raise ValueError("The point coordinate must be a positive integer.")
            if c >= max_value:
                raise ValueError("The point coordinate must be inside the array shape.")
            # np.nan fails both comparisons above, so it is detected here
            if np.isnan(c):
                raise ValueError("The point coordinate must not be np.nan.")
    # Check case with nan array
    try:
        point = centered_on(arr * np.nan)
    except Exception as err:
        raise ValueError(
            f"The 'centered_on' function should be able to deal with a np.nan ndarray. Error is {err}.",
        ) from err
    if point is not None:
        raise ValueError("The 'centered_on' function should return None if the input array is a np.nan ndarray.")
191

192

193
def _check_centered_on(centered_on):
    """Validate the method used to identify a point in an array.

    Parameters
    ----------
    centered_on : (str, callable)
        Name of a pre-implemented method or a custom function.

    Returns
    -------
    centered_on
        The input value, unchanged.

    Raises
    ------
    TypeError
        If ``centered_on`` is neither a string nor a callable.
    ValueError
        If the string is not one of the pre-implemented methods.
    """
    is_function = callable(centered_on)
    is_name = isinstance(centered_on, str)
    if not is_function and not is_name:
        raise TypeError("'centered_on' must be a string or a function.")

    if is_name:
        # Note: "label_bbox" does not imply a fixed patch_size
        valid_centered_on = ["max", "min", "centroid", "center_of_mass", "random", "label_bbox"]
        if centered_on not in valid_centered_on:
            raise ValueError(f"Valid 'centered_on' values are: {valid_centered_on}.")

    # Custom functions are probed with test arrays
    if is_function:
        _check_callable_centered_on(centered_on)
    return centered_on
212

213

214
def _get_variable_arr(xr_obj, variable, centered_on):
    """Return the variable array in memory, or None if no variable is given.

    Parameters
    ----------
    xr_obj : xr.DataArray or xr.Dataset
        Object from which to retrieve the array.
    variable : str or None
        Name of the variable to retrieve. Not used if ``xr_obj`` is a
        ``xr.DataArray``.
    centered_on : (str, callable) or None
        Method to identify the label point.

    Raises
    ------
    ValueError
        If ``variable`` is None while ``centered_on`` is "max", "min" or a
        callable (and ``xr_obj`` is not a ``xr.DataArray``).
    """
    # A DataArray is the variable itself
    if isinstance(xr_obj, xr.DataArray):
        return np.asanyarray(xr_obj.data)
    if variable is None:
        requires_values = centered_on is not None and (centered_on in ["max", "min"] or callable(centered_on))
        if requires_values:
            raise ValueError("'variable' must be specified if 'centered_on' is specified.")
        return None
    return np.asanyarray(xr_obj[variable].data)
221

222

223
def _check_variable_arr(variable_arr, label_arr):
    """Ensure that the variable array (if any) has the shape of the label array.

    Returns
    -------
    variable_arr
        The input ``variable_arr``, unchanged (possibly None).

    Raises
    ------
    ValueError
        If the two arrays have different shapes.
    """
    if variable_arr is None:
        return variable_arr
    if variable_arr.shape != label_arr.shape:
        raise ValueError("Arrays corresponding to 'variable' and 'label_name' must have same shape.")
    return variable_arr
228

229

230
def _get_point_centroid(arr):
    """Return the center of the array extent.

    The array is assumed to be already cropped around the label, so that
    the array center corresponds to the label bounding box center.

    Returns
    -------
    tuple or None
        Half of each dimension size (as floats), or None if the array
        does not contain any finite value.
    """
    if not np.isfinite(arr).any():
        return None
    return tuple(size / 2.0 for size in arr.shape)
240

241

242
def _get_point_random(arr):
    """Pick at random the index of a cell with a finite value.

    Returns
    -------
    np.ndarray or None
        Index of the selected cell (one element per array dimension),
        or None if the array does not contain any finite value.
    """
    finite_points = np.argwhere(np.isfinite(arr))
    if len(finite_points) == 0:
        return None
    return random.choice(finite_points)
249

250

251
def _get_point_with_max_value(arr):
    """Return the index of the first cell holding the maximum value (NaN ignored).

    Returns None if no cell equals the maximum (i.e. the array is all NaN).
    """
    with warnings.catch_warnings():
        # np.nanmax emits a RuntimeWarning with all-NaN arrays
        warnings.simplefilter("ignore", category=RuntimeWarning)
        candidates = np.argwhere(arr == np.nanmax(arr))
    if len(candidates) == 0:
        return None
    return tuple(candidates[0].tolist())
257

258

259
def _get_point_with_min_value(arr):
    """Return the index of the first cell holding the minimum value (NaN ignored).

    Returns None if no cell equals the minimum (i.e. the array is all NaN).
    """
    with warnings.catch_warnings():
        # np.nanmin emits a RuntimeWarning with all-NaN arrays
        warnings.simplefilter("ignore", category=RuntimeWarning)
        candidates = np.argwhere(arr == np.nanmin(arr))
    if len(candidates) == 0:
        return None
    return tuple(candidates[0].tolist())
265

266

267
def _get_point_center_of_mass(arr, integer_index=True):
    """Return the mean position of the cells with finite values.

    All finite cells contribute equally (the cell values are not used as
    weights). If the 0 value represents a non-label area, mask it with
    np.nan before calling this function.

    Parameters
    ----------
    arr : np.ndarray
        Input array.
    integer_index : bool, optional
        If True (the default), round the coordinates and cast them to
        integers.

    Returns
    -------
    tuple or None
        Center of mass coordinates, or None if the array does not contain
        any finite value.
    """
    finite_idx = np.argwhere(np.isfinite(arr))
    if finite_idx.shape[0] == 0:
        return None
    mean_position = np.nanmean(finite_idx, axis=0)
    if not integer_index:
        return tuple(mean_position.tolist())
    return tuple(mean_position.round().astype(int).tolist())
281

282

283
def find_point(arr, centered_on: Union[str, Callable] = "max"):
    """Identify a specific point of the array.

    Parameters
    ----------
    arr : np.ndarray
        Input array.
    centered_on : (str, callable), optional
        Pre-implemented method name ("max", "min", "centroid",
        "center_of_mass", "random") or custom function receiving the array.
        The default is "max".

    Returns
    -------
    tuple or None
        Point coordinates casted to integers, or None if the point
        can not be found.
    """
    centered_on = _check_centered_on(centered_on)

    point_finders = {
        "max": _get_point_with_max_value,
        "min": _get_point_with_min_value,
        "centroid": _get_point_centroid,
        "center_of_mass": _get_point_center_of_mass,
        "random": _get_point_random,
    }
    # Anything not listed above is called as it is
    if isinstance(centered_on, str):
        finder = point_finders.get(centered_on, centered_on)
    else:
        finder = centered_on

    point = finder(arr)
    if point is None:
        return None
    return tuple(int(p) for p in point)
305

306

307
def _get_labels_bbox_slices(arr):
    """
    Compute the bounding box slices of the non-zero elements of a n-dimensional array.

    It assumes that a single non-zero value is present in the array and
    that NaN and Inf have been replaced by zeros.

    An alternative implementation is scipy.ndimage.find_objects.

    Parameters
    ----------
    arr : np.ndarray
        n-dimensional numpy array.

    Returns
    -------
    list_slices : list or None
        One slice per dimension, built by ``get_slice_from_idx_bounds`` from
        the smallest and largest non-zero index along that dimension.
        None if all values are zeros.
    """
    # Nothing to bound if there are only zeros
    if not np.any(arr):
        return None
    return [get_slice_from_idx_bounds(np.min(idx), np.max(idx)) for idx in np.nonzero(arr)]
332

333

334
def _get_patch_list_slices_around_label_point(
    label_arr,
    label_id,
    variable_arr,
    patch_size,
    centered_on,
):
    """Get list_slices to extract patch around a label point.

    Assume label_arr must match variable_arr shape (when variable_arr is provided).
    Assume patch_size shape must match label_arr shape.

    Parameters
    ----------
    label_arr : np.ndarray
        Label array.
    label_id : int
        Label around which to extract the patch.
    variable_arr : np.ndarray or None
        Array of values used to identify the point. If None, the label
        values are used instead, which is enough for the methods that only
        rely on the position of the finite cells.
    patch_size : list
        Patch size along each dimension.
    centered_on : (str, callable)
        Method to identify the point (see ``find_point``).

    Returns
    -------
    list or None
        One slice per dimension, or None if the label is not present or
        the point can not be identified.
    """
    # Subset arrays around label
    list_slices = _get_labels_bbox_slices(label_arr == label_id)
    if list_slices is None:
        return None
    bbox = tuple(list_slices)
    label_subset_arr = label_arr[bbox]
    # Without a variable, fall back to the label values
    # (indexing None would otherwise raise a TypeError)
    source_arr = label_arr if variable_arr is None else variable_arr
    # Work on a float copy in memory (also if dask): the masking below
    # must neither modify the input array nor fail with integer arrays
    values_subset_arr = np.array(source_arr[bbox], dtype=float)
    # Mask values outside the label
    values_subset_arr[label_subset_arr != label_id] = np.nan
    # Find point of subset array
    point_subset_arr = find_point(arr=values_subset_arr, centered_on=centered_on)
    if point_subset_arr is None:
        return None
    # Find point in original array
    point = [slc.start + c for slc, c in zip(list_slices, point_subset_arr)]
    # Find patch list slices
    # TODO: also return a flag if the p midpoint is conserved (by +/- 1) or not
    return [
        get_slice_around_index(p, size=size, min_start=0, max_stop=shape)
        for p, size, shape in zip(point, patch_size, label_arr.shape)
    ]
370

371

372
def _get_patch_list_slices_around_label(label_arr, label_id, padding, min_patch_size):
    """Get list_slices to extract the patch around the label bounding box.

    The label bounding box is first passed to ``pad_slices`` and then to
    ``enlarge_slices`` (with ``min_size=min_patch_size``), both bounded by
    the label array shape.

    Returns None if the label is not present in ``label_arr``.
    """
    bbox_slices = _get_labels_bbox_slices(label_arr == label_id)
    if bbox_slices is None:
        return None
    valid_shape = label_arr.shape
    padded_slices = pad_slices(bbox_slices, padding=padding, valid_shape=valid_shape)
    return enlarge_slices(padded_slices, min_size=min_patch_size, valid_shape=valid_shape)
382

383

384
def _get_patch_list_slices(label_arr, label_id, variable_arr, patch_size, centered_on, padding):
    """Get the n-dimensional list of slices defining the patch.

    With ``centered_on="label_bbox"`` the patch is defined around the label
    bounding box (``patch_size`` acts as minimum size and ``padding`` is
    applied). Otherwise the patch is defined around a label point and
    ``padding`` is not used.
    """
    around_bbox = not callable(centered_on) and centered_on == "label_bbox"
    if not around_bbox:
        return _get_patch_list_slices_around_label_point(
            label_arr=label_arr,
            label_id=label_id,
            variable_arr=variable_arr,
            patch_size=patch_size,
            centered_on=centered_on,
        )
    return _get_patch_list_slices_around_label(
        label_arr=label_arr,
        label_id=label_id,
        padding=padding,
        min_patch_size=patch_size,
    )
402

403

404
def _get_masked_arrays(label_arr, variable_arr, partition_list_slices):
    """Return copies of the label and variable arrays set to NaN outside the partition.

    Parameters
    ----------
    label_arr : np.ndarray
        Label array.
    variable_arr : np.ndarray or None
        Variable array. If None, None is returned in its place.
    partition_list_slices : list
        One slice per dimension defining the partition area.

    Returns
    -------
    tuple
        ``(masked_label_arr, masked_variable_arr)``.
    """
    region = tuple(partition_list_slices)

    masked_labels = np.full(label_arr.shape, np.nan)
    masked_labels[region] = label_arr[region]

    masked_variable = None
    if variable_arr is not None:
        masked_variable = np.full(variable_arr.shape, np.nan)
        masked_variable[region] = variable_arr[region]
    return masked_labels, masked_variable
414

415

416
def _get_patches_from_partitions_list_slices(
    partitions_list_slices,
    label_arr,
    variable_arr,
    label_id,
    patch_size,
    centered_on,
    n_patches_per_partition,
    padding,
    verbose=False,
):
    """Return patches list slices from list of partitions list_slices.

    For each partition, the label and variable arrays are masked outside
    the partition and the patch search is attempted n_patches_per_partition
    times. Attempts that do not find a patch, or find a patch already
    collected, are discarded (no additional attempt is made), so a
    partition can contribute fewer than n_patches_per_partition patches.

    n_patches_per_partition is 1 unless centered_on is 'random' or a callable.

    Returns
    -------
    list
        List of unique patch list_slices.
    """
    patches_list_slices = []
    for partition_list_slices in partitions_list_slices:
        if verbose:
            print(f" -  partition: {partition_list_slices}")
        masked_label_arr, masked_variable_arr = _get_masked_arrays(
            label_arr=label_arr,
            variable_arr=variable_arr,
            partition_list_slices=partition_list_slices,
        )
        for _ in range(n_patches_per_partition):
            patch_list_slices = _get_patch_list_slices(
                label_arr=masked_label_arr,
                variable_arr=masked_variable_arr,
                label_id=label_id,
                patch_size=patch_size,
                centered_on=centered_on,
                padding=padding,
            )
            # Keep only the new patches
            if patch_list_slices is not None and patch_list_slices not in patches_list_slices:
                patches_list_slices.append(patch_list_slices)
    return patches_list_slices
454

455

456
def _get_list_isel_dicts(patches_list_slices, dims):
    """Pair each patch list of slices with the dimension names.

    Returns a list with one ``{dim: slice}`` dictionary for each patch.
    """
    isel_dicts = []
    for slices in patches_list_slices:
        isel_dicts.append({dim: slc for dim, slc in zip(dims, slices)})
    return isel_dicts
459

460

461
def _extract_xr_patch(xr_obj, isel_dict, label_name, label_id, highlight_label_id):
    """Extract a patch from the xarray object.

    The patch is selected with ``xr_obj.isel(isel_dict)``. If
    ``highlight_label_id`` is True, the patch is additionally passed to
    ``highlight_label`` for the given ``label_name`` and ``label_id``.
    """
    patch = xr_obj.isel(isel_dict)
    if not highlight_label_id:
        return patch
    return highlight_label(patch, label_name=label_name, label_id=label_id)
470

471

472
def _get_patches_isel_dict_generator(
    xr_obj,
    label_name,
    patch_size,
    variable=None,
    # Output options
    n_patches=np.inf,
    n_labels=None,
    labels_id=None,
    grouped_by_labels_id=False,
    # (Tile) label patch extraction
    padding=0,
    centered_on="max",
    n_patches_per_label=np.inf,
    n_patches_per_partition=1,
    debug=False,
    # Label Tiling/Sliding Options
    partitioning_method=None,
    n_partitions_per_label=None,
    kernel_size=None,
    buffer=0,
    stride=None,
    include_last=True,
    ensure_slice_size=True,
    verbose=False,
):
    """Generate the isel-dictionaries required to extract patches around labels.

    For each selected label, the (padded) label bounding box is optionally
    split into partitions, and patches are searched within each partition.

    Yields
    ------
    tuple
        ``(label_id, isel_dict)`` for each patch if ``grouped_by_labels_id=False``,
        otherwise ``(label_id, list_of_isel_dicts)`` for each label.

    Notes
    -----
    ``n_patches`` limits the number of yielded items: patches if
    ``grouped_by_labels_id=False``, label groups otherwise.
    Being a generator, the arguments are checked only when the first item
    is requested.
    """
    # Get label array information
    label_arr = xr_obj[label_name].data
    dims = xr_obj[label_name].dims
    shape = label_arr.shape

    # Check input arguments
    if n_labels is not None and labels_id is not None:
        raise ValueError("Specify either n_labels or labels_id.")
    if kernel_size is None:
        kernel_size = patch_size
    patch_size = check_patch_size(patch_size, dims, shape)
    buffer = check_buffer(buffer, dims, shape)
    padding = check_padding(padding, dims, shape)

    partitioning_method = check_partitioning_method(partitioning_method)
    stride = check_stride(stride, dims, shape, partitioning_method)
    kernel_size = check_kernel_size(kernel_size, dims, shape)

    centered_on = _check_centered_on(centered_on)
    n_patches_per_partition = _check_n_patches_per_partition(n_patches_per_partition, centered_on)
    n_patches_per_label = _check_n_patches_per_label(n_patches_per_label, n_patches_per_partition)

    label_arr = _check_label_arr(label_arr)  # output is np.array !
    labels_id = _check_labels_id(labels_id=labels_id, label_arr=label_arr)
    variable_arr = _get_variable_arr(xr_obj, variable, centered_on)  # if required
    variable_arr = _check_variable_arr(variable_arr, label_arr)

    # The checked sizes expose .values(): pass them downstream as lists
    patch_size_list = list(patch_size.values())
    padding_list = list(padding.values())

    # Define number of labels from which to extract patches
    available_n_labels = len(labels_id)
    n_labels = min(available_n_labels, n_labels) if n_labels else available_n_labels
    if verbose:
        print(f"Extracting patches from {n_labels} labels.")
    # -------------------------------------------------------------------------.
    # Extract patch(es) around the label
    patch_counter = 0
    for i, label_id in enumerate(labels_id[0:n_labels]):
        if verbose:
            print(f"Label ID: {label_id} ({i}/{n_labels})")

        # Subset label_arr around the given label
        label_bbox_slices = _get_labels_bbox_slices(label_arr == label_id)

        # Apply padding to the label bounding box
        label_bbox_slices = pad_slices(label_bbox_slices, padding=padding_list, valid_shape=label_arr.shape)

        # --------------------------------------------------------------------.
        # Retrieve partitions list_slices
        if partitioning_method is not None:
            partitions_list_slices = get_nd_partitions_list_slices(
                label_bbox_slices,
                arr_shape=label_arr.shape,
                method=partitioning_method,
                kernel_size=list(kernel_size.values()),
                stride=list(stride.values()),
                buffer=list(buffer.values()),
                include_last=include_last,
                ensure_slice_size=ensure_slice_size,
            )
            if n_partitions_per_label is not None:
                n_to_select = min(len(partitions_list_slices), n_partitions_per_label)
                partitions_list_slices = partitions_list_slices[0:n_to_select]
        else:
            partitions_list_slices = [label_bbox_slices]

        # --------------------------------------------------------------------.
        # Retrieve patches list_slices from partitions list slices
        patches_list_slices = _get_patches_from_partitions_list_slices(
            partitions_list_slices=partitions_list_slices,
            label_arr=label_arr,
            variable_arr=variable_arr,
            label_id=label_id,
            patch_size=patch_size_list,
            centered_on=centered_on,
            n_patches_per_partition=n_patches_per_partition,
            padding=padding_list,
            verbose=verbose,
        )

        # ---------------------------------------------------------------------.
        # Retrieve patches isel_dictionaries
        partitions_isel_dicts = _get_list_isel_dicts(partitions_list_slices, dims=dims)
        patches_isel_dicts = _get_list_isel_dicts(patches_list_slices, dims=dims)

        n_to_select = min(len(patches_isel_dicts), n_patches_per_label)
        patches_isel_dicts = patches_isel_dicts[0:n_to_select]

        # --------------------------------------------------------------------.
        # If debug=True, plot patches boundaries
        if debug and label_arr.ndim == 2:
            _ = plot_label_patch_extraction_areas(
                xr_obj,
                label_name=label_name,
                patches_isel_dicts=patches_isel_dicts,
                partitions_isel_dicts=partitions_isel_dicts,
            )
            plt.show()

        # ---------------------------------------------------------------------.
        # Return isel_dicts
        # - Stop the generator as soon as n_patches items have been yielded
        if grouped_by_labels_id:
            patch_counter += 1
            if patch_counter > n_patches:
                return
            yield label_id, patches_isel_dicts
        else:
            for isel_dict in patches_isel_dicts:
                patch_counter += 1
                if patch_counter > n_patches:
                    return
                yield label_id, isel_dict
    # ---------------------------------------------------------------------.
614

615

616
def get_patches_isel_dict_from_labels(
    xr_obj,
    label_name,
    patch_size,
    variable=None,
    # Output options
    n_patches=np.inf,
    n_labels=None,
    labels_id=None,
    # Label Patch Extraction Settings
    centered_on="max",
    padding=0,
    n_patches_per_label=np.inf,
    n_patches_per_partition=1,
    # Label Tiling/Sliding Options
    partitioning_method=None,
    n_partitions_per_label=None,
    kernel_size=None,
    buffer=0,
    stride=None,
    include_last=True,
    ensure_slice_size=True,
    debug=False,
    verbose=False,
):
    """
    Return isel-dictionaries to extract patches around labels.

    The isel-dictionaries are grouped by label_id and returned in a
    dictionary.

    Please refer to ``get_patches_from_labels`` for a detailed description of
    the function arguments.

    Notes
    -----
    Since the isel-dictionaries are grouped by label, ``n_patches`` limits
    the number of labels included in the output dictionary (and not the
    number of isel-dictionaries).

    Return
    ------
    dict
        A dictionary of the form: ``{label_id: list_isel_dicts}``.

    """
    gen = _get_patches_isel_dict_generator(
        xr_obj=xr_obj,
        label_name=label_name,
        patch_size=patch_size,
        variable=variable,
        n_patches=n_patches,
        n_labels=n_labels,
        labels_id=labels_id,
        grouped_by_labels_id=True,
        # Patch extraction options
        centered_on=centered_on,
        padding=padding,
        n_patches_per_label=n_patches_per_label,
        n_patches_per_partition=n_patches_per_partition,
        # Tiling/Sliding settings
        partitioning_method=partitioning_method,
        n_partitions_per_label=n_partitions_per_label,
        kernel_size=kernel_size,
        buffer=buffer,
        stride=stride,
        include_last=include_last,
        ensure_slice_size=ensure_slice_size,
        debug=debug,
        verbose=verbose,
    )
    # Cast label_id to Python int to have plain integers as dictionary keys
    return {int(label_id): list_isel_dicts for label_id, list_isel_dicts in gen}
682

683

684
def get_patches_from_labels(
    xr_obj,
    label_name,
    patch_size,
    variable=None,
    # Output options
    n_patches=np.inf,
    n_labels=None,
    labels_id=None,
    highlight_label_id=True,
    # Label Patch Extraction Options
    centered_on="max",
    padding=0,
    n_patches_per_label=np.inf,
    n_patches_per_partition=1,
    # Label Tiling/Sliding Options
    partitioning_method=None,
    n_partitions_per_label=None,
    kernel_size=None,
    buffer=0,
    stride=None,
    include_last=True,
    ensure_slice_size=True,
    debug=False,
    verbose=False,
):
    """
    Create a generator extracting patches around labels.

    The generator extracts (from a prelabeled xr.Dataset) a patch around:

    - a label point
    - a label bounding box

    If ``centered_on`` is a point method (anything but ``'label_bbox'``),
    output patches are guaranteed to have equal shape.
    If ``centered_on='label_bbox'``, output patches are only guaranteed to
    have a minimum shape.

    If you want to extract the patch around the label bounding box, use
    ``centered_on='label_bbox'``.

    If you want to extract the patch around a label point, a point
    ``centered_on`` method must be specified. If the identified point is close
    to an array boundary, the patch is expanded toward the valid directions.

    Tiling or sliding enables to split/slide over each label and extract
    multiple patches for each tile.

    Suggested ``centered_on`` methods when ``partitioning_method='tiling'``:

    - ``centered_on="centroid"`` (tiling around the label bbox)
    - ``centered_on="center_of_mass"`` (better coverage around the label)

    Suggested ``centered_on`` methods when ``partitioning_method='sliding'``:

    - ``centered_on="center_of_mass"`` (better coverage around the label,
      further data coverage)

    Only one parameter between n_patches and labels_id can be specified.

    Parameters
    ----------
    xr_obj : xr.Dataset
        xr.Dataset with a label array named label_name.
    label_name : str
        Name of the variable/coordinate representing the label array.
    patch_size : (int, tuple, dict)
        The dimensions of the n-dimensional patch to extract.
        Only positive values (>1) are allowed.
        The value -1 can be used to specify the full array dimension shape.
        If the centered_on method is not 'label_bbox', all output patches
        are ensured to have the same shape.
        Otherwise, if centered_on='label_bbox', the patch_size argument
        defines the minimum n-dimensional shape of the output patches.
        If int, the value is applied to all label array dimensions.
        If list or tuple, the length must match the number of dimensions of the array.
        If a dict, the dictionary must have as keys the label array dimensions.
    variable : str, optional
        Dataset variable to use to identify the patch center when centered_on is defined.
        This is required only for centered_on='max', 'min' or a custom function.
    n_patches : int, optional
        Maximum number of patches to extract.
        The default (np.inf) enables to extract all available patches allowed
        by the specified patch extraction criteria.
    n_labels : int, optional
        The number of labels for which to extract patches.
        If None (the default), it extracts patches for all labels.
        This argument can be specified only if labels_id is unspecified !
    labels_id : list, optional
        List of labels for which to extract the patch.
        If None, it extracts the patches by label order (1, 2, 3, ...)
        The default is None.
    highlight_label_id : bool, optional
        If True, the label_name array of each patch is modified to contain only
        the label_id used to select the patch.
        The default is True.
    centered_on : (str, callable), optional
        The centered_on method characterizes the point around which the patch is extracted.
        Valid pre-implemented centered_on methods are 'label_bbox', 'max', 'min',
        'centroid', 'center_of_mass', 'random'.
        The default method is 'max'.

        If 'label_bbox', it extracts the patches around the (padded) bounding box of the label.
        If 'label_bbox', the output patch sizes are only ensured to have a minimum patch_size,
        and will likely be of different size.
        Otherwise, the other methods guarantee that the output patches have a common shape.

        If centered_on is 'max', 'min' or a custom function, the 'variable' must be specified.
        If centered_on is a custom function, it must:

        - return None if all array values are non-finite (i.e np.nan)
        - return a tuple with same length as the array shape.
    padding : (int, tuple, dict), optional
        The padding to apply in each direction around a label prior to
        partitioning (tiling/sliding) or direct patch extraction.
        The default, 0, applies 0 padding in every dimension.
        Negative padding values are allowed !
        If int, the value is applied to all label array dimensions.
        If list or tuple, the length must match the number of dimensions of the array.
        If a dict, the dictionary must have as keys the label array dimensions.
    n_patches_per_label : int, optional
        The maximum number of patches to extract for each label.
        The default (np.inf) enables to extract all the available patches per label.
        n_patches_per_label must be larger than n_patches_per_partition !
    n_patches_per_partition : int, optional
        The maximum number of patches to extract from each label partition.
        The default value is 1.
        This can be specified only if centered_on='random' or a callable.
    partitioning_method : str, optional
        Whether to retrieve 'tiling' or 'sliding' slices.
        If 'tiling', partition start slices are separated by stride + kernel_size.
        If 'sliding', partition start slices are separated by stride.
        The default is None.
    n_partitions_per_label : int, optional
        The maximum number of partitions to extract for each label.
        The default (None) enables to extract all the available partitions per label.
    kernel_size : (int, tuple, dict), optional
        The shape of the desired partitions.
        Only positive values (>1) are allowed.
        The value -1 can be used to specify the full array dimension shape.
        If int, the value is applied to all label array dimensions.
        If list or tuple, the length must match the number of dimensions of the array.
        If a dict, the dictionary must have as keys the label array dimensions.
    buffer : (int, tuple, dict), optional
        The default is 0.
        Value by which to enlarge a partition on each side.
        The final partition size should be kernel_size + buffer.
        If 'tiling' and stride=0, a positive buffer value corresponds to
        the amount of overlap between each partition.
        Depending on min_start and max_stop values, buffering might cause
        border partitions to not have same sizes.
        If int, the value is applied to all label array dimensions.
        If list or tuple, the length must match the number of dimensions of the array.
        If a dict, the dictionary must have as keys the label array dimensions.
    stride : (int, tuple, dict), optional
        Step size between slices.
        If partitioning_method is 'sliding', default stride is set to 1.
        If partitioning_method is 'tiling', default stride is set to 0.
        When 'tiling', a positive stride makes partition slices not overlap and not touch,
        while a negative stride makes partition slices overlap by 'stride' amount.
        If stride is 0, the partition slices are contiguous (no spacing between partitions).
        When 'sliding', only a positive stride (>= 1) is allowed.
        If int, the value is applied to all label array dimensions.
        If list or tuple, the length must match the number of dimensions of the array.
        If a dict, the dictionary must have as keys the label array dimensions.
    include_last : bool, optional
        Whether to include the last partition if it does not match the kernel_size.
        The default is True.
    ensure_slice_size : bool, optional
        Used only if include_last is True.
        If False, the last partition will not have the specified kernel_size.
        If True, the last partition is enlarged to the specified kernel_size by
        tentatively expanding it on both sides (accounting for min_start and max_stop).
        The default is True.
    debug : bool, optional
        Forwarded unchanged to the isel-dictionary generator.
        Presumably enables debugging output of the extraction steps
        (to be confirmed against the generator implementation).
        The default is False.
    verbose : bool, optional
        Forwarded unchanged to the isel-dictionary generator.
        Presumably enables informative messages
        (to be confirmed against the generator implementation).
        The default is False.

    Yields
    ------
    label_id
        The identifier of the label around which the patch has been extracted.
    xr_obj_patch : (xr.Dataset or xr.DataArray)
        The xarray object patch extracted around the label.

    """
    # Define patches isel dictionary generator.
    # grouped_by_labels_id=False: each generated item is unpacked below as a
    # single (label_id, isel_dict) pair, i.e. one item per patch.
    patches_isel_dicts_gen = _get_patches_isel_dict_generator(
        xr_obj=xr_obj,
        label_name=label_name,
        patch_size=patch_size,
        variable=variable,
        n_patches=n_patches,
        n_labels=n_labels,
        labels_id=labels_id,
        grouped_by_labels_id=False,
        # Label Patch Extraction Options
        centered_on=centered_on,
        padding=padding,
        n_patches_per_label=n_patches_per_label,
        n_patches_per_partition=n_patches_per_partition,
        # Tiling/Sliding Options
        partitioning_method=partitioning_method,
        n_partitions_per_label=n_partitions_per_label,
        kernel_size=kernel_size,
        buffer=buffer,
        stride=stride,
        include_last=include_last,
        ensure_slice_size=ensure_slice_size,
        debug=debug,
        verbose=verbose,
    )

    # Extract the patches lazily, one per isel dictionary
    for label_id, isel_dict in patches_isel_dicts_gen:
        xr_obj_patch = _extract_xr_patch(
            xr_obj=xr_obj,
            label_name=label_name,
            isel_dict=isel_dict,
            label_id=label_id,
            highlight_label_id=highlight_label_id,
        )

        # Return the label identifier together with the patch around the label
        yield label_id, xr_obj_patch
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc