exalearn / ExaMol / 6251994855

20 Sep 2023 05:17PM UTC coverage: 98.25% (-0.04% from 98.293%)

Pull Request #108: "Use a more robust relaxation technique" (via github)
Commit by WardLT: "Use molecule which takes longer to optimize in test"

11 of 11 new or added lines in 2 files covered (100.0%)
1740 of 1771 relevant lines covered (98.25%)
0.98 hits per line

Source file: /examol/score/nfp.py (96.15% covered)

"""Train neural network models using `NFP <https://github.com/NREL/nfp>`_"""

from sklearn.model_selection import train_test_split
try:
    from tensorflow.keras import callbacks as cb
except ImportError as e:  # pragma: no-coverage
    raise ImportError('You may need to install Tensorflow and NFP.') from e
import tensorflow as tf
import numpy as np
import nfp

from examol.store.models import MoleculeRecord
from examol.utils.conversions import convert_string_to_nx
from .base import Scorer
from .utils.tf import LRLogger, TimeLimitCallback, EpochTimeLogger


class ReduceAtoms(tf.keras.layers.Layer):
    """Reduce the atoms along a certain direction

    Args:
        reduction_op: Name of the operation used for reduction
    """

    def __init__(self, reduction_op: str = 'mean', **kwargs):
        super().__init__(**kwargs)
        self.reduction_op = reduction_op

    def get_config(self):
        config = super().get_config()
        config['reduction_op'] = self.reduction_op
        return config

    def call(self, inputs, mask=None):  # pragma: no-coverage
        """
        Args:
            inputs: Matrix to be reduced
            mask: Identifies which rows are placeholders rather than real atoms
        """
        masked_tensor = tf.ragged.boolean_mask(inputs, mask)
        reduce_fn = getattr(tf.math, f'reduce_{self.reduction_op}')
        return reduce_fn(masked_tensor, axis=1)
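

# --- Illustrative sketch (not part of the original module): how ReduceAtoms
# collapses per-atom feature vectors into one vector per molecule. The shapes
# and mask below are assumptions chosen only for the example.
def _example_reduce_atoms() -> tf.Tensor:
    features = tf.random.normal((2, 4, 8))  # 2 molecules, up to 4 atoms, 8 features
    # The second molecule has only 2 real atoms; the rest are zero padding
    mask = tf.constant([[True, True, True, True],
                        [True, True, False, False]])
    # Averages over the unmasked atoms only, returning a (2, 8) tensor
    return ReduceAtoms('mean')(features, mask=mask)
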

# Define the custom layers for our class
custom_objects = nfp.custom_objects.copy()
custom_objects['ReduceAtoms'] = ReduceAtoms


def make_simple_network(
        atom_features: int = 64,
        message_steps: int = 8,
        output_layers: list[int] = (512, 256, 128),
        reduce_op: str = 'mean',
        atomwise: bool = True,
) -> tf.keras.models.Model:
    """Construct a Keras model using the settings provided by a user

    Args:
        atom_features: Number of features used per atom and bond
        message_steps: Number of message-passing steps
        output_layers: Number of neurons in each of the readout layers
        reduce_op: Operation used to reduce from atom-level to molecule-level vectors
        atomwise: Whether to reduce atomwise contributions to form an output,
                  or reduce to a single vector per molecule before the output layers
    Returns:
        A model instantiated with the user-defined options
    """
    atom = tf.keras.layers.Input(shape=[None], dtype=tf.int32, name='atom')
    bond = tf.keras.layers.Input(shape=[None], dtype=tf.int32, name='bond')
    connectivity = tf.keras.layers.Input(shape=[None, 2], dtype=tf.int32, name='connectivity')

    # Convert from a single integer defining the atom state to a vector
    # of weights associated with that class
    atom_state = tf.keras.layers.Embedding(64, atom_features, name='atom_embedding', mask_zero=True)(atom)

    # Ditto with the bond state
    bond_state = tf.keras.layers.Embedding(5, atom_features, name='bond_embedding', mask_zero=True)(bond)

    # Here we use our first nfp layer. This is an attention layer that looks at
    # the atom and bond states and reduces them to a single, graph-level vector.
    # num_heads * units has to be the same dimension as the atom / bond dimension
    global_state = nfp.GlobalUpdate(units=4, num_heads=1, name='problem')([atom_state, bond_state, connectivity])

    for _ in range(message_steps):  # Do the message passing
        new_bond_state = nfp.EdgeUpdate()([atom_state, bond_state, connectivity, global_state])
        bond_state = tf.keras.layers.Add()([bond_state, new_bond_state])

        new_atom_state = nfp.NodeUpdate()([atom_state, bond_state, connectivity, global_state])
        atom_state = tf.keras.layers.Add()([atom_state, new_atom_state])

        new_global_state = nfp.GlobalUpdate(units=4, num_heads=1)(
            [atom_state, bond_state, connectivity, global_state]
        )
        global_state = tf.keras.layers.Add()([global_state, new_global_state])

    # Pass the atom states through the output layers
    output = atom_state
    if not atomwise:
        output = ReduceAtoms(reduce_op)(output)
    for shape in output_layers:
        output = tf.keras.layers.Dense(shape, activation='relu')(output)
    output = tf.keras.layers.Dense(1)(output)
    if atomwise:
        output = ReduceAtoms(reduce_op)(output)
    output = tf.keras.layers.Dense(1, activation='linear', name='scale')(output)

    # Construct the tf.keras model
    return tf.keras.Model([atom, bond, connectivity], [output])
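

# --- Illustrative sketch (not part of the original module): building a small
# network. The hyperparameters are assumptions chosen to keep the example light.
def _example_make_network() -> tf.keras.models.Model:
    model = make_simple_network(atom_features=16, message_steps=2, output_layers=[32])
    # The model maps padded 'atom', 'bond', and 'connectivity' arrays (such as
    # the batches produced by make_data_loader below) to one scalar per molecule
    return model
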

class NFPMessage:
    """Package for sending an MPNN model over connections that require pickling"""

    def __init__(self, model: tf.keras.Model):
        """
        Args:
            model: Model to be sent
        """

        self.config = model.to_json()
        # Makes a copy of the weights to ensure they are not memoryview objects
        self.weights = [np.array(v) for v in model.get_weights()]

        # Cached copy of the model
        self._model = model

    def __getstate__(self):
        """Get state except the model"""
        state = self.__dict__.copy()
        state['_model'] = None
        return state

    def get_model(self) -> tf.keras.Model:
        """Get a copy of the model

        Returns:
            The model specified by this message
        """
        if self._model is None:
            self._model = tf.keras.models.model_from_json(
                self.config,
                custom_objects=custom_objects
            )
            self._model.set_weights(self.weights)
        return self._model
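

# --- Illustrative sketch (not part of the original module): NFPMessage survives
# pickling because __getstate__ drops the live Keras model, and get_model()
# rebuilds it from the JSON config and weights using `custom_objects` so that
# ReduceAtoms and the nfp layers deserialize correctly.
def _example_message_roundtrip() -> tf.keras.Model:
    import pickle

    model = make_simple_network(atom_features=16, message_steps=2, output_layers=[32])
    message = pickle.loads(pickle.dumps(NFPMessage(model)))
    return message.get_model()  # A fresh model carrying the original weights
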

def convert_string_to_dict(mol_string: str) -> dict:
    """Convert a molecule to an NFP-compatible dictionary form

    Args:
        mol_string: SMILES or InChI string
    Returns:
        Dictionary
    """

    # Convert first to a nx.Graph
    graph = convert_string_to_nx(mol_string)

    # Get the atom types
    atom_type_id = [n['atomic_num'] for _, n in graph.nodes(data=True)]

    # Get the bond types, adding each edge in both directions
    bond_types = ["", "AROMATIC", "DOUBLE", "SINGLE", "TRIPLE"]  # 0 is a dummy type
    connectivity = []
    edge_type = []
    for a, b, d in graph.edges(data=True):
        connectivity.append([a, b])
        connectivity.append([b, a])
        edge_type.append(str(d['bond_type']))
        edge_type.append(str(d['bond_type']))
    edge_type_id = list(map(bond_types.index, edge_type))

    # Sort connectivity array by the first column
    #  This is needed for the MPNN code to efficiently group messages for
    #  each node when performing the message passing step
    connectivity = np.array(connectivity)
    if connectivity.size > 0:
        # Skip a special case of a molecule w/o bonds
        inds = np.lexsort((connectivity[:, 1], connectivity[:, 0]))
        connectivity = connectivity[inds, :]

        # Tensorflow's "segment_sum" will cause problems if the last atom
        #  is not bonded, because the array it returns is sized by the largest
        #  node index that appears in the connectivity
        if connectivity.max() != len(atom_type_id) - 1:
            raise ValueError(f"Problem with unconnected atoms for \"{mol_string}\"")
    else:
        connectivity = np.zeros((0, 2))

    return {
        'atom': atom_type_id,
        'bond': edge_type_id,
        'connectivity': connectivity
    }
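

# --- Illustrative sketch (not part of the original module): what the dictionary
# form looks like. 'C=O' (formaldehyde) is an assumed example input; the exact
# node ordering depends on convert_string_to_nx.
def _example_convert_string() -> dict:
    mol = convert_string_to_dict('C=O')
    # mol['atom']: atomic numbers, e.g. [6, 8]
    # mol['bond']: bond-type ids indexing ["", "AROMATIC", "DOUBLE", "SINGLE", "TRIPLE"]
    # mol['connectivity']: (n_edges, 2) array listing each bond in both directions
    return mol
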

def make_data_loader(mol_dicts: list[dict],
                     values: np.ndarray | list[object] | None = None,
                     batch_size: int = 32,
                     repeat: bool = False,
                     shuffle_buffer: int | None = None,
                     value_spec: tf.TensorSpec = tf.TensorSpec((), dtype=tf.float32),
                     drop_last_batch: bool = False) -> tf.data.Dataset:
    """Make an in-memory data loader for data compatible with NFP-style neural networks

    Args:
        mol_dicts: List of molecules parsed into the moldesign format
        values: List of output values, if included in the output
        value_spec: Tensorflow specification for the output
        batch_size: Number of molecules per batch
        repeat: Whether to create an infinitely-repeating iterator
        shuffle_buffer: Size of a shuffle buffer. Use ``None`` to leave data unshuffled
        drop_last_batch: Whether to drop the last, incomplete batch from the dataset. Set to ``True`` if, for example, you need every batch to be the same size
    Returns:
        Data loader that generates molecules in the desired shapes
    """

    # Determine the maximum size of molecule, used when padding the arrays
    max_atoms = max(len(x['atom']) for x in mol_dicts)
    max_bonds = max(len(x['bond']) for x in mol_dicts)

    # Make the initial data loader
    record_sig = {
        "atom": tf.TensorSpec(shape=(None,), dtype=tf.int32),
        "bond": tf.TensorSpec(shape=(None,), dtype=tf.int32),
        "connectivity": tf.TensorSpec(shape=(None, 2), dtype=tf.int32),
    }
    if values is None:
        def generator():
            yield from mol_dicts
    else:
        def generator():
            yield from zip(mol_dicts, values)

        record_sig = (record_sig, value_spec)

    loader = tf.data.Dataset.from_generator(generator=generator, output_signature=record_sig).cache()  # TODO (wardlt): Make caching optional?

    # Repeat the molecule list before shuffling
    if repeat:
        loader = loader.repeat()

    # Shuffle, if desired
    if shuffle_buffer is not None:
        loader = loader.shuffle(shuffle_buffer)

    # Make the batches. Pads the data to make them all the same size, adding 0's to signify padded values
    padded_records = {
        "atom": tf.TensorShape((max_atoms,)),
        "bond": tf.TensorShape((max_bonds,)),
        "connectivity": tf.TensorShape((max_bonds, 2))
    }
    if values is not None:
        padded_records = (padded_records, value_spec.shape)
    loader = loader.padded_batch(batch_size=batch_size, padded_shapes=padded_records, drop_remainder=drop_last_batch)

    return loader
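

# --- Illustrative sketch (not part of the original module): building a loader
# from a few molecules. The SMILES strings and target values are assumptions.
def _example_data_loader() -> tf.data.Dataset:
    dicts = [convert_string_to_dict(s) for s in ('CC', 'CCO', 'C=O')]
    loader = make_data_loader(dicts, values=[0.1, 0.2, 0.3], batch_size=2)
    # Each element is a (features, value) pair where the feature arrays are
    # zero-padded so every molecule in a batch has the same shape
    return loader
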

class NFPScorer(Scorer):
    """Train message-passing neural networks based on the `NFP <https://github.com/NREL/nfp>`_ library.

    NFP uses Keras, which is backed by Tensorflow, to define message-passing networks and execute them on different hardware."""

    def __init__(self, retrain_from_scratch: bool = True):
        """
        Args:
            retrain_from_scratch: Whether to retrain models from scratch or not
        """
        self.retrain_from_scratch = retrain_from_scratch

    def prepare_message(self, model: tf.keras.models.Model, training: bool = False) -> dict | NFPMessage:
        if training and self.retrain_from_scratch:
            return model.get_config()
        else:
            return NFPMessage(model)

    def transform_inputs(self, record_batch: list[MoleculeRecord]) -> list:
        return [convert_string_to_dict(record.identifier.inchi) for record in record_batch]

    def score(self, model_msg: NFPMessage, inputs: list[dict], batch_size: int = 64, **kwargs) -> np.ndarray:
        """Assign a score to molecules

        Args:
            model_msg: Model in a transmittable format
            inputs: Batch of inputs ready for the model (in dictionary format)
            batch_size: Number of molecules to evaluate at a time
        Returns:
            Scores for each input record
        """
        model = model_msg.get_model()  # Unpack the model
        loader = make_data_loader(inputs, batch_size=batch_size)
        return model.predict(loader, verbose=False)

    def retrain(self,
                model_msg: dict | NFPMessage,
                inputs: list,
                outputs: np.ndarray,
                num_epochs: int = 4,
                batch_size: int = 32,
                validation_split: float = 0.1,
                learning_rate: float = 1e-3,
                device_type: str = 'gpu',
                steps_per_exec: int = 1,
                patience: int | None = None,
                timeout: float | None = None,
                verbose: bool = False) -> tuple[list[np.ndarray], dict]:
        """Retrain the scorer based on new training records

        Args:
            model_msg: Model to be retrained
            inputs: Training set inputs, as generated by :meth:`transform_inputs`
            outputs: Training set outputs, as generated by :meth:`transform_outputs`
            num_epochs: Maximum number of epochs to run
            batch_size: Number of molecules per training batch
            validation_split: Fraction of molecules reserved for the validation set
            learning_rate: Learning rate for the Adam optimizer
            device_type: Type of device used for training
            steps_per_exec: Number of training steps to run per execution on the accelerator
            patience: Number of epochs without improvement before terminating training. Default is 10% of ``num_epochs``
            timeout: Maximum training time in seconds
            verbose: Whether to print training information to screen
        Returns:
            Message defining how to update the model: the new weights and the training history
        """

        # Make the model
        if isinstance(model_msg, NFPMessage):
            model = model_msg.get_model()
        elif isinstance(model_msg, dict):
            model = tf.keras.Model.from_config(model_msg, custom_objects=custom_objects)
        else:
            raise NotImplementedError(f'Unrecognized message type: {type(model_msg)}')

        # Split off a validation set
        train_x, valid_x, train_y, valid_y = train_test_split(inputs, outputs, test_size=validation_split)

        # Make the loaders
        steps_per_epoch = len(train_x) // batch_size
        train_loader = make_data_loader(train_x, train_y, repeat=True, batch_size=batch_size, drop_last_batch=True, shuffle_buffer=32768)
        valid_steps = len(valid_x) // batch_size
        assert valid_steps > 0, 'We need some validation data'
        valid_loader = make_data_loader(valid_x, valid_y, batch_size=batch_size, drop_last_batch=True)

        # Define initial guesses for the "scale" layer
        try:
            scale_layer = model.get_layer('scale')
            outputs = np.array(outputs)
            scale_layer.set_weights([outputs.std()[None, None], outputs.mean()[None]])
        except ValueError:
            pass

        # Configure the LR schedule
        init_learn_rate = learning_rate
        final_learn_rate = init_learn_rate * 1e-3
        decay_rate = (final_learn_rate / init_learn_rate) ** (1. / (num_epochs - 1))

        def lr_schedule(epoch, lr):
            return lr * decay_rate

        # Compile the model then train
        model.compile(
            tf.optimizers.Adam(init_learn_rate),
            'mean_squared_error',
            metrics=['mean_absolute_error'],
            steps_per_execution=steps_per_exec,
        )

        # Make the callbacks
        if patience is None:
            patience = num_epochs // 10
        early_stopping = cb.EarlyStopping(patience=patience, restore_best_weights=True)
        callbacks = [
            LRLogger(),
            EpochTimeLogger(),
            cb.LearningRateScheduler(lr_schedule),
            early_stopping,
            cb.TerminateOnNaN(),
        ]
        if timeout is not None:
            callbacks.append(TimeLimitCallback(timeout))

        history = model.fit(
            train_loader,
            epochs=num_epochs,
            shuffle=False,
            verbose=verbose,
            callbacks=callbacks,
            steps_per_epoch=steps_per_epoch,
            validation_data=valid_loader,
            validation_steps=valid_steps,
            validation_freq=1,
        )

        # If a timeout is used, make sure we are using the best weights
        #  The training may have exited without storing the best weights
        if timeout is not None:
            model.set_weights(early_stopping.best_weights)

        # Convert weights to numpy arrays (avoids mmap issues)
        weights = []
        for v in model.get_weights():
            v = np.array(v)
            if np.isnan(v).any():
                raise ValueError('Found some NaN weights.')
            weights.append(v)

        # Once we are finished training, call "clear_session" to flush the model out of GPU memory
        tf.keras.backend.clear_session()
        return weights, history.history

    def update(self, model: tf.keras.models.Model, update_msg: tuple[list[np.ndarray], dict]) -> tf.keras.models.Model:
        model.set_weights(update_msg[0])
        return model