angelolab / cell_classification · build 6002459992 (push · github · web-flow: "Update README.md")
28 Aug 2023 04:33PM UTC · coverage: 80.919% (remained the same)

574 of 738 branches covered (77.78%); branch coverage is included in the aggregate %
1398 of 1699 relevant lines covered (82.28%) · 0.82 hits per line

Source file: /src/cell_classification/model_builder.py · 83.51% covered
import argparse
import json
import os
from time import time

import h5py
import numpy as np
import pandas as pd
import tensorflow as tf
import toml
from deepcell.model_zoo.panopticnet import PanopticNet
from deepcell.utils.train_utils import count_gpus
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import CosineDecay
from tqdm import tqdm

from cell_classification.augmentation_pipeline import (
    get_augmentation_pipeline, prepare_tf_aug, py_aug)
from cell_classification.loss import Loss
from cell_classification.post_processing import (merge_activity_df,
                                                 process_to_cells)
from cell_classification.segmentation_data_prep import (feature_description,
                                                        parse_dict)
from cell_classification.semantic_head import create_semantic_head


class ModelBuilder:
    """Builds, trains and writes validation metrics for models"""

    def __init__(self, params):
        """Initialize the trainer with the parameters from the config file
        Args:
            params (dict): Dictionary of parameters from the config file
        """
        self.params = params
        self.params["model"] = "ModelBuilder"
        self.num_gpus = count_gpus()
        if "batch_constituents" in list(self.params.keys()):
            self.prep_batches = self.gen_prep_batches_fn(self.params["batch_constituents"])
        else:
            self.prep_batches = self.gen_prep_batches_fn()
        # make prep_batches a callable static method
        self.prep_batches = staticmethod(self.prep_batches).__func__

    def prep_data(self):
        """Prepares training and validation data"""
        # make datasets and splits
        datasets = [
            tf.data.TFRecordDataset(record_path) for record_path in self.params["record_path"]
        ]
        datasets = [
            dataset.map(
                lambda x: tf.io.parse_single_example(x, feature_description),
                num_parallel_calls=tf.data.AUTOTUNE,
            ) for dataset in datasets
        ]
        datasets = [
            dataset.map(parse_dict, num_parallel_calls=tf.data.AUTOTUNE) for dataset in datasets
        ]

        # filter out sparse samples
        if "filter_quantile" in self.params.keys():
            datasets = [
                self.quantile_filter(dataset, record_path) for dataset, record_path in
                zip(datasets, self.params["record_path"])
            ]

        # split into train, validation and test
        if "data_splits" in self.params.keys():
            data_splits = []
            for fpath in self.params["data_splits"]:
                with open(fpath, "r") as f:
                    data_splits.append(json.load(f))
            self.validation_datasets = [
                self.fov_filter(dataset, data_split["validation"]) for dataset, data_split in zip(
                    datasets, data_splits
                )
            ]
            self.test_datasets = [
                self.fov_filter(dataset, data_split["test"]) for dataset, data_split in zip(
                    datasets, data_splits
                )
            ]
            self.train_datasets = [
                self.fov_filter(dataset, data_split["train"]) for dataset, data_split in zip(
                    datasets, data_splits
                )
            ]
        else:
            self.validation_datasets = [
                dataset.take(num_validation) for dataset, num_validation in zip(
                    datasets, self.params["num_validation"])
            ]
            datasets = [dataset.skip(num_validation) for dataset, num_validation in zip(
                datasets, self.params["num_validation"])
            ]
            self.test_datasets = [
                dataset.take(num_test) for dataset, num_test in zip(
                    datasets, self.params["num_test"])
            ]
            self.train_datasets = [
                dataset.skip(num_test) for dataset, num_test in zip(
                    datasets, self.params["num_test"])
            ]
        # add external validation datasets
        if "external_validation_path" in self.params.keys():
            external_validation_datasets = [
                tf.data.TFRecordDataset(record_path) for record_path in
                self.params["external_validation_path"]
            ]
            external_validation_datasets = [
                dataset.map(
                    lambda x: tf.io.parse_single_example(x, feature_description),
                    num_parallel_calls=tf.data.AUTOTUNE,
                ) for dataset in external_validation_datasets
            ]
            external_validation_datasets = [
                dataset.map(parse_dict, num_parallel_calls=tf.data.AUTOTUNE) for dataset in
                external_validation_datasets
            ]
            self.external_validation_datasets = external_validation_datasets
            self.external_validation_names = self.params["external_validation_names"]

        if "num_training" in self.params.keys() and self.params["num_training"] is not None:
            self.train_datasets = [
                train_dataset.take(num_training) for train_dataset, num_training
                in zip(self.train_datasets, self.params["num_training"])
            ]

        # merge datasets with tf.data.Dataset.sample_from_datasets
        self.train_dataset = tf.data.Dataset.sample_from_datasets(
            datasets=self.train_datasets, weights=self.params["dataset_sample_probs"],
            stop_on_empty_dataset=True
        )

        # shuffle, batch and augment the datasets
        self.train_dataset = self.train_dataset.shuffle(self.params["shuffle_buffer_size"]).batch(
            self.params["batch_size"] * np.max([self.num_gpus, 1])
        )
        self.validation_datasets = [validation_dataset.batch(
            self.params["batch_size"] * np.max([self.num_gpus, 1])
        ) for validation_dataset in self.validation_datasets]
        self.test_datasets = [test_dataset.batch(
            self.params["batch_size"] * np.max([self.num_gpus, 1])
        ) for test_dataset in self.test_datasets]

        self.dataset_names = self.params["dataset_names"]
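
    # --- Illustrative sketch (not part of the original module) ---------------
    # prep_data() above reads each file in params["data_splits"] with json.load
    # and indexes the result with the keys "train", "validation" and "test",
    # so each split file is assumed to be a JSON mapping from those keys to
    # lists of FOV names (compare fov_filter below, which matches them against
    # the "folder_name" feature). The FOV names and file name below are
    # hypothetical, and this helper is never called anywhere in this module.
    @staticmethod
    def _example_write_data_split(fpath="data_split_0.json"):
        split = {
            "train": ["fov_01", "fov_02"],   # hypothetical FOV folder names
            "validation": ["fov_03"],
            "test": ["fov_04"],
        }
        with open(fpath, "w") as f:
            json.dump(split, f, indent=2)
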
    def prep_model(self):
        """Prepares the model for training"""
        # prepare folders
        self.params["model_dir"] = os.path.join(
            os.path.normpath(self.params["path"]), self.params["experiment"]
        )
        self.params["log_dir"] = os.path.join(self.params["model_dir"], "logs", str(int(time())))
        os.makedirs(self.params["model_dir"], exist_ok=True)
        os.makedirs(self.params["log_dir"], exist_ok=True)
        if "model_path" not in self.params.keys() or self.params["model_path"] is None:
            self.params["model_path"] = os.path.join(
                self.params["model_dir"], "{}.h5".format(self.params["experiment"])
            )
        self.params["loss_path"] = os.path.join(
            self.params["model_dir"], "{}.npz".format(self.params["experiment"])
        )

        # initialize optimizer and lr scheduler
        # replace with AdamW when available
        self.lr_sched = CosineDecay(
            initial_learning_rate=self.params["lr"],
            decay_steps=self.params["num_steps"],
            alpha=1e-6,
        )
        self.optimizer = Adam(learning_rate=self.lr_sched, clipnorm=0.001)

        # initialize model
        if "test" in self.params.keys() and self.params["test"]:
            self.model = tf.keras.Sequential(
                [tf.keras.layers.Conv2D(
                    1, (3, 3), input_shape=self.params["input_shape"], padding="same",
                    name="semantic_head", activation="sigmoid", data_format="channels_last",
                )]
            )
        else:
            self.model = PanopticNet(
                backbone=self.params["backbone"], input_shape=self.params["input_shape"],
                norm_method="std", num_semantic_classes=self.params["classes"],
                create_semantic_head=create_semantic_head, location=self.params["location"],
            )

        loss = {}
        # give losses for all of the semantic heads
        for layer in self.model.layers:
            if layer.name.startswith("semantic_"):
                loss[layer.name] = self.prep_loss()

        if "weight_decay" in self.params.keys():
            self.add_weight_decay()
        self.model.compile(loss=loss, optimizer=self.optimizer)
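
    # --- Illustrative sketch (not part of the original module) ---------------
    # prep_model() above pairs Adam with a CosineDecay schedule, so the
    # learning rate starts at params["lr"] and decays towards
    # alpha * params["lr"] over params["num_steps"] steps. The numbers below
    # are hypothetical and only illustrate the shape of the schedule; this
    # helper is never called anywhere in this module.
    @staticmethod
    def _example_lr_schedule():
        sched = CosineDecay(initial_learning_rate=1e-3, decay_steps=1000, alpha=1e-6)
        for step in (0, 500, 1000):
            # roughly 1e-3 at step 0, 5e-4 halfway, 1e-9 (= alpha * lr) at the end
            print(step, float(sched(step)))
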
    @staticmethod
    @tf.function
    def train_step(model, x, y):
        """Trains the model for one step"""
        with tf.GradientTape() as tape:
            y_pred = model(x, training=True)
            loss = model.compute_loss(x, y, y_pred)
        gradients = tape.gradient(loss, model.trainable_variables)
        model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

    def distributed_train_step(self, model, x, y):
        """Trains the model for one step on multiple GPUs"""
        loss = self.strategy.run(self.train_step, args=(model, x, y))
        return self.strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None)

    def train(self):
        """Calls prep functions and starts training loops"""
        print("Training on", self.num_gpus, "GPUs.")
        # initialize data and model
        self.prep_data()

        # make transformations on the training dataset
        augmentation_pipeline = get_augmentation_pipeline(self.params)
        tf_aug = prepare_tf_aug(augmentation_pipeline)
        self.train_dataset = self.train_dataset.map(
            lambda x: py_aug(x, tf_aug), num_parallel_calls=tf.data.AUTOTUNE
        )
        self.train_dataset = self.train_dataset.map(
            self.prep_batches, num_parallel_calls=tf.data.AUTOTUNE
        )
        self.train_dataset = self.train_dataset.prefetch(tf.data.AUTOTUNE)

        if self.num_gpus > 1:
            # set up distributed training
            self.strategy = tf.distribute.MirroredStrategy()
            self.train_dataset = self.strategy.experimental_distribute_dataset(self.train_dataset)
            with self.strategy.scope():
                self.prep_model()
                checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
            print("Distributed training on {} devices".format(self.strategy.num_replicas_in_sync))
            train_step = self.distributed_train_step
        else:
            self.prep_model()
            train_step = self.train_step

        # save the run parameters to the model folder
        with open(os.path.join(self.params["model_dir"], "params.toml"), "w") as f:
            toml.dump(self.params, f)

        self.summary_writer = tf.summary.create_file_writer(self.params["log_dir"])
        self.step = 0
        self.global_val_loss = []
        self.val_loss_history = {}
        self.train_loss_tmp = []
        while self.step < self.params["num_steps"]:
            for x, y in tqdm(self.train_dataset):
                train_loss = train_step(self.model, x, y)
                self.train_loss_tmp.append(train_loss)
                self.step += 1
                self.tensorboard_callbacks(x, y)
                if self.step > self.params["num_steps"]:
                    break
    def tensorboard_callbacks(self, x, y):
        """Logs training metrics to Tensorboard
        Args:
            x (tf.Tensor): input image
            y (tf.Tensor): ground truth labels
        """
        if self.step % 10 == 0:
            with self.summary_writer.as_default():
                tf.summary.scalar(
                    "train_loss", tf.reduce_mean(self.train_loss_tmp), step=self.step
                )
                tf.summary.scalar(
                    "lr", self.model.optimizer._decayed_lr(tf.float32), step=self.step
                )
            print(
                "Step: {step}, loss {loss}".format(
                    step=self.step, loss=tf.reduce_mean(self.train_loss_tmp))
            )
            self.train_loss_tmp = []
        if self.step % self.params["snap_steps"] == 0:
            print("Saving training snapshots")
            if self.num_gpus > 1:
                x = self.strategy.experimental_local_results(x)[0]
                y_pred = self.model(x, training=False)
                y_pred = self.strategy.experimental_local_results(y_pred)[0]
                y = self.strategy.experimental_local_results(y)[0]
            else:
                y_pred = self.model(x, training=False)
            with self.summary_writer.as_default():
                tf.summary.image(
                    "x_0 | y | y_pred",
                    tf.concat([
                        x[:1, ..., :1],
                        x[:1, ..., 1:2] * 0.25 + tf.cast(y[:1, ..., :1], tf.float32),
                        y_pred[:1, ..., :1]], axis=0,
                    ),
                    step=self.step,
                )
        # run validation and write to tensorboard
        if self.step % self.params["val_steps"] == 0:
            print("Running validation...")
            for validation_dataset, dataset_name in zip(
                self.validation_datasets, self.dataset_names
            ):
                validation_dataset = validation_dataset.map(
                    self.prep_batches, num_parallel_calls=tf.data.AUTOTUNE
                )
                val_loss = self.model.evaluate(validation_dataset, verbose=1)
                print("Validation loss:", val_loss)
                if dataset_name not in self.val_loss_history.keys():
                    self.val_loss_history[dataset_name] = []
                self.val_loss_history[dataset_name].append(val_loss)
                with self.summary_writer.as_default():
                    tf.summary.scalar(dataset_name + "_val", val_loss, step=self.step)
            val_loss = np.mean([val_loss[-1] for val_loss in self.val_loss_history.values()])
            self.global_val_loss.append(val_loss)
            with self.summary_writer.as_default():
                tf.summary.scalar("global_val", val_loss, step=self.step)
            if val_loss <= tf.reduce_min(self.global_val_loss):
                print("Saving model to", self.params["model_path"])
                self.model.save_weights(self.params["model_path"])
            # run external validation
            if hasattr(self, "external_validation_datasets"):
                for validation_dataset, dataset_name in zip(
                    self.external_validation_datasets, self.external_validation_names
                ):
                    validation_dataset = validation_dataset.map(
                        self.prep_batches, num_parallel_calls=tf.data.AUTOTUNE
                    )
                    val_loss = self.model.evaluate(validation_dataset, verbose=1)
                    print("Validation loss:", val_loss)
                    if dataset_name not in self.val_loss_history.keys():
                        self.val_loss_history[dataset_name] = []
                    self.val_loss_history[dataset_name].append(val_loss)
                    with self.summary_writer.as_default():
                        tf.summary.scalar(dataset_name + "_val", val_loss, step=self.step)
            if "save_model_on_dataset_name" in self.params.keys():
                save_on = self.params["save_model_on_dataset_name"]
                current = self.val_loss_history[save_on][-1]
                # save when this dataset's loss is the best recorded so far
                if current <= min(self.val_loss_history[save_on]):
                    print("Saving model to", self.params["model_path"])
                    self.model.save_weights(self.params["model_path"] + "_best.pkl")
    def prep_loss(self):
        """Prepares the loss function for the model as configured by
        params["loss_fn"], params["loss_selective_masking"] and params["loss_kwargs"]
        Returns:
            loss_fn (function): Loss function for the model
        """
        loss_fn = Loss(
            self.params["loss_fn"],
            self.params["loss_selective_masking"],
            **self.params["loss_kwargs"]
        )
        return loss_fn
    def gen_prep_batches_fn(self, keys=["mplex_img", "binary_mask"]):
        """Generates a function that preprocesses batches for training
        Args:
            keys (list): List of keys to concatenate into a single batch
        Returns:
            prep_batches (function): Function that preprocesses batches for training
        """

        def prep_batches(batch):
            """Preprocess batches for training
            Args:
                batch (dict):
                    Dictionary of tensors and strings containing data from a single batch
            Returns:
                inputs (tf.Tensor):
                    Batch of images
                targets (tf.Tensor):
                    Batch of labels
            """
            inputs = tf.concat(
                [tf.cast(batch[key], tf.float32) for key in keys], axis=-1
            )
            targets = batch["marker_activity_mask"]
            return inputs, targets

        return prep_batches
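
    # --- Illustrative sketch (not part of the original module) ---------------
    # Shows what the closure returned by gen_prep_batches_fn() does to a batch:
    # the listed keys are concatenated along the channel axis to form the
    # network input, and "marker_activity_mask" becomes the target. The batch
    # size, spatial dimensions and dtypes below are hypothetical; this helper
    # is never called anywhere in this module.
    @staticmethod
    def _example_prep_batches():
        builder = ModelBuilder({})  # default keys: ["mplex_img", "binary_mask"]
        batch = {
            "mplex_img": tf.zeros((2, 256, 256, 1), tf.float32),
            "binary_mask": tf.zeros((2, 256, 256, 1), tf.uint8),
            "marker_activity_mask": tf.zeros((2, 256, 256, 1), tf.uint8),
        }
        inputs, targets = builder.prep_batches(batch)
        print(inputs.shape, targets.shape)  # (2, 256, 256, 2) (2, 256, 256, 1)
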
    def predict(self, image):
        """Runs inference on a single image or a batch of images
        Args:
            image (np.ndarray or tf.Tensor):
                Image to run inference on, shape (H, W, C) or (N, H, W, C)
        Returns:
            prediction (np.ndarray):
                Prediction from the model (N, H, W, 1)
        """
        if image.ndim != 4:
            image = tf.expand_dims(image, axis=0)
        prediction = self.model.predict(image)
        return prediction

    def load_model(self, path):
        """Loads a model from a path
        Args:
            path (str):
                Path to the model checkpoint file
        """
        if not hasattr(self, "model") or self.model is None:
            self.prep_model()
        self.model.load_weights(path)

    def validate(self, val_dset):
        """Runs inference on a validation dataset
        Args:
            val_dset (tf.data.Dataset):
                Dataset to run inference on
        Returns:
            loss (float):
                Loss on the validation dataset
        """
        val_dset = val_dset.map(self.prep_batches, num_parallel_calls=tf.data.AUTOTUNE)
        loss = self.model.evaluate(val_dset)
        return loss
    def predict_dataset(self, test_dset, save_predictions=False):
        """Runs inference on a test dataset
        Args:
            test_dset (tf.data.Dataset):
                Dataset to run inference on
            save_predictions (bool):
                Whether to save the predictions to a file
        Returns:
            single_example_list (list):
                List of example dicts with model predictions, cell-level outputs and the
                merged activity dataframe
        """
        # prepare output folder
        if "eval_dir" not in self.params.keys() and save_predictions:
            self.params["eval_dir"] = os.path.join(self.params["model_dir"], "eval")
            os.makedirs(self.params["eval_dir"], exist_ok=True)

        single_example_list = []
        j = 0
        for sample in tqdm(test_dset):
            sample["prediction"] = self.predict(self.prep_batches(sample)[0])

            # split batches to single samples
            # split numpy arrays to list of arrays
            for key in sample.keys():
                sample[key] = np.split(sample[key], sample[key].shape[0])
            # iterate over samples in batch
            for i in range(len(sample["prediction"])):
                single_example = {}
                for key in sample.keys():
                    single_example[key] = np.squeeze(sample[key][i], axis=0)
                    if single_example[key].dtype == object:
                        single_example[key] = sample[key][i].item().decode("utf-8")
                # decode activity df
                if not isinstance(single_example["activity_df"], pd.DataFrame):
                    single_example["activity_df"] = pd.read_json(single_example["activity_df"])
                # calculate cell level predictions
                single_example["prediction_mean"], pred_df = process_to_cells(
                    single_example["instance_mask"], single_example["prediction"]
                )
                single_example["activity_df"] = merge_activity_df(
                    single_example["activity_df"], pred_df
                )
                # save single example to file
                if save_predictions:
                    fname = os.path.join(self.params["eval_dir"], str(j).zfill(4) + "_pred.hdf")
                    j += 1
                    with h5py.File(fname, "w") as f:
                        for key in [key for key in single_example.keys() if key != "activity_df"]:
                            f.create_dataset(key, data=single_example[key])
                        f.create_dataset(
                            "activity_df", data=single_example["activity_df"].to_json()
                        )
                single_example_list.append(single_example)
        # save params to toml file
        with open(os.path.join(self.params["model_dir"], "params.toml"), "w") as f:
            toml.dump(self.params, f)
        return single_example_list
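
    # --- Illustrative sketch (not part of the original module) ---------------
    # predict_dataset() above writes one HDF5 file per example into
    # params["eval_dir"], with the activity dataframe stored as a JSON string
    # under the "activity_df" key. This helper reads one such file back; the
    # file name is hypothetical and the helper is never called anywhere in
    # this module.
    @staticmethod
    def _example_read_prediction(fname="eval/0000_pred.hdf"):
        with h5py.File(fname, "r") as f:
            prediction = f["prediction"][()]
            raw_df = f["activity_df"][()]
            if isinstance(raw_df, bytes):
                raw_df = raw_df.decode("utf-8")
            activity_df = pd.read_json(raw_df)
        return prediction, activity_df
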
    def add_weight_decay(self):
        """Adds L2 regularization on layer kernels and biases via layer.add_loss,
        scaled by params["weight_decay"]"""
        if self.params["weight_decay"] in [False, None]:
            return None
        alpha = self.params["weight_decay"]
        for layer in self.model.layers:
            if isinstance(layer, tf.keras.layers.Conv2D) or isinstance(
                layer, tf.keras.layers.Dense
            ):
                layer.add_loss(lambda layer=layer: tf.keras.regularizers.l2(alpha)(layer.kernel))
            if hasattr(layer, "bias_regularizer") and layer.use_bias:
                layer.add_loss(lambda layer=layer: tf.keras.regularizers.l2(alpha)(layer.bias))
    def quantile_filter(self, dataset, record_path):
        """Filters out training examples whose number of positive cells falls below the
        params["filter_quantile"] quantile of their marker
        Args:
            dataset (tf.data.Dataset):
                Dataset to filter
            record_path (str):
                Path to the tfrecord file
        Returns:
            dataset (tf.data.Dataset):
                Filtered dataset
        """
        print("Filtering out sparse training examples...")
        self.num_pos_dict_path = record_path.split(".tfrecord")[0] + "num_pos_dict.json"
        if os.path.exists(self.num_pos_dict_path):
            with open(self.num_pos_dict_path, "r") as f:
                num_pos_dict = json.load(f)
        else:
            num_pos_dict = {}
            for example in tqdm(dataset):
                marker = tf.get_static_value(example["marker"]).decode("utf-8")
                activity_df = pd.read_json(
                    tf.get_static_value(example["activity_df"]).decode("utf-8")
                )
                if marker not in num_pos_dict.keys():
                    num_pos_dict[marker] = []
                num_pos_dict[marker].append(int(np.sum(activity_df.activity == 1)))

            # save num_pos_dict to file
            with open(self.num_pos_dict_path, "w") as f:
                json.dump(num_pos_dict, f)

        quantile_dict = {}
        for marker, pos_list in num_pos_dict.items():
            quantile_dict[marker] = np.quantile(pos_list, self.params["filter_quantile"])

        def predicate(marker, activity_df):
            """Helper function that returns true if the number of positive cells is above the
            quantile threshold
            Args:
                marker (tf.Tensor):
                    Marker name of the example
                activity_df (tf.Tensor):
                    Activity dataframe of the example
            Returns:
                tf.Tensor:
                    True if the number of positive cells is above the quantile threshold
            """
            marker = tf.get_static_value(marker).decode("utf-8")
            activity_df = pd.read_json(tf.get_static_value(activity_df).decode("utf-8"))
            num_pos = tf.reduce_sum(tf.constant(activity_df.activity == 1, dtype=tf.float32))
            return tf.greater_equal(num_pos, quantile_dict[marker])

        dataset = dataset.filter(
            lambda example: tf.py_function(
                predicate, [example["marker"], example["activity_df"]], tf.bool
            )
        )
        return dataset
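
    # --- Illustrative sketch (not part of the original module) ---------------
    # quantile_filter() above caches, per marker, the number of positive cells
    # of every example in a "<record>num_pos_dict.json" file and then keeps
    # only the examples at or above the params["filter_quantile"] quantile.
    # The marker name and counts below are hypothetical; this helper is never
    # called anywhere in this module.
    @staticmethod
    def _example_quantile_threshold():
        num_pos_dict = {"CD8": [0, 2, 5, 11, 40]}  # positive cells per example
        threshold = np.quantile(num_pos_dict["CD8"], 0.2)
        print(threshold)  # 1.6 -> only examples with at least 1.6 positive cells are kept
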
    def fov_filter(self, dataset, fov_list, fov_key="folder_name"):
        """Filter out training examples that are not in the fov_list and return a copy of the
        dataset
        Args:
            dataset (tf.data.Dataset):
                Dataset to filter
            fov_list (list):
                List of fovs to keep
            fov_key (str):
                Key of the fov in the dataset
        Returns:
            dataset (tf.data.Dataset):
                Filtered dataset
        """

        def predicate(example):
            """Helper function that returns true if the fov is in fov_list
            Args:
                example (dict):
                    Example dictionary
            Returns:
                tf.Tensor:
                    True if the fov is in fov_list
            """
            return tf.reduce_any(tf.equal(example[fov_key], fov_list))

        dataset = dataset.filter(predicate)
        return dataset
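

# --- Illustrative sketch (not part of the original module) -------------------
# Minimal end-to-end usage under the assumption that the TOML file provides the
# keys ModelBuilder reads (record_path, dataset_names, dataset_sample_probs,
# num_validation, num_test, batch_size, shuffle_buffer_size, path, experiment,
# lr, num_steps, snap_steps, val_steps, input_shape, backbone, classes,
# location, loss_fn, loss_selective_masking, loss_kwargs, ...). The path below
# is hypothetical; this function is never called anywhere in this module.
def _example_usage(params_path="configs/params.toml"):
    params = toml.load(params_path)
    trainer = ModelBuilder(params)
    trainer.train()                            # trains and checkpoints the best weights
    trainer.load_model(params["model_path"])   # reload the best checkpoint
    return trainer.predict_dataset(trainer.test_datasets[0], save_predictions=True)
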

if __name__ == "__main__":
    print("CUDA_VISIBLE_DEVICES: " + str(os.getenv("CUDA_VISIBLE_DEVICES")))
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--params",
        type=str,
        default="configs/params.toml",
    )
    args = parser.parse_args()
    params = toml.load(args.params)
    trainer = ModelBuilder(params)
    trainer.train()