• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

DeepRank / deeprank-core / 4075652401

pending completion
4075652401

Pull #330

github

GitHub
Merge 45ea1393e into d73e8c34f
Pull Request #330: fix: data generation threading locked

1046 of 1331 branches covered (78.59%)

Branch coverage included in aggregate %.

36 of 36 new or added lines in 2 files covered. (100.0%)

2949 of 3482 relevant lines covered (84.69%)

0.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.29
/deeprankcore/query.py
1
import logging
1✔
2
import os
1✔
3
from typing import Dict, List, Optional, Iterator, Union
1✔
4
import tempfile
1✔
5
import pdb2sql
1✔
6
import pickle
1✔
7
from glob import glob
1✔
8
from types import ModuleType
1✔
9
from functools import partial
1✔
10
from multiprocessing import Pool
1✔
11
import importlib
1✔
12
from os.path import basename
1✔
13
import h5py
1✔
14
import pkgutil
1✔
15
import numpy as np
1✔
16
from deeprankcore.utils.graph import Graph
1✔
17
from deeprankcore.utils.grid import GridSettings, MapMethod
1✔
18
from deeprankcore.molstruct.aminoacid import AminoAcid
1✔
19
from deeprankcore.molstruct.residue import get_residue_center
1✔
20
from deeprankcore.molstruct.atom import Atom
1✔
21
from deeprankcore.molstruct.structure import PDBStructure
1✔
22
from deeprankcore.utils.buildgraph import (
1✔
23
    get_contact_atoms,
24
    get_surrounding_residues,
25
    get_structure,
26
    add_hydrogens,
27
)
28
from deeprankcore.utils.parsing.pssm import parse_pssm
1✔
29
from deeprankcore.utils.graph import build_residue_graph, build_atomic_graph
1✔
30
from deeprankcore.molstruct.variant import SingleResidueVariant
1✔
31
import deeprankcore.features
1✔
32

33

34
# Module-level logger, named after this module so its records can be configured via the standard `logging` hierarchy.
_log = logging.getLogger(__name__)
1✔
35

36

37
class Query:

    def __init__(self, model_id: str, targets: Optional[Dict[str, Union[float, int]]] = None):
        """
        Represents one entity of interest, like a single residue variant or a protein-protein interface.

        :class:`Query` objects are used to generate graphs from structures, and they should be created before any model is loaded.
        They can have target values associated with them, which will be stored with the resulting graph.

        Subclasses must implement :py:meth:`get_query_id` and :py:meth:`build`.

        Args:
            model_id(str): The ID of the model to load, usually a .PDB accession code.
            targets(Dict[str, Union[float, int]], optional): Target values associated with the query, defaults to None.
        """

        self._model_id = model_id

        # Keep the caller's dict object when one is given; only create a fresh one for None.
        if targets is None:
            self._targets = {}
        else:
            self._targets = targets

    def get_query_id(self) -> str:
        "Returns the string representing the complete query ID. Must be implemented by subclasses."
        raise NotImplementedError(f"{type(self).__name__} must implement get_query_id()")

    def build(self, feature_modules: List, include_hydrogens: bool = False) -> Graph:
        """
        Builds the graph for this query. Must be implemented by subclasses.

        Args:
            feature_modules(List[ModuleType]): Each must implement the :py:func:`add_features` function.

            include_hydrogens(bool, optional): Whether to include hydrogens in the :class:`Graph`, defaults to False.

        Returns:
            :class:`Graph`: The resulting :class:`Graph` object with all the features and targets.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement build()")

    def _set_graph_targets(self, graph: Graph):
        "Simply copies target data from query to graph."

        for target_name, target_data in self._targets.items():
            graph.targets[target_name] = target_data

    def _load_structure(
        self, pdb_path: str, pssm_paths: Optional[Dict[str, str]],
        include_hydrogens: bool
    ):
        """
        A helper function, to build the structure from .PDB and .PSSM files.

        Args:
            pdb_path(str): The path to the .PDB file.

            pssm_paths(Dict[str, str], optional): The paths to the .PSSM files, per chain identifier.
                Chains without an entry are left without PSSM data.

            include_hydrogens(bool): Whether to add hydrogens first. If True, a hydrogenated temporary
                copy of the .PDB file is parsed instead of the original, and removed afterwards.

        Returns:
            The structure returned by :py:func:`get_structure` for this query's model ID, with `pssm`
            set on every chain listed in `pssm_paths`.
        """

        if include_hydrogens:
            # The temporary copy is only created when it is actually needed, so no stray
            # temp file is left behind when hydrogens are not requested.
            pdb_name = os.path.basename(pdb_path)
            hydrogen_pdb_file, hydrogen_pdb_path = tempfile.mkstemp(
                prefix="hydrogenated-", suffix=pdb_name
            )
            os.close(hydrogen_pdb_file)

            try:
                # hydrogenate and read the .PDB copy; the copy is removed even if either step fails
                add_hydrogens(pdb_path, hydrogen_pdb_path)
                pdb = pdb2sql.pdb2sql(hydrogen_pdb_path)
            finally:
                os.remove(hydrogen_pdb_path)
        else:
            pdb = pdb2sql.pdb2sql(pdb_path)

        try:
            structure = get_structure(pdb, self.model_id)
        finally:
            pdb._close() # pylint: disable=protected-access

        # read the pssm
        if pssm_paths is not None:
            for chain in structure.chains:
                if chain.id in pssm_paths:
                    pssm_path = pssm_paths[chain.id]

                    with open(pssm_path, "rt", encoding="utf-8") as f:
                        chain.pssm = parse_pssm(f, chain)

        return structure

    @property
    def model_id(self) -> str:
        "The ID of the model, usually a .PDB accession code."
        return self._model_id

    @model_id.setter
    def model_id(self, value):
        self._model_id = value

    @property
    def targets(self) -> Dict[str, Union[float, int]]:
        "The target values associated with the query."
        return self._targets

    def __repr__(self) -> str:
        return f"{type(self)}({self.get_query_id()})"
119

120

121
class QueryCollection:
    """
    Represents the collection of data queries.

    The whole collection can be pickled to a file with :py:meth:`export_dict`.
    """

    def __init__(self):

        self._queries = []
        # number of worker processes used by the last call to process(); None until then
        self.cpu_count = None
        # query ID (as first seen) -> how many queries with that ID have been added
        self.ids_count = {}

    def add(self, query: Query, verbose: bool = False):
        """
        Adds a new query to the collection.

        If a query with the same ID was already added, the new query's `model_id` is suffixed with
        "_<count>" and a warning is logged.

        Args:
            query(:class:`Query`): Must be a :class:`Query` object, either :class:`ProteinProteinInterfaceResidueQuery` or
                :class:`SingleResidueVariantAtomicQuery`.

            verbose(bool, optional): For logging query IDs added, defaults to False.
        """
        query_id = query.get_query_id()

        if verbose:
            _log.info(f'Adding query with ID {query_id}.')

        if query_id not in self.ids_count:
            self.ids_count[query_id] = 1
        else:
            self.ids_count[query_id] += 1
            new_id = query.model_id + "_" + str(self.ids_count[query_id])
            query.model_id = new_id
            _log.warning(f'Query with ID {query_id} has already been added to the collection. Renaming it as {query.get_query_id()}')

        self._queries.append(query)

    def export_dict(self, dataset_path: str):
        """
        Exports the collection of all queries (this whole :class:`QueryCollection` object) to a pickle file.

        Args:
            dataset_path(str): The path where to save the list of queries.
        """
        # NOTE: the output is a pickle; it should only ever be unpickled from a trusted source.
        with open(dataset_path, "wb") as pkl_file:
            pickle.dump(self, pkl_file)

    @property
    def queries(self) -> List[Query]:
        "The list of queries added to the collection."
        return self._queries

    def __contains__(self, query: Query) -> bool:
        return query in self._queries

    def __iter__(self) -> Iterator[Query]:
        return iter(self._queries)

    def _process_one_query(  # pylint: disable=too-many-arguments
        self,
        prefix: str,
        feature_names: List[str],
        grid_settings: Union[GridSettings, None],
        grid_map_method: Union[MapMethod, None],
        query: Query):
        """
        Builds the graph of one query and writes it to this process's own .HDF5 file.

        Meant to be run inside a worker of the pool created by :py:meth:`process`. ValueError,
        AttributeError, KeyError and TimeoutError are logged and the query is skipped.

        Args:
            prefix(str): Prefix of the output file; the process ID is appended to it.
            feature_names(List[str]): Names of modules inside `deeprankcore.features` to compute features with.
            grid_settings(:class:`GridSettings`, optional): Grid settings; the grid is written only if both grid arguments are set.
            grid_map_method(:class:`MapMethod`, optional): Grid mapping method; see `grid_settings`.
            query(:class:`Query`): The query to process.

        Returns:
            None
        """

        try:
            # because only one process may access an hdf5 file at the time:
            output_path = f"{prefix}-{os.getpid()}.hdf5"

            # modules are imported by name here, inside the worker process
            feature_modules = [
                importlib.import_module('deeprankcore.features.' + name) for name in feature_names]

            graph = query.build(feature_modules)
            graph.write_to_hdf5(output_path)

            if grid_settings is not None and grid_map_method is not None:
                graph.write_as_grid_to_hdf5(output_path, grid_settings, grid_map_method)

            return None

        except (ValueError, AttributeError, KeyError, TimeoutError) as e:
            _log.warning(f'\nGraph/Query with ID {query.get_query_id()} run into an Exception ({e.__class__.__name__}: {e}),'
            ' and it has not been written to the hdf5 file. More details below:')
            _log.exception(e)
            return None

    def process( # pylint: disable=too-many-arguments, too-many-locals
        self,
        prefix: Optional[str] = None,
        feature_modules: List[ModuleType] = None,
        cpu_count: Optional[int] = None,
        combine_output: Optional[bool] = True,
        grid_settings: Optional[GridSettings] = None,
        grid_map_method: Optional[MapMethod] = None
        ) -> List[str]:
        """
        Processes all queries of the collection in parallel and writes the resulting graphs to .HDF5 files.

        Args:
            prefix(str, optional): Prefix for the output files. Defaults to None, which sets ./processed-queries- prefix.

            feature_modules(List[ModuleType], optional): List of features' modules used to generate features. Each feature's module must
                implement the :py:func:`add_features` function, and features' modules can be found (or should be placed in case of a custom made feature)
                in `deeprankcore.features` folder. Defaults to None, which means that all available modules in `deeprankcore.features` are used to generate
                the features.

            cpu_count(int, optional): How many processes to be run simultaneously. Defaults to None, which takes all available cpu cores.
                Values above the number of cores in the system are capped to that number.

            combine_output(bool, optional): For combining the .HDF5 files generated by the processes, defaults to True.

            grid_settings(:class:`GridSettings`, optional): if valid together with `grid_map_method`, the grid data will be stored as well.
                Defaults to None.

            grid_map_method(:class:`MapMethod`, optional): if valid together with `grid_settings`, the grid data will be stored as well.
                Defaults to None.

        Returns:
            List(str): The list of paths of the generated .HDF5 files.
        """

        # os.cpu_count() may return None when the count cannot be determined; fall back to 1.
        cpu_count_system = os.cpu_count() or 1
        if cpu_count is None:
            cpu_count = cpu_count_system
        elif cpu_count > cpu_count_system:
            _log.warning(f'\nTried to set {cpu_count} CPUs, but only {cpu_count_system} are present in the system.')
            cpu_count = cpu_count_system

        self.cpu_count = cpu_count

        _log.info(f'\nNumber of CPUs for processing the queries set to: {self.cpu_count}.')

        if prefix is None:
            prefix = "processed-queries"

        if feature_modules is None:
            feature_names = [modname for _, modname, _ in pkgutil.iter_modules(deeprankcore.features.__path__)]
        else:
            # module names are passed (not module objects) so workers import them themselves
            feature_names = [basename(m.__file__)[:-3] for m in feature_modules]

        _log.info(f'Creating pool function to process {len(self.queries)} queries...')
        pool_function = partial(self._process_one_query, prefix,
                                feature_names,
                                grid_settings, grid_map_method)

        with Pool(self.cpu_count) as pool:
            _log.info('Starting pooling...\n')
            pool.map(pool_function, self.queries)

        output_paths = glob(f"{prefix}-*.hdf5")

        if combine_output:
            # Only create/open the combined file when there is something to merge into it,
            # and open it once rather than once per process file.
            if output_paths:
                with h5py.File(f"{prefix}.hdf5", 'a') as f_dest:
                    for output_path in output_paths:
                        with h5py.File(output_path, 'r') as f_src:
                            for value in f_src.values():
                                f_src.copy(value, f_dest)
                        os.remove(output_path)
            return glob(f"{prefix}.hdf5")

        return output_paths
281

282

283
class SingleResidueVariantResidueQuery(Query):

    def __init__(  # pylint: disable=too-many-arguments
        self,
        pdb_path: str,
        chain_id: str,
        residue_number: int,
        insertion_code: str,
        wildtype_amino_acid: AminoAcid,
        variant_amino_acid: AminoAcid,
        pssm_paths: Optional[Dict[str, str]] = None,
        radius: Optional[float] = 10.0,
        distance_cutoff: Optional[float] = 4.5,
        targets: Optional[Dict[str, float]] = None,
    ):
        """
        Creates a residue graph from a single residue variant in a .PDB file.

        Args:
            pdb_path(str): The path to the .PDB file.

            chain_id(str): The .PDB chain identifier of the variant residue.

            residue_number(int): The number of the variant residue.

            insertion_code(str): The insertion code of the variant residue, set to None if not applicable.

            wildtype_amino_acid(:class:`AminoAcid`): The wildtype amino acid.

            variant_amino_acid(:class:`AminoAcid`): The variant amino acid.

            pssm_paths(Dict(str,str), optional): The paths to the .PSSM files, per chain identifier. Defaults to None.

            radius(float, optional): In Ångström, determines how many residues will be included in the graph. Defaults to 10.0.

            distance_cutoff(float, optional): Max distance in Ångström between a pair of atoms to consider them as an external edge in the graph.
                Defaults to 4.5.

            targets(Dict(str,float), optional): Named target values associated with this query. Defaults to None.
        """

        self._pdb_path = pdb_path
        self._pssm_paths = pssm_paths

        # the model ID is the .PDB file name without directory and extension
        model_id = os.path.splitext(os.path.basename(pdb_path))[0]

        Query.__init__(self, model_id, targets)

        self._chain_id = chain_id
        self._residue_number = residue_number
        self._insertion_code = insertion_code
        self._wildtype_amino_acid = wildtype_amino_acid
        self._variant_amino_acid = variant_amino_acid

        self._radius = radius
        self._distance_cutoff = distance_cutoff

    @property
    def residue_id(self) -> str:
        "String representation of the residue number and insertion code."

        if self._insertion_code is not None:

            return f"{self._residue_number}{self._insertion_code}"

        return str(self._residue_number)

    def get_query_id(self) -> str:
        "Returns the string representing the complete query ID."
        return f"residue-graph-{self.model_id}:{self._chain_id}:{self.residue_id}:{self._wildtype_amino_acid.name}->{self._variant_amino_acid.name}"

    def build(self, feature_modules: List, include_hydrogens: bool = False) -> Graph:
        """
        Builds the graph from the .PDB structure.

        Args:
            feature_modules(List[ModuleType]): Each must implement the :py:func:`add_features` function.

            include_hydrogens(bool, optional): Whether to include hydrogens in the :class:`Graph`, defaults to False.

        Returns:
            :class:`Graph`: The resulting :class:`Graph` object with all the features and targets.

        Raises:
            ValueError: If the variant residue is not present in the given chain.
        """

        # load .PDB structure
        structure = self._load_structure(self._pdb_path, self._pssm_paths, include_hydrogens)

        # find the variant residue: first residue of the chain matching both number and insertion code
        variant_residue = next(
            (
                residue
                for residue in structure.get_chain(self._chain_id).residues
                if residue.number == self._residue_number
                and residue.insertion_code == self._insertion_code
            ),
            None,
        )

        if variant_residue is None:
            raise ValueError(
                f"Residue not found in {self._pdb_path}: {self._chain_id} {self.residue_id}"
            )

        # define the variant
        variant = SingleResidueVariant(variant_residue, self._variant_amino_acid)

        # select which residues will be the graph; centered explicitly on the variant residue
        residues = list(get_surrounding_residues(structure, variant_residue, self._radius))

        # build the graph
        graph = build_residue_graph(
            residues, self.get_query_id(), self._distance_cutoff
        )

        # add data to the graph
        self._set_graph_targets(graph)

        for feature_module in feature_modules:
            feature_module.add_features(self._pdb_path, graph, variant)

        graph.center = get_residue_center(variant_residue)
        return graph
404

405

406
class SingleResidueVariantAtomicQuery(Query):

    def __init__(  # pylint: disable=too-many-arguments
        self,
        pdb_path: str,
        chain_id: str,
        residue_number: int,
        insertion_code: str,
        wildtype_amino_acid: AminoAcid,
        variant_amino_acid: AminoAcid,
        pssm_paths: Optional[Dict[str, str]] = None,
        radius: Optional[float] = 10.0,
        distance_cutoff: Optional[float] = 4.5,
        targets: Optional[Dict[str, float]] = None,
    ):
        """
        Creates an atomic graph for a single residue variant in a .PDB file.

        Args:
            pdb_path(str): The path to the .PDB file.

            chain_id(str): The .PDB chain identifier of the variant residue.

            residue_number(int): The number of the variant residue.

            insertion_code(str): The insertion code of the variant residue, set to None if not applicable.

            wildtype_amino_acid(deeprank amino acid object): The wildtype amino acid.

            variant_amino_acid(deeprank amino acid object): The variant amino acid.

            pssm_paths(dict(str,str), optional): The paths to the .PSSM files, per chain identifier. Defaults to None.

            radius(float, optional): In Ångström, determines how many residues will be included in the graph. Defaults to 10.0.

            distance_cutoff(float, optional): Max distance in Ångström between a pair of atoms to consider them as an external edge in the graph.
                Defaults to 4.5.

            targets(dict(str,float), optional): Named target values associated with this query. Defaults to None.
        """

        self._pdb_path = pdb_path
        self._pssm_paths = pssm_paths

        # the model ID is the .PDB file name without directory and extension
        model_id = os.path.splitext(os.path.basename(pdb_path))[0]

        Query.__init__(self, model_id, targets)

        self._chain_id = chain_id
        self._residue_number = residue_number
        self._insertion_code = insertion_code
        self._wildtype_amino_acid = wildtype_amino_acid
        self._variant_amino_acid = variant_amino_acid

        self._radius = radius

        self._distance_cutoff = distance_cutoff

    @property
    def residue_id(self) -> str:
        "String representation of the residue number and insertion code."

        if self._insertion_code is not None:
            return f"{self._residue_number}{self._insertion_code}"

        return str(self._residue_number)

    def get_query_id(self) -> str:
        "Returns the string representing the complete query ID."
        # model_id is interpolated as a plain string (a trailing comma here would format a tuple)
        return f"{self.model_id}:{self._chain_id}:{self.residue_id}:{self._wildtype_amino_acid.name}->{self._variant_amino_acid.name}"

    def __eq__(self, other) -> bool:
        return (
            isinstance(self, type(other))
            and self.model_id == other.model_id
            and self._chain_id == other._chain_id
            and self.residue_id == other.residue_id
            and self._wildtype_amino_acid == other._wildtype_amino_acid
            and self._variant_amino_acid == other._variant_amino_acid
        )

    def __hash__(self) -> int:
        # hashes exactly the fields compared in __eq__
        return hash(
            (
                self.model_id,
                self._chain_id,
                self.residue_id,
                self._wildtype_amino_acid,
                self._variant_amino_acid,
            )
        )

    @staticmethod
    def _get_atom_node_key(atom) -> str:
        """Pickle has problems serializing the graph when the nodes are atoms,
        so use this function to generate an unique key for the atom"""

        # This should include the model, chain, residue and atom
        return str(atom)

    def build(self, feature_modules: List, include_hydrogens: bool = False) -> Graph:
        """
        Builds the graph from the .PDB structure.

        Args:
            feature_modules(List[ModuleType]): Each must implement the :py:func:`add_features` function.

            include_hydrogens(bool, optional): Whether to include hydrogens in the :class:`Graph`, defaults to False.

        Returns:
            :class:`Graph`: The resulting :class:`Graph` object with all the features and targets.

        Raises:
            ValueError: If the variant residue is not present in the given chain.
        """

        # load .PDB structure
        structure = self._load_structure(self._pdb_path, self._pssm_paths, include_hydrogens)

        # find the variant residue
        variant_residue = None
        for residue in structure.get_chain(self._chain_id).residues:
            if (
                residue.number == self._residue_number
                and residue.insertion_code == self._insertion_code
            ):
                variant_residue = residue
                break

        if variant_residue is None:
            raise ValueError(
                f"Residue not found in {self._pdb_path}: {self._chain_id} {self.residue_id}"
            )

        # define the variant
        variant = SingleResidueVariant(variant_residue, self._variant_amino_acid)

        # get the residues and atoms involved; the variant residue itself is always included
        residues = get_surrounding_residues(structure, variant_residue, self._radius)
        residues.add(variant_residue)
        # residues without an amino acid assignment contribute no atoms
        atoms = list({
            atom
            for residue in residues
            if residue.amino_acid is not None
            for atom in residue.atoms
        })

        # build the graph
        graph = build_atomic_graph(
            atoms, self.get_query_id(), self._distance_cutoff
        )

        # add data to the graph
        self._set_graph_targets(graph)

        for feature_module in feature_modules:
            feature_module.add_features(self._pdb_path, graph, variant)

        graph.center = get_residue_center(variant_residue)
        return graph
563

564

565
def _load_ppi_atoms(pdb_path: str,
                    chain_id1: str, chain_id2: str,
                    distance_cutoff: float,
                    include_hydrogens: bool) -> List[Atom]:
    """
    Collects the contact atoms between two chains of a .PDB file.

    Args:
        pdb_path(str): The path to the .PDB file.
        chain_id1(str): The .PDB chain identifier of the first protein.
        chain_id2(str): The .PDB chain identifier of the second protein.
        distance_cutoff(float): Distance cutoff passed on to :py:func:`get_contact_atoms`.
        include_hydrogens(bool): If True, hydrogens are added to a temporary copy of the .PDB file,
            which is used instead of the original and removed afterwards.

    Returns:
        List[Atom]: The contact atoms, as returned by :py:func:`get_contact_atoms`.

    Raises:
        ValueError: If no contact atoms are found.
    """

    # get the contact atoms
    if include_hydrogens:

        pdb_name = os.path.basename(pdb_path)
        hydrogen_pdb_file, hydrogen_pdb_path = tempfile.mkstemp(
            prefix="hydrogenated-", suffix=pdb_name
        )
        os.close(hydrogen_pdb_file)

        try:
            # hydrogenation is inside the try so the temporary copy is removed even if it fails
            add_hydrogens(pdb_path, hydrogen_pdb_path)
            contact_atoms = get_contact_atoms(hydrogen_pdb_path,
                                              chain_id1, chain_id2,
                                              distance_cutoff)
        finally:
            os.remove(hydrogen_pdb_path)
    else:
        contact_atoms = get_contact_atoms(pdb_path,
                                          chain_id1, chain_id2,
                                          distance_cutoff)

    if len(contact_atoms) == 0:
        raise ValueError("no contact atoms found")

    return contact_atoms
596

597

598
def _load_ppi_pssms(pssm_paths: Union[Dict[str, str], None],
                    chain_id1: str, chain_id2: str,
                    structure: PDBStructure):
    """
    Attaches parsed PSSM data to the two interface chains of a structure, in place.

    Args:
        pssm_paths(Dict[str, str], optional): The paths to the .PSSM files, per chain identifier.
            When None, nothing is loaded.
        chain_id1(str): The .PDB chain identifier of the first protein.
        chain_id2(str): The .PDB chain identifier of the second protein.
        structure(:class:`PDBStructure`): The structure whose chains receive the `pssm` attribute.
    """

    if pssm_paths is None:
        return

    for wanted_chain_id in (chain_id1, chain_id2):
        # chains without a .PSSM file are simply left untouched
        if wanted_chain_id not in pssm_paths:
            continue

        target_chain = structure.get_chain(wanted_chain_id)

        with open(pssm_paths[wanted_chain_id], "rt", encoding="utf-8") as pssm_file:
            target_chain.pssm = parse_pssm(pssm_file, target_chain)
612

613

614

615
class ProteinProteinInterfaceAtomicQuery(Query):

    def __init__(  # pylint: disable=too-many-arguments
        self,
        pdb_path: str,
        chain_id1: str,
        chain_id2: str,
        pssm_paths: Optional[Dict[str, str]] = None,
        distance_cutoff: Optional[float] = 5.5,
        targets: Optional[Dict[str, float]] = None,
    ):
        """
        A query that builds atom-based graphs, using the residues at a protein-protein interface.

        Args:
            pdb_path(str): The path to the .PDB file.

            chain_id1(str): The .PDB chain identifier of the first protein of interest.

            chain_id2(str): The .PDB chain identifier of the second protein of interest.

            pssm_paths(dict(str,str), optional): The paths to the .PSSM files, per chain identifier. Defaults to None.

            distance_cutoff(float, optional): Max distance in Ångström between two interacting atoms of the two proteins,
                defaults to 5.5. The same cutoff is also used when building the edges of the atomic graph.

            targets(dict, optional): Named target values associated with this query, defaults to None.
        """

        # the model ID is the .PDB file name without directory and extension
        model_id = os.path.splitext(os.path.basename(pdb_path))[0]

        Query.__init__(self, model_id, targets)

        self._pdb_path = pdb_path

        self._chain_id1 = chain_id1
        self._chain_id2 = chain_id2

        self._pssm_paths = pssm_paths

        self._distance_cutoff = distance_cutoff

    def get_query_id(self) -> str:
        "Returns the string representing the complete query ID."
        return f"atom-ppi-{self.model_id}:{self._chain_id1}-{self._chain_id2}"

    def __eq__(self, other) -> bool:
        # the order of the two chains does not matter for equality
        return (
            isinstance(self, type(other))
            and self.model_id == other.model_id
            and {self._chain_id1, self._chain_id2}
            == {other._chain_id1, other._chain_id2}
        )

    def __hash__(self) -> int:
        # sorted chain IDs make the hash order-independent, consistent with __eq__
        return hash((self.model_id, tuple(sorted([self._chain_id1, self._chain_id2]))))

    def build(self, feature_modules: List, include_hydrogens: bool = False) -> Graph:
        """
        Builds the graph from the .PDB structure.

        Args:
            feature_modules(List[ModuleType]): Each must implement the :py:func:`add_features` function.

            include_hydrogens(bool, optional): Whether to include hydrogens in the :class:`Graph`, defaults to False.

        Returns:
            :class:`Graph`: The resulting :class:`Graph` object with all the features and targets.

        Raises:
            ValueError: If no contact atoms are found between the two chains.
        """

        contact_atoms = _load_ppi_atoms(self._pdb_path,
                                        self._chain_id1, self._chain_id2,
                                        self._distance_cutoff,
                                        include_hydrogens)

        # build the graph
        graph = build_atomic_graph(
            contact_atoms, self.get_query_id(), self._distance_cutoff
        )

        # add data to the graph
        self._set_graph_targets(graph)

        # read the pssm; _load_ppi_atoms guarantees at least one contact atom, so [0] is safe
        structure = contact_atoms[0].residue.chain.model

        _load_ppi_pssms(self._pssm_paths,
                        self._chain_id1, self._chain_id2,
                        structure)

        # add the features
        for feature_module in feature_modules:
            feature_module.add_features(self._pdb_path, graph)

        # the graph is centered on the mean position of the contact atoms
        graph.center = np.mean([atom.position for atom in contact_atoms], axis=0)
        return graph
711

712

713
class ProteinProteinInterfaceResidueQuery(Query):

    def __init__(  # pylint: disable=too-many-arguments
        self,
        pdb_path: str,
        chain_id1: str,
        chain_id2: str,
        pssm_paths: Optional[Dict[str, str]] = None,
        distance_cutoff: float = 10,
        targets: Optional[Dict[str, float]] = None,
    ):
        """
        A query that builds residue-based graphs, using the residues at a protein-protein interface.

        Args:
            pdb_path(str): The path to the .PDB file.

            chain_id1(str): The .PDB chain identifier of the first protein of interest.

            chain_id2(str): The .PDB chain identifier of the second protein of interest.

            pssm_paths(dict(str,str), optional): The paths to the .PSSM files, per chain identifier. Defaults to None.

            distance_cutoff(float, optional): Max distance in Ångström between two interacting residues of the two proteins,
                defaults to 10. The same cutoff is also used when building the edges of the residue graph.

            targets(dict, optional): Named target values associated with this query, defaults to None.
        """

        # the model ID is the .PDB file name without directory and extension
        model_id = os.path.splitext(os.path.basename(pdb_path))[0]

        Query.__init__(self, model_id, targets)

        self._pdb_path = pdb_path

        self._chain_id1 = chain_id1
        self._chain_id2 = chain_id2

        self._pssm_paths = pssm_paths

        self._distance_cutoff = distance_cutoff

    def get_query_id(self) -> str:
        "Returns the string representing the complete query ID."
        return f"residue-ppi-{self.model_id}:{self._chain_id1}-{self._chain_id2}"

    def __eq__(self, other) -> bool:
        # the order of the two chains does not matter for equality
        return (
            isinstance(self, type(other))
            and self.model_id == other.model_id
            and {self._chain_id1, self._chain_id2}
            == {other._chain_id1, other._chain_id2}
        )

    def __hash__(self) -> int:
        # sorted chain IDs make the hash order-independent, consistent with __eq__
        return hash((self.model_id, tuple(sorted([self._chain_id1, self._chain_id2]))))

    def build(self, feature_modules: List, include_hydrogens: bool = False) -> Graph:
        """
        Builds the graph from the .PDB structure.

        Args:
            feature_modules(List[ModuleType]): Each must implement the :py:func:`add_features` function.

            include_hydrogens(bool, optional): Whether to include hydrogens in the :class:`Graph`, defaults to False.

        Returns:
            :class:`Graph`: The resulting :class:`Graph` object with all the features and targets.

        Raises:
            ValueError: If no contact atoms are found between the two chains.
        """

        contact_atoms = _load_ppi_atoms(self._pdb_path,
                                        self._chain_id1, self._chain_id2,
                                        self._distance_cutoff,
                                        include_hydrogens)

        # the graph nodes are the (deduplicated) residues owning the contact atoms
        atom_positions = [atom.position for atom in contact_atoms]
        residues_selected = list({atom.residue for atom in contact_atoms})

        # build the graph
        graph = build_residue_graph(
            residues_selected, self.get_query_id(), self._distance_cutoff
        )

        # add data to the graph
        self._set_graph_targets(graph)

        # read the pssm; _load_ppi_atoms guarantees at least one contact atom, so [0] is safe
        structure = contact_atoms[0].residue.chain.model

        _load_ppi_pssms(self._pssm_paths,
                        self._chain_id1, self._chain_id2,
                        structure)

        # add the features
        for feature_module in feature_modules:
            feature_module.add_features(self._pdb_path, graph)

        # the graph is centered on the mean position of the contact atoms
        graph.center = np.mean(atom_positions, axis=0)
        return graph
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc