• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

angelolab / cell_classification / 5943061665

22 Aug 2023 07:11PM UTC coverage: 79.006% (-2.4%) from 81.434%
5943061665

Pull #65

github

web-flow
Merge 950cf388d into cbd1d19a3
Pull Request #65: Prediction notebook

566 of 735 branches covered (77.01%)

Branch coverage included in aggregate %.

287 of 287 new or added lines in 3 files covered. (100.0%)

1357 of 1699 relevant lines covered (79.87%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.89
/src/cell_classification/application.py
1
from deepcell.model_zoo.panopticnet import PanopticNet
1✔
2
from cell_classification.semantic_head import create_semantic_head
1✔
3
from deepcell.applications import Application
1✔
4
from alpineer import io_utils
1✔
5
from cell_classification.inference import prepare_normalization_dict, predict_fovs
1✔
6
import tensorflow as tf
1✔
7
import numpy as np
1✔
8
import json
1✔
9
import os
1✔
10

11

12
def nimbus_preprocess(image, **kwargs):
    """Preprocess input data for Nimbus model.

    Channel 0 of the last axis (the marker channel) is divided by a normalization factor.
    Then the whole array is clipped to [0, 1]. The other channels, e.g. the binary mask in
    channel 1, are only clipped.

    Args:
        image: 4D array to be processed; the last axis indexes channels.
        **kwargs:
            normalize (bool): whether to normalize the marker channel (default True).
            marker (str): marker name used to look up the normalization factor.
            normalization_dict (dict): maps marker names to normalization factors. If the
                marker is not found, the 0.999 quantile of channel 0 is used instead.
    Returns:
        np.array: processed copy of the image. When normalizing, non-float inputs are
            converted to float32 so that the division is not truncated or rejected.
    Raises:
        ValueError: if the image is not 4D.
    """
    output = np.copy(image)
    if output.ndim != 4:
        raise ValueError("Image data must be 4D, got image of shape {}".format(output.shape))

    if not kwargs.get('normalize', True):
        return output

    # numpy refuses to divide an integer array in place by a float factor; promote first
    if not np.issubdtype(output.dtype, np.floating):
        output = output.astype(np.float32)

    marker = kwargs.get('marker', None)
    # tolerate normalization_dict=None as well as a missing key
    normalization_dict = kwargs.get('normalization_dict') or {}
    if marker in normalization_dict:
        norm_factor = normalization_dict[marker]
    else:
        print(
            "Norm_factor not found for marker {}, calculating directly from the image."
            .format(marker)
        )
        norm_factor = np.quantile(output[..., 0], 0.999)
    # a zero factor (e.g. an empty channel) would turn the channel into NaN/inf;
    # in that case leave the channel unscaled and rely on the clipping below
    if norm_factor != 0:
        # normalize only marker channel in chan 0 not binary mask in chan 1
        output[..., 0] /= norm_factor
    output = output.clip(0, 1)
    return output
37

38

39
def nimbus_postprocess(model_output):
    """Post-processing step for the Nimbus application: passes the output through.

    Args:
        model_output: output of the model.
    Returns:
        The very same object that was passed in, unmodified.
    """
    processed = model_output
    return processed
41

42

43
def format_output(model_output):
    """Select the first element of the model output.

    Args:
        model_output: indexable model output.
    Returns:
        ``model_output[0]``.
    """
    first_item = model_output[0]
    return first_item
45

46

47
def prep_deepcell_naming_convention(deepcell_output_dir):
    """Prepares the naming convention for the segmentation data.

    Args:
        deepcell_output_dir (str): path to directory where segmentation data is saved
    Returns:
        segmentation_naming_convention (function): function that maps a fov path to
            ``<deepcell_output_dir>/<fov_name>_whole_cell.tiff``
    """
    def segmentation_naming_convention(fov_path):
        """Prepares the path to the segmentation data for a given fov.

        Args:
            fov_path (str): path to fov; a trailing path separator is ignored
        Returns:
            seg_path (str): path to the segmentation file of the fov
        """
        # normpath drops a trailing separator, otherwise basename("…/fov1/") would be ""
        fov_name = os.path.basename(os.path.normpath(fov_path))
        return os.path.join(deepcell_output_dir, fov_name + "_whole_cell.tiff")
    return segmentation_naming_convention
67

68

69
class Nimbus(Application):
    """Nimbus application class for predicting marker activity for cells in multiplexed images.
    """
    def __init__(
            self, fov_paths, segmentation_naming_convention, output_dir,
            save_predictions=True, exclude_channels=None, half_resolution=True,
            batch_size=4, test_time_aug=True, input_shape=None,
    ):
        """Initializes a Nimbus Application.

        Args:
            fov_paths (list): List of paths to fovs to be analyzed.
            segmentation_naming_convention (function): Function that returns the path to the
                segmentation mask for a given fov path.
            output_dir (str): Path to directory to save output (created if missing).
            save_predictions (bool): Whether to save predictions.
            exclude_channels (list or str): Channel(s) to exclude from analysis; defaults to
                none. The list is copied, so the caller's list is never modified. The name
                (without extension) of the first fov's segmentation file is appended to it.
            half_resolution (bool): Whether to run model on half resolution images.
            batch_size (int): Batch size for model inference.
            test_time_aug (bool): Whether to use test time augmentation.
            input_shape (list): Spatial shape of input images, defaults to [1024, 1024].
        Raises:
            ValueError: if fov_paths is empty.
        """
        # len() instead of truthiness so that numpy arrays of paths are accepted too
        if len(fov_paths) == 0:
            raise ValueError("fov_paths must contain at least one fov path.")
        # None sentinels instead of mutable defaults: exclude_channels is appended to below,
        # which previously mutated a list shared by every Nimbus instance
        if exclude_channels is None:
            exclude_channels = []
        elif isinstance(exclude_channels, str):
            exclude_channels = [exclude_channels]
        if input_shape is None:
            input_shape = [1024, 1024]

        self.fov_paths = fov_paths
        self.exclude_channels = list(exclude_channels)
        self.segmentation_naming_convention = segmentation_naming_convention
        self.output_dir = output_dir
        self.half_resolution = half_resolution
        self.save_predictions = save_predictions
        self._batch_size = batch_size
        self.checked_inputs = False
        self.test_time_aug = test_time_aug
        self.input_shape = list(input_shape)
        # exclude segmentation channel from analysis
        seg_name = os.path.basename(self.segmentation_naming_convention(self.fov_paths[0]))
        self.exclude_channels.append(seg_name.split(".")[0])
        os.makedirs(self.output_dir, exist_ok=True)

        # initialize model and parent class
        self.initialize_model()

        super().__init__(
            model=self.model,
            model_image_shape=self.model.input_shape[1:],
            preprocessing_fn=nimbus_preprocess,
            postprocessing_fn=nimbus_postprocess,
            format_model_output_fn=format_output,
        )

    def check_inputs(self):
        """Check inputs for the Nimbus model.

        The fov paths and output_dir are checked with ``io_utils.validate_paths``.

        Raises:
            FileNotFoundError: if the segmentation path of the first fov does not exist.
        """
        # check if all paths in fov_paths exist
        io_utils.validate_paths(self.fov_paths)

        # check if segmentation_naming_convention returns valid paths (first fov only)
        path_to_segmentation = self.segmentation_naming_convention(self.fov_paths[0])
        if not os.path.exists(path_to_segmentation):
            raise FileNotFoundError(
                "Function segmentation_naming_convention does not return valid path. "
                "Segmentation path {} does not exist.".format(path_to_segmentation)
            )
        # check if output_dir exists
        io_utils.validate_paths([self.output_dir])

        if isinstance(self.exclude_channels, str):
            self.exclude_channels = [self.exclude_channels]
        self.checked_inputs = True
        print("All inputs are valid.")

    def initialize_model(self):
        """Initializes the PanopticNet model and loads the Nimbus checkpoint weights.

        The checkpoint is first looked up relative to the current working directory.
        Otherwise it is taken from the ``checkpoints`` directory two levels above this
        module's package directory.
        """
        backbone = "efficientnetv2bs"
        # two input channels: marker image (chan 0) and binary segmentation mask (chan 1)
        input_shape = list(self.input_shape) + [2]
        model = PanopticNet(
            backbone=backbone, input_shape=input_shape,
            norm_method="std", num_semantic_classes=[1],
            create_semantic_head=create_semantic_head, location=False,
        )
        checkpoint_name = "halfres_512_checkpoint_160000.h5"
        # path relative to the current working directory
        self.checkpoint_path = os.path.normpath(
            os.path.join("..", "cell_classification", "checkpoints", checkpoint_name)
        )
        if not os.path.exists(self.checkpoint_path):
            # fall back to <repo>/checkpoints, resolved from this file's location
            # (<repo>/src/cell_classification/application.py), so it works on any OS
            module_dir = os.path.dirname(os.path.abspath(__file__))
            repo_root = os.path.dirname(os.path.dirname(module_dir))
            self.checkpoint_path = os.path.join(repo_root, "checkpoints", checkpoint_name)
        model.load_weights(self.checkpoint_path)
        print("Loaded weights from {}".format(self.checkpoint_path))
        self.model = model

    def prepare_normalization_dict(
            self, quantile=0.999, n_subset=10, multiprocessing=False, overwrite=False,
    ):
        """Load or prepare and save normalization dictionary for Nimbus model.

        An existing ``normalization_dict.json`` in output_dir is loaded unless overwrite is
        True. Otherwise the dict is computed by ``inference.prepare_normalization_dict``.

        Args:
            quantile (float): Quantile to use for normalization.
            n_subset (int): Number of fovs to use for normalization.
            multiprocessing (bool): Whether to use multiprocessing (one job per CPU).
            overwrite (bool): Whether to overwrite existing normalization dict.
        Returns:
            dict: Dictionary of normalization factors (also stored on
                ``self.normalization_dict``).
        """
        self.normalization_dict_path = os.path.join(self.output_dir, "normalization_dict.json")
        if os.path.exists(self.normalization_dict_path) and not overwrite:
            # context manager so the file handle is closed right away
            with open(self.normalization_dict_path) as f:
                self.normalization_dict = json.load(f)
        else:
            # os.cpu_count() may return None when the count cannot be determined
            n_jobs = (os.cpu_count() or 1) if multiprocessing else 1
            # module-level helper from cell_classification.inference, not this method
            self.normalization_dict = prepare_normalization_dict(
                self.fov_paths, self.output_dir, quantile, self.exclude_channels, n_subset, n_jobs
            )
        return self.normalization_dict

    def predict_fovs(self):
        """Predicts cell classification for input data.

        Validates inputs and prepares the normalization dict if not done yet. Then it runs
        inference over all fovs and writes the cell table to
        ``<output_dir>/nimbus_cell_table.csv``.

        Returns:
            The cell table produced by ``inference.predict_fovs`` (presumably a pandas
            DataFrame, since it is written with ``to_csv``).
        """
        if not self.checked_inputs:
            self.check_inputs()
        if not hasattr(self, "normalization_dict"):
            self.prepare_normalization_dict()
        # check if GPU is available
        print("Available GPUs: ", tf.config.list_physical_devices('GPU'))
        print("Predictions will be saved in {}".format(self.output_dir))
        print("Iterating through fovs will take a while...")
        # module-level helper from cell_classification.inference, not this method
        self.cell_table = predict_fovs(
            self.fov_paths, self.output_dir, self, self.normalization_dict,
            self.segmentation_naming_convention, self.exclude_channels, self.save_predictions,
            self.half_resolution, batch_size=self._batch_size,
            test_time_augmentation=self.test_time_aug,
        )
        self.cell_table.to_csv(
            os.path.join(self.output_dir, "nimbus_cell_table.csv"), index=False
        )
        return self.cell_table
206
        
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc