angelolab / cell_classification / build 6002459992

28 Aug 2023 04:33PM UTC · coverage: 80.919% (remained the same)
push via github (web-flow): "Update README.md"

574 of 738 branches covered (77.78%); branch coverage is included in the aggregate %.
1398 of 1699 relevant lines covered (82.28%), 0.82 hits per line.

Source file: /src/cell_classification/evaluation_script.py (0.0% of lines covered)
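The script below evaluates a trained cell-classification model on its held-out test datasets and, optionally, on external TFRecord datasets, writing ROC data, metrics tables and diagnostic plots next to the model weights. For orientation, a typical invocation might look like the following (the paths and dataset names are hypothetical):

    python evaluation_script.py \
        --model_path /path/to/model_weights \
        --params_path /path/to/params.toml \
        --external_datasets external_a.tfrecord external_b.tfrecord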
import argparse
import os
import pickle

import numpy as np
import pandas as pd
import tensorflow as tf
import toml

from cell_classification.metrics import (average_roc, calc_metrics, calc_roc,
                                         load_model)
from cell_classification.plot_utils import (heatmap_plot, plot_average_roc,
                                            plot_metrics_against_threshold,
                                            plot_together, subset_plots)
from cell_classification.segmentation_data_prep import (feature_description,
                                                        parse_dict)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to model weights",
        default=None,
    )
    parser.add_argument(
        "--params_path",
        type=str,
        help="Path to model params",
        default="E:\\angelo_lab\\test\\params.toml",
    )
    parser.add_argument(
        "--worst_n",
        type=int,
        help="Number of worst predictions to plot",
        default=20,
    )
    parser.add_argument(
        "--best_n",
        type=int,
        help="Number of best predictions to plot",
        default=20,
    )
    parser.add_argument(
        "--split_by_marker",
        # type=bool is a known argparse pitfall (bool("False") is True), so use
        # BooleanOptionalAction instead, which also provides --no-split_by_marker
        action=argparse.BooleanOptionalAction,
        help="Split best/worst predictions by marker",
        default=True,
    )
    parser.add_argument(
        "--external_datasets",
        type=str,
        help="List of paths to tfrecord datasets",
        nargs='+',
        default=[],
    )
    args = parser.parse_args()
    # read hyperparameters and paths from the toml file; a --model_path given
    # on the command line overrides the one stored in params
    with open(args.params_path, "r") as f:
        params = toml.load(f)
    if args.model_path is not None:
        params["model_path"] = args.model_path
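    # params.toml is assumed to provide at least "model_path" (optional when
    # --model_path is passed) and "batch_size", the only keys this script
    # references directly. A minimal sketch, with hypothetical values:
    #
    #     model_path = "/path/to/model_weights"
    #     batch_size = 4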

    model = load_model(params)
    # pair each held-out test dataset with its name
    datasets = {name: dataset for name, dataset in zip(model.dataset_names, model.test_datasets)}
    # only build the external pipeline when datasets were actually passed; the
    # attribute itself always exists because the argument has a default
    if args.external_datasets:
        # name each external dataset after its file name (stripping any commas
        # left over from comma-separated input)
        external_datasets = {
            os.path.split(external_dataset)[-1].split(".")[0]: external_dataset.replace(",", "")
            for external_dataset in args.external_datasets
        }
        # open the raw tfrecord files ...
        external_datasets = {
            key: tf.data.TFRecordDataset(external_datasets[key])
            for key in external_datasets.keys()
        }
        # ... deserialize each example into a feature dict ...
        external_datasets = {
            name: dataset.map(
                lambda x: tf.io.parse_single_example(x, feature_description),
                num_parallel_calls=tf.data.AUTOTUNE,
            ) for name, dataset in external_datasets.items()
        }
        # ... decode the parsed features ...
        external_datasets = {
            name: dataset.map(
                parse_dict, num_parallel_calls=tf.data.AUTOTUNE
            ) for name, dataset in external_datasets.items()
        }
        # ... and batch, keeping the final partial batch
        external_datasets = {
            name: dataset.batch(
                params["batch_size"], drop_remainder=False
            ) for name, dataset in external_datasets.items()
        }

        datasets.update(external_datasets)
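    # model.predict_dataset (used in the loop below) is expected to return a
    # list of per-sample dicts; the keys this script relies on are
    # "activity_df", "dataset", "marker", "folder_name", "mplex_img",
    # "marker_activity_mask" and "prediction"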

    for name, val_dset in datasets.items():
        # write evaluation artifacts next to the model weights, in one
        # subdirectory per dataset
        params["eval_dir"] = os.path.join(os.path.dirname(params["model_path"]), "eval", name)
        os.makedirs(params["eval_dir"], exist_ok=True)
        pred_list = model.predict_dataset(val_dset, False)

        # prepare cell_table: concatenate the per-sample activity tables and
        # annotate each row with the dataset, marker and folder it came from
        activity_list = []
        for pred in pred_list:
            activity_df = pred["activity_df"].copy()
            for key in ["dataset", "marker", "folder_name"]:
                activity_df[key] = [pred[key]] * len(activity_df)
            activity_list.append(activity_df)
        activity_df = pd.concat(activity_list)
        activity_df.to_csv(os.path.join(params["eval_dir"], "pred_cell_table.csv"), index=False)

        # cell level evaluation
        roc = calc_roc(pred_list, gt_key="activity", pred_key="pred_activity", cell_level=True)
        with open(os.path.join(params["eval_dir"], "roc_cell_lvl.pkl"), "wb") as f:
            pickle.dump(roc, f)
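        # calc_roc is assumed to return a dict of parallel per-sample entries;
        # the keys accessed directly below are "auc" and "marker", and
        # average_roc reduces the individual curves to mean TPRs over a common
        # FPR grid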

        # find the indices of the n worst (and best) predictions by AUC and
        # save plots of them
        roc_df = pd.DataFrame(roc)
        if args.split_by_marker:
            worst_idx = []
            best_idx = []
            markers = np.unique(roc_df.marker)
            for marker in markers:
                marker_df = roc_df[roc_df.marker == marker]
                # index labels ordered by ascending AUC, so the head is the
                # worst and the tail the best
                sort_idx = marker_df.auc.sort_values().index
                worst_idx.extend(sort_idx[:args.worst_n])
                best_idx.extend(sort_idx[-args.best_n:])
        else:
            # positions ordered by ascending AUC
            sort_idx = np.argsort(roc["auc"])
            worst_idx = sort_idx[:args.worst_n]
            best_idx = sort_idx[-args.best_n:]
        for idx_list, best_worst in [(best_idx, "best"), (worst_idx, "worst")]:
            for i, idx in enumerate(idx_list):
                pred = pred_list[idx]
                plot_together(
                    pred, keys=["mplex_img", "marker_activity_mask", "prediction"],
                    save_dir=os.path.join(params["eval_dir"], best_worst + "_predictions"),
                    save_file="{}_{}_{}_{}_{}.png".format(
                        best_worst, i, pred["marker"], pred["dataset"], pred["folder_name"]
                    )
                )

        tprs, mean_tprs, fpr, std, mean_thresh = average_roc(roc)
        plot_average_roc(
            mean_tprs, std, save_dir=params["eval_dir"], save_file="avg_roc_cell_lvl.png"
        )
        print("AUC: {}".format(np.mean(roc["auc"])))

        print("Calculate precision, recall, f1_score and accuracy on the cell level")
        avg_metrics = calc_metrics(
            pred_list, gt_key="activity", pred_key="pred_activity", cell_level=True
        )
        pd.DataFrame(avg_metrics).to_csv(
            os.path.join(params["eval_dir"], "cell_metrics.csv"), index=False
        )

        plot_metrics_against_threshold(
            avg_metrics,
            metric_keys=["precision", "recall", "f1_score", "specificity"],
            threshold_key="threshold",
            save_dir=params["eval_dir"],
            save_file="precision_recall_f1_cell_lvl.png",
        )

        print("Plot activity predictions split by markers and cell types")
        subset_plots(
            activity_df, subset_list=["marker"],
            save_dir=params["eval_dir"],
            save_file="split_by_marker.png",
            gt_key="activity",
            pred_key="pred_activity",
        )
        if "cell_type" in activity_df.columns:
            subset_plots(
                activity_df, subset_list=["cell_type"],
                save_dir=params["eval_dir"],
                save_file="split_by_cell_type.png",
                gt_key="activity",
                pred_key="pred_activity",
            )
        heatmap_plot(
            activity_df, subset_list=["marker"],
            save_dir=params["eval_dir"],
            save_file="heatmap_split_by_marker.png",
            gt_key="activity",
            pred_key="pred_activity",
        )
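For each dataset, the evaluation artifacts therefore end up under <model directory>/eval/<dataset name>: the predicted cell table (pred_cell_table.csv), the pickled cell-level ROC data (roc_cell_lvl.pkl), the cell-level metrics table (cell_metrics.csv), and the best/worst-prediction, average-ROC, threshold and per-marker plots saved above.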