• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

DeepRank / deeprank-core / 4075781048

pending completion
4075781048

Pull #336

github

GitHub
Merge 011ebfa73 into d73e8c34f
Pull Request #336: adds data augmentation

1059 of 1351 branches covered (78.39%)

Branch coverage included in aggregate %.

57 of 57 new or added lines in 2 files covered. (100.0%)

2977 of 3517 relevant lines covered (84.65%)

0.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

74.74
/deeprankcore/utils/grid.py
1
"""
2
This module holds the classes that are used when working with a 3D grid.
3
"""
4

5
import logging
1✔
6
from enum import Enum
1✔
7
from typing import Dict, Union, List
1✔
8
import numpy as np
1✔
9
import h5py
1✔
10
import itertools
1✔
11
from scipy.signal import bspline
1✔
12

13
from deeprankcore.domain import gridstorage
1✔
14

15

16
# module-level logger, named after this module
_log = logging.getLogger(__name__)
1✔
17

18

19
class MapMethod(Enum):
    """Enumerates the four available grid mapping methods.

    The chosen method decides how the value of a feature point
    is distributed over the points of the grid.
    """

    GAUSSIAN = 1
    FAST_GAUSSIAN = 2
    BSP_LINE = 3
    NEAREST_NEIGHBOURS = 4
28

29

30
class Augmentation:
    """Describes a rotation around an axis.

    The rotation is meant to be applied to a feature before that feature is mapped to a grid.
    """

    def __init__(self, axis: np.ndarray, angle: float):
        self._axis, self._angle = axis, angle

    @property
    def angle(self) -> float:
        "the rotation angle, as passed to the constructor"
        return self._angle

    @property
    def axis(self) -> np.ndarray:
        "the rotation axis, as passed to the constructor"
        return self._axis
44

45

46
class GridSettings:
    """Holds the settings that are needed to build a grid.

    The grid is a 3D box, subdivided along each of its edges.
    It is described by the following properties:

     - sizes: the x, y, z sizes of the box in Å
     - points_counts: the number of points on the x, y, z edges of the box
     - resolutions: the size in Å of one x, y, z edge subdivision (size divided by points count),
       which is also the distance between two neighbouring points on that edge.
    """

    def __init__(
        self,
        points_counts: List[int],
        sizes: List[float]
    ):
        # every setting must hold exactly one entry per axis: x, y, z
        for per_axis_setting in (points_counts, sizes):
            assert len(per_axis_setting) == 3

        self._points_counts = points_counts
        self._sizes = sizes

    @property
    def points_counts(self) -> List[int]:
        return self._points_counts

    @property
    def sizes(self) -> List[float]:
        return self._sizes

    @property
    def resolutions(self) -> List[float]:
        # computed on every access from the stored sizes and points counts
        return [self._sizes[axis] / self._points_counts[axis] for axis in (0, 1, 2)]
78

79

80
class Grid:
    """An instance of this class holds everything that the grid is made of:
    - coordinates of points
    - names of features
    - feature values on each point
    """

    def __init__(self, id_: str, center: List[float], settings: "GridSettings"):
        """Builds the grid points around the given center.

        Args:
            id_: identifier of the grid, used as the group name in `to_hdf5`.
            center: x, y, z coordinates of the grid center.
            settings: the sizes and points counts of the grid.
        """

        self.id = id_

        self._center = np.array(center)

        self._settings = settings

        self._set_mesh(self._center, settings)

        # maps feature names to their values per grid point
        self._features = {}

    def _set_mesh(self, center: np.ndarray, settings: "GridSettings"):
        """Builds the grid points.

        Per axis, the first point lies half a box size below the center and
        the last point lies (points count - 1) resolutions above the first.
        """

        resolutions = settings.resolutions  # computed property: read it only once

        axes = []
        for axis in range(3):
            count = settings.points_counts[axis]
            lowest = center[axis] - settings.sizes[axis] / 2
            highest = lowest + (count - 1.0) * resolutions[axis]
            axes.append(np.linspace(lowest, highest, num=count))

        self._xs, self._ys, self._zs = axes

        # 'ij' indexing: all three grids get shape (len(xs), len(ys), len(zs)),
        # with x varying over the first dimension, y over the second and z over the third
        self._xgrid, self._ygrid, self._zgrid = np.meshgrid(
            self._xs, self._ys, self._zs, indexing="ij"
        )

    @property
    def center(self) -> np.ndarray:
        return self._center

    @property
    def xs(self) -> np.ndarray:
        return self._xs

    @property
    def xgrid(self) -> np.ndarray:
        return self._xgrid

    @property
    def ys(self) -> np.ndarray:
        return self._ys

    @property
    def ygrid(self) -> np.ndarray:
        return self._ygrid

    @property
    def zs(self) -> np.ndarray:
        return self._zs

    @property
    def zgrid(self) -> np.ndarray:
        return self._zgrid

    @property
    def features(self) -> Dict[str, np.ndarray]:
        return self._features

    def add_feature_values(self, feature_name: str, data: np.ndarray):
        """Makes sure feature values per grid point get stored.

        This method may be called repeatedly to add on to existing grid point values.

        Args:
            feature_name: the name under which the values are stored.
            data: the values per grid point.
        """

        if feature_name not in self._features:
            # store a copy, so that the in-place addition below
            # never modifies an array that is owned by the caller
            self._features[feature_name] = np.array(data)
        else:
            self._features[feature_name] += data

    def _get_distances_to(self, position: np.ndarray) -> np.ndarray:
        "returns the euclidean distance from every grid point to the given x, y, z position"

        fx, fy, fz = position

        return np.sqrt(
            np.square(self.xgrid - fx) + np.square(self.ygrid - fy) + np.square(self.zgrid - fz)
        )

    def _get_mapped_feature_gaussian(
        self, position: np.ndarray, value: float
    ) -> np.ndarray:
        """Spreads the value over all grid points, as value * exp(-beta * distance) with beta = 1.

        Args:
            position: x, y, z coordinates of the feature.
            value: the feature value.

        Returns:
            the mapped values per grid point
        """

        beta = 1.0

        return value * np.exp(-beta * self._get_distances_to(position))

    def _get_mapped_feature_fast_gaussian(
        self, position: np.ndarray, value: float
    ) -> np.ndarray:
        """Like the gaussian mapping, but grid points at a distance of 5 * beta or more stay zero.

        Args:
            position: x, y, z coordinates of the feature.
            value: the feature value.

        Returns:
            the mapped values per grid point
        """

        beta = 1.0
        cutoff = 5.0 * beta

        distances = self._get_distances_to(position)
        within_cutoff = distances < cutoff

        data = np.zeros(distances.shape)
        data[within_cutoff] = value * np.exp(-beta * distances[within_cutoff])

        return data

    def _get_mapped_feature_bsp_line(
        self, position: np.ndarray, value: float
    ) -> np.ndarray:
        """Spreads the value over the grid as a product of order 4 b-splines, one per axis.

        The distance to the feature is expressed in grid resolutions, per axis.

        NOTE(review): `scipy.signal.bspline` appears to be deprecated since SciPy 1.11
        and absent from SciPy >= 1.13 - confirm against the SciPy version in use.

        Args:
            position: x, y, z coordinates of the feature.
            value: the feature value.

        Returns:
            the mapped values per grid point
        """

        order = 4

        fx, fy, fz = position
        resolution_x, resolution_y, resolution_z = self._settings.resolutions

        bsp_data = (
            bspline((self.xgrid - fx) / resolution_x, order)
            * bspline((self.ygrid - fy) / resolution_y, order)
            * bspline((self.zgrid - fz) / resolution_z, order)
        )

        return value * bsp_data

    def _get_mapped_feature_nearest_neighbour( # pylint: disable=too-many-locals
        self, position: np.ndarray, value: float
    ) -> np.ndarray:
        """Divides the value over the grid points that surround the feature.

        Per axis the two grid coordinates nearest to the feature are selected,
        so that 2 x 2 x 2 grid points get a value. All other points stay zero.

        NOTE(review): the weight of a selected point grows with its distance to
        the feature and the weights of the three axes are summed, not multiplied.
        This is kept exactly as it was - confirm that it is intended.

        Args:
            position: x, y, z coordinates of the feature.
            value: the feature value.

        Returns:
            the mapped values per grid point
        """

        fx, fy, fz = position

        axis_indices = []
        axis_weights = []
        # every axis must be compared against its own coordinate of the feature
        for axis_points, coordinate in ((self.xs, fx), (self.ys, fy), (self.zs, fz)):
            distances = np.abs(axis_points - coordinate)
            nearest = np.argsort(distances)[:2]
            nearest_distances = distances[nearest]

            axis_indices.append(nearest)
            axis_weights.append(nearest_distances / np.sum(nearest_distances))

        # both products iterate in the same order, so points and weights stay paired
        points = itertools.product(*axis_indices)
        weight_combinations = itertools.product(*axis_weights)

        neighbour_data = np.zeros(
            (self.xs.shape[0], self.ys.shape[0], self.zs.shape[0])
        )

        for point, weight_combination in zip(points, weight_combinations):
            neighbour_data[point] = np.sum(weight_combination) * value

        return neighbour_data

    def _get_atomic_density_koes(self, position: np.ndarray, vanderwaals_radius: float) -> np.ndarray:
        """
        Function to map individual atomic density on the grid.
        The formula is equation (1) of the Koes paper
        Protein-Ligand Scoring with Convolutional NN Arxiv:1612.02751v1

        Args:
            position: x, y, z coordinates of the atom.
            vanderwaals_radius: the radius that sets the extent of the density.

        Returns:
            the mapped density
        """

        distances = self._get_distances_to(position)

        density_data = np.zeros(distances.shape)

        # within the radius: gaussian-shaped density
        indices_close = distances < vanderwaals_radius
        # from the radius up to 1.5 times the radius: quadratic density
        indices_far = (distances >= vanderwaals_radius) & (distances < 1.5 * vanderwaals_radius)
        # at 1.5 times the radius and beyond, the density stays zero

        density_data[indices_close] = np.exp(-2.0 * np.square(distances[indices_close]) / np.square(vanderwaals_radius))
        density_data[indices_far] = 4.0 / np.square(np.e) / np.square(vanderwaals_radius) * np.square(distances[indices_far]) - \
                                    12.0 / np.square(np.e) / vanderwaals_radius * distances[indices_far] + \
                                    9.0 / np.square(np.e)

        return density_data

    def map_feature(
        self,
        position: np.ndarray,
        feature_name: str,
        feature_value: Union[np.ndarray, float],
        method: "MapMethod",
    ):
        """
        Maps point feature data at a given position to the grid, using the given method.
        The feature_value should either be a single number or a one-dimensional array.

        A single number is stored under the feature name itself. The entries of an array
        are stored separately, under the feature name suffixed with a three digit index.

        Args:
            position: x, y, z coordinates of the feature.
            feature_name: the name to store the mapped values under.
            feature_value: the value or values to map.
            method: how to divide the values over the grid points.

        Raises:
            ValueError: if the method is not one of the supported mapping methods.
        """

        # determine whether we're dealing with a single number or multiple numbers,
        # numpy scalars and zero-dimensional arrays count as single numbers too
        is_single_number = isinstance(feature_value, (float, int, np.number)) or (
            isinstance(feature_value, np.ndarray) and feature_value.ndim == 0
        )

        if is_single_number:
            index_names_values = [(feature_name, float(feature_value))]
        else:
            index_names_values = [
                (f"{feature_name}_{index:03d}", value)
                for index, value in enumerate(feature_value)
            ]

        # select the mapping function up front, so that an unsupported method fails clearly
        mapping_functions = {
            MapMethod.GAUSSIAN: self._get_mapped_feature_gaussian,
            MapMethod.FAST_GAUSSIAN: self._get_mapped_feature_fast_gaussian,
            MapMethod.BSP_LINE: self._get_mapped_feature_bsp_line,
            MapMethod.NEAREST_NEIGHBOURS: self._get_mapped_feature_nearest_neighbour,
        }
        if method not in mapping_functions:
            raise ValueError(f"unsupported map method: {method}")
        get_mapped_feature = mapping_functions[method]

        # map the data to the grid
        for index_name, value in index_names_values:
            self.add_feature_values(index_name, get_mapped_feature(position, value))

    def to_hdf5(self, hdf5_path: str):
        """Write the grid data to hdf5, according to deeprank standards.

        The file is opened in append mode and everything is stored in a group that is named after the grid id.

        NOTE(review): the datasets are created unconditionally, so writing the same
        grid id to the same file twice presumably fails - confirm.

        Args:
            hdf5_path: path of the hdf5 file to write to.
        """

        with h5py.File(hdf5_path, "a") as hdf5_file:

            # create a group to hold everything
            grid_group = hdf5_file.require_group(self.id)

            # store grid points
            points_group = grid_group.require_group("grid_points")
            points_group.create_dataset("x", data=self.xs)
            points_group.create_dataset("y", data=self.ys)
            points_group.create_dataset("z", data=self.zs)
            points_group.create_dataset("center", data=self.center)

            # store grid features, one group per feature name
            features_group = grid_group.require_group(gridstorage.MAPPED_FEATURES)
            for feature_name, feature_data in self.features.items():

                feature_group = features_group.require_group(feature_name)
                feature_group.create_dataset(
                    gridstorage.FEATURE_VALUE,
                    data=feature_data,
                    compression="lzf",
                    chunks=True,
                )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc