• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

DeepRank / deeprank-core / 4075781048

pending completion
4075781048

Pull #336

github

GitHub
Merge 011ebfa73 into d73e8c34f
Pull Request #336: adds data augmentation

1059 of 1351 branches covered (78.39%)

Branch coverage included in aggregate %.

57 of 57 new or added lines in 2 files covered. (100.0%)

2977 of 3517 relevant lines covered (84.65%)

0.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.7
/deeprankcore/utils/graph.py
1
from typing import Callable, Union, List, Optional
1✔
2
import os
1✔
3
import logging
1✔
4
import numpy as np
1✔
5
import h5py
1✔
6
import pdb2sql.transform
1✔
7
from deeprankcore.molstruct.atom import Atom
1✔
8
from deeprankcore.molstruct.residue import Residue, get_residue_center
1✔
9
from deeprankcore.molstruct.pair import Contact, AtomicContact, ResidueContact
1✔
10
from deeprankcore.utils.grid import MapMethod, Grid, GridSettings, Augmentation
1✔
11
from deeprankcore.domain import (edgestorage as Efeat, nodestorage as Nfeat, 
1✔
12
                                targetstorage as targets)
13
from scipy.spatial import distance_matrix
1✔
14

15

16
# module-level logger; not referenced by any code in this module at present
_log = logging.getLogger(__name__)
1✔
17

18

19
class Edge:
    """A graph edge whose identity is the `Contact` it was built from.

    Feature values live in the `features` dict, keyed by feature name.
    """

    def __init__(self, id_: Contact):
        """
        Args:
            id_: the contact this edge represents; it is also the key under
                which a `Graph` stores the edge.
        """
        self.features = {}
        self.id = id_

    def add_feature(
        self, feature_name: str, feature_function: Callable[[Contact], float]
    ):
        """Store `feature_function(self.id)` under `feature_name`.

        A value already stored under the same name is overwritten.
        """
        self.features[feature_name] = feature_function(self.id)

    @property
    def position1(self) -> np.array:
        """Position of the first item of the contact."""
        first_item = self.id.item1
        return first_item.position

    @property
    def position2(self) -> np.array:
        """Position of the second item of the contact."""
        second_item = self.id.item2
        return second_item.position

    def has_nan(self) -> bool:
        """Return True if any stored feature value contains a NaN."""
        return any(np.any(np.isnan(value)) for value in self.features.values())
47

48

49
class Node:
    """A graph node wrapping either an `Atom` or a `Residue`.

    The wrapped object is the node's identity (and its key when stored in a
    `Graph`); feature values live in the `features` dict, keyed by name.
    """

    def __init__(self, id_: Union[Atom, Residue]):
        """
        Args:
            id_: the atom or residue this node represents.

        Raises:
            TypeError: if `id_` is neither an `Atom` nor a `Residue`.
        """
        if isinstance(id_, Atom):
            self._type = "atom"
        elif isinstance(id_, Residue):
            self._type = "residue"
        else:
            raise TypeError(type(id_))

        self.id = id_

        # feature name -> feature value
        self.features = {}

    @property
    def type(self):
        """Either "atom" or "residue", depending on the wrapped object."""
        return self._type

    def has_nan(self) -> bool:
        "whether there are any NaN values in the node's features"

        for feature_data in self.features.values():
            if np.any(np.isnan(feature_data)):
                return True
        return False

    def add_feature(
        self,
        feature_name: str,
        feature_function: Callable[[Union[Atom, Residue]], np.ndarray],
    ):
        """Compute `feature_function(self.id)` and store it under `feature_name`.

        Args:
            feature_name: key under which the value is stored in `features`.
            feature_function: callable applied to the wrapped atom/residue; it
                must return an object with a `shape` (e.g. a numpy array).

        Raises:
            ValueError: if the computed value is not 1-dimensional.
        """
        feature_value = feature_function(self.id)

        if len(feature_value.shape) != 1:
            # shape entries are ints and must be stringified before joining;
            # a 0-d array has an empty shape, hence the fallback
            shape_s = "x".join(str(dim) for dim in feature_value.shape) or "scalar"
            raise ValueError(
                f"Expected a 1-dimensional array for feature {feature_name}, but got {shape_s}"
            )

        self.features[feature_name] = feature_value

    @property
    def position(self) -> np.ndarray:
        """Position of the wrapped atom/residue (its `position` attribute)."""
        return self.id.position
92

93

94
class Graph:
    """A graph of `Node` and `Edge` objects, plus optional target values.

    Nodes are stored keyed by `node.id` and edges keyed by `edge.id`, so adding
    a node or edge whose id is already present replaces the earlier one.
    """

    def __init__(self, id_: str):
        """
        Args:
            id_: name of the graph; `write_to_hdf5` uses it as the group name.
        """
        self.id = id_

        self._nodes = {}
        self._edges = {}

        # targets are optional and may be set later
        self.targets = {}

        # the center only needs to be set when this graph should be mapped to a grid.
        self.center = np.array((0.0, 0.0, 0.0))

    def add_node(self, node: Node):
        """Add `node`, replacing any node that has the same id."""
        self._nodes[node.id] = node

    def get_node(self, id_: Union[Atom, Residue]) -> Node:
        """Return the node whose id is `id_` (raises KeyError if absent)."""
        return self._nodes[id_]

    def add_edge(self, edge: Edge):
        """Add `edge`, replacing any edge that has the same id."""
        self._edges[edge.id] = edge

    def get_edge(self, id_: Contact) -> Edge:
        """Return the edge whose id is `id_` (raises KeyError if absent)."""
        return self._edges[id_]

    @property
    def nodes(self) -> List[Node]:
        """The nodes, in the order their ids were first added."""
        return list(self._nodes.values())

    @property
    def edges(self) -> List[Edge]:
        """The edges, in the order their ids were first added."""
        return list(self._edges.values())

    def has_nan(self) -> bool:
        "whether there are any NaN values in the graph's features"

        for node in self._nodes.values():
            if node.has_nan():
                return True

        for edge in self._edges.values():
            if edge.has_nan():
                return True

        return False

    def _map_point_features(self, grid: Grid, method: MapMethod,  # pylint: disable=too-many-arguments
                            feature_name: str, points: List[np.ndarray],
                            values: List[Union[float, np.ndarray]],
                            augmentation: Optional[Augmentation] = None):
        """Map `values[i]` at `points[i]` onto `grid` under `feature_name`.

        If `augmentation` is given, the points are first transformed with
        pdb2sql's `rot_xyz_around_axis`, using `augmentation.axis`,
        `augmentation.angle` and `self.center`.
        """
        points = np.stack(points, axis=0)

        if augmentation is not None:
            points = pdb2sql.transform.rot_xyz_around_axis(points,
                                                           augmentation.axis,
                                                           augmentation.angle,
                                                           self.center)

        for point_index in range(points.shape[0]):
            position = points[point_index]
            value = values[point_index]

            grid.map_feature(position, feature_name, value, method)

    def map_to_grid(self, grid: Grid, method: MapMethod, augmentation: Optional[Augmentation] = None):
        """Map all edge features, then all node features, onto `grid`.

        Each edge feature value is mapped at both end positions of its edge;
        each node feature value is mapped at the node's position.

        Args:
            grid: the grid to map onto.
            method: mapping method, passed through to `grid.map_feature`.
            augmentation: optional rotation applied to every point first.
        """
        # Points are gathered per feature (not once for all features) so that
        # points and values stay aligned even when some edges/nodes lack a
        # feature. Using extend/append also avoids re-copying a growing list
        # for every item.
        edge_points = {}
        edge_values = {}
        for edge in self._edges.values():
            end_points = [edge.position1, edge.position2]
            for feature_name, feature_value in edge.features.items():
                edge_points.setdefault(feature_name, []).extend(end_points)
                edge_values.setdefault(feature_name, []).extend([feature_value, feature_value])

        # map edge features to grid
        for feature_name, values in edge_values.items():
            self._map_point_features(grid, method, feature_name,
                                     edge_points[feature_name], values, augmentation)

        node_points = {}
        node_values = {}
        for node in self._nodes.values():
            position = node.position
            for feature_name, feature_value in node.features.items():
                node_points.setdefault(feature_name, []).append(position)
                node_values.setdefault(feature_name, []).append(feature_value)

        # map node features to grid
        for feature_name, values in node_values.items():
            self._map_point_features(grid, method, feature_name,
                                     node_points[feature_name], values, augmentation)

    def write_to_hdf5(self, hdf5_path: str): # pylint: disable=too-many-locals
        """Write a featured graph to an hdf5 file, according to deeprank standards.

        The file is opened in append mode. Under a group named `self.id` this
        writes:
          - `Nfeat.NODE`: node names, chain IDs and one dataset per node feature;
          - `Efeat.EDGE`: edge names ("<id1>-<id2>"), node-index pairs and one
            dataset per edge feature;
          - `targets.VALUES`: one dataset per target.

        Notes:
          - The chain ID is the second whitespace-separated token of `str(node.id)`.
          - Feature names come from the first node and the first edge, so the
            graph needs at least one of each, and every node/edge is expected
            to carry those features.
          - The node, edge and target subgroups are created with `create_group`,
            so writing into a group that already holds them is expected to fail.

        Raises:
            ValueError: if an edge refers to a node that is not in the graph.
        """
        with h5py.File(hdf5_path, "a") as hdf5_file:

            # create groups to hold data
            graph_group = hdf5_file.require_group(self.id)
            node_features_group = graph_group.create_group(Nfeat.NODE)
            edge_feature_group = graph_group.create_group(Efeat.EDGE)

            # store node names and chain_ids
            node_names = np.array([str(key) for key in self._nodes]).astype("S")
            node_features_group.create_dataset(Nfeat.NAME, data=node_names)
            chain_ids = np.array([str(key).split()[1] for key in self._nodes]).astype("S")
            node_features_group.create_dataset(Nfeat.CHAINID, data=chain_ids)

            # store node features
            first_node_data = list(self._nodes.values())[0].features
            node_feature_names = list(first_node_data.keys())
            for node_feature_name in node_feature_names:

                node_feature_data = [
                    node.features[node_feature_name] for node in self._nodes.values()
                ]

                node_features_group.create_dataset(
                    node_feature_name, data=node_feature_data
                )

            # row index of every node; a dict gives O(1) lookups per edge
            # instead of a linear list search
            node_indices = {node_id: index for index, node_id in enumerate(self._nodes)}

            # identify edges
            edge_indices = []
            edge_names = []

            first_edge_data = list(self._edges.values())[0].features
            edge_feature_names = list(first_edge_data.keys())

            edge_feature_data = {name: [] for name in edge_feature_names}

            for edge_id, edge in self._edges.items():

                id1, id2 = edge_id
                try:
                    node_index1 = node_indices[id1]
                    node_index2 = node_indices[id2]
                except KeyError as error:
                    raise ValueError(
                        f"edge {id1}-{id2} refers to a node that is not in graph {self.id}"
                    ) from error

                edge_indices.append((node_index1, node_index2))
                edge_names.append(f"{id1}-{id2}")

                for edge_feature_name in edge_feature_names:
                    edge_feature_data[edge_feature_name].append(
                        edge.features[edge_feature_name]
                    )

            # store edge names and indices
            edge_feature_group.create_dataset(
                Efeat.NAME, data=np.array(edge_names).astype("S")
            )
            edge_feature_group.create_dataset(Efeat.INDEX, data=edge_indices)

            # store edge features
            for edge_feature_name in edge_feature_names:
                edge_feature_group.create_dataset(
                    edge_feature_name, data=edge_feature_data[edge_feature_name]
                )

            # store target values
            score_group = graph_group.create_group(targets.VALUES)
            for target_name, target_data in self.targets.items():
                score_group.create_dataset(target_name, data=target_data)

    @staticmethod
    def _find_unused_augmentation_name(unaugmented_id: str, hdf5_path: str) -> str:
        """Return the first name "<unaugmented_id>_NNN" not yet used in the file.

        NNN counts up from 000. Only top-level entries are checked, and only
        if `hdf5_path` already exists as a file.
        """
        prefix = f"{unaugmented_id}_"

        entry_names_taken = set()
        if os.path.isfile(hdf5_path):
            with h5py.File(hdf5_path, 'r') as hdf5_file:
                entry_names_taken = {
                    entry_name for entry_name in hdf5_file if entry_name.startswith(prefix)
                }

        augmentation_count = 0
        chosen_name = f"{prefix}{augmentation_count:03}"
        while chosen_name in entry_names_taken:
            augmentation_count += 1
            chosen_name = f"{prefix}{augmentation_count:03}"

        return chosen_name

    def write_as_grid_to_hdf5(
        self, hdf5_path: str,
        settings: GridSettings,
        method: MapMethod,
        augmentation: Optional[Augmentation] = None
    ) -> str:
        """Map this graph's features onto a new `Grid` and write it to `hdf5_path`.

        The grid is named `self.id`. When `augmentation` is given, it is
        instead named with the first unused "<id>_NNN" name found in the file,
        and the points are rotated before mapping. The grid is constructed with
        `self.center` (presumably used as the grid center — see `Grid`).

        Returns:
            `hdf5_path`, unchanged.
        """
        id_ = self.id
        if augmentation is not None:
            id_ = self._find_unused_augmentation_name(id_, hdf5_path)

        grid = Grid(id_, self.center.tolist(), settings)

        self.map_to_grid(grid, method, augmentation)
        grid.to_hdf5(hdf5_path)

        return hdf5_path
296

297

298
def build_atomic_graph( # pylint: disable=too-many-locals
    atoms: List[Atom], graph_id: str, edge_distance_cutoff: float
) -> Graph:
    """Build a graph whose nodes are atoms.

    Every ordered pair of distinct atoms lying closer than
    `edge_distance_cutoff` (in Ångströms) becomes an edge, so each contact is
    added in both directions. Atoms with no neighbour inside the cutoff do not
    become nodes. Every node gets its atom's position as `Nfeat.POSITION`.
    """
    coordinates = np.empty((len(atoms), 3))
    for row, atom in enumerate(atoms):
        coordinates[row] = atom.position

    within_cutoff = distance_matrix(coordinates, coordinates, p=2) < edge_distance_cutoff

    graph = Graph(graph_id)
    for first_index, second_index in np.argwhere(within_cutoff):
        if first_index == second_index:
            # an atom is trivially within the cutoff of itself; skip it
            continue

        first_atom, second_atom = atoms[first_index], atoms[second_index]
        contact = AtomicContact(first_atom, second_atom)

        pair_nodes = (Node(first_atom), Node(second_atom))
        for node, atom in zip(pair_nodes, (first_atom, second_atom)):
            node.features[Nfeat.POSITION] = atom.position

        # re-adding an atom's node simply replaces it (Graph keys nodes by id)
        for node in pair_nodes:
            graph.add_node(node)
        graph.add_edge(Edge(contact))

    return graph
330

331

332
def build_residue_graph( # pylint: disable=too-many-locals
    residues: List[Residue], graph_id: str, edge_distance_cutoff: float
) -> Graph:
    """Build a graph whose nodes are residues.

    Two distinct residues are connected when their shortest interatomic
    distance is below `edge_distance_cutoff` (in Ångströms); each such contact
    is added in both directions. Every node gets its residue center (from
    `get_residue_center`) as `Nfeat.POSITION`.
    """
    # flatten all atoms, remembering for each one the index of its residue
    all_atoms = []
    owner_of_atom = []
    for residue_position, residue in enumerate(residues):
        residue_atoms = list(residue.atoms)
        all_atoms.extend(residue_atoms)
        owner_of_atom.extend([residue_position] * len(residue_atoms))
    owner_of_atom = np.array(owner_of_atom)

    coordinates = np.empty((len(all_atoms), 3))
    for row, atom in enumerate(all_atoms):
        coordinates[row] = atom.position

    # atom pairs within the cutoff, translated to distinct residue pairs
    close_atom_pairs = np.argwhere(
        distance_matrix(coordinates, coordinates, p=2) < edge_distance_cutoff
    )
    close_residue_pairs = np.unique(owner_of_atom[close_atom_pairs], axis=0)

    graph = Graph(graph_id)
    for first_index, second_index in close_residue_pairs:
        first_residue = residues[first_index]
        second_residue = residues[second_index]

        if first_residue != second_residue:
            contact = ResidueContact(first_residue, second_residue)

            first_node = Node(first_residue)
            second_node = Node(second_residue)
            edge = Edge(contact)

            first_node.features[Nfeat.POSITION] = get_residue_center(first_residue)
            second_node.features[Nfeat.POSITION] = get_residue_center(second_residue)

            # a residue's node may be added repeatedly; Graph keys nodes by id,
            # so later additions just replace the earlier one
            graph.add_node(first_node)
            graph.add_node(second_node)
            graph.add_edge(edge)

    return graph
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc