• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

angelolab / ark-analysis / 11242350337

08 Oct 2024 07:26PM UTC coverage: 93.078% (-0.4%) from 93.47%
11242350337

push

github

web-flow
adjust mask suffix input (#1164)

1819 of 2070 branches covered (87.87%)

Branch coverage included in aggregate %.

4 of 4 new or added lines in 1 file covered. (100.0%)

8 existing lines in 1 file now uncovered.

3923 of 4099 relevant lines covered (95.71%)

2.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.55
/src/ark/utils/data_utils.py
1
import numba as nb
3✔
2
import itertools
3✔
3
import os
3✔
4
import pathlib
3✔
5
import re
3✔
6
from typing import List, Literal, Union, Sequence
3✔
7

8
from numpy.typing import ArrayLike, DTypeLike
3✔
9
from numpy import ma
3✔
10
import feather
3✔
11
import natsort as ns
3✔
12
import numpy as np
3✔
13
import pandas as pd
3✔
14
import skimage.io as io
3✔
15
from alpineer import data_utils, image_utils, io_utils, load_utils, misc_utils
3✔
16
from alpineer.settings import EXTENSION_TYPES
3✔
17
from tqdm.notebook import tqdm_notebook as tqdm
3✔
18
import xarray as xr
3✔
19
from ark import settings
3✔
20
from skimage.segmentation import find_boundaries
3✔
21
from pandas.core.groupby.generic import DataFrameGroupBy
3✔
22
from anndata import AnnData, read_zarr
3✔
23
from anndata.experimental import AnnCollection
3✔
24
from anndata.experimental.multi_files._anncollection import ConvertType
3✔
25
from typing import Optional
3✔
26
try:
3✔
27
    from typing import TypedDict, Unpack
3✔
28
except ImportError:
29
    from typing_extensions import TypedDict, Unpack
30

31

32
def save_fov_mask(fov, data_dir, mask_data, sub_dir=None, name_suffix=''):
    """Saves a provided cluster label mask overlay for a FOV.

    The mask is written to `<data_dir>/<sub_dir>/<fov><name_suffix>.tiff`.

    Args:
        fov (str):
            The FOV to save
        data_dir (str):
            The directory to save the cluster mask. Checked with `io_utils.validate_paths`
            before anything is written.
        mask_data (numpy.ndarray):
            The cluster mask data for the FOV
        sub_dir (Optional[str]):
            The subdirectory to save the masks in. If specified images are saved to
            "data_dir/sub_dir". If `sub_dir = None` the images are saved to `"data_dir"`.
            Defaults to `None`.
        name_suffix (str):
            Specify what to append at the end of every fov.
    """

    # data_dir validation
    io_utils.validate_paths(data_dir)

    # a missing sub_dir is joined as '' so the mask lands directly in data_dir
    save_dir = os.path.join(data_dir, sub_dir if sub_dir is not None else '')

    # exist_ok=True replaces the former "check, then create" sequence, which could fail
    # if another worker created the directory between the two calls
    os.makedirs(save_dir, exist_ok=True)

    # define the file name as the fov name with the name suffix appended
    fov_file = fov + name_suffix + '.tiff'

    # save the image to save_dir
    image_utils.save_image(os.path.join(save_dir, fov_file), mask_data)
68

69

70
def erode_mask(seg_mask: np.ndarray, **kwargs) -> np.ndarray:
    """
    Zeroes out the boundary pixels of the labels in a segmentation mask.
    Boundaries are located with `skimage.segmentation.find_boundaries`; any other keyword
    arguments are forwarded to it unchanged.

    Args:
        seg_mask (np.ndarray): The segmentation mask to erode.

    Returns:
        np.ndarray: The mask with boundary pixels set to 0 and all other pixels unchanged.
    """
    boundary_pixels = find_boundaries(label_img=seg_mask, **kwargs)

    # keep the original label wherever no boundary was flagged, write 0 everywhere else
    return np.where(boundary_pixels == 0, seg_mask, 0)
85

86

87
class ClusterMaskData:
    """
    A class containing the cell labels, cluster labels, and segmentation labels for the
    whole cohort. Also contains the mapping from the segmentation label to the cluster
    label for each FOV.

    Attributes:
        fov_column (str): Name of the FOV ID column.
        label_column (str): Name of the segmentation label column.
        cluster_column (str): Name of the cluster label column.
        cluster_id_column (str): Name of the generated integer ID column (`"cluster_id"`).
        cluster_name_id (pd.DataFrame): One row per unique cluster label with its integer ID.
            IDs start at 1 and follow the sorted order of the cluster labels.
        unique_fovs (List[str]): Naturally sorted FOV names found in the data.
        n_clusters (int): The largest cluster ID that was assigned.
        unassigned_id (np.int32): `n_clusters + 1`, the ID reserved for segmentation labels
            without a cluster.
        mapping (pd.DataFrame): FOV / segmentation label / cluster label / cluster ID rows,
            including one background row (label 0 -> cluster ID 0) per FOV, sorted by FOV and
            then by segmentation label.
    """

    fov_column: str
    label_column: str
    cluster_column: str
    unique_fovs: List[str]
    cluster_id_column: str
    unassigned_id: np.int32
    n_clusters: int
    cluster_name_id: pd.DataFrame
    mapping: pd.DataFrame

    def __init__(
        self, data: pd.DataFrame, fov_col: str, label_col: str, cluster_col: str
    ) -> None:
        """
        A class containing the cell data, cell label column, cluster column and the mapping from a
        cell label to a cluster.

        Args:
            data (pd.DataFrame):
                A cell table with the cell label column and the cluster column.
            fov_col (str):
                The name of the column in the cell table that contains the FOV ID.
            label_col (str):
                The name of the column in the cell table that contains the cell label.
            cluster_col (str):
                The name of the column in the cell table that contains the cluster label.
        """
        self.fov_column = fov_col
        self.label_column = label_col
        self.cluster_column = cluster_col
        self.cluster_id_column = "cluster_id"

        # Extract only the necessary columns: fov ID, segmentation label, cluster label
        mapping_data: pd.DataFrame = data[
            [self.fov_column, self.label_column, self.cluster_column]
        ].copy()

        # The cluster column may be non-numeric (i.e. string), so give every unique cluster
        # label an integer ID: 1..n in ascending order of the cluster label
        cluster_name_id = (
            pd.DataFrame({self.cluster_column: mapping_data[self.cluster_column].unique()})
            .sort_values(by=self.cluster_column)
            .reset_index(drop=True)
        )
        cluster_name_id[self.cluster_id_column] = (cluster_name_id.index + 1).astype(np.int32)
        self.cluster_name_id = cluster_name_id

        # dtypes enforced on the mapping, both before and after the background rows are added
        column_dtypes = {
            self.fov_column: str,
            self.label_column: np.int32,
            self.cluster_id_column: np.int32,
        }

        # merge the cluster_id_column to the mapping_data dataframe
        mapping_data = mapping_data.merge(
            right=self.cluster_name_id, on=self.cluster_column
        ).astype(column_dtypes)

        self.unique_fovs = ns.natsorted(mapping_data[self.fov_column].unique().tolist())

        # store a plain Python int, as the annotation promises, rather than a NumPy scalar
        self.n_clusters = int(mapping_data[self.cluster_id_column].max())
        self.unassigned_id = np.int32(self.n_clusters + 1)

        # For each FOV map the segmentation label 0 (background) to the cluster label 0
        n_fovs = len(self.unique_fovs)
        cluster0_mapping: pd.DataFrame = pd.DataFrame(
            data={
                self.fov_column: self.unique_fovs,
                self.label_column: np.repeat(0, repeats=n_fovs),
                self.cluster_column: np.repeat(0, repeats=n_fovs),
                self.cluster_id_column: np.repeat(0, repeats=n_fovs),
            }
        )

        mapping_data = pd.concat(objs=[mapping_data, cluster0_mapping]).astype(column_dtypes)

        # Sort by FOV first, then by segmentation label
        self.mapping = mapping_data.sort_values(by=[self.fov_column, self.label_column])

    def fov_mapping(self, fov: str) -> pd.DataFrame:
        """Returns the mapping for a specific FOV.
        Args:
            fov (str):
                The FOV to get the mapping for.
        Returns:
            pd.DataFrame:
                The rows of the mapping belonging to the FOV, with a fresh 0-based index.
        """
        misc_utils.verify_in_list(requested_fov=[fov], all_fovs=self.unique_fovs)
        fov_data: pd.DataFrame = self.mapping[self.mapping[self.fov_column] == fov]
        return fov_data.reset_index(drop=True)

    @property
    def cluster_names(self) -> List[str]:
        """Returns the cluster names.
        Returns:
            List[str]:
                The cluster names, in the same order as their cluster IDs.
        """
        return self.cluster_name_id[self.cluster_column].tolist()
202

203

204
def label_cells_by_cluster(
        fov: str,
        cmd: ClusterMaskData,
        label_map: Union[np.ndarray, xr.DataArray],
) -> np.ndarray:
    """Replaces every segmentation label in a FOV's label map with its cluster ID,
    using the assignments stored in `cmd`.

    Args:
        fov (str):
            The FOV to relabel
        cmd (ClusterMaskData):
            Holds the cell data, cell label column, cluster column and the mapping from
            the segmentation label to the cluster label for a given FOV.
        label_map (Union[numpy.ndarray, xarray.DataArray]):
            label map for a single FOV.

    Returns:
        numpy.ndarray:
            The `int16` image with the new label assignments. Labels that have no entry in
            the FOV's mapping receive `cmd.unassigned_id`.
    """

    # the fov has to be one of the fovs known to cmd
    misc_utils.verify_in_list(
        fov_name=[fov],
        all_data_fovs=cmd.unique_fovs
    )

    # drop singleton axes; a DataArray is additionally unwrapped to its numpy values
    squeezed = label_map.squeeze()
    if isinstance(label_map, xr.DataArray):
        squeezed = squeezed.values
    labeled_image = squeezed.astype(np.int32)

    # numba typed dict: segmentation label -> cluster id
    label_to_cluster = nb.typed.Dict.empty(
        key_type=nb.types.int32,
        value_type=nb.types.int32,
    )

    fov_clusters = cmd.fov_mapping(fov=fov)
    for seg_label, cluster_id in zip(
        fov_clusters[cmd.label_column], fov_clusters[cmd.cluster_id_column]
    ):
        label_to_cluster[np.int32(seg_label)] = np.int32(cluster_id)

    relabeled_image = relabel_segmentation(
        mapping=label_to_cluster,
        unassigned_id=cmd.unassigned_id,
        labeled_image=labeled_image,
        _dtype=np.int32)

    return relabeled_image.astype(np.int16)
257

258

259
def map_segmentation_labels(
    labels: Union[pd.Series, np.ndarray],
    values: Union[pd.Series, np.ndarray],
    label_map: ArrayLike,
    unassigned_id: float = 0,
) -> np.ndarray:
    """
    Turns an image of segmentation labels into an image of a per-label statistic, metric, or
    other value of interest.

    Args:
        labels (Union[pd.Series, np.ndarray]): The segmentation labels.
        values (Union[pd.Series, np.ndarray]): The values to map to the segmentation labels,
            paired with `labels` by position.
        label_map (ArrayLike): The segmentation labels as an image to map to.
        unassigned_id (int | float, optional): The value written for every label in
            `label_map` that does not appear in `labels`. Defaults to 0.

    Returns:
        np.ndarray: The mapped image, as `float64`.
    """
    # drop singleton axes; a DataArray is additionally unwrapped to its numpy values
    squeezed = label_map.squeeze()
    if isinstance(label_map, xr.DataArray):
        squeezed = squeezed.values
    labeled_image = squeezed.astype(np.int32)

    if isinstance(labels, pd.Series):
        labels = labels.to_numpy(dtype=np.int32)

    if isinstance(values, pd.Series):
        # invalid entries (e.g. NaN) are replaced with 0
        values = ma.fix_invalid(values.to_numpy(dtype=np.float64), fill_value=0).data

    # numba typed dict: segmentation label -> value
    label_to_value = nb.typed.Dict.empty(
        key_type=nb.types.int32, value_type=nb.types.float64
    )
    for seg_label, mapped_value in zip(labels, values):
        label_to_value[seg_label] = mapped_value

    return relabel_segmentation(
        mapping=label_to_value,
        unassigned_id=unassigned_id,
        labeled_image=labeled_image,
        _dtype=np.float64,
    )
306

307

308
@nb.njit(parallel=True)
def relabel_segmentation(
    mapping: nb.typed.typeddict,
    unassigned_id: np.int32,
    labeled_image: np.ndarray,
    _dtype: DTypeLike = np.float64,
) -> np.ndarray:
    """
    Relabels a labeled segmentation image according to the provided mapping.

    Args:
        mapping (nb.typed.typeddict):
            A Numba typed dictionary mapping segmentation labels to their new values.
        unassigned_id (np.int32):
            The value given to a pixel whose label is not a key of `mapping`.
        labeled_image (np.ndarray):
            The labeled segmentation image. It is indexed as `[row, column]`.
        _dtype (DTypeLike, optional):
            The data type of the relabeled image. Defaults to `np.float64`.

    Returns:
        np.ndarray: The relabeled segmentation image, same shape as `labeled_image`.
    """
    relabeled_image = np.empty(shape=labeled_image.shape, dtype=_dtype)
    n_rows = labeled_image.shape[0]
    n_cols = labeled_image.shape[1]

    # every pixel is written exactly once, so the uninitialised buffer is fully overwritten
    for row in nb.prange(n_rows):
        for col in nb.prange(n_cols):
            relabeled_image[row, col] = mapping.get(labeled_image[row, col], unassigned_id)
    return relabeled_image
336

337

338
def generate_cluster_mask(
        fov: str,
        seg_dir: Union[str, pathlib.Path],
        cmd: ClusterMaskData,
        seg_suffix: str = "_whole_cell.tiff",
        erode: bool = True,
        **kwargs) -> np.ndarray:
    """Builds, for one fov, a mask in which each cell carries its SOM or meta cluster label

    Args:
        fov (str):
            The fov to relabel
        seg_dir (str):
            The path to the segmentation data
        cmd (ClusterMaskData):
            Holds the cell data, cell label column, cluster column and the mapping from
            the segmentation label to the cluster label for a given FOV.
        seg_suffix (str):
            The suffix that the segmentation images use. Defaults to `'_whole_cell.tiff'`.
        erode (bool):
            Whether to erode the edges of the segmentation mask. Defaults to `True`.
        **kwargs:
            Accepted by the signature but not used anywhere in this function.

    Returns:
        numpy.ndarray:
            The image where values represent cell cluster labels.
    """

    # path checking
    io_utils.validate_paths([seg_dir])

    # load the single segmentation file of this fov and select the fov from the result
    seg_imgs = load_utils.load_imgs_from_dir(
        data_dir=seg_dir,
        files=[fov + seg_suffix],
        xr_dim_name='compartments',
        xr_channel_names=['whole_cell'],
        trim_suffix=seg_suffix.split('.')[0],
    )
    label_map = seg_imgs.loc[fov, ...]

    if erode:
        label_map = erode_mask(label_map, connectivity=2, mode="thick", background=0)

    # translate segmentation labels into cluster labels
    return label_cells_by_cluster(
        fov=fov,
        cmd=cmd,
        label_map=label_map
    )
388

389

390
def generate_and_save_cell_cluster_masks(
    fovs: List[str],
    save_dir: Union[pathlib.Path, str],
    seg_dir: Union[pathlib.Path, str],
    cell_data: pd.DataFrame,
    cluster_id_to_name_path: Union[pathlib.Path, str],
    fov_col: str = settings.FOV_ID,
    label_col: str = settings.CELL_LABEL,
    cell_cluster_col: str = settings.CELL_TYPE,
    seg_suffix: str = "_whole_cell.tiff",
    sub_dir: str = None,
    name_suffix: str = "",
):
    """Generates cell cluster masks and saves them for downstream analysis.

    The CSV at `cluster_id_to_name_path` is rewritten in place with a `cluster_id` column
    holding the integer written into the masks for each cluster.

    Args:
        fovs (List[str]):
            A list of fovs to generate and save cell masks for.
        save_dir (Union[pathlib.Path, str]):
            The directory to save the generated cell cluster masks.
        seg_dir (Union[pathlib.Path, str]):
            The path to the segmentation data.
        cell_data (pd.DataFrame):
            The cell data with both cell SOM and meta cluster assignments.
        cluster_id_to_name_path (Union[str, pathlib.Path]): A path to a CSV identifying the
            cell cluster to manually-defined name mapping. This is output by the remapping
            visualization found in `metacluster_remap_gui`.
        fov_col (str, optional):
            The column name containing the FOV IDs. Defaults to `settings.FOV_ID`.
        label_col (str, optional):
            The column name containing the cell label. Defaults to `settings.CELL_LABEL`.
        cell_cluster_col (str, optional):
            The column of `cell_data` holding the cluster assignment to use. Defaults to
            `settings.CELL_TYPE`.
        seg_suffix (str, optional):
            The suffix that the segmentation images use. Defaults to `"_whole_cell.tiff"`.
        sub_dir (str, optional):
            The subdirectory to save the images in. If specified images are saved to
            `"save_dir/sub_dir"`. If `sub_dir = None` the images are saved to `"save_dir"`.
            Defaults to `None`.
        name_suffix (str, optional):
            Specify what to append at the end of every cell mask. Defaults to `""`.
    """

    cmd = ClusterMaskData(
        data=cell_data,
        fov_col=fov_col,
        label_col=label_col,
        cluster_col=cell_cluster_col,
    )

    # the gui mapping file, minus any cluster_id column left over from an earlier run
    gui_map = pd.read_csv(cluster_id_to_name_path).drop(columns="cluster_id", errors="ignore")

    # one row per (cluster label, cluster id) pair created by ClusterMaskData
    cluster_map = cmd.mapping.filter(
        [cmd.cluster_column, cmd.cluster_id_column]
    ).drop_duplicates()

    # attach the cluster_id used in the masks and write the file back
    updated_cluster_map = gui_map.merge(cluster_map, on=[cmd.cluster_column], how="left")
    updated_cluster_map.to_csv(cluster_id_to_name_path, index=False)

    # create the cell cluster masks across each fov
    with tqdm(total=len(fovs), desc="Cell Cluster Mask Generation", unit="FOVs") as pbar:
        for fov in fovs:
            pbar.set_postfix(FOV=fov)

            cell_mask = generate_cluster_mask(
                fov=fov, seg_dir=seg_dir, cmd=cmd, seg_suffix=seg_suffix
            )
            save_fov_mask(
                fov,
                data_dir=save_dir,
                mask_data=cell_mask,
                sub_dir=sub_dir,
                name_suffix=name_suffix,
            )

            pbar.update(1)
474

475

476
def generate_pixel_cluster_mask(fov, base_dir, tiff_dir, chan_file_path,
                                pixel_data_dir, cluster_mapping,
                                pixel_cluster_col='pixel_meta_cluster'):
    """For a fov, create a mask labeling each pixel with their SOM or meta cluster label

    Args:
        fov (str):
            The fov to relabel
        base_dir (str):
            The path to the data directory
        tiff_dir (str):
            The path to the tiff data
        chan_file_path (str):
            The path to the sample channel file to load (`tiff_dir` as root).
            Used to determine dimensions of the pixel mask.
        pixel_data_dir (str):
            The path to the data with full pixel data (`base_dir` as root).
            This data should also have the SOM and meta cluster labels appended.
        cluster_mapping (pd.DataFrame):
            Dataframe detailing which meta_cluster IDs map to which cluster_id
        pixel_cluster_col (str):
            Whether to assign SOM or meta clusters
            needs to be `'pixel_som_cluster'` or `'pixel_meta_cluster'`

    Returns:
        numpy.ndarray:
            The `int16` image holding the `cluster_id` of every pixel listed in the fov's
            pixel data. Pixels that are not listed stay 0.

    Raises:
        KeyError: If a pixel's cluster label has no entry in `cluster_mapping`.
    """

    pixel_data_path = os.path.join(base_dir, pixel_data_dir)
    chan_path = os.path.join(tiff_dir, chan_file_path)

    # path checking
    io_utils.validate_paths([tiff_dir, chan_path, pixel_data_path])

    # verify the pixel_cluster_col provided is valid
    misc_utils.verify_in_list(
        provided_cluster_col=[pixel_cluster_col],
        valid_cluster_cols=['pixel_som_cluster', 'pixel_meta_cluster']
    )

    # verify the fov is valid
    misc_utils.verify_in_list(
        provided_fov_file=[fov + '.feather'],
        consensus_fov_files=os.listdir(pixel_data_path)
    )

    # read the sample channel file to determine size of pixel cluster mask
    channel_data = np.squeeze(io.imread(chan_path))

    # define an array to hold the overlays for the fov
    # use int16 to allow for Photoshop loading
    img_data = np.zeros((channel_data.shape[0], channel_data.shape[1]), dtype='int16')

    fov_data = feather.read_dataframe(os.path.join(pixel_data_path, fov + '.feather'))

    # relabel meta_cluster numbers with cluster_id
    cluster_mapping = cluster_mapping.drop_duplicates()[[pixel_cluster_col, 'cluster_id']]
    id_mapping = dict(zip(cluster_mapping[pixel_cluster_col], cluster_mapping['cluster_id']))

    # cast to int first so the lookup uses integer cluster labels, not floats
    cluster_ids = [id_mapping[label] for label in fov_data[pixel_cluster_col].astype(int)]

    # index with (row, column) pairs directly: unlike a flattened `row * width + column`
    # index, an out-of-range column raises instead of silently wrapping into the next row
    rows = fov_data['row_index'].values
    cols = fov_data['column_index'].values
    img_data[rows, cols] = cluster_ids

    return img_data
556

557

558
def generate_and_save_pixel_cluster_masks(fovs: List[str],
                                          base_dir: Union[pathlib.Path, str],
                                          save_dir: Union[pathlib.Path, str],
                                          tiff_dir: Union[pathlib.Path, str],
                                          chan_file: Union[pathlib.Path, str],
                                          pixel_data_dir: Union[pathlib.Path, str],
                                          cluster_id_to_name_path: Union[pathlib.Path, str],
                                          pixel_cluster_col: str = 'pixel_meta_cluster',
                                          sub_dir: str = None,
                                          name_suffix: str = ''):
    """Generates pixel cluster masks and saves them for downstream analysis.

    The CSV at `cluster_id_to_name_path` is rewritten in place with a `cluster_id` column
    holding the integer written into the masks for each cluster.

    Args:
        fovs (List[str]):
            A list of fovs to generate and save pixel masks for.
        base_dir (Union[pathlib.Path, str]):
            The path to the data directory.
        save_dir (Union[pathlib.Path, str]):
            The directory to save the generated pixel cluster masks.
        tiff_dir (Union[pathlib.Path, str]):
            The path to the directory with the tiff data.
        chan_file (Union[pathlib.Path, str]):
            The path to the channel file inside each FOV folder (FOV folder as root).
            Used to determine dimensions of the pixel mask.
        pixel_data_dir (Union[pathlib.Path, str]):
            The path to the data with full pixel data.
            This data should also have the SOM and meta cluster labels appended.
        cluster_id_to_name_path (Union[str, pathlib.Path]): A path to a CSV identifying the
            pixel cluster to manually-defined name mapping. This is output by the remapping
            visualization found in `metacluster_remap_gui`.
        pixel_cluster_col (str, optional):
            The name of the column holding the cluster assignment to use.
            Defaults to 'pixel_meta_cluster'.
        sub_dir (str, optional):
            The subdirectory to save the images in. If specified images are saved to
            `"save_dir/sub_dir"`. If `sub_dir = None` the images are saved to `"save_dir"`.
            Defaults to `None`.
        name_suffix (str, optional):
            Specify what to append at the end of every pixel mask. Defaults to `''`.
    """
    gui_map = pd.read_csv(cluster_id_to_name_path)

    # number the unique clusters 1..n in ascending order of the cluster column
    cluster_map = (
        gui_map[[pixel_cluster_col]]
        .drop_duplicates()
        .sort_values(by=[pixel_cluster_col])
    )
    cluster_map["cluster_id"] = list(range(1, len(cluster_map) + 1))

    # replace any cluster_id column left over from an earlier run with the new numbering
    updated_cluster_map = gui_map.drop(columns="cluster_id", errors="ignore").merge(
        cluster_map, on=[pixel_cluster_col], how="left"
    )
    updated_cluster_map.to_csv(cluster_id_to_name_path, index=False)

    # create the pixel cluster masks across each fov
    with tqdm(total=len(fovs), desc="Pixel Cluster Mask Generation", unit="FOVs") \
            as pixel_mask_progress:
        for fov in fovs:
            pixel_mask_progress.set_postfix(FOV=fov)

            # the channel file lives inside the fov's folder; it is only used for dimensions
            pixel_mask = generate_pixel_cluster_mask(
                fov=fov,
                base_dir=base_dir,
                tiff_dir=tiff_dir,
                chan_file_path=os.path.join(fov, chan_file),
                pixel_data_dir=pixel_data_dir,
                pixel_cluster_col=pixel_cluster_col,
                cluster_mapping=updated_cluster_map,
            )

            save_fov_mask(
                fov,
                data_dir=save_dir,
                mask_data=pixel_mask,
                sub_dir=sub_dir,
                name_suffix=name_suffix,
            )

            pixel_mask_progress.update(1)
635

636

637
def generate_and_save_neighborhood_cluster_masks(
    fovs: List[str],
    save_dir: Union[pathlib.Path, str],
    seg_dir: Union[pathlib.Path, str],
    neighborhood_data: pd.DataFrame,
    fov_col: str = settings.FOV_ID,
    label_col: str = settings.CELL_LABEL,
    cluster_col: str = settings.KMEANS_CLUSTER,
    seg_suffix: str = "_whole_cell.tiff",
    xr_channel_name="label",
    sub_dir: Union[pathlib.Path, str] = None,
    name_suffix: str = "",
):
    """Generates neighborhood cluster masks and saves them for downstream analysis.

    Args:
        fovs (List[str]):
            A list of fovs to generate and save neighborhood masks for.
        save_dir (Union[pathlib.Path, str]):
            The directory to save the generated neighborhood cluster masks.
        seg_dir (Union[pathlib.Path, str]):
            The path to the segmentation data.
        neighborhood_data (pd.DataFrame):
            Contains the neighborhood cluster assignments for each cell.
        fov_col (str, optional):
            The column name containing the FOV IDs. Defaults to `settings.FOV_ID`.
        label_col (str, optional):
            The column name containing the cell label. Defaults to `settings.CELL_LABEL`.
        cluster_col (str, optional):
            The column name containing the cluster label. Defaults to
            `settings.KMEANS_CLUSTER`.
        seg_suffix (str, optional):
            The suffix that the segmentation images use. Defaults to `'_whole_cell.tiff'`
        xr_channel_name (str):
            Channel name for segmented data array.
        sub_dir (str, optional):
            The subdirectory to save the images in. If specified images are saved to
            `"save_dir/sub_dir"`. If `sub_dir = None` the images are saved to `"save_dir"`.
            Defaults to `None`.
        name_suffix (str, optional):
            Specify what to append at the end of every neighborhood mask. Defaults to `''`.
    """

    cmd = ClusterMaskData(
        data=neighborhood_data,
        fov_col=fov_col,
        label_col=label_col,
        cluster_col=cluster_col,
    )

    # create the neighborhood cluster masks across each fov
    with tqdm(total=len(fovs), desc="Neighborhood Cluster Mask Generation", unit="FOVs") \
            as neigh_mask_progress:
        for fov in fovs:
            neigh_mask_progress.set_postfix(FOV=fov)

            # load the fov's segmentation file and select the fov from the result
            seg_imgs = load_utils.load_imgs_from_dir(
                seg_dir,
                files=[fov + seg_suffix],
                xr_channel_names=[xr_channel_name],
                trim_suffix=seg_suffix.split(".")[0],
            )
            label_map = seg_imgs.loc[fov, ..., :]

            # translate segmentation labels into neighborhood cluster labels and save
            save_fov_mask(
                fov,
                data_dir=save_dir,
                mask_data=label_cells_by_cluster(fov, cmd, label_map),
                sub_dir=sub_dir,
                name_suffix=name_suffix,
            )

            neigh_mask_progress.update(1)
716

717

718
def split_img_stack(stack_dir, output_dir, stack_list, indices, names, channels_first=True):
    """Splits the channels in a given directory of images into separate files

    Images are saved in the output_dir, inside one folder per stack named after the
    stack file without its extension.

    Args:
        stack_dir (str):
            where we read the input files
        output_dir (str):
            where we write the split channel data
        stack_list (list):
            the names of the files we want to read from stack_dir
        indices (list):
            the indices we want to pull data from
        names (list):
            the corresponding names of the channels, needs at least one entry per index
        channels_first (bool):
            whether we index at the beginning or end of the array

    Raises:
        ValueError:
            if fewer `names` than `indices` are provided
        FileExistsError:
            if the output folder of a stack already exists
    """

    # validate before touching the disk, otherwise a short `names` list would only fail
    # after folders and some channel images had already been written
    if len(names) < len(indices):
        raise ValueError(
            f"Each index needs a channel name: got {len(indices)} indices "
            f"but only {len(names)} names."
        )

    for stack_name in stack_list:
        img_stack = io.imread(os.path.join(stack_dir, stack_name))
        img_dir = os.path.join(output_dir, os.path.splitext(stack_name)[0])
        os.makedirs(img_dir)

        # zip stops at the shorter sequence, so surplus names are ignored
        for index, name in zip(indices, names):
            if channels_first:
                channel = img_stack[index, ...]
            else:
                channel = img_stack[..., index]

            image_utils.save_image(os.path.join(img_dir, name), channel)
3✔
751

752

753
def stitch_images_by_shape(data_dir, stitched_dir, img_sub_folder=None, channels=None,
                           segmentation=False, clustering=False):
    """ Creates stitched images for the specified channels based on the FOV folder names

    Args:
        data_dir (str):
            path to directory containing images
        stitched_dir (str):
            path to directory to save stitched images to
        img_sub_folder (str):
            optional name of image sub-folder within each fov
        channels (list):
            optional list of imgs to load, otherwise loads all imgs
        segmentation (bool):
            if stitching images from the single segmentation dir
        clustering (bool or str):
            if stitching images from the single pixel or cell mask dir, specify 'pixel' / 'cell'

    Raises:
        ValueError:
            if `clustering` is set to anything other than 'pixel' or 'cell', if no FOVs are
            found, if `stitched_dir` already exists, if a FOV name does not contain an RnCm
            pattern, or if no images are found for the first FOV
    """

    io_utils.validate_paths(data_dir)

    # no img_sub_folder, change to empty string to read directly from base folder
    if img_sub_folder is None:
        img_sub_folder = ""

    if clustering and clustering not in ['pixel', 'cell']:
        raise ValueError('If stitching images from the pixie pipeline, the clustering arg must be '
                         'set to either \"pixel\" or \"cell\".')

    # retrieve valid fov names
    if segmentation:
        fovs = ns.natsorted(io_utils.list_files(data_dir, substrs='_whole_cell.tiff'))
        fovs = io_utils.extract_delimited_names(fovs, delimiter='_whole_cell.tiff')
    elif clustering:
        fovs = ns.natsorted(io_utils.list_files(data_dir, substrs=f'_{clustering}_mask.tiff'))
        fovs = io_utils.extract_delimited_names(fovs, delimiter=f'_{clustering}_mask.tiff')
    else:
        fovs = ns.natsorted(io_utils.list_folders(data_dir))
        # ignore previous toffy stitching in fov directory
        if 'stitched_images' in fovs:
            fovs.remove('stitched_images')

    if len(fovs) == 0:
        raise ValueError(f"No FOVs found in directory, {data_dir}.")

    # check previous stitching
    if os.path.exists(stitched_dir):
        raise ValueError(f"The {stitched_dir} directory already exists.")

    # every FOV name has to carry a row/column tag such as R1C2 (or R+1C+2)
    fov_name_pattern = re.compile(r"(R\+?\d+)(C\+?\d+)")
    bad_fov_names = [fov for fov in fovs if fov_name_pattern.search(fov) is None]

    if len(bad_fov_names) > 0:
        raise ValueError(f"Invalid FOVs found in directory, {data_dir}. FOV names "
                         f"{bad_fov_names} should have the form RnCm.")

    # retrieve all extracted channel names and verify list if provided
    if not segmentation and not clustering:
        channel_imgs = io_utils.list_files(
            dir_name=os.path.join(data_dir, fovs[0], img_sub_folder),
            substrs=EXTENSION_TYPES["IMAGE"])
    else:
        # single directory layout: the "channel" is whatever follows "<fov>_" in the file name
        channel_imgs = io_utils.list_files(data_dir, substrs=fovs[0]+'_')
        channel_imgs = [chan.split(fovs[0] + '_')[1] for chan in channel_imgs]

    # the file extension is read off the first image below, so report an empty listing
    # the same way as the other validation failures instead of a bare IndexError
    if len(channel_imgs) == 0:
        raise ValueError(f"No images found for FOV {fovs[0]} in directory, {data_dir}.")

    if channels is None:
        channels = io_utils.remove_file_extensions(channel_imgs)
    else:
        misc_utils.verify_in_list(channel_inputs=channels,
                                  valid_channels=io_utils.remove_file_extensions(channel_imgs))

    file_ext = os.path.splitext(channel_imgs[0])[1]
    expected_tiles = load_utils.get_tiled_fov_names(fovs, return_dims=True)

    # segmentation and clustering masks all live in one flat directory
    single_dir = any([segmentation, clustering])

    # save new images to the stitched_images, one channel at a time
    for chan, tile in itertools.product(channels, expected_tiles):
        prefix, expected_fovs, _num_rows, num_cols = tile
        if prefix == "":
            prefix = "unnamed_tile"
        stitched_subdir = os.path.join(stitched_dir, prefix)
        # the same tile folder is revisited once per channel
        os.makedirs(stitched_subdir, exist_ok=True)
        image_data = load_utils.load_tiled_img_data(data_dir, fovs, expected_fovs, chan,
                                                    single_dir=single_dir,
                                                    file_ext=file_ext[1:],
                                                    img_sub_folder=img_sub_folder)
        stitched_data = data_utils.stitch_images(image_data, num_cols)
        current_img = stitched_data.loc['stitched_image', :, :, chan].values
        image_utils.save_image(os.path.join(stitched_subdir, chan + '_stitched' + file_ext),
                               current_img)
848

849

850
def _convert_ct_fov_to_adata(
    fov_group: pd.DataFrame,
    var_names: list[str],
    obs_names: list[str],
    save_dir: os.PathLike,
) -> str:
    """Converts the cell table for a single FOV to an `AnnData` object and saves it to disk as a
    `Zarr` store.

    Parameters
    ----------
    fov_group : pd.DataFrame
        The cell table subset on a single FOV, i.e. one group handed over by a
        `groupby(...).apply` call.
    var_names: list[str]
        The marker names to extract from the cell table.
    obs_names: list[str]
        The cell-level measurements and properties to extract from the cell table. Must
        contain the cell label, FOV id, both centroid columns and `cell_meta_cluster`,
        since all of them are accessed below.
    save_dir: os.PathLike
        The directory to save the `AnnData` object to.

    Returns
    -------
    str
        The path of the saved `AnnData` object.
    """

    fov_pd: pd.DataFrame = fov_group.sort_values(
        by=settings.CELL_LABEL, key=ns.natsort_key
    ).reset_index()
    fov_id: str = fov_pd[settings.FOV_ID].iloc[0]

    # Set the index to be the FOV and the segmentation label to create a unique index
    fov_pd.index = [f"{fov_id}_{int(label)}" for label in fov_pd[settings.CELL_LABEL]]

    # Extract the X matrix
    X_dd: pd.DataFrame = fov_pd[var_names]

    # Extract the obs dataframe and convert the cell label to integer
    obs_pd: pd.DataFrame = fov_pd[obs_names].astype(
        {settings.CELL_LABEL: int, settings.FOV_ID: str}
    )
    obs_pd["cell_meta_cluster"] = pd.Categorical(obs_pd["cell_meta_cluster"].astype(str))

    # Move centroids from obs to obsm["spatial"]
    centroid_cols = [settings.CENTROID_0, settings.CENTROID_1]
    obsm_pd = obs_pd[centroid_cols].rename(
        columns={settings.CENTROID_0: "centroid_y", settings.CENTROID_1: "centroid_x"}
    )
    obs_pd = obs_pd.drop(columns=centroid_cols)

    # Create the AnnData object
    adata: AnnData = AnnData(X=X_dd, obs=obs_pd, obsm={"spatial": obsm_pd})

    # Convert any extra string labels to categorical if it's beneficial.
    adata.strings_to_categoricals()

    # Build the store path once so the written location and the returned one cannot diverge
    zarr_path = pathlib.Path(save_dir, f"{fov_id}.zarr")
    adata.write_zarr(zarr_path, chunks=(1000, 1000))
    return zarr_path.as_posix()
3✔
896

897

898
class ConvertToAnnData:
    """ A class which converts the Cell Table `.csv` file to a series of `AnnData` objects,
    one object per FOV.

    The default parameters stored in the `.obs` slot include:
    - `area`
    - `cell_meta_cluster`
    - `centroid_dif`
    - `convex_area`
    - `convex_hull_resid`
    - `eccentricity`
    - `fov`
    - `major_axis_equiv_diam_ratio`

    Visit the Data Types document to see the full list of parameters.

    The default parameters stored in the `.obsm` slot include:
    - `centroid_x`
    - `centroid_y`

    Args:
        cell_table_path (os.PathLike): The path to the cell table.
        markers (list[str], "auto"): The markers to extract and store in `.X`. Defaults to "auto",
        which will extract all markers.
        extra_obs_parameters (list[str], optional): Extra parameters to load in `.obs`. Defaults
        to None.
    """

    def __init__(self, cell_table_path: os.PathLike,
                 markers: Union[list[str], Literal["auto"]] = "auto",
                 extra_obs_parameters: Optional[list[str]] = None) -> None:

        io_utils.validate_paths(paths=cell_table_path)

        # Read in the cell table
        cell_table: pd.DataFrame = pd.read_csv(cell_table_path)
        ct_columns = cell_table.columns

        # Markers sit strictly between the pre- and post-channel columns, every column after
        # the post-channel column is treated as an observation-level property
        marker_index_start: int = ct_columns.get_loc(settings.PRE_CHANNEL_COL) + 1
        marker_index_stop: int = ct_columns.get_loc(settings.POST_CHANNEL_COL)
        obs_index_start: int = marker_index_stop + 1

        all_markers: list[str] = ct_columns[marker_index_start:marker_index_stop].to_list()
        trailing_columns: list[str] = ct_columns[obs_index_start:].to_list()

        if markers == "auto":
            # Default to all markers based on settings Pre and Post channel column values
            markers = all_markers
        else:
            # Verify that the correct markers exist
            misc_utils.verify_in_list(requested_markers=markers, all_markers=all_markers)
        self.var_names = markers

        # Verify extra obs parameters
        if extra_obs_parameters:
            misc_utils.verify_in_list(requested_parameters=extra_obs_parameters,
                                      all_parameters=trailing_columns)
        else:
            extra_obs_parameters = []

        obs_names, rename_size_to_area = self._assemble_obs_names(
            label_col=settings.CELL_LABEL,
            size_col=settings.CELL_SIZE,
            trailing_cols=trailing_columns,
            extra_cols=extra_obs_parameters,
        )

        # Use "area" as the default area id instead of settings.CELL_SIZE to account for
        # non-cellular observations (ez_seg, fiber, etc...)
        if rename_size_to_area:
            cell_table = cell_table.rename(columns={settings.CELL_SIZE: "area"})

        self.obs_names: list[str] = obs_names
        self.cell_table = cell_table

    @staticmethod
    def _assemble_obs_names(label_col: str, size_col: str, trailing_cols: Sequence[str],
                            extra_cols: Sequence[str]) -> tuple[list[str], bool]:
        """Builds the ordered, duplicate-free list of columns destined for `.obs`.

        Args:
            label_col (str): Name of the cell label column, always placed first.
            size_col (str): Name of the cell size column, which is replaced by `"area"`.
            trailing_cols (Sequence[str]): The columns following the marker columns.
            extra_cols (Sequence[str]): Additional requested columns. They are validated
            against `trailing_cols` by the caller, so they normally repeat names already
            present there.

        Returns:
            tuple[list[str], bool]: The column names, and whether the size column has to be
            renamed to `"area"` in the cell table (False when an `"area"` column is already
            among the names).
        """
        # dict.fromkeys drops repeated names while keeping first-seen order, a repeated
        # name would otherwise select the same cell table column twice
        obs_names = list(dict.fromkeys([label_col, size_col, *trailing_cols, *extra_cols]))

        obs_names.remove(size_col)
        rename_size_to_area = "area" not in obs_names
        if rename_size_to_area:
            obs_names.append("area")
        return obs_names, rename_size_to_area

    def convert_to_adata(
        self,
        save_dir: os.PathLike,
    ) -> dict[str, str]:
        """Converts the cell table to a FOV-level `AnnData` object, and saves the results as
        a `Zarr` store to disk in the `save_dir`.

        Args:
            save_dir (os.PathLike): The directory to save the `AnnData` objects to. It is
            created, including missing parents, when it does not exist yet.

        Returns:
            dict[str, str]: A dictionary containing the names of the FOVs and the paths where
            they were saved.
        """

        if not isinstance(save_dir, pathlib.Path):
            save_dir = pathlib.Path(save_dir)
        if not save_dir.exists():
            save_dir.mkdir(parents=True, exist_ok=True)

        n_unique_fovs = self.cell_table[settings.FOV_ID].nunique()

        # Registers `progress_apply` on pandas objects so the groupby below reports progress
        tqdm.pandas(desc="Converting Cell Table to AnnData Tables", total=n_unique_fovs, unit="FOVs")

        result: pd.Series = self.cell_table.groupby(by=settings.FOV_ID, sort=True).progress_apply(
            lambda x: _convert_ct_fov_to_adata(
                x, var_names=self.var_names, obs_names=self.obs_names, save_dir=save_dir
            ),
        )
        return result.to_dict()
3✔
1005

1006

1007
class AnnCollectionKwargs(TypedDict, total=False):
    """Keyword arguments which `load_anndatas` forwards unchanged to `AnnCollection`.

    Declared with `total=False` so that every key may be omitted: with a total `TypedDict`
    a type checker would treat each key as a mandatory keyword argument of any function
    annotated with `Unpack[AnnCollectionKwargs]`.
    """
    join_obs: Optional[Literal["inner", "outer"]]
    join_obsm: Optional[Literal["inner"]]
    join_vars: Optional[Literal["inner"]]
    label: Optional[str]
    keys: Optional[Sequence[str]]
    index_unique: Optional[str]
    convert: Optional[ConvertType]
    harmonize_dtypes: bool
    indices_strict: bool
3✔
1017

1018

1019
def load_anndatas(anndata_dir: os.PathLike, **anncollection_kwargs: Unpack[AnnCollectionKwargs]) -> AnnCollection:
    """Lazily gathers every `*.zarr` `AnnData` store of a directory into one `AnnCollection`.
    The concatenation happens across the `.obs` axis.

    For `AnnCollection` kwargs, see https://anndata.readthedocs.io/en/latest/generated/anndata.experimental.AnnCollection.html

    Args:
        anndata_dir (os.PathLike): The directory containing the `AnnData` objects.
        **anncollection_kwargs: Keyword arguments passed through to `AnnCollection`.

    Returns:
        AnnCollection: The `AnnCollection` containing the `AnnData` objects.
    """
    zarr_root = anndata_dir if isinstance(anndata_dir, pathlib.Path) else pathlib.Path(anndata_dir)

    # one entry per store in natural sort order, keyed by the store name without ".zarr"
    adatas_by_name = {}
    for zarr_path in ns.natsorted(zarr_root.glob("*.zarr")):
        adatas_by_name[zarr_path.stem] = read_zarr(zarr_path)

    return AnnCollection(adatas=adatas_by_name, **anncollection_kwargs)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc