• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ghiggi / ximage / 23169211772

16 Mar 2026 10:37PM UTC coverage: 94.953%. Remained the same
23169211772

Pull #21

github

web-flow
Merge 31815bcf3 into f10b9569e
Pull Request #21: [pre-commit.ci] pre-commit autoupdate

8 of 8 new or added lines in 8 files covered. (100.0%)

21 existing lines in 4 files now uncovered.

809 of 852 relevant lines covered (94.95%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.79
/ximage/patch/labels_patch.py
1
# -----------------------------------------------------------------------------.
2
# MIT License
3

4
# Copyright (c) 2024-2026 ximage developers
5
#
6
# This file is part of ximage.
7

8
# Permission is hereby granted, free of charge, to any person obtaining a copy
9
# of this software and associated documentation files (the "Software"), to deal
10
# in the Software without restriction, including without limitation the rights
11
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
# copies of the Software, and to permit persons to whom the Software is
13
# furnished to do so, subject to the following conditions:
14
#
15
# The above copyright notice and this permission notice shall be included in all
16
# copies or substantial portions of the Software.
17
#
18
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25

26
# -----------------------------------------------------------------------------.
27
"""Functions to extract patch around labels."""
28
import random
1✔
29
import warnings
1✔
30
from collections.abc import Callable
1✔
31

32
import matplotlib.pyplot as plt
1✔
33
import numpy as np
1✔
34
import xarray as xr
1✔
35

36
from ximage.labels.labels import highlight_label
1✔
37
from ximage.patch.checks import (
1✔
38
    are_all_natural_numbers,
39
    check_buffer,
40
    check_kernel_size,
41
    check_padding,
42
    check_partitioning_method,
43
    check_patch_size,
44
    check_stride,
45
)
46
from ximage.patch.plot2d import plot_label_patch_extraction_areas
1✔
47
from ximage.patch.slices import (
1✔
48
    enlarge_slices,
49
    get_nd_partitions_list_slices,
50
    get_slice_around_index,
51
    get_slice_from_idx_bounds,
52
    pad_slices,
53
)
54

55
# -----------------------------------------------------------------------------.
56
#### TODOs
57
## Partitioning
58
# - Option to bound min_start and max_stop to labels bbox
59
# - Option to define min_start and max_stop to be divisible by patch_size + stride
60
# - When tiling ... define start so as to center tiles around label_bbox, instead of starting at label bbox start
61
# - Option: partition_only_when_label_bbox_exceed_patch_size
62

63
# - Add option that returns a flag if the point center is the actual identified one,
64
#   or was close to the boundary !
65

66
# -----------------------------------------------------------------------------.
67
# - Implement dilate option (to subset pixel within partitions).
68
#   --> slice(start, stop, step=dilate) ... with patch_size redefined at start to patch_size*dilate
69
#   --> Need updates of enlarge_slices, pad_slices utilities (but first test current usage !)
70

71
# -----------------------------------------------------------------------------.
72

73
## Image sliding/tiling reconstruction
74
# - get_index_overlapping_slices
75
# - trim: bool, keyword only
76
#   Whether or not to trim stride elements from each block after calling the map function.
77
#   Set this to False if your mapping function already does this for you.
78
#   This for when merging !
79

80
####--------------------------------------------------------------------------.
81

82

83
def _check_label_arr(label_arr):
1✔
84
    """Check label_arr."""
85
    # Note: If label array is all zero or nan, labels_id will be []
86

87
    # Put label array in memory
88
    label_arr = np.asanyarray(label_arr)
1✔
89

90
    # Set 0 label to nan
91
    label_arr = label_arr.astype(float)  # otherwise if int throw an error when assigning nan
1✔
92
    label_arr[label_arr == 0] = np.nan
1✔
93

94
    # Check labels_id are natural number >= 1
95
    valid_labels = np.unique(label_arr[~np.isnan(label_arr)])
1✔
96
    if not are_all_natural_numbers(valid_labels):
1✔
97
        raise ValueError("The label array contains non positive natural numbers.")
1✔
98

99
    return label_arr
1✔
100

101

102
def _check_labels_id(labels_id, label_arr):
    """Validate the requested label identifiers against the label array.

    Parameters
    ----------
    labels_id : None, int, numpy.integer, list or numpy.ndarray
        Label identifier(s) to extract. If ``None``, all labels found in
        ``label_arr`` are used.
    label_arr : numpy.ndarray
        Label array in which non-label cells are ``np.nan``
        (NaN cells are ignored when listing the available labels).

    Returns
    -------
    numpy.ndarray
        Integer array of label identifiers.

    Raises
    ------
    TypeError
        If ``labels_id`` has an unsupported type.
    ValueError
        If ``labels_id`` contains 0, values that are not positive natural
        numbers, values absent from ``label_arr``, or is empty.
    """
    # Check labels_id type.
    # NumPy integer scalars (e.g. an element of np.unique output) are not
    # instances of the builtin int, so they are accepted explicitly.
    if not isinstance(labels_id, (type(None), int, np.integer, list, np.ndarray)):
        raise TypeError("labels_id must be None or a list or a np.array.")
    if isinstance(labels_id, (int, np.integer)):
        labels_id = [labels_id]
    # Get list of valid labels
    valid_labels = np.unique(label_arr[~np.isnan(label_arr)]).astype(int)
    # If labels_id is None, assign the valid_labels
    if labels_id is None:
        return valid_labels
    # Make labels_id an (at least) 1D integer array.
    # A 0-d array would otherwise fail on len() below.
    labels_id = np.atleast_1d(labels_id).astype(int)
    # Check labels_id are natural number >= 1
    if np.any(labels_id == 0):
        raise ValueError("labels id must not contain the 0 value.")
    if not are_all_natural_numbers(labels_id):
        raise ValueError("labels id must be positive natural numbers.")
    # Check labels_id are number present in the label_arr
    invalid_labels = labels_id[~np.isin(labels_id, valid_labels)]
    if invalid_labels.size != 0:
        invalid_labels = invalid_labels.astype(int)
        raise ValueError(f"The following labels id are not valid: {invalid_labels}")
    # If no labels, no patch to extract
    if len(labels_id) == 0:
        raise ValueError("No labels available.")
    return labels_id
131

132

133
def _check_n_patches_per_partition(n_patches_per_partition, centered_on):
1✔
134
    """
135
    Check the number of patches to extract from each partition.
136

137
    It is used only if centered_on is a callable or 'random'
138

139
    Parameters
140
    ----------
141
    n_patches_per_partition : int
142
        Number of patches to extract from each partition.
143
    centered_on : str or callable
144
        Method to extract the patch around a label point.
145

146
    Returns
147
    -------
148
    n_patches_per_partition: int
149
       The number of patches to extract from each partition.
150
    """
151
    if n_patches_per_partition < 1:
1✔
152
        raise ValueError("n_patches_per_partitions must be a positive integer.")
1✔
153
    if isinstance(centered_on, str) and centered_on not in ["random"] and n_patches_per_partition > 1:
1✔
154
        raise ValueError(
1✔
155
            "Only the pre-implemented centered_on='random' method allow n_patches_per_partition values > 1.",
156
        )
157
    return n_patches_per_partition
1✔
158

159

160
def _check_n_patches(n_patches):
1✔
161
    if n_patches is None:
1✔
162
        n_patches = np.inf
1✔
163
    if n_patches <= 0:
1✔
UNCOV
164
        raise ValueError("n_patches must be a positive integer.")
×
165
    return n_patches
1✔
166

167

168
def _check_n_patches_per_label(n_patches_per_label, n_patches_per_partition):
1✔
169
    if n_patches_per_label is None:
1✔
170
        n_patches_per_label = np.inf
1✔
171
    if n_patches_per_label <= 0:
1✔
UNCOV
172
        raise ValueError("n_patches_per_label must be a positive integer.")
×
173
    if n_patches_per_label < n_patches_per_partition:
1✔
174
        raise ValueError("n_patches_per_label must be equal or larger to n_patches_per_partition.")
1✔
175
    return n_patches_per_label
1✔
176

177

178
def _check_callable_centered_on(centered_on):
1✔
179
    """Check validity of callable centered_on."""
180
    input_shape = (2, 3)
1✔
181
    arr = np.zeros(input_shape)
1✔
182
    point = centered_on(arr)
1✔
183
    if not isinstance(point, (tuple, type(None))):
1✔
184
        raise ValueError("The 'centered_on' function should return a point coordinates tuple or None.")
1✔
185
    if len(point) != len(input_shape):
1✔
186
        raise ValueError(
1✔
187
            "The 'centered_on' function should return point coordinates having same dimensions has input array.",
188
        )
189
    for c, max_value in zip(point, input_shape, strict=True):
1✔
190
        if c < 0:
1✔
191
            raise ValueError("The point coordinate must be a positive integer.")
1✔
192
        if c >= max_value:
1✔
193
            raise ValueError("The point coordinate must be inside the array shape.")
1✔
194
        if np.isnan(c):
1✔
195
            raise ValueError("The point coordinate must not be np.nan.")
1✔
196
    # Check case with nan array
197
    try:
1✔
198
        point = centered_on(arr * np.nan)
1✔
199
    except Exception as err:
1✔
200
        raise ValueError(f"The 'centered_on' function should be able to deal with a np.nan ndarray. Error is {err}.")
1✔
201
    if point is not None:
1✔
202
        raise ValueError("The 'centered_on' function should return None if the input array is a np.nan ndarray.")
1✔
203

204

205
def _check_centered_on(centered_on):
1✔
206
    """Check valid centered_on to identify a point in an array."""
207
    if not (callable(centered_on) or isinstance(centered_on, str)):
1✔
208
        raise TypeError("'centered_on' must be a string or a function.")
1✔
209
    if isinstance(centered_on, str):
1✔
210
        valid_centered_on = [
1✔
211
            "max",
212
            "min",
213
            "centroid",
214
            "center_of_mass",
215
            "random",
216
            "label_bbox",  # unfixed patch_size
217
        ]
218
        if centered_on not in valid_centered_on:
1✔
219
            raise ValueError(f"Valid 'centered_on' values are: {valid_centered_on}.")
1✔
220

221
    if callable(centered_on):
1✔
222
        _check_callable_centered_on(centered_on)
1✔
223
    return centered_on
1✔
224

225

226
def _get_variable_arr(xr_obj, variable, centered_on):
1✔
227
    """Get variable array (in memory)."""
228
    if isinstance(xr_obj, xr.DataArray):
1✔
229
        return np.asanyarray(xr_obj.data)
1✔
230
    if centered_on is not None and variable is None and (centered_on in ["max", "min"] or callable(centered_on)):
1✔
231
        raise ValueError("'variable' must be specified if 'centered_on' is specified.")
1✔
232
    return np.asanyarray(xr_obj[variable].data) if variable is not None else None
1✔
233

234

235
def _check_variable_arr(variable_arr, label_arr):
1✔
236
    """Check variable array validity."""
237
    if variable_arr is not None and variable_arr.shape != label_arr.shape:
1✔
238
        raise ValueError("Arrays corresponding to 'variable' and 'label_name' must have same shape.")
1✔
239
    return variable_arr
1✔
240

241

242
def _get_point_centroid(arr):
1✔
243
    """Get the coordinate of label bounding box center.
244

245
    It assumes that the array has been cropped around the label.
246
    It returns None if all values are non-finite (i.e. np.nan).
247
    """
248
    if np.all(~np.isfinite(arr)):
1✔
249
        return None
1✔
250
    centroid = np.array(arr.shape) / 2.0
1✔
251
    return tuple(centroid.tolist())
1✔
252

253

254
def _get_point_random(arr):
    """Get the coordinates of a randomly selected cell with finite value.

    Parameters
    ----------
    arr : numpy.ndarray
        Input array.

    Returns
    -------
    tuple or None
        Tuple of integer indices (one per array dimension) of a finite cell,
        or ``None`` if the array has no finite value.
    """
    is_finite = np.isfinite(arr)
    if not is_finite.any():
        return None
    points = np.argwhere(is_finite)
    # Return a tuple of Python ints, like the other point finders of this
    # module, instead of a row of the argwhere array.
    return tuple(random.choice(points).tolist())
261

262

263
def _get_point_with_max_value(arr):
1✔
264
    """Get point with maximum value."""
265
    with warnings.catch_warnings():
1✔
266
        warnings.simplefilter("ignore", category=RuntimeWarning)
1✔
267
        point = np.argwhere(arr == np.nanmax(arr))
1✔
268
    return None if len(point) == 0 else tuple(point[0].tolist())
1✔
269

270

271
def _get_point_with_min_value(arr):
1✔
272
    """Get point with minimum value."""
273
    with warnings.catch_warnings():
1✔
274
        warnings.simplefilter("ignore", category=RuntimeWarning)
1✔
275
        point = np.argwhere(arr == np.nanmin(arr))
1✔
276
    return None if len(point) == 0 else tuple(point[0].tolist())
1✔
277

278

279
def _get_point_center_of_mass(arr, integer_index=True):
1✔
280
    """Get the coordinate of the label center of mass.
281

282
    It uses all cells which have finite values.
283
    If `0` value should be a non-label area, mask before with `np.nan`.
284
    It returns `None` if all values are non-finite (i.e. ``np.nan``).
285
    """
286
    indices = np.argwhere(np.isfinite(arr))
1✔
287
    if len(indices) == 0:
1✔
288
        return None
1✔
289
    center_of_mass = np.nanmean(indices, axis=0)
1✔
290
    if integer_index:
1✔
291
        center_of_mass = center_of_mass.round().astype(int)
1✔
292
    return tuple(center_of_mass.tolist())
1✔
293

294

295
def find_point(arr, centered_on: str | Callable = "max"):
1✔
296
    """Find a specific point coordinate of the array.
297

298
    If the coordinate can't be found, return ``None``.
299
    """
300
    centered_on = _check_centered_on(centered_on)
1✔
301

302
    if centered_on == "max":
1✔
303
        point = _get_point_with_max_value(arr)
1✔
304
    elif centered_on == "min":
1✔
305
        point = _get_point_with_min_value(arr)
1✔
306
    elif centered_on == "centroid":
1✔
307
        point = _get_point_centroid(arr)
1✔
308
    elif centered_on == "center_of_mass":
1✔
309
        point = _get_point_center_of_mass(arr)
1✔
310
    elif centered_on == "random":
1✔
311
        point = _get_point_random(arr)
1✔
312
    else:  # callable centered_on
313
        point = centered_on(arr)
1✔
314
    if point is not None:
1✔
315
        point = tuple(int(p) for p in point)
1✔
316
    return point
1✔
317

318

319
def _get_labels_bbox_slices(arr):
1✔
320
    """
321
    Compute the bounding box slices of non-zero elements in a n-dimensional numpy array.
322

323
    Assume that only one unique non-zero elements values is present in the array.
324
    Assume that NaN and Inf have been replaced by zeros.
325

326
    Other implementations: scipy.ndimage.find_objects
327

328
    Parameters
329
    ----------
330
    arr : numpy.ndarray
331
        n-dimensional numpy array.
332

333
    Returns
334
    -------
335
    list_slices : list
336
        List of slices to extract the region with non-zero elements in the input array.
337
    """
338
    # Return None if all values are zeros
339
    if not np.any(arr):
1✔
340
        return None
1✔
341
    ndims = arr.ndim
1✔
342
    coords = np.nonzero(arr)
1✔
343
    return [get_slice_from_idx_bounds(np.min(coords[i]), np.max(coords[i])) for i in range(ndims)]
1✔
344

345

346
def _get_patch_list_slices_around_label_point(
    label_arr,
    label_id,
    variable_arr,
    patch_size,
    centered_on,
):
    """Get list_slices to extract a patch around a label point.

    The point is searched within the bounding box of ``label_id``, after
    masking with ``np.nan`` all cells that do not belong to the label.

    Parameters
    ----------
    label_arr : numpy.ndarray
        Label array.
    label_id : int
        Identifier of the label around which to extract the patch.
    variable_arr : numpy.ndarray or None
        Array (same shape as ``label_arr``) used to identify the point.
        If ``None``, the label array itself is used, which only conveys
        which cells belong to the label.
    patch_size : sequence of int
        Patch size along each dimension of ``label_arr``.
    centered_on : str or callable
        Method passed to ``find_point`` to identify the point.

    Returns
    -------
    list of slice or None
        Slices (one per dimension) defining the patch, or ``None`` if the
        label is not present or no point can be identified.
    """
    # Subset arrays around the label bounding box
    list_slices = _get_labels_bbox_slices(label_arr == label_id)
    if list_slices is None:
        return None
    bbox = tuple(list_slices)
    label_subset_arr = label_arr[bbox]
    # Without a variable array, fall back to the label array
    source_arr = label_arr if variable_arr is None else variable_arr
    # Work on a float copy (in memory): assigning NaN below must neither
    # modify the caller's array through a view nor fail on integer dtypes.
    variable_subset_arr = np.array(source_arr[bbox], dtype=float)
    # Mask variable arr outside the label
    variable_subset_arr[label_subset_arr != label_id] = np.nan
    # Find point of subset array
    point_subset_arr = find_point(arr=variable_subset_arr, centered_on=centered_on)
    if point_subset_arr is None:
        return None
    # Find point in original array (shift by the bounding box start)
    point = [slc.start + c for slc, c in zip(list_slices, point_subset_arr, strict=True)]
    # Find patch list slices
    # TODO: also return a flag if the p midpoint is conserved (by +/- 1) or not
    return [
        get_slice_around_index(p, size=size, min_start=0, max_stop=shape)
        for p, size, shape in zip(point, patch_size, label_arr.shape, strict=True)
    ]
382

383

384
def _get_patch_list_slices_around_label(label_arr, label_id, padding, min_patch_size):
1✔
385
    """Get list_slices to extract patch around a label."""
386
    # Get label bounding box slices
387
    list_slices = _get_labels_bbox_slices(label_arr == label_id)
1✔
388
    if list_slices is None:
1✔
389
        return None
1✔
390
    # Apply padding to the slices
391
    list_slices = pad_slices(list_slices, padding=padding, valid_shape=label_arr.shape)
1✔
392
    # Increase slices to match min_patch_size
393
    return enlarge_slices(list_slices, min_size=min_patch_size, valid_shape=label_arr.shape)
1✔
394

395

396
def _get_patch_list_slices(label_arr, label_id, variable_arr, patch_size, centered_on, padding):
1✔
397
    """Get patch n-dimensional list slices."""
398
    if not callable(centered_on) and centered_on == "label_bbox":
1✔
399
        list_slices = _get_patch_list_slices_around_label(
1✔
400
            label_arr=label_arr,
401
            label_id=label_id,
402
            padding=padding,
403
            min_patch_size=patch_size,
404
        )
405
    else:
406
        list_slices = _get_patch_list_slices_around_label_point(
1✔
407
            label_arr=label_arr,
408
            label_id=label_id,
409
            variable_arr=variable_arr,
410
            patch_size=patch_size,
411
            centered_on=centered_on,
412
        )
413
    return list_slices
1✔
414

415

416
def _get_masked_arrays(label_arr, variable_arr, partition_list_slices):
1✔
417
    """Mask labels and variable arrays outside the partitions area."""
418
    masked_partition_label_arr = np.zeros(label_arr.shape) * np.nan
1✔
419
    masked_partition_label_arr[tuple(partition_list_slices)] = label_arr[tuple(partition_list_slices)]
1✔
420
    if variable_arr is not None:
1✔
421
        masked_partition_variable_arr = np.zeros(variable_arr.shape) * np.nan
1✔
422
        masked_partition_variable_arr[tuple(partition_list_slices)] = variable_arr[tuple(partition_list_slices)]
1✔
423
    else:
UNCOV
424
        masked_partition_variable_arr = None
×
425
    return masked_partition_label_arr, masked_partition_variable_arr
1✔
426

427

428
def _get_patches_from_partitions_list_slices(
    partitions_list_slices,
    label_arr,
    variable_arr,
    label_id,
    patch_size,
    centered_on,
    n_patches_per_partition,
    padding,
    verbose=False,
):
    """Return patches list slices from a list of partitions ``list_slices``.

    For each partition, the label and variable arrays are masked outside the
    partition and up to ``n_patches_per_partition`` patches are searched.
    A patch is kept only if it has not already been collected.

    ``n_patches_per_partition`` is 1 unless ``centered_on`` is 'random' or a callable.

    Parameters
    ----------
    partitions_list_slices : list
        List of partitions, each defined by a list of slices.
    label_arr : numpy.ndarray
        Label array.
    variable_arr : numpy.ndarray or None
        Variable array with same shape as ``label_arr``.
    label_id : int
        Identifier of the label around which to extract patches.
    patch_size : list
        Patch size along each dimension.
    centered_on : str or callable
        Method used to define the patch location.
    n_patches_per_partition : int
        Number of patch extraction attempts for each partition.
    padding : list
        Padding along each dimension (used when ``centered_on='label_bbox'``).
    verbose : bool, optional
        If ``True``, print the partition being processed.

    Returns
    -------
    list
        List of patches, each defined by a list of slices.
    """
    patches_list_slices = []
    for partition_list_slices in partitions_list_slices:
        if verbose:
            print(f" -  partition: {partition_list_slices}")
        masked_label_arr, masked_variable_arr = _get_masked_arrays(
            label_arr=label_arr,
            variable_arr=variable_arr,
            partition_list_slices=partition_list_slices,
        )
        # Each iteration is one extraction attempt: duplicated or missing
        # patches are discarded and not retried.
        for _ in range(n_patches_per_partition):
            patch_list_slices = _get_patch_list_slices(
                label_arr=masked_label_arr,
                variable_arr=masked_variable_arr,
                label_id=label_id,
                patch_size=patch_size,
                centered_on=centered_on,
                padding=padding,
            )
            if patch_list_slices is not None and patch_list_slices not in patches_list_slices:
                patches_list_slices.append(patch_list_slices)
    return patches_list_slices
466

467

468
def _get_list_isel_dicts(patches_list_slices, dims):
1✔
469
    """Return a list with isel dictionaries."""
470
    return [dict(zip(dims, patch_list_slices, strict=True)) for patch_list_slices in patches_list_slices]
1✔
471

472

473
def _extract_xr_patch(xr_obj, isel_dict, label_name, label_id, highlight_label_id):
1✔
474
    """Extract a xarray patch."""
475
    # Extract xarray patch around label
476
    xr_obj_patch = xr_obj.isel(isel_dict)
1✔
477

478
    # If asked, set label array to 0 except for label_id
479
    if highlight_label_id:
1✔
480
        xr_obj_patch = highlight_label(xr_obj_patch, label_name=label_name, label_id=label_id)
1✔
481
    return xr_obj_patch
1✔
482

483

484
def _get_patches_isel_dict_generator(
    xr_obj,
    label_name,
    patch_size,
    variable=None,
    # Output options
    n_patches=None,
    n_labels=None,
    labels_id=None,
    grouped_by_labels_id=False,
    # (Tile) label patch extraction
    padding=0,
    centered_on="max",
    n_patches_per_label=None,
    n_patches_per_partition=1,
    debug=False,
    # Label Tiling/Sliding Options
    partitioning_method=None,
    n_partitions_per_label=None,
    kernel_size=None,
    buffer=0,
    stride=None,
    include_last=True,
    ensure_slice_size=True,
    verbose=False,
):
    """Yield the isel-dictionaries required to extract patches around labels.

    For each selected label, the label bounding box is padded and optionally
    partitioned, then patches are identified within each partition.

    Yields
    ------
    tuple
        If ``grouped_by_labels_id=True``: ``(label_id, list_isel_dicts)``,
        one item per label. In this mode ``n_patches`` limits the number of
        yielded label groups.
        Otherwise: ``(label_id, isel_dict)``, one item per patch, with
        ``n_patches`` limiting the number of yielded patches.

    Raises
    ------
    ValueError
        If both ``n_labels`` and ``labels_id`` are specified, or if one of
        the argument checks fails.
    """
    # Get label array information
    label_arr = xr_obj[label_name].data
    dims = xr_obj[label_name].dims
    shape = label_arr.shape

    # Check input arguments
    if n_labels is not None and labels_id is not None:
        raise ValueError("Specify either n_labels or labels_id.")
    if kernel_size is None:
        kernel_size = patch_size

    patch_size = check_patch_size(patch_size, dims, shape)
    buffer = check_buffer(buffer, dims, shape)
    padding = check_padding(padding, dims, shape)

    partitioning_method = check_partitioning_method(partitioning_method)
    stride = check_stride(stride, dims, shape, partitioning_method)
    kernel_size = check_kernel_size(kernel_size, dims, shape)

    centered_on = _check_centered_on(centered_on)
    n_patches = _check_n_patches(n_patches)
    n_patches_per_partition = _check_n_patches_per_partition(n_patches_per_partition, centered_on)
    n_patches_per_label = _check_n_patches_per_label(n_patches_per_label, n_patches_per_partition)

    label_arr = _check_label_arr(label_arr)  # output is np.array !
    labels_id = _check_labels_id(labels_id=labels_id, label_arr=label_arr)
    variable_arr = _get_variable_arr(xr_obj, variable, centered_on)  # if required
    variable_arr = _check_variable_arr(variable_arr, label_arr)

    # The checked patch_size and padding are mappings: extract their values
    # once, as lists, so that all helpers receive the same sequence type.
    patch_size_values = list(patch_size.values())
    padding_values = list(padding.values())

    # Define number of labels from which to extract patches
    available_n_labels = len(labels_id)
    n_labels = min(available_n_labels, n_labels) if n_labels else available_n_labels
    if verbose:
        print(f"Extracting patches from {n_labels} labels.")
    # -------------------------------------------------------------------------.
    # Extract patch(es) around the label
    patch_counter = 0
    for i, label_id in enumerate(labels_id[0:n_labels], start=1):
        if verbose:
            print(f"Label ID: {label_id} ({i}/{n_labels})")

        # Subset label_arr around the given label
        label_bbox_slices = _get_labels_bbox_slices(label_arr == label_id)

        # Apply padding to the label bounding box
        label_bbox_slices = pad_slices(label_bbox_slices, padding=padding_values, valid_shape=label_arr.shape)

        # --------------------------------------------------------------------.
        # Retrieve partitions list_slices
        if partitioning_method is not None:
            partitions_list_slices = get_nd_partitions_list_slices(
                label_bbox_slices,
                arr_shape=label_arr.shape,
                method=partitioning_method,
                kernel_size=list(kernel_size.values()),
                stride=list(stride.values()),
                buffer=list(buffer.values()),
                include_last=include_last,
                ensure_slice_size=ensure_slice_size,
            )
            if n_partitions_per_label is not None:
                n_to_select = min(len(partitions_list_slices), n_partitions_per_label)
                partitions_list_slices = partitions_list_slices[0:n_to_select]
        else:
            partitions_list_slices = [label_bbox_slices]

        # --------------------------------------------------------------------.
        # Retrieve patches list_slices from partitions list slices
        patches_list_slices = _get_patches_from_partitions_list_slices(
            partitions_list_slices=partitions_list_slices,
            label_arr=label_arr,
            variable_arr=variable_arr,
            label_id=label_id,
            patch_size=patch_size_values,
            centered_on=centered_on,
            n_patches_per_partition=n_patches_per_partition,
            padding=padding_values,
            verbose=verbose,
        )

        # ---------------------------------------------------------------------.
        # Retrieve patches isel_dictionaries
        partitions_isel_dicts = _get_list_isel_dicts(partitions_list_slices, dims=dims)
        patches_isel_dicts = _get_list_isel_dicts(patches_list_slices, dims=dims)

        # n_patches_per_label may be np.inf: min() then returns the list length
        n_to_select = min(len(patches_isel_dicts), n_patches_per_label)
        patches_isel_dicts = patches_isel_dicts[0:n_to_select]

        # --------------------------------------------------------------------.
        # If debug=True, plot patches boundaries
        if debug and label_arr.ndim == 2:
            _ = plot_label_patch_extraction_areas(
                xr_obj,
                label_name=label_name,
                patches_isel_dicts=patches_isel_dicts,
                partitions_isel_dicts=partitions_isel_dicts,
            )
            plt.show()

        # ---------------------------------------------------------------------.
        # Return isel_dicts
        # The generator stops as soon as the n_patches limit is exceeded.
        if grouped_by_labels_id:
            patch_counter += 1
            if patch_counter > n_patches:
                return
            yield label_id, patches_isel_dicts
        else:
            for isel_dict in patches_isel_dicts:
                patch_counter += 1
                if patch_counter > n_patches:
                    return
                yield label_id, isel_dict
627
    # ---------------------------------------------------------------------.
628

629

630
def get_patches_isel_dict_from_labels(
1✔
631
    xr_obj,
632
    label_name,
633
    patch_size,
634
    variable=None,
635
    # Output options
636
    n_patches=None,
637
    n_labels=None,
638
    labels_id=None,
639
    # Label Patch Extraction Settings
640
    centered_on="max",
641
    padding=0,
642
    n_patches_per_label=None,
643
    n_patches_per_partition=1,
644
    # Label Tiling/Sliding Options
645
    partitioning_method=None,
646
    n_partitions_per_label=None,
647
    kernel_size=None,
648
    buffer=0,
649
    stride=None,
650
    include_last=True,
651
    ensure_slice_size=True,
652
    debug=False,
653
    verbose=False,
654
):
655
    """
656
    Return isel-dictionaries to extract patches around labels.
657

658
    The isel-dictionaries are grouped by ``label_id`` and returned in a
659
    dictionary.
660

661
    Please refer to ``ximage.patch.get_patches_from_labels`` for a detailed description of
662
    the function arguments.
663

664
    Return
665
    ------
666
    dict
667
        A dictionary of the form: ``{label_id: list_isel_dicts}``.
668

669
    """
670
    gen = _get_patches_isel_dict_generator(
1✔
671
        xr_obj=xr_obj,
672
        label_name=label_name,
673
        patch_size=patch_size,
674
        variable=variable,
675
        n_patches=n_patches,
676
        n_labels=n_labels,
677
        labels_id=labels_id,
678
        grouped_by_labels_id=True,
679
        # Patch extraction options
680
        centered_on=centered_on,
681
        padding=padding,
682
        n_patches_per_label=n_patches_per_label,
683
        n_patches_per_partition=n_patches_per_partition,
684
        # Tiling/Sliding settings
685
        partitioning_method=partitioning_method,
686
        n_partitions_per_label=n_partitions_per_label,
687
        kernel_size=kernel_size,
688
        buffer=buffer,
689
        stride=stride,
690
        include_last=include_last,
691
        ensure_slice_size=ensure_slice_size,
692
        debug=debug,
693
        verbose=verbose,
694
    )
695
    return {int(label_id): list_isel_dicts for label_id, list_isel_dicts in gen}
1✔
696

697

698
def get_patches_from_labels(
    xr_obj,
    label_name,
    patch_size,
    variable=None,
    # Output options
    n_patches=None,
    n_labels=None,
    labels_id=None,
    highlight_label_id=True,
    # Label Patch Extraction Options
    centered_on="max",
    padding=0,
    n_patches_per_label=None,
    n_patches_per_partition=1,
    # Label Tiling/Sliding Options
    partitioning_method=None,
    n_partitions_per_label=None,
    kernel_size=None,
    buffer=0,
    stride=None,
    include_last=True,
    ensure_slice_size=True,
    debug=False,
    verbose=False,
):
    """
    Generate patches extracted around the labels of a prelabeled xarray object.

    The generator extracts (from a prelabeled xarray.Dataset) a patch around either:

    - a label point
    - a label bounding box

    To extract the patch around the label bounding box, use
    ``centered_on='label_bbox'``. In that case the output patches are only
    guaranteed to have a minimum shape.

    To extract the patch around a label point, use any other ``centered_on``
    method. In that case all output patches are guaranteed to have the same shape.
    If the identified point is close to an array boundary, the patch is expanded
    toward the valid directions.

    Partitioning (tiling or sliding) allows to split/slide over each label and
    extract multiple patches for each partition.

    Suggested ``centered_on`` methods with ``partitioning_method='tiling'``:

    - ``centered_on = "centroid"`` (tiling around labels bbox)
    - ``centered_on = "center_of_mass"`` (better coverage around label)

    Suggested ``centered_on`` methods with ``partitioning_method='sliding'``:

    - ``centered_on = "center_of_mass"`` (better coverage around label) (further data coverage)

    Only one parameter between ``n_patches`` and ``labels_id`` can be specified.

    Parameters
    ----------
    xr_obj : xarray.Dataset
        xarray.Dataset with a label array named ``label_name``.
    label_name : str
        Name of the variable/coordinate representing the label array.
    patch_size : int or tuple
        The dimensions of the n-dimensional patch to extract.
        Only positive values (>1) are allowed.
        The value -1 can be used to specify the full array dimension shape.
        If the ``centered_on`` method is not ``'label_bbox'``, all output patches
        are ensured to have the same shape.
        Otherwise, if ``centered_on='label_bbox'``, the ``patch_size`` argument
        defines the minimum n-dimensional shape of the output patches.
        If ``int``, the value is applied to all label array dimensions.
        If ``list`` or ``tuple``, the length must match the number of dimensions of the array.
        If a ``dict``, the dictionary must have as keys the label array dimensions.
    variable : str, optional
        Dataset variable to use to identify the patch center when ``centered_on`` is defined.
        This is required only for ``centered_on='max'``, ``centered_on='min'`` or a custom function.
    n_patches : int, optional
        Maximum number of patches to extract.
        The default (``None``) enables to extract all available patches allowed by the
        specified patch extraction criteria.
    n_labels : int, optional
        The number of labels for which to extract patches.
        If ``None`` (the default), it extracts patches for all labels.
        This argument can be specified only if ``labels_id`` is unspecified !
    labels_id : list, optional
        List of labels for which to extract the patch.
        If ``None``, it extracts the patches by label order ``(1, 2, 3, ...)``
        The default is ``None``.
    highlight_label_id : bool, optional
        If ``True``, the ``label_name`` array of each patch is modified to contain only
        the ``label_id`` used to select the patch.
        The default is ``True``.
    centered_on : str or callable, optional
        The ``centered_on`` method characterizes the point around which the patch is extracted.
        Valid pre-implemented ``centered_on`` methods are ``'label_bbox'``, ``'max'``, ``'min'``,
        ``'centroid'``, ``'center_of_mass'``, ``'random'``.
        The default method is ``'max'``.

        If ``label_bbox``, it extracts the patches around the (padded) bounding box of the label.
        If ``label_bbox``, the output patch sizes are only ensured to have a minimum ``patch_size``,
        and will likely be of different size.
        Otherwise, the other methods guarantee that the output patches have a common shape.

        If ``centered_on`` is ``'max'``, ``'min'`` or a custom function,
        the ``variable`` argument must be specified.
        If ``centered_on`` is a custom function, it must:

        - return ``None`` if all array values are non-finite (i.e ``np.nan``)
        - return a tuple with same length as the array shape.
    padding : int, tuple or dict, optional
        The padding to apply in each direction around a label prior to
        partitioning (tiling/sliding) or direct patch extraction.
        The default, 0, applies 0 padding in every dimension.
        Negative padding values are allowed !
        If ``int``, the value is applied to all label array dimensions.
        If ``list`` or ``tuple``, the length must match the number of dimensions of the array.
        If a ``dict``, the dictionary must have as keys the label array dimensions.
    n_patches_per_label : int, optional
        The maximum number of patches to extract for each label.
        The default (``None``) enables to extract all the available patches per label.
        If specified, ``n_patches_per_label`` must be larger than ``n_patches_per_partition`` !
    n_patches_per_partition : int, optional
        The maximum number of patches to extract from each label partition.
        The default value is 1.
        This method can be specified only if ``centered_on='random'`` or a callable.
    partitioning_method : str
        Whether to retrieve ``'tiling'`` or ``'sliding'`` slices.
        If ``'tiling'``, partition start slices are separated by ``stride`` + ``kernel_size``.
        If ``'sliding'``, partition start slices are separated by ``stride``.
    n_partitions_per_label : int, optional
        The maximum number of partitions to extract for each label.
        The default (``None``) enables to extract all the available partitions per label.
    kernel_size : int, tuple or dict, optional
        The shape of the desired partitions.
        Only positive values (>1) are allowed.
        The value ``-1`` can be used to specify the full array dimension shape.
        If ``int``, the value is applied to all label array dimensions.
        If ``list`` or ``tuple``, the length must match the number of dimensions of the array.
        If a ``dict``, the dictionary must have as keys the label array dimensions.
    buffer : int, tuple or dict, optional
        Value by which to enlarge a partition on each side.
        The default is ``0``.
        The final partition size should be ``kernel_size`` + ``buffer``.
        If ``partitioning_method='tiling'`` and ``stride=0``, a positive buffer value corresponds to
        the amount of overlap between each partition.
        Depending on ``min_start`` and ``max_stop`` values, buffering might cause
        border partitions to not have same sizes.
        If ``int``, the value is applied to all label array dimensions.
        If ``list`` or ``tuple``, the length must match the number of dimensions of the array.
        If a ``dict``, the dictionary must have as keys the label array dimensions.
    stride : int, tuple or dict, optional
        Step size between slices.
        If ``partitioning_method = 'sliding'``, default ``stride`` is set to 1.
        If ``partitioning_method = 'tiling'``, default ``stride`` is set to 0.
        When ``partitioning_method='tiling'``, a positive stride makes partition slices
        not overlap and not touch, while a negative stride makes partition slices
        overlap by ``stride`` amount.
        If ``stride=0``, the partition slices are contiguous (no spacing between partitions).
        When ``partitioning_method='sliding'``, only a positive stride (>= 1) is allowed.
        If ``int``, the value is applied to all label array dimensions.
        If ``list`` or ``tuple``, the length must match the number of dimensions of the array.
        If a ``dict``, the dictionary must have as keys the label array dimensions.
    include_last : bool, optional
        Whether to include the last partition if it does not match the ``kernel_size``.
        The default is ``True``.
    ensure_slice_size : bool, optional
        Used only if ``include_last`` is ``True``.
        If ``False``, the last partition will not have the specified ``kernel_size``.
        If ``True``, the last partition is enlarged to the specified ``kernel_size`` by
        tentatively expanding it on both sides (accounting for ``min_start`` and ``max_stop``).
        The default is ``True``.
    debug : bool, optional
        Forwarded unchanged to the patch isel-dictionary generator.
        The default is ``False``.
    verbose : bool, optional
        Forwarded unchanged to the patch isel-dictionary generator.
        The default is ``False``.

    Yields
    ------
    tuple
        A ``(label_id, patch)`` tuple, where ``patch`` is the xarray object
        (xarray.Dataset or xarray.DataArray) extracted for ``label_id``.

    """
    # Gather every option forwarded to the isel-dictionary generator
    generator_kwargs = {
        "xr_obj": xr_obj,
        "label_name": label_name,
        "patch_size": patch_size,
        "variable": variable,
        "n_patches": n_patches,
        "n_labels": n_labels,
        "labels_id": labels_id,
        # Ungrouped output: the loop below unpacks one (label_id, isel_dict) pair per item
        "grouped_by_labels_id": False,
        # Label Patch Extraction Options
        "centered_on": centered_on,
        "padding": padding,
        "n_patches_per_label": n_patches_per_label,
        "n_patches_per_partition": n_patches_per_partition,
        # Tiling/Sliding Options
        "partitioning_method": partitioning_method,
        "n_partitions_per_label": n_partitions_per_label,
        "kernel_size": kernel_size,
        "buffer": buffer,
        "stride": stride,
        "include_last": include_last,
        "ensure_slice_size": ensure_slice_size,
        "debug": debug,
        "verbose": verbose,
    }
    isel_dicts_stream = _get_patches_isel_dict_generator(**generator_kwargs)

    # Turn each isel-dictionary into the corresponding patch and hand it out
    # together with the label it belongs to
    for label_id, isel_dict in isel_dicts_stream:
        yield label_id, _extract_xr_patch(
            xr_obj=xr_obj,
            label_name=label_name,
            isel_dict=isel_dict,
            label_id=label_id,
            highlight_label_id=highlight_label_id,
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc