• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

angelolab / ark-analysis / 20330656268

18 Dec 2025 08:25AM UTC coverage: 95.155% (+2.1%) from 93.078%
20330656268

push

github

web-flow
Fix preprocessing description in Pixie notebook (#1197)

* Add preprocessing description

* Update twine and packaging to support metadata 2.4

* Readd uv.lock file

* Fix missing filename in example dataset tests

* Testing save colored mask

* More testing

---------

Co-authored-by: Alex Kong <alkong@ucdavis.edu>

3889 of 4087 relevant lines covered (95.16%)

2.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.1
/src/ark/utils/plot_utils.py
1
import os
3✔
2
import pathlib
3✔
3
import shutil
3✔
4
from dataclasses import dataclass, field
3✔
5
from operator import contains
3✔
6
from typing import Dict, List, Literal, Optional, Tuple, Union
3✔
7
from matplotlib import gridspec
3✔
8
from matplotlib.axes import Axes
3✔
9
import matplotlib.colors as colors
3✔
10
from matplotlib import cm
3✔
11
from matplotlib import colormaps, patches
3✔
12
from matplotlib.figure import Figure
3✔
13
import matplotlib.pyplot as plt
3✔
14
import natsort
3✔
15
import numpy as np
3✔
16
import pandas as pd
3✔
17
from pandas.core.groupby.generic import DataFrameGroupBy
3✔
18
import skimage
3✔
19
import xarray as xr
3✔
20
from alpineer import image_utils, io_utils, load_utils, misc_utils
3✔
21
from alpineer.settings import EXTENSION_TYPES
3✔
22
from mpl_toolkits.axes_grid1 import make_axes_locatable
3✔
23
from skimage.exposure import rescale_intensity
3✔
24
from skimage import io
3✔
25

26

27
from skimage.util import img_as_ubyte
3✔
28
from tqdm.auto import tqdm
3✔
29
from ark import settings
3✔
30
from skimage.segmentation import find_boundaries
3✔
31
from ark.utils.data_utils import (
3✔
32
    ClusterMaskData,
33
    erode_mask,
34
    generate_cluster_mask,
35
    save_fov_mask,
36
    map_segmentation_labels,
37
)
38

39

40
@dataclass
class MetaclusterColormap:
    """
    A dataclass which contains the colormap-related information for the metaclusters.

    Args:
        cluster_type (str):
            The type of clustering being done. It is used as the prefix of the
            `*_som_cluster`, `*_meta_cluster` and `*_meta_cluster_rename` columns of the mapping.
        cluster_id_to_name_path (Union[str, pathlib.Path]):
            A path to a CSV identifying the pixel/cell cluster to manually-defined name mapping,
            this is output by the remapping visualization found in `metacluster_remap_gui`.
        metacluster_colors (Dict):
            Maps each metacluster id to a color. The dictionary passed in is not modified, the
            instance stores a copy extended with the "Unassigned" and "Empty" entries.
    """
    cluster_type: str
    cluster_id_to_name_path: Union[str, pathlib.Path]
    metacluster_colors: Dict

    # Fields initialized after `__post_init__`
    unassigned_color: Tuple[float, ...] = field(init=False)
    unassigned_id: int = field(init=False)
    background_color: Tuple[float, ...] = field(init=False)
    metacluster_id_to_name: pd.DataFrame = field(init=False)
    mc_colors: np.ndarray = field(init=False)
    cmap: colors.ListedColormap = field(init=False)
    norm: colors.BoundaryNorm = field(init=False)

    def __post_init__(self) -> None:
        """
        Initializes the fields of the dataclass after the object is instantiated.
        """

        # A pixel with no associated metacluster (light gray, approximately #E6E6E6)
        self.unassigned_color = (0.9, 0.9, 0.9, 1.0)

        # A pixel assigned to the background (black, #000000)
        self.background_color = (0.0, 0.0, 0.0, 1.0)

        self._metacluster_cmap_generator()

    def _metacluster_cmap_generator(self) -> None:
        """
        A helper function which generates a colormap for the metaclusters with a given cluster ID
        to name mapping.

        Reads `cluster_id_to_name_path`, `metacluster_colors` and `cluster_type` from the
        instance and sets `unassigned_id`, `metacluster_colors`, `metacluster_id_to_name`,
        `mc_colors`, `cmap` and `norm` on it. Nothing is returned.
        """

        meta_col: str = f"{self.cluster_type}_meta_cluster"
        rename_col: str = f"{self.cluster_type}_meta_cluster_rename"

        cluster_id_to_name: pd.DataFrame = pd.read_csv(self.cluster_id_to_name_path)

        # The mapping file needs to contain the following columns:
        # '*_som_cluster', '*_meta_cluster', '*_meta_cluster_rename', 'cluster_id'
        misc_utils.verify_in_list(
            required_cols=[
                f"{self.cluster_type}_som_cluster",
                meta_col,
                rename_col,
                "cluster_id",
            ],
            cluster_mapping_cols=cluster_id_to_name.columns.values
        )

        # subset on just metacluster, mc_name and cluster_id
        metacluster_id_to_name = cluster_id_to_name[[meta_col, rename_col, "cluster_id"]].copy()

        unassigned_meta_cluster: int = int(metacluster_id_to_name[meta_col].max() + 1)
        unassigned_cluster_id: int = int(metacluster_id_to_name["cluster_id"].max() + 1)

        # Extract unique pairs of (metacluster-ID,  name, cluster_id)
        # Set the unassigned meta cluster to be the max ID + 1
        # Set 0 as the Empty value
        metacluster_id_to_name = pd.concat(
            [
                metacluster_id_to_name.drop_duplicates(),
                pd.DataFrame(
                    data={
                        meta_col: [unassigned_meta_cluster, 0],
                        rename_col: ["Unassigned", "Empty"],
                        "cluster_id": [unassigned_cluster_id, 0]
                    }
                )
            ]
        )

        # Add the unassigned and the no-cluster (background) colors. A new dict is built rather
        # than updating in place, so the dictionary owned by the caller is left untouched and
        # can be reused with a different mapping file.
        self.metacluster_colors = {
            **self.metacluster_colors,
            unassigned_meta_cluster: self.unassigned_color,
            0: self.background_color,
        }

        # assert the metacluster index in colors matches with the ids in metacluster_id_to_name
        misc_utils.verify_same_elements(
            metacluster_colors_ids=list(self.metacluster_colors.keys()),
            metacluster_mapping_ids=metacluster_id_to_name[meta_col].values
        )

        # use metacluster_colors to add the colors to metacluster_id_to_name
        metacluster_id_to_name["color"] = metacluster_id_to_name[meta_col].map(
            self.metacluster_colors)

        # sort by cluster_id ascending, so colors align with mask integers
        metacluster_id_to_name.sort_values(by="cluster_id", inplace=True)
        metacluster_id_to_name.reset_index(drop=True, inplace=True)

        # Convert the list of tuples to a numpy array, each index is a color
        mc_colors: np.ndarray = np.array(metacluster_id_to_name['color'].to_list())

        # generate the colormap, with one bin centered on every integer index
        cmap = colors.ListedColormap(mc_colors)
        norm = colors.BoundaryNorm(
            np.linspace(0, len(mc_colors), len(mc_colors) + 1) - 0.5,
            len(mc_colors)
        )

        # Assign created values to dataclass attributes.
        # `unassigned_id` must be set: a declared dataclass field that is never assigned makes
        # the generated `__repr__` / `__eq__` raise AttributeError.
        # NOTE(review): the cluster_id-level value is stored because the colors are ordered by
        # `cluster_id`; confirm against the consumers of `unassigned_id`.
        self.unassigned_id = unassigned_cluster_id
        self.metacluster_id_to_name = metacluster_id_to_name
        self.mc_colors = mc_colors
        self.cmap = cmap
        self.norm = norm
168

169

170
def create_cmap(cmap: Union[np.ndarray, list[str], str],
                n_clusters: int) -> tuple[colors.ListedColormap, colors.BoundaryNorm]:
    """
    Creates a discrete colormap and a boundary norm from the provided colors.

    Colors can be of any format that matplotlib accepts.
    See here for color formats: https://matplotlib.org/stable/tutorials/colors/colors.html

    The background color is prepended and the unassigned color is appended to the cluster
    colors, so the returned colormap holds `n_clusters + 2` colors.

    Args:
        cmap (Union[np.ndarray, list[str], str]): The colormap, or set of colors to use.
            A `np.ndarray` must be 2D with one row per cluster, a `list` must contain one
            matplotlib color per cluster, and a `str` must name a registered matplotlib
            colormap, which is sampled at `n_clusters` evenly spaced positions.
        n_clusters (int): The number of clusters for the colormap.

    Returns:
        tuple[colors.ListedColormap, colors.BoundaryNorm]:
            The generated colormap and boundary norm.

    Raises:
        ValueError: If an array is not 2D, or the number of colors is not `n_clusters`.
        KeyError: If `cmap` is a string which is not a registered colormap name.
        TypeError: If `cmap` is not an array, a list or a string.
    """

    if isinstance(cmap, np.ndarray):
        if cmap.ndim != 2:
            raise ValueError(
                f"colors_array must be a 2D array, got {cmap.ndim}D array")
        if cmap.shape[0] != n_clusters:
            raise ValueError(
                f"colors_array must have {n_clusters} colors, got {cmap.shape[0]} colors")
        cluster_colors: np.ndarray = cmap
    elif isinstance(cmap, list):
        if len(cmap) != n_clusters:
            raise ValueError(
                f"colors_array must have {n_clusters} colors, got {len(cmap)} colors")
        # Convert the color specifications to an (n_clusters, 4) RGBA array. Previously no
        # colormap was built for lists, and the function failed with an UnboundLocalError.
        cluster_colors = colors.to_rgba_array(cmap)
    elif isinstance(cmap, str):
        try:
            # Any colormap registered with matplotlib can be used, which includes the
            # colorcet and cmocean colormaps once those packages have registered them.
            named_cmap = colormaps[cmap]
        except KeyError as err:
            raise KeyError(f"Colormap {cmap} not found.") from err
        cluster_colors = named_cmap(np.linspace(0, 1, n_clusters))
    else:
        raise TypeError(
            f"cmap must be a numpy array, a list or a string, got {type(cmap).__name__}")

    color_map = colors.ListedColormap(colors=_cmap_add_background_unassigned(cluster_colors))

    # One bin per color, centered on the integer index of that color
    bounds = np.linspace(0, color_map.N, color_map.N + 1) - 0.5

    norm = colors.BoundaryNorm(bounds, color_map.N)
    return color_map, norm
224

225

226
def _cmap_add_background_unassigned(cluster_colors: np.ndarray):
3✔
227
    # A pixel with no associated metacluster (gray, #5A5A5A)
228
    unassigned_color: np.ndarray = np.array([0.9, 0.9, 0.9, 1.0])
3✔
229

230
    # A pixel assigned to the background (black, #000000)
231
    background_color: np.ndarray = np.array([0.0, 0.0, 0.0, 1.0])
3✔
232

233
    return np.vstack([background_color, cluster_colors, unassigned_color])
3✔
234

235

236
def plot_cluster(
        image: np.ndarray,
        fov: str,
        cmap: colors.ListedColormap,
        norm: colors.BoundaryNorm,
        cbar_visible: bool = True,
        cbar_labels: list[str] = None,
        dpi: int = 300,
        figsize: tuple[int, int] = None) -> Figure:
    """
    Draws a cluster image with a discrete colormap, and optionally a labeled colorbar.

    Args:
        image (np.ndarray):
            The cluster image to plot.
        fov (str):
            The name of the clustered FOV, it is used as the title of the figure.
        cmap (colors.ListedColormap):
            A colormap to use for the cluster image.
        norm (colors.BoundaryNorm):
            A normalization to use for the cluster image.
        cbar_visible (bool, optional):
            Whether or not to display the colorbar. Defaults to True.
        cbar_labels (list[str], optional):
            Colorbar labels for the clusters. Defaults to None, where
            the labels will be automatically generated.
        dpi (int, optional):
            The resolution of the figure. Defaults to 300.
        figsize (tuple, optional):
            The size of the figure. Defaults to None, which is handed to `plt.figure` as is.

    Returns:
        Figure: Returns the cluster image as a matplotlib Figure.
    """
    # Generate the colorbar labels when none are given
    if cbar_labels is None:
        n_colors = len(cmap.colors)
        cbar_labels = [f"Cluster {x}" for x in range(1, n_colors)]

    figure: Figure = plt.figure(figsize=figsize, dpi=dpi)
    figure.set_layout_engine(layout="tight")
    figure.suptitle(f"{fov}")

    # A single cell grid holding the image axis
    grid = gridspec.GridSpec(nrows=1, ncols=1, figure=figure)
    image_ax: Axes = figure.add_subplot(grid[0, 0])
    image_ax.axis("off")
    image_ax.grid(visible=False)

    image_ax.imshow(
        X=image,
        cmap=cmap,
        norm=norm,
        origin="upper",
        aspect="equal",
        interpolation="none",
    )

    if not cbar_visible:
        return figure

    # Manually place the colorbar in its own axis, to the right of the image
    divider = make_axes_locatable(figure.gca())
    cbar_ax = divider.append_axes(position="right", size="5%", pad="3%")
    mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
    colorbar = figure.colorbar(mappable,
                               cax=cbar_ax, orientation="vertical", use_gridspec=True, pad=0.1,
                               shrink=0.9, drawedges=True)

    # One tick per label, placed on the integer cluster indices
    tick_positions = np.arange(len(cbar_labels))
    colorbar.ax.set_yticks(ticks=tick_positions, labels=cbar_labels)
    colorbar.minorticks_off()

    return figure
307

308

309
def plot_neighborhood_cluster_result(img_xr: xr.DataArray,
                                     fovs: list[str],
                                     k: int,
                                     cmap_name: str = "tab20",
                                     cbar_visible: bool = True,
                                     save_dir: Union[str, pathlib.Path] = None,
                                     fov_col: str = "fovs",
                                     dpi: int = 300,
                                     figsize=(10, 10)
                                     ) -> None:
    """
    Plots the neighborhood clustering results for the provided FOVs.

    Args:
        img_xr (xr.DataArray):
            DataArray containing labeled cells.
        fovs (list[str]):
            A list of FOVs to plot.
        k (int):
            The number of neighborhoods / clusters.
        cmap_name (str, optional):
            The Colormap to use for clustering results. Defaults to "tab20".
        cbar_visible (bool, optional):
            Whether or not to display the colorbar. Defaults to True.
        save_dir (Union[str, pathlib.Path], optional):
            The image will be saved to this location if provided. Defaults to None.
        fov_col (str, optional):
            The column with the fov names in `img_xr`. Defaults to "fovs".
        dpi (int, optional):
            The resolution of the image to use for displaying and saving. Defaults to 300.
        figsize (tuple, optional):
            The size of the image to display. Defaults to (10, 10).
    """

    # verify the fovs are valid, reading the names from the column the caller specified
    misc_utils.verify_in_list(fovs=fovs, unique_fovs=img_xr[fov_col].values)

    # define the colormap
    my_colors = plt.get_cmap(cmap_name, k).colors

    cmap, norm = create_cmap(my_colors, n_clusters=k)

    # index 0 of the colormap is the background, the clusters follow it
    cbar_labels = ["Empty"]
    cbar_labels.extend([f"Cluster {x}" for x in range(1, k+1)])

    for fov_img in img_xr.sel({fov_col: fovs}):
        fov_name = fov_img[fov_col].values

        fig: Figure = plot_cluster(
            image=fov_img.values.squeeze(),
            fov=fov_name,
            cmap=cmap,
            norm=norm,
            cbar_visible=cbar_visible,
            cbar_labels=cbar_labels,
            dpi=dpi,
            figsize=figsize
        )

        # save if specified, at the requested resolution instead of a fixed 300 dpi
        if save_dir:
            fig.savefig(fname=os.path.join(save_dir, f"{fov_name}.png"), dpi=dpi)
370

371

372
def plot_pixel_cell_cluster(
        img_xr: xr.DataArray,
        fovs: list[str],
        cluster_id_to_name_path: Union[str, pathlib.Path],
        metacluster_colors: Dict,
        cluster_type: Union[Literal["pixel"], Literal["cell"]] = "pixel",
        cbar_visible: bool = True,
        save_dir=None,
        fov_col: str = "fovs",
        erode: bool = False,
        dpi=300,
        figsize=(10, 10)
):
    """Overlays the pixel and cell clusters on an image

    Args:
        img_xr (xr.DataArray):
            DataArray containing labeled pixel or cell clusters
        fovs (list[str]):
            A list of FOVs to plot.
        cluster_id_to_name_path (str):
            A path to a CSV identifying the pixel/cell cluster to manually-defined name mapping
            this is output by the remapping visualization found in `metacluster_remap_gui`
        metacluster_colors (dict):
            Dictionary which maps each metacluster id to a color
        cluster_type ("pixel" or "cell"):
            the type of clustering being done.
        cbar_visible (bool, optional):
            Whether or not to display the colorbar. Defaults to True.
        save_dir (str):
            If provided, the image will be saved to this location.
        fov_col (str):
            The column with the fov names in `img_xr`. Defaults to "fovs".
        erode (bool):
            Whether or not to erode the segmentation mask. Defaults to False.
        dpi (int):
            The resolution of the image to use for displaying and saving. Defaults to 300.
        figsize (tuple):
            Size of the image that will be displayed.
    """

    # verify the type of clustering provided is valid
    misc_utils.verify_in_list(
        provided_cluster_type=[cluster_type],
        valid_cluster_types=['pixel', 'cell']
    )

    # verify the fovs are valid, reading the names from the column the caller specified
    misc_utils.verify_in_list(fovs=fovs, unique_fovs=img_xr[fov_col].values)

    # verify cluster_id_to_name_path exists
    io_utils.validate_paths(cluster_id_to_name_path)

    # read the cluster to name mapping with the helper function
    mcc = MetaclusterColormap(cluster_type=cluster_type,
                              cluster_id_to_name_path=cluster_id_to_name_path,
                              metacluster_colors=metacluster_colors)

    # the labels do not depend on the FOV, so they are looked up once
    cbar_labels = mcc.metacluster_id_to_name[f'{cluster_type}_meta_cluster_rename'].values

    for fov_img in img_xr.sel({fov_col: fovs}):
        # read the name before eroding, the eroded mask replaces the DataArray
        fov_name = fov_img[fov_col].values
        if erode:
            fov_img = erode_mask(seg_mask=fov_img, connectivity=2, mode="thick", background=0)

        fig: Figure = plot_cluster(
            image=fov_img,
            fov=fov_name,
            cmap=mcc.cmap,
            norm=mcc.norm,
            cbar_visible=cbar_visible,
            cbar_labels=cbar_labels,
            dpi=dpi,
            figsize=figsize,
        )

        # save if specified, at the requested resolution instead of a fixed 300 dpi
        if save_dir:
            fig.savefig(fname=os.path.join(save_dir, f"{fov_name}.png"), dpi=dpi)
449

450

451
def tif_overlay_preprocess(segmentation_labels, plotting_tif):
    """Checks plotting_tif and converts it to a three channel image

    A 2D image is written to the last channel of the output. For a 3D image the
    channels are written to the output in reverse order, starting from the last channel.
    The channels which are not filled stay at zero.

    Args:
        segmentation_labels (numpy.ndarray):
            2D numpy array of labeled cell objects
        plotting_tif (numpy.ndarray):
            2D or 3D numpy array of imaging signal
    Returns:
        numpy.ndarray:
            The preprocessed image, with the row and column sizes of `plotting_tif`,
            three channels and the dtype of `plotting_tif`
    Raises:
        ValueError:
            If `plotting_tif` is neither 2D nor 3D, if a 2D `plotting_tif` does not have the
            shape of `segmentation_labels`, or if a 3D `plotting_tif` has more than 3 channels
    """

    n_dims = plotting_tif.ndim

    # reject the invalid inputs first
    if n_dims not in (2, 3):
        raise ValueError("plotting tif must be 2D or 3D array, got {}".
                         format(plotting_tif.shape))
    if n_dims == 2 and plotting_tif.shape != segmentation_labels.shape:
        raise ValueError("plotting_tif and segmentation_labels array dimensions not equal.")
    if n_dims == 3 and plotting_tif.shape[2] > 3:
        # can only support up to 3 channels
        raise ValueError("max 3 channels of overlay supported, got {}".
                         format(plotting_tif.shape))

    n_rows, n_cols = plotting_tif.shape[:2]
    rgb_image = np.zeros((n_rows, n_cols, 3), dtype=plotting_tif.dtype)

    if n_dims == 2:
        # the blue channel contains the plotting tif data
        rgb_image[..., 2] = plotting_tif
        return rgb_image

    # fill the first n channels, then reverse the channel order
    n_channels = plotting_tif.shape[2]
    rgb_image[..., :n_channels] = plotting_tif
    return np.flip(rgb_image, axis=2)
488

489

490
def create_overlay(fov, segmentation_dir, data_dir,
                   img_overlay_chans, seg_overlay_comp, alternate_segmentation=None):
    """Overlays the segmentation contours of a fov on its rescaled image channels

    Generates the outline(s) of the mask(s) as well as intensity from plotting tif. The
    contours of the loaded segmentation are set to 255 on all three channels (white), while
    the contours of the alternate segmentation, if provided, are set to 255 on the first
    channel and to 0 on the other channels (red).

    Args:
        fov (str):
            The name of the fov to overlay
        segmentation_dir (str):
            The path to the directory containing the segmentation data
        data_dir (str):
            The path to the directory containing the nuclear and whole cell image data
        img_overlay_chans (list):
            List of channels the user will overlay
        seg_overlay_comp (str):
            The segmented compartment the user will overlay
        alternate_segmentation (numpy.ndarray):
            2D numpy array of labeled cell objects
    Returns:
        numpy.ndarray:
            The image with the channel overlay
    Raises:
        ValueError:
            If `alternate_segmentation` does not have the shape of the loaded segmentation
    """

    # load the specified fov data in
    channel_xr = load_utils.load_imgs_from_dir(
        data_dir=data_dir,
        files=[fov + '.tiff'],
        xr_dim_name='channels',
        xr_channel_names=['nuclear_channel', 'membrane_channel']
    )

    # verify that the provided image channels exist in the loaded image
    misc_utils.verify_in_list(
        provided_channels=img_overlay_chans,
        img_channels=channel_xr.channels.values
    )

    # subset the loaded image with the provided image overlay channels
    overlay_img = channel_xr.loc[fov, :, :, img_overlay_chans].values

    # read the segmentation data in, one compartment at a time
    whole_cell_xr = load_utils.load_imgs_from_dir(data_dir=segmentation_dir,
                                                  files=[fov + '_whole_cell.tiff'],
                                                  xr_dim_name='compartments',
                                                  xr_channel_names=['whole_cell'],
                                                  trim_suffix='_whole_cell',
                                                  match_substring='_whole_cell')
    nuclear_xr = load_utils.load_imgs_from_dir(data_dir=segmentation_dir,
                                               files=[fov + '_nuclear.tiff'],
                                               xr_dim_name='compartments',
                                               xr_channel_names=['nuclear'],
                                               trim_suffix='_nuclear',
                                               match_substring='_nuclear')

    # join both compartments along the last axis
    stacked_masks = np.concatenate((whole_cell_xr.values, nuclear_xr.values), axis=-1)
    compartment_xr = xr.DataArray(stacked_masks,
                                  coords=[whole_cell_xr.fovs,
                                          whole_cell_xr.rows,
                                          whole_cell_xr.cols,
                                          ['whole_cell', 'nuclear']],
                                  dims=whole_cell_xr.dims)

    # verify that the provided segmentation compartment exists
    misc_utils.verify_in_list(
        provided_compartments=seg_overlay_comp,
        seg_compartments=compartment_xr.compartments.values
    )

    # subset the segmentation with the provided compartment
    seg_mask = compartment_xr.loc[fov, :, :, seg_overlay_comp].values

    # validate the image against the mask and convert it to three channels
    overlay_img = tif_overlay_preprocess(seg_mask, overlay_img)

    # define borders of cells in mask
    predicted_edges = find_boundaries(seg_mask, connectivity=1, mode='inner').astype(np.uint8)
    predicted_edges[predicted_edges > 0] = 255

    # rescale each channel to go from 0 to 255
    rescaled = np.zeros(overlay_img.shape, dtype='uint8')

    for chan_idx in range(overlay_img.shape[2]):
        chan_img = overlay_img[:, :, chan_idx]

        # an all zero channel does not need to be rescaled
        if np.max(chan_img) == 0:
            continue

        # clip to the 5th and 95th percentiles of the positive values
        low, high = np.percentile(chan_img[chan_img > 0], [5, 95])
        rescaled[:, :, chan_idx] = rescale_intensity(chan_img,
                                                     in_range=(low, high),
                                                     out_range='uint8')

    # overlay first contour on all three RGB, to have it show up as white border
    rescaled[predicted_edges > 0, :] = 255

    if alternate_segmentation is None:
        return rescaled

    # overlay second contour as red outline
    if seg_mask.shape != alternate_segmentation.shape:
        raise ValueError("segmentation_labels and alternate_"
                         "segmentation array dimensions not equal.")

    # define borders of cell in mask
    alternate_edges = find_boundaries(alternate_segmentation, connectivity=1,
                                      mode='inner').astype(np.uint8)
    on_alternate_edge = alternate_edges > 0
    rescaled[on_alternate_edge, 0] = 255
    rescaled[on_alternate_edge, 1:] = 0

    return rescaled
604

605

606
def set_minimum_color_for_colormap(cmap, default=(0, 0, 0, 1)):
    """ Replaces the first color of the provided colormap with black (#000000) or provided color

    This is useful for instances where zero-valued regions of an image should be
    distinct from positive regions (i.e transparent or non-colormap member color)

    Args:
        cmap (matplotlib.colors.Colormap):
            matplotlib color map
        default (Iterable):
            RGBA color values for minimum color. Default is black, (0, 0, 0, 1).

    Returns:
        matplotlib.colors.Colormap:
            corrected colormap, a `ListedColormap` built from the `cmap.N` colors of `cmap`
    """
    # look up every color of the colormap by its integer index
    color_table = cmap(np.arange(cmap.N))

    # overwrite the color of the lowest index
    color_table[0] = list(default)

    return colors.ListedColormap(color_table)
626

627

628
def create_mantis_dir(fovs: List[str], mantis_project_path: Union[str, pathlib.Path],
                      img_data_path: Union[str, pathlib.Path],
                      mask_output_dir: Union[str, pathlib.Path],
                      mapping: Union[str, pathlib.Path, pd.DataFrame],
                      seg_dir: Optional[Union[str, pathlib.Path]],
                      cluster_type='pixel',
                      mask_suffix: str = "_mask",
                      seg_suffix_name: Optional[str] = "_whole_cell.tiff",
                      img_sub_folder: str = None,
                      new_mask_suffix: str = None):
    """Creates a mantis project directory so that it can be opened by the mantis viewer.
    Copies fovs, segmentation files, masks, and mapping csv's into a new directory structure.
    Here is how the contents of the mantis project folder will look like.

    ```{code-block} sh
    mantis_project
    ├── fov0
    │   ├── cell_segmentation.tiff
    │   ├── chan0.tiff
    │   ├── chan1.tiff
    │   ├── chan2.tiff
    │   ├── ...
    │   ├── population_mask.csv
    │   └── population_mask.tiff
    ├── fov1
    │   ├── cell_segmentation.tiff
    │   ├── chan0.tiff
    │   ├── chan1.tiff
    │   ├── chan2.tiff
    │   ├── ...
    │   ├── population_mask.csv
    │   └── population_mask.tiff
    └── ...
    ```

    Args:
        fovs (List[str]):
            A list of FOVs to create a Mantis Project for.
        mantis_project_path (Union[str, pathlib.Path]):
            The folder where the mantis project will be created.
        img_data_path (Union[str, pathlib.Path]):
            The location of the all the fovs you wish to create a project from.
        mask_output_dir (Union[str, pathlib.Path]):
            The folder containing all the masks of the fovs.
        mapping (Union[str, pathlib.Path, pd.DataFrame]):
            The location of the mapping file, or the mapping Pandas DataFrame itself.
        seg_dir (Union[str, pathlib.Path], optional):
            The location of the segmentation directory for the fovs. If None, then
            the segmentation file will not be copied over.
        cluster_type (str):
            the type of clustering being done
        mask_suffix (str, optional):
            The suffix used to find the mask tiffs. Defaults to "_mask".
        seg_suffix_name (str, optional):
            The suffix of the segmentation file and its file extension. If None, then
            the segmentation file will not be copied over.
            Defaults to "_whole_cell.tiff".
        img_sub_folder (str, optional):
            The subfolder where the channels exist within the `img_data_path`.
            Defaults to None.
        new_mask_suffix (str, optional):
            The new suffix added to the copied mask tiffs. If not provided, `mask_suffix`
            is reused.

    Raises:
        ValueError:
            If `mapping` is neither a path nor a DataFrame. The `misc_utils.verify_in_list`
            checks on `cluster_type` and `fovs` are expected to reject invalid values too.
    """

    # reject an unusable `mapping` up front, before anything is written to disk
    if not isinstance(mapping, (str, pathlib.Path, pd.DataFrame)):
        raise ValueError("Mapping must either be a path to an already saved mapping csv, "
                         "or a DataFrame that is already loaded in.")

    # verify the type of clustering provided is valid
    misc_utils.verify_in_list(
        provided_cluster_type=[cluster_type],
        valid_cluster_types=['pixel', 'cell']
    )

    os.makedirs(mantis_project_path, exist_ok=True)

    # account for non-sub folder channel file structures
    img_sub_folder = "" if not img_sub_folder else img_sub_folder

    # create key from cluster number to cluster name
    if isinstance(mapping, pd.DataFrame):
        map_df = mapping
    else:
        map_df = pd.read_csv(mapping)

    # Save the segmentation tiff or not
    save_seg_tiff: bool = all(v is not None for v in [seg_dir, seg_suffix_name])

    # if no new suffix specified, copy over with original mask name
    if not new_mask_suffix:
        new_mask_suffix = mask_suffix

    cluster_id_key = 'cluster_id'
    cluster_name_key = f'{cluster_type}_meta_cluster_rename'
    map_df = map_df.loc[:, [cluster_id_key, cluster_name_key]]
    # remove duplicates from df
    map_df = map_df.drop_duplicates()
    map_df = map_df.sort_values(by=[cluster_id_key])

    # rename for mantis names
    map_df = map_df.rename(
        {
            cluster_id_key: 'region_id',
            cluster_name_key: 'region_name'
        },
        axis=1
    )

    # get names of fovs with masks
    mask_names_loaded = io_utils.list_files(mask_output_dir, mask_suffix)
    mask_names_delimited = io_utils.extract_delimited_names(mask_names_loaded,
                                                            delimiter=mask_suffix)

    # every requested FOV must have a mask in `mask_output_dir`
    fovs = natsort.natsorted(fovs)
    misc_utils.verify_in_list(fovs=fovs, img_data_fovs=mask_names_delimited)

    # create a folder with image data, pixel masks, and segmentation mask
    for fov in fovs:
        # set up paths
        img_source_dir = os.path.join(img_data_path, fov, img_sub_folder)
        output_dir = os.path.join(mantis_project_path, fov)

        # copy image data if not already copied in from previous round of clustering
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

            # copy all channels into new folder
            chans = io_utils.list_files(img_source_dir, substrs=EXTENSION_TYPES["IMAGE"])
            for chan in chans:
                shutil.copy(os.path.join(img_source_dir, chan), os.path.join(output_dir, chan))

        # copy mask into new folder; the mask name is derived from the FOV itself rather
        # than paired positionally, since a substring match pairs e.g. `fov11` with the
        # mask of `fov10` whenever `fov1` is also requested
        mask_name: str = fov + mask_suffix + ".tiff"
        shutil.copy(os.path.join(mask_output_dir, mask_name),
                    os.path.join(output_dir, 'population{}.tiff'.format(new_mask_suffix)))

        # copy the segmentation files into the output directory
        # if `seg_dir` or `seg_suffix_name` is None, then skip copying
        if save_seg_tiff:
            if not os.path.exists(os.path.join(output_dir, 'cell_segmentation.tiff')):
                seg_name: str = fov + seg_suffix_name
                shutil.copy(os.path.join(seg_dir, seg_name),
                            os.path.join(output_dir, 'cell_segmentation.tiff'))

        # copy mapping into directory
        map_df.to_csv(os.path.join(output_dir, 'population{}.csv'.format(new_mask_suffix)),
                      index=False)
780

781

782
def save_colored_mask(
    fov: str,
    save_dir: str,
    suffix: str,
    data: np.ndarray,
    cmap: colors.ListedColormap,
    norm: colors.BoundaryNorm,
) -> None:
    """Colors a mask with the given colormap and writes it into `save_dir`.

    The output file is named `{fov}{suffix}`.

    Args:
        fov (str):
            The name of the FOV, used as the stem of the output file name.
        save_dir (str):
            The directory the colored mask is written to. Created if missing.
        suffix (str):
            The text appended to the FOV name to form the file name.
        data (np.ndarray):
            The mask to color and save.
        cmap (colors.ListedColormap):
            The colormap applied to the normalized mask.
        norm (colors.BoundaryNorm):
            The normalization applied to the mask before the colormap lookup.
    """

    # make sure the destination exists before anything is written
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # normalize -> colormap lookup -> convert to unsigned bytes for saving
    normalized = norm(data)
    colored = cmap(normalized)
    colored_ubyte = img_as_ubyte(colored)

    destination = os.path.join(save_dir, f"{fov}{suffix}")
    image_utils.save_image(fname=destination, data=colored_ubyte)
819

820

821
def save_colored_masks(
        fovs: List[str],
        mask_dir: Union[str, pathlib.Path],
        save_dir: Union[str, pathlib.Path],
        cluster_id_to_name_path: Union[str, pathlib.Path],
        metacluster_colors: Dict,
        cluster_type: Literal["cell", "pixel"],
) -> None:
    """
    Turns the pixie TIFF masks of the given FOVs into colored TIFF masks and writes them to
    `save_dir` as `{fov}_{cluster_type}_mask_colored.tiff`. Intended for visualization.

    Args:
        fovs (List[str]): The FOVs whose masks should be colored and saved.
        mask_dir (Union[str, pathlib.Path]): The directory holding the pixie masks, named
            `{fov}_{cluster_type}_mask.tiff`.
        save_dir (Union[str, pathlib.Path]): The directory the colored masks are written to.
            Created if it does not exist.
        cluster_id_to_name_path (Union[str, pathlib.Path]): A path to a CSV identifying the
            pixel/cell cluster to manually-defined name mapping. This is output by the
            remapping visualization found in `metacluster_remap_gui`.
        metacluster_colors (Dict): Maps each metacluster id to a color.
        cluster_type (Literal["cell", "pixel"]): The type of clustering being done.
    """

    # Input validation
    misc_utils.verify_in_list(
        provided_cluster_type=[cluster_type],
        valid_cluster_types=["pixel", "cell"])

    # Work with Path objects from here on; the save directory is created when missing
    mask_dir = pathlib.Path(mask_dir) if isinstance(mask_dir, str) else mask_dir
    save_dir = pathlib.Path(save_dir) if isinstance(save_dir, str) else save_dir
    save_dir.mkdir(parents=True, exist_ok=True)

    io_utils.validate_paths([mask_dir, save_dir])

    metacluster_cmap = MetaclusterColormap(
        cluster_type=cluster_type,
        cluster_id_to_name_path=cluster_id_to_name_path,
        metacluster_colors=metacluster_colors,
    )

    # the same suffix is used both to trim and to match the mask file names
    mask_tag = f"{cluster_type}_mask"

    with tqdm(total=len(fovs), desc="Saving colored masks", unit="FOVs") as progress:
        for fov in fovs:
            progress.set_postfix(FOV=fov, refresh=False)

            fov_mask: xr.DataArray = load_utils.load_imgs_from_dir(
                data_dir=mask_dir,
                files=[f"{fov}_{cluster_type}_mask.tiff"],
                trim_suffix=mask_tag,
                match_substring=mask_tag,
                xr_dim_name="pixel_mask",
                xr_channel_names=None,
            )

            # Look up the color of every mask value, then scale by 255.999 and truncate
            # to uint8 to land in the 0-255 range
            color_lookup = metacluster_cmap.mc_colors[fov_mask.squeeze()]
            colored_mask: np.ndarray = (color_lookup * 255.999).astype(np.uint8)

            image_utils.save_image(
                fname=save_dir / f"{fov}_{cluster_type}_mask_colored.tiff",
                data=colored_mask,
            )

            progress.update(1)
3✔
883

884

885
def cohort_cluster_plot(
    fovs: List[str],
    seg_dir: Union[pathlib.Path, str],
    save_dir: Union[pathlib.Path, str],
    cell_data: pd.DataFrame,
    fov_col: str = settings.FOV_ID,
    label_col: str = settings.CELL_LABEL,
    cluster_col: str = settings.CELL_TYPE,
    seg_suffix: str = "_whole_cell.tiff",
    cmap: Union[str, pd.DataFrame] = "viridis",
    style: str = "seaborn-v0_8-paper",
    erode: bool = False,
    display_fig: bool = False,
    fig_file_type: str = "png",
    figsize: tuple = (10, 10),
    dpi: int = 300,
) -> None:
    """
    Saves the cluster masks for each FOV in the cohort as the following:
    - Cluster mask numbered 1-N, where N is the number of clusters (tiff)
    - Cluster mask colored by cluster with or without a colorbar (png)
    - Cluster mask colored by cluster (tiff).

    The three outputs are written to the `cluster_masks`, `cluster_plots` and
    `cluster_masks_colored` sub-directories of `save_dir` respectively.

    Args:
        fovs (List[str]): A list of FOVs to generate cluster masks for. A single FOV name
            given as a string is also accepted.
        seg_dir (Union[pathlib.Path, str]): The directory containing the segmentation masks.
        save_dir (Union[pathlib.Path, str]): The directory to save the cluster masks to.
        cell_data (pd.DataFrame): The cell data table containing the cluster labels.
        fov_col (str, optional): The column containing the FOV name. Defaults to settings.FOV_ID.
        label_col (str, optional): The column containing the segmentation label.
            Defaults to settings.CELL_LABEL.
        cluster_col (str, optional): The column containing the cluster a segmentation label
            belongs to. Defaults to settings.CELL_TYPE.
        seg_suffix (str, optional): The kind of segmentation file to read.
            Defaults to "_whole_cell.tiff".
        cmap (str, pd.DataFrame, optional): The colormap to generate clusters from,
            or a DataFrame, where the user can specify their own colors per cluster.
            The color column must be labeled "color". Defaults to "viridis".
        style (str, optional): Set the matplotlib style image style. Defaults to
            "seaborn-v0_8-paper".
            View the available styles here:
            https://matplotlib.org/stable/gallery/style_sheets/style_sheets_reference.html
            Or run matplotlib.pyplot.style.available in a notebook to view all the styles.
        erode (bool, optional): Option to "thicken" the cell boundary via the segmentation label
            for visualization purposes. Defaults to False.
        display_fig (bool, optional): Option to display the cluster mask plots as they are
            generated. Defaults to False. Displaying each figure can use a lot of memory,
            so it's best to try to visualize just a few FOVs, before generating the cluster masks
            for the entire cohort.
        fig_file_type (str, optional): The file type to save figures as. Defaults to 'png'.
        figsize (tuple, optional):
            The size of the figure to display. Defaults to (10, 10).
        dpi (int, optional):
            The resolution of the image to use for saving. Defaults to 300.

    Raises:
        ValueError:
            If `cmap` is a DataFrame containing values that are not valid colors, or if
            `cmap` is neither a string nor a DataFrame.
    """

    plt.style.use(style)

    if isinstance(seg_dir, str):
        seg_dir = pathlib.Path(seg_dir)

    try:
        io_utils.validate_paths(seg_dir)
    except ValueError:
        raise ValueError(f"Could not find the segmentation directory at {seg_dir.as_posix()}")

    if isinstance(save_dir, str):
        save_dir = pathlib.Path(save_dir)
    if isinstance(fovs, str):
        fovs = [fovs]

    # Create the subdirectories for the 3 cluster mask files; `parents=True` also creates
    # `save_dir` itself when it does not exist yet
    for sub_dir in ["cluster_masks", "cluster_masks_colored", "cluster_plots"]:
        (save_dir / sub_dir).mkdir(parents=True, exist_ok=True)

    cmd = ClusterMaskData(
        data=cell_data,
        fov_col=fov_col,
        label_col=label_col,
        cluster_col=cluster_col,
    )

    if isinstance(cmap, pd.DataFrame):
        # user-specified colors: attach the cluster ids so the colors can be ordered by id
        unique_clusters: pd.DataFrame = cmd.mapping[[cmd.cluster_column,
                                                     cmd.cluster_id_column]].drop_duplicates()
        cmap_colors: np.ndarray = cmap.merge(
            right=unique_clusters,
            on=cmd.cluster_column
        ).sort_values(by="cluster_id")["color"].values
        colors_like: list[bool] = [colors.is_color_like(c) for c in cmap_colors]

        if not all(colors_like):
            bad_color_values: np.ndarray = cmap_colors[~np.array(colors_like)]
            raise ValueError(
                ("Not all colors in the provided cmap are valid colors."
                 f"The following colors are invalid: {bad_color_values}"))

        np_colors = colors.to_rgba_array(cmap_colors)

        color_map, norm = create_cmap(np_colors, n_clusters=cmd.n_clusters)
    elif isinstance(cmap, str):
        color_map, norm = create_cmap(cmap, n_clusters=cmd.n_clusters)
    else:
        # without this guard `color_map` and `norm` would be unbound further down
        raise ValueError(
            "cmap must either be the name of a matplotlib colormap or a DataFrame "
            f"with a 'color' column, got {type(cmap).__name__}.")

    # create the pixel cluster masks across each fov
    with tqdm(total=len(fovs), desc="Cluster Mask Generation", unit="FOVs") as pbar:
        for fov in fovs:
            pbar.set_postfix(FOV=fov)

            # generate the cell mask for the FOV
            cluster_mask: np.ndarray = generate_cluster_mask(
                fov=fov,
                seg_dir=seg_dir,
                cmd=cmd,
                seg_suffix=seg_suffix,
                erode=erode,
            )

            # save the cluster mask generated
            save_fov_mask(
                fov,
                data_dir=save_dir / "cluster_masks",
                mask_data=cluster_mask,
                sub_dir=None,
            )

            save_colored_mask(
                fov=fov,
                save_dir=save_dir / "cluster_masks_colored",
                suffix=".tiff",
                data=cluster_mask,
                cmap=color_map,
                norm=norm,
            )

            # colorbar labels: the cluster names bracketed by the two extra categories
            cluster_labels = ["Background"] + cmd.cluster_names + ["Unassigned"]

            fig = plot_cluster(
                image=cluster_mask,
                fov=fov,
                cmap=color_map,
                norm=norm,
                cbar_visible=True,
                cbar_labels=cluster_labels,
                figsize=figsize,
                dpi=dpi,
            )

            fig.savefig(
                fname=os.path.join(save_dir, "cluster_plots", f"{fov}.{fig_file_type}"),
            )

            # figures that are not displayed are closed right away to release their memory
            if display_fig:
                fig.show(warn=False)
            else:
                plt.close(fig)

            pbar.update(1)
3✔
1044

1045

1046
def plot_continuous_variable(
    image: np.ndarray,
    name: str,
    stat_name: str,
    cmap: Union[colors.Colormap, str],
    norm: colors.Normalize = None,
    cbar_visible: bool = True,
    dpi: int = 300,
    figsize: tuple[int, int] = (10, 10),
) -> Figure:
    """
    Draws an image of a continuous variable with a user provided colormap, optionally
    accompanied by a vertical colorbar on its right side.

    Args:
        image (np.ndarray):
            An array representing an image to plot.
        name (str):
            The name of the image, used as the figure title.
        stat_name (str):
            The name of the statistic being plotted, used as the colorbar's label.
        cmap (colors.Colormap, str):
            The colormap to plot the array with.
        norm (colors.Normalize, optional):
            A normalization to apply to the colormap. Defaults to None.
        cbar_visible (bool, optional):
            A flag for drawing the colorbar or not. Defaults to True.
        dpi (int, optional):
            The resolution of the image. Defaults to 300.
        figsize (tuple[int, int], optional):
            The size of the image. Defaults to (10, 10).

    Returns:
        Figure : The Figure object of the image.
    """
    figure: Figure = plt.figure(figsize=figsize, dpi=dpi)
    figure.set_layout_engine(layout="tight")
    grid = gridspec.GridSpec(nrows=1, ncols=1, figure=figure)
    figure.suptitle(f"{name}")

    # A single axis holds the image; ticks, frame and grid lines are hidden
    image_ax: Axes = figure.add_subplot(grid[0, 0])
    image_ax.axis("off")
    image_ax.grid(visible=False)

    rendered = image_ax.imshow(
        X=image,
        cmap=cmap,
        norm=norm,
        origin="upper",
        aspect="equal",
        interpolation="none",
    )

    if not cbar_visible:
        return figure

    # Carve a dedicated colorbar axis out of the right side of the current axis
    colorbar_ax = make_axes_locatable(figure.gca()).append_axes(
        position="right", size="5%", pad="3%"
    )

    figure.colorbar(mappable=rendered, cax=colorbar_ax, orientation="vertical",
                    use_gridspec=True, pad=0.1, shrink=0.9, drawedges=False, label=stat_name)

    return figure
3✔
1108

1109

1110
def color_segmentation_by_stat(
    fovs: List[str],
    data_table: pd.DataFrame,
    seg_dir: Union[pathlib.Path, str],
    save_dir: Union[pathlib.Path, str],
    fov_col: str = settings.FOV_ID,
    label_col: str = settings.CELL_LABEL,
    stat_name: str = settings.CELL_TYPE,
    cmap: str = "viridis",
    reverse: bool = False,
    seg_suffix: str = "_whole_cell.tiff",
    cbar_visible: bool = True,
    style: str = "seaborn-v0_8-paper",
    erode: bool = False,
    display_fig: bool = False,
    fig_file_type: str = "png",
    figsize: tuple = (10, 10),
    dpi: int = 300,
):
    """
    Colors the segmentation mask of each FOV by a continuous statistic and saves, per FOV,
    a figure under `save_dir/continuous_plots` and a colored mask under `save_dir/colored`.

    Args:
        fovs: (List[str]):
            A list of FOVs to plot.
        data_table (pd.DataFrame):
            A DataFrame containing FOV and segmentation label identifiers
            as well as a collection of statistics for each label in a segmentation
            mask such as:

                - `fov_id` (identifier)
                - `label` (identifier)
                - `area` (statistic)
                - `fiber` (statistic)
                - etc...

        seg_dir (Union[pathlib.Path, str]):
            Path to the directory containing segmentation masks.
        save_dir (Union[pathlib.Path, str]):
            Path to the directory where the colored segmentation masks will be saved.
        fov_col: (str, optional):
            The name of the column in `data_table` containing the FOV identifiers.
            Defaults to `settings.FOV_ID`.
        label_col (str, optional):
            The name of the column in `data_table` containing the segmentation label identifiers.
            Defaults to `settings.CELL_LABEL`.
        stat_name (str):
            The name of the statistic to color the segmentation masks by. This should be a column
            in `data_table`.
        cmap (str, optional): The colormap for plotting. Defaults to "viridis".
        reverse (bool, optional):
            A flag to reverse the colormap provided. Defaults to False.
        seg_suffix (str, optional):
            The suffix of the segmentation file and its file extension. Defaults to
            "_whole_cell.tiff".
        cbar_visible (bool, optional):
            A flag to display the colorbar. Defaults to True.
        style (str, optional): Set the matplotlib style image style. Defaults to
            "seaborn-v0_8-paper".
            View the available styles here:
            https://matplotlib.org/stable/gallery/style_sheets/style_sheets_reference.html
            Or run matplotlib.pyplot.style.available in a notebook to view all the styles.
        erode (bool, optional): Option to "thicken" the cell boundary via the segmentation label
            for visualization purposes. Defaults to False.
        display_fig: (bool, optional):
            Option to display the cluster mask plots as they are generated. Defaults to False.
        fig_file_type (str, optional): The file type to save figures as. Defaults to 'png'.
        figsize (tuple, optional):
            The size of the figure to display. Defaults to (10, 10).
        dpi (int, optional):
            The resolution of the image to use for saving. Defaults to 300.
    """
    plt.style.use(style)

    seg_dir = seg_dir if isinstance(seg_dir, pathlib.Path) else pathlib.Path(seg_dir)
    save_dir = save_dir if isinstance(save_dir, pathlib.Path) else pathlib.Path(save_dir)

    io_utils.validate_paths([seg_dir])

    # a missing save directory is not an error, it simply gets created
    try:
        io_utils.validate_paths([save_dir])
    except FileNotFoundError:
        save_dir.mkdir(parents=True, exist_ok=True)

    misc_utils.verify_in_list(
        statistic_name=[fov_col, label_col, stat_name],
        data_table_columns=data_table.columns,
    )

    # one output folder for the figures, one for the colored masks
    plots_dir = save_dir / "continuous_plots"
    colored_dir = save_dir / "colored"
    for out_dir in (plots_dir, colored_dir):
        if not out_dir.exists():
            out_dir.mkdir(parents=True, exist_ok=True)

    # keep only the rows of the FOVs we want to plot
    data_table = data_table[data_table[fov_col].isin(fovs)]

    relevant_columns = data_table[[fov_col, label_col, stat_name]]
    naturally_sorted = relevant_columns.sort_values(
        by=[fov_col, label_col], key=natsort.natsort_keygen()
    )
    fov_groups: DataFrameGroupBy = naturally_sorted.groupby(by=fov_col)

    # A single normalization spanning the statistic's range over all selected FOVs
    stat_min: np.float64 = data_table[stat_name].min()
    stat_max: np.float64 = data_table[stat_name].max()
    norm = colors.Normalize(vmin=stat_min, vmax=stat_max)

    # Adding the suffix "_r" selects the reversed variant of the colormap
    cmap_name = f"{cmap}_r" if reverse else cmap

    # Replace the lowest color of the colormap with black
    color_map = set_minimum_color_for_colormap(
        cmap=colormaps[cmap_name], default=(0, 0, 0, 1)
    )

    with tqdm(
        total=len(fov_groups),
        desc=f"Generating {stat_name} Plots",
        unit="FOVs",
    ) as progress:
        for fov, fov_rows in fov_groups:
            progress.set_postfix(FOV=fov)

            seg_labels: np.ndarray = io.imread(seg_dir / f"{fov}{seg_suffix}")

            if erode:
                seg_labels = erode_mask(
                    seg_labels, connectivity=2, mode="thick", background=0
                )

            # swap every segmentation label for its statistic value
            stat_image: np.ndarray = map_segmentation_labels(
                labels=fov_rows[label_col],
                values=fov_rows[stat_name],
                label_map=seg_labels,
            )

            fig = plot_continuous_variable(
                image=stat_image,
                name=fov,
                stat_name=stat_name,
                norm=norm,
                cmap=color_map,
                cbar_visible=cbar_visible,
                figsize=figsize,
                dpi=dpi,
            )
            fig.savefig(fname=os.path.join(plots_dir, f"{fov}.{fig_file_type}"))

            save_colored_mask(
                fov=fov,
                save_dir=colored_dir,
                suffix=".tiff",
                data=stat_image,
                cmap=color_map,
                norm=norm,
            )

            # figures that are not displayed are closed right away
            if display_fig:
                fig.show(warn=False)
            else:
                plt.close(fig)

            progress.update(1)
3✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc