• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

DeepRank / deeprank-core / 4075652401

pending completion
4075652401

Pull #330

github

GitHub
Merge 45ea1393e into d73e8c34f
Pull Request #330: fix: data generation threading locked

1046 of 1331 branches covered (78.59%)

Branch coverage included in aggregate %.

36 of 36 new or added lines in 2 files covered. (100.0%)

2949 of 3482 relevant lines covered (84.69%)

0.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.06
/deeprankcore/utils/graph.py
1
from typing import Callable, Union, List
1✔
2
import logging
1✔
3
import numpy as np
1✔
4
import h5py
1✔
5
from deeprankcore.molstruct.atom import Atom
1✔
6
from deeprankcore.molstruct.residue import Residue, get_residue_center
1✔
7
from deeprankcore.molstruct.pair import Contact, AtomicContact, ResidueContact
1✔
8
from deeprankcore.utils.grid import MapMethod, Grid, GridSettings
1✔
9
from deeprankcore.domain import (edgestorage as Efeat, nodestorage as Nfeat, 
1✔
10
                                targetstorage as targets)
11
from scipy.spatial import distance_matrix
1✔
12

13

14
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)
1✔
15

16

17
class Edge:
    """An edge of a graph, identified by the contact it represents.

    Attributes:
        id: the contact this edge stands for; its ``item1``/``item2`` carry the
            endpoint positions.
        features: mapping from feature name to feature value.
    """

    def __init__(self, id_: Contact):
        self.id = id_
        self.features = {}

    def add_feature(
        self, feature_name: str, feature_function: Callable[[Contact], float]
    ):
        """Compute a feature from this edge's contact and store it.

        Args:
            feature_name: key under which the value is stored in ``features``.
            feature_function: called with this edge's contact; its return value
                is stored as-is.
        """
        feature_value = feature_function(self.id)

        self.features[feature_name] = feature_value

    @property
    def position1(self) -> np.ndarray:
        "Position of the first item of the contact."
        return self.id.item1.position

    @property
    def position2(self) -> np.ndarray:
        "Position of the second item of the contact."
        return self.id.item2.position

    def has_nan(self) -> bool:
        "whether there are any NaN values in the edge's features"

        # any() short-circuits on the first feature containing a NaN
        return any(
            np.any(np.isnan(feature_data)) for feature_data in self.features.values()
        )
45

46

47
class Node:
    """A node of a graph, wrapping either an atom or a residue.

    Attributes:
        id: the wrapped :class:`Atom` or :class:`Residue`.
        features: mapping from feature name to feature value.
    """

    def __init__(self, id_: Union[Atom, Residue]):
        """Wrap ``id_`` as a node.

        Raises:
            TypeError: if ``id_`` is neither an :class:`Atom` nor a :class:`Residue`.
        """
        if isinstance(id_, Atom):
            self._type = "atom"
        elif isinstance(id_, Residue):
            self._type = "residue"
        else:
            raise TypeError(type(id_))

        self.id = id_

        self.features = {}

    @property
    def type(self):
        """Either "atom" or "residue", depending on what this node wraps."""
        return self._type

    def has_nan(self) -> bool:
        "whether there are any NaN values in the node's features"

        # any() short-circuits on the first feature containing a NaN
        return any(
            np.any(np.isnan(feature_data)) for feature_data in self.features.values()
        )

    def add_feature(
        self,
        feature_name: str,
        feature_function: Callable[[Union[Atom, Residue]], np.ndarray],
    ):
        """Compute a feature from this node's atom/residue and store it.

        Args:
            feature_name: key under which the value is stored in ``features``.
            feature_function: called with this node's ``id``; must return an
                array-like object with a ``shape`` attribute.

        Raises:
            ValueError: if the returned value is not 1-dimensional.
        """
        feature_value = feature_function(self.id)

        if len(feature_value.shape) != 1:
            # shape entries are ints and must be stringified before joining;
            # a 0-dimensional value has an empty shape, reported as "()"
            shape_s = "x".join(str(dim) for dim in feature_value.shape) or "()"
            raise ValueError(
                f"Expected a 1-dimensional array for feature {feature_name}, but got {shape_s}"
            )

        self.features[feature_name] = feature_value

    @property
    def position(self) -> np.ndarray:
        "Position of the wrapped atom/residue."
        return self.id.position
90

91

92
class Graph:
    """A graph of :class:`Node` objects connected by :class:`Edge` objects.

    Nodes are stored keyed by their ``id`` (atom or residue) and edges keyed by
    their contact, so adding a node or edge whose id is already present replaces
    the stored object instead of duplicating it.
    """

    def __init__(self, id_: str):
        self.id = id_

        self._nodes = {}
        self._edges = {}

        # targets are optional and may be set later
        self.targets = {}

        # the center only needs to be set when this graph should be mapped to a grid.
        self.center = np.array((0.0, 0.0, 0.0))

    def add_node(self, node: Node):
        "Add a node, replacing any node already stored under the same id."
        self._nodes[node.id] = node

    def get_node(self, id_: Union[Atom, Residue]) -> Node:
        "Return the node stored under ``id_``; raises KeyError if absent."
        return self._nodes[id_]

    def add_edge(self, edge: Edge):
        "Add an edge, replacing any edge already stored under the same contact."
        self._edges[edge.id] = edge

    def get_edge(self, id_: Contact) -> Edge:
        "Return the edge stored under ``id_``; raises KeyError if absent."
        return self._edges[id_]

    @property
    def nodes(self) -> List[Node]:
        "The nodes, in insertion order."
        return list(self._nodes.values())

    @property
    def edges(self) -> List[Edge]:
        "The edges, in insertion order."
        return list(self._edges.values())

    def has_nan(self) -> bool:
        "whether there are any NaN values in the graph's features"

        return any(node.has_nan() for node in self._nodes.values()) or any(
            edge.has_nan() for edge in self._edges.values()
        )

    def map_to_grid(self, grid: Grid, method: MapMethod):
        """Map all features onto ``grid`` using ``method``.

        Each edge feature is mapped at both endpoint positions of the edge;
        each node feature is mapped at the node's position.
        """

        for edge in self._edges.values():
            for feature_name, feature_value in edge.features.items():
                grid.map_feature(edge.position1, feature_name, feature_value, method)
                grid.map_feature(edge.position2, feature_name, feature_value, method)

        for node in self._nodes.values():
            for feature_name, feature_value in node.features.items():
                grid.map_feature(node.position, feature_name, feature_value, method)

    def write_to_hdf5(self, hdf5_path: str): # pylint: disable=too-many-locals
        """Write a featured graph to an hdf5 file, according to deeprank standards.

        The file is opened in append mode. The graph is written into the group
        named ``self.id`` with node-feature, edge-feature and target subgroups.
        Feature names are taken from the first node and the first edge, so every
        node/edge is expected to carry the same feature names.

        Args:
            hdf5_path: path of the hdf5 file to write to.

        Raises:
            IndexError: if the graph has no nodes or no edges.
            ValueError: if an edge refers to a node that is not in the graph.
        """

        with h5py.File(hdf5_path, "a") as hdf5_file:

            # create groups to hold data
            graph_group = hdf5_file.require_group(self.id)
            node_features_group = graph_group.create_group(Nfeat.NODE)
            edge_feature_group = graph_group.create_group(Efeat.EDGE)

            # store node names and chain_ids
            node_names = np.array([str(key) for key in self._nodes]).astype("S")
            node_features_group.create_dataset(Nfeat.NAME, data=node_names)
            # NOTE(review): assumes str(node id) has the chain id as its second
            # whitespace-separated token — confirm against Atom/Residue.__str__
            chain_ids = np.array([str(key).split()[1] for key in self._nodes]).astype("S")
            node_features_group.create_dataset(Nfeat.CHAINID, data=chain_ids)

            # store node features
            first_node_data = list(self._nodes.values())[0].features
            node_feature_names = list(first_node_data.keys())
            for node_feature_name in node_feature_names:

                node_feature_data = [
                    node.features[node_feature_name] for node in self._nodes.values()
                ]

                node_features_group.create_dataset(
                    node_feature_name, data=node_feature_data
                )

            # row index of every node (same order as the node datasets above), so
            # each edge endpoint is found in O(1) rather than by a list.index() scan
            node_indices = {node_id: index for index, node_id in enumerate(self._nodes)}

            # identify edges
            edge_indices = []
            edge_names = []

            first_edge_data = list(self._edges.values())[0].features
            edge_feature_names = list(first_edge_data.keys())

            edge_feature_data = {name: [] for name in edge_feature_names}

            for edge_id, edge in self._edges.items():

                id1, id2 = edge_id
                try:
                    node_index1 = node_indices[id1]
                    node_index2 = node_indices[id2]
                except KeyError as error:
                    raise ValueError(
                        f"edge {id1}-{id2} refers to a node that is not in graph {self.id}"
                    ) from error

                edge_indices.append((node_index1, node_index2))
                edge_names.append(f"{id1}-{id2}")

                for edge_feature_name in edge_feature_names:
                    edge_feature_data[edge_feature_name].append(
                        edge.features[edge_feature_name]
                    )

            # store edge names and indices
            edge_feature_group.create_dataset(
                Efeat.NAME, data=np.array(edge_names).astype("S")
            )
            edge_feature_group.create_dataset(Efeat.INDEX, data=edge_indices)

            # store edge features
            for edge_feature_name in edge_feature_names:
                edge_feature_group.create_dataset(
                    edge_feature_name, data=edge_feature_data[edge_feature_name]
                )

            # store target values
            score_group = graph_group.create_group(targets.VALUES)
            for target_name, target_data in self.targets.items():
                score_group.create_dataset(target_name, data=target_data)

    def write_as_grid_to_hdf5(
        self, hdf5_path: str, settings: GridSettings, method: MapMethod
    ) -> str:
        """Map this graph onto a grid centered on ``self.center`` and write it.

        Returns:
            ``hdf5_path``, unchanged.
        """

        grid = Grid(self.id, self.center.tolist(), settings)

        self.map_to_grid(grid, method)
        grid.to_hdf5(hdf5_path)

        return hdf5_path
229

230

231
def build_atomic_graph( # pylint: disable=too-many-locals
    atoms: List[Atom], graph_id: str, edge_distance_cutoff: float
) -> Graph:
    """Builds a graph, using the atoms as nodes.
    The edge distance cutoff is in Ångströms.

    Every pair of distinct atoms whose distance is strictly below
    ``edge_distance_cutoff`` gets an edge. Atoms with no such neighbour are not
    added to the graph. Each node carries its atom's position as the
    ``Nfeat.POSITION`` feature.

    Args:
        atoms: the atoms to connect.
        graph_id: id of the resulting graph.
        edge_distance_cutoff: maximum (exclusive) interatomic distance for an edge.

    Returns:
        The built :class:`Graph`.
    """

    positions = np.empty((len(atoms), 3))
    for atom_index, atom in enumerate(atoms):
        positions[atom_index] = atom.position

    distances = distance_matrix(positions, positions, p=2)
    neighbours = distances < edge_distance_cutoff

    graph = Graph(graph_id)

    # one Node per atom index, created the first time the atom appears in a pair
    # (previously a fresh Node was built for both atoms of every pair)
    nodes = {}

    # NOTE(review): the neighbour matrix is symmetric, so both (i, j) and (j, i)
    # are visited; whether they collapse to one edge depends on Contact equality.
    for atom1_index, atom2_index in np.transpose(np.nonzero(neighbours)):
        if atom1_index == atom2_index:
            continue  # every atom is within the cutoff of itself

        atom1 = atoms[atom1_index]
        atom2 = atoms[atom2_index]
        contact = AtomicContact(atom1, atom2)

        for atom_index, atom in ((atom1_index, atom1), (atom2_index, atom2)):
            if atom_index not in nodes:
                node = Node(atom)
                node.features[Nfeat.POSITION] = atom.position
                nodes[atom_index] = node

        graph.add_node(nodes[atom1_index])
        graph.add_node(nodes[atom2_index])
        graph.add_edge(Edge(contact))

    return graph
263

264

265
def build_residue_graph( # pylint: disable=too-many-locals
    residues: List[Residue], graph_id: str, edge_distance_cutoff: float
) -> Graph:
    """Builds a graph, using the residues as nodes.
    The edge distance cutoff is in Ångströms.
    It's the shortest interatomic distance between two residues.

    Two residues get an edge when any atom of one lies strictly closer than
    ``edge_distance_cutoff`` to any atom of the other. Residues with no such
    neighbour are not added to the graph. Each node carries the residue center
    (from ``get_residue_center``) as the ``Nfeat.POSITION`` feature.

    Args:
        residues: the residues to connect.
        graph_id: id of the resulting graph.
        edge_distance_cutoff: maximum (exclusive) interatomic distance for an edge.

    Returns:
        The built :class:`Graph`.
    """

    # collect the set of atoms and remember which are on the same residue (by index)
    atoms = []
    atoms_residues = []
    for residue_index, residue in enumerate(residues):
        for atom in residue.atoms:
            atoms.append(atom)
            atoms_residues.append(residue_index)

    # explicit int dtype so the array stays usable as an index even when empty
    atoms_residues = np.array(atoms_residues, dtype=int)

    # calculate the distance matrix
    positions = np.empty((len(atoms), 3))
    for atom_index, atom in enumerate(atoms):
        positions[atom_index] = atom.position

    distances = distance_matrix(positions, positions, p=2)

    # determine which atoms are close enough
    neighbours = distances < edge_distance_cutoff

    atom_index_pairs = np.transpose(np.nonzero(neighbours))

    # point out the unique residues for the atom pairs
    residue_index_pairs = np.unique(atoms_residues[atom_index_pairs], axis=0)

    # build the graph; each residue gets one Node (and a single
    # get_residue_center call) the first time it appears in a contact
    graph = Graph(graph_id)
    nodes = {}
    for residue1_index, residue2_index in residue_index_pairs:

        residue1 = residues[residue1_index]
        residue2 = residues[residue2_index]

        # skip self-pairs (atoms of the same residue are always close)
        if residue1 != residue2:

            contact = ResidueContact(residue1, residue2)

            for residue_index, residue in (
                (residue1_index, residue1),
                (residue2_index, residue2),
            ):
                if residue_index not in nodes:
                    node = Node(residue)
                    node.features[Nfeat.POSITION] = get_residue_center(residue)
                    nodes[residue_index] = node

            # The same residue appears in many contacts; add_node keys on the
            # residue, so re-adding it does not duplicate the node.
            graph.add_node(nodes[residue1_index])
            graph.add_node(nodes[residue2_index])
            graph.add_edge(Edge(contact))

    return graph
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc