angelolab / cell_classification / 6002459992

28 Aug 2023 04:33PM UTC coverage: 80.919%. Remained the same.

Triggered by a push via github (committer: web-flow): "Update README.md"

574 of 738 branches covered (77.78%). Branch coverage included in aggregate %.

1398 of 1699 relevant lines covered (82.28%). 0.82 hits per line.

Source File: /src/cell_classification/metrics.py (85.8% covered)
import os
from copy import deepcopy

import h5py
import numpy as np
import pandas as pd
from sklearn.metrics import auc, confusion_matrix, roc_curve

from cell_classification.model_builder import ModelBuilder
from cell_classification.promix_naive import PromixNaive


def load_model(params):
    """Load model and validation data from params dict
    Args:
        params (dict):
            dictionary containing model and validation data paths
    Returns:
        model (ModelBuilder):
            trained model with its validation data prepared
    """
    params["eval"] = True
    if params["model"] == "ModelBuilder":
        model = ModelBuilder(params)
    elif params["model"] == "PromixNaive":
        model = PromixNaive(params)
    model.prep_data()
    model.load_model(params["model_path"])
    return model

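# --- Usage sketch (illustrative, not part of metrics.py) ------------------------------
# load_model is driven entirely by the training params dict. Only the keys read directly
# by load_model are shown here; a real config also needs the data and training keys
# consumed by prep_data() and the model constructors, which are defined elsewhere in
# this repo, so the call itself is left commented out.
example_params = {
    "model": "PromixNaive",                  # or "ModelBuilder"
    "model_path": "checkpoints/promix.h5",   # hypothetical checkpoint path
    # ... remaining data / training keys go here ...
}
# model = load_model(example_params)

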
def calc_roc(pred_list, gt_key="marker_activity_mask", pred_key="prediction", cell_level=False):
    """Calculate ROC curve
    Args:
        pred_list (list):
            list of samples with predictions
        gt_key (str):
            key for ground truth labels
        pred_key (str):
            key for predictions
        cell_level (bool):
            if True, calculate the ROC per cell from activity_df instead of per pixel
    Returns:
        roc (dict):
            dictionary containing ROC curve data
    """
    roc = {"fpr": [], "tpr": [], "thresholds": [], "auc": [], "marker": []}
    for sample in pred_list:
        if cell_level:
            # filter out cells with gt activity == 2
            df = sample["activity_df"].copy()
            df = df[df[gt_key] != 2]
            gt = df[gt_key].to_numpy()
            pred = df[pred_key].to_numpy()
        else:
            foreground = sample["binary_mask"] > 0
            gt = sample[gt_key][foreground].flatten()
            pred = sample[pred_key][foreground].flatten()
        if gt.size > 0 and gt.min() == 0 and gt.max() > 0:  # roc needs both classes present
            fpr, tpr, thresholds = roc_curve(gt, pred)
            roc["fpr"].append(fpr)
            roc["tpr"].append(tpr)
            roc["thresholds"].append(thresholds)
            roc["auc"].append(auc(fpr, tpr))
            roc["marker"].append(sample["marker"])
    return roc

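# --- Usage sketch (illustrative, not part of metrics.py) ------------------------------
# calc_roc consumes a list of per-sample dicts. In the default pixel-level mode each
# sample needs a prediction map, a ground-truth activity mask, a binary foreground mask
# and a marker name; the synthetic arrays and the marker name below are stand-ins for
# real model output.
rng = np.random.default_rng(seed=42)
example_sample = {
    "prediction": rng.random((64, 64)),                    # predicted activity per pixel
    "marker_activity_mask": rng.integers(0, 2, (64, 64)),  # ground truth in {0, 1}
    "binary_mask": np.ones((64, 64), dtype=np.uint8),      # every pixel counts as foreground
    "marker": "CD45",
}
example_roc = calc_roc([example_sample])
# example_roc["auc"][0] holds the AUC for this sample's marker

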
def calc_scores(gt, pred, threshold):
    """Calculate scores for a given threshold
    Args:
        gt (np.array):
            ground truth labels
        pred (np.array):
            predictions
        threshold (float):
            threshold for predictions
    Returns:
        scores (dict):
            dictionary containing scores
    """
    # exclude masked out regions from metric calculation
    pred = pred[gt < 2]
    gt = gt[gt < 2]
    tn, fp, fn, tp = confusion_matrix(
        y_true=gt, y_pred=(pred >= threshold).astype(int), labels=[0, 1]
    ).ravel()
    metrics = {
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "accuracy": (tp + tn) / (tp + tn + fp + fn + 1e-8),
        "precision": tp / (tp + fp + 1e-8),
        "recall": tp / (tp + fn + 1e-8),
        "specificity": tn / (tn + fp + 1e-8),
        "f1_score": 2 * tp / (2 * tp + fp + fn + 1e-8),
    }
    return metrics

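# --- Worked example (illustrative, not part of metrics.py) ----------------------------
# calc_scores thresholds the predictions, drops entries whose ground truth is 2 (the
# masked-out value) and derives the scores from the confusion matrix. With the toy
# arrays below and threshold 0.5 the last entry is excluded, leaving
# tp = tn = fp = fn = 1, so accuracy, precision, recall, specificity and f1_score all
# come out at ~0.5 (up to the 1e-8 smoothing term in the denominators).
toy_gt = np.array([0, 0, 1, 1, 2])
toy_pred = np.array([0.1, 0.8, 0.6, 0.3, 0.9])
toy_scores = calc_scores(toy_gt, toy_pred, threshold=0.5)
# toy_scores["precision"] ~ 0.5, toy_scores["recall"] ~ 0.5

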
def calc_metrics(
    pred_list, gt_key="marker_activity_mask", pred_key="prediction", cell_level=False
):
    """Calculate metrics
    Args:
        pred_list (list):
            list of samples with predictions
        gt_key (str):
            key of ground truth in pred_list
        pred_key (str):
            key of prediction in pred_list
        cell_level (bool):
            if True, calculate the metrics per cell from activity_df instead of per pixel
    Returns:
        avg_metrics (dict):
            dictionary containing metrics averaged over all samples
    """
    metrics_dict = {
        "accuracy": [], "precision": [], "recall": [], "specificity": [], "f1_score": [], "tp": [],
        "tn": [], "fp": [], "fn": [],
    }

    def _calc_metrics(threshold):
        """Helper function to calculate metrics for a given threshold"""
        metrics = deepcopy(metrics_dict)

        for sample in pred_list:
            if cell_level:
                df = sample["activity_df"]
                # filter out cells with gt activity == 2
                df = df[df[gt_key] != 2]
                gt = np.array(df[gt_key])
                pred = np.array(df[pred_key])
            else:
                foreground = sample["binary_mask"] > 0
                gt = sample[gt_key][foreground].flatten()
                pred = sample[pred_key][foreground].flatten()
            if gt.size == 0:
                continue
            scores = calc_scores(gt, pred, threshold)

            # only add specificity for samples that have no positives
            if np.sum(gt) == 0:
                keys = ["specificity"]
            else:
                keys = scores.keys()
            for key in keys:
                metrics[key].append(scores[key])
            metrics["threshold"] = threshold
        for key in ["dataset", "imaging_platform", "marker"]:
            metrics[key] = sample[key]
        return metrics

    # calculate metrics for all thresholds (a joblib.Parallel version is kept for reference)
    thresholds = np.linspace(0.01, 1, 50)
    # metric_list = Parallel(n_jobs=8)(delayed(_calc_metrics)(i) for i in thresholds)
    metric_list = [_calc_metrics(i) for i in thresholds]
    # reduce metrics over all samples for each threshold
    avg_metrics = deepcopy(metrics_dict)
    for key in ["dataset", "imaging_platform", "marker", "threshold"]:
        avg_metrics[key] = []
    for metrics in metric_list:
        for key in ["accuracy", "precision", "recall", "specificity", "f1_score"]:
            avg_metrics[key].append(np.mean(metrics[key]))
        for key in ["tp", "tn", "fp", "fn"]:  # sum fn, fp, tn, tp
            avg_metrics[key].append(np.sum(metrics[key]))
        for key in ["dataset", "imaging_platform", "marker", "threshold"]:  # copy strings
            avg_metrics[key].append(metrics[key])
    return avg_metrics

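# --- Usage sketch (illustrative, not part of metrics.py) ------------------------------
# calc_metrics takes the same per-sample dicts as calc_roc, plus the metadata keys
# ("dataset", "imaging_platform", "marker") that are copied into the output. It sweeps
# 50 thresholds between 0.01 and 1 and returns one averaged entry per threshold; the
# metadata values below are arbitrary and reuse the synthetic example_sample from the
# calc_roc sketch above.
example_meta_sample = dict(example_sample)
example_meta_sample.update({"dataset": "toy_dataset", "imaging_platform": "MIBI"})
example_avg_metrics = calc_metrics([example_meta_sample])
example_metrics_df = pd.DataFrame(example_avg_metrics)  # 50 rows, one per threshold
# example_metrics_df[["threshold", "precision", "recall", "f1_score"]]

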
def average_roc(roc_list):
    """Average ROC curves
    Args:
        roc_list (dict):
            ROC curve data as returned by calc_roc
    Returns:
        tprs (np.array):
            standardized true positive rates for each sample
        mean_tprs (np.array):
            mean true positive rates over all samples
        base (np.array):
            fpr values for interpolation
        std (np.array):
            standard deviation of true positive rates over all samples
        mean_thresh (np.array):
            mean of the threshold values over all samples
    """
    base = np.linspace(0, 1, 101)
    tpr_list = []
    thresh_list = []
    for i in range(len(roc_list["tpr"])):
        tpr_list.append(np.interp(base, roc_list["fpr"][i], roc_list["tpr"][i]))
        thresh_list.append(np.interp(base, roc_list["tpr"][i], roc_list["thresholds"][i]))

    tprs = np.array(tpr_list)
    thresh_list = np.array(thresh_list)
    mean_thresh = np.mean(thresh_list, axis=0)
    mean_tprs = tprs.mean(axis=0)
    std = tprs.std(axis=0)
    return tprs, mean_tprs, base, std, mean_thresh

class HDF5Loader(object):
    """HDF5 iterator for loading data from HDF5 files"""

    def __init__(self, folder):
        """Initialize HDF5 generator
        Args:
            folder (str):
                path to folder containing HDF5 files
        """
        self.folder = folder
        self.files = os.listdir(folder)
        # keep only the .hdf files
        self.files = [os.path.join(folder, f) for f in self.files if f.endswith(".hdf")]
        self.file_idx = 0

    def __len__(self):
        return len(self.files)

    def load_hdf(self, file):
        """Load HDF5 file
        Args:
            file (str):
                path to HDF5 file
        Returns:
            data (dict):
                dictionary containing data from HDF5 file
        """
        out_dict = {}
        with h5py.File(file, "r") as f:
            keys = [key for key in f.keys() if key != "activity_df"]
            for key in keys:
                if isinstance(f[key][()], bytes):
                    out_dict[key] = f[key][()].decode("utf-8")
                else:
                    out_dict[key] = f[key][()]
            out_dict["activity_df"] = pd.read_json(f["activity_df"][()].decode())
        return out_dict

    def __iter__(self):
        self.file_idx = 0
        return self

    def __next__(self):
        if self.file_idx >= len(self.files):
            raise StopIteration
        else:
            self.file_idx += 1
            return self.load_hdf(self.files[self.file_idx - 1])