• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

kipoi / kipoiseq / ec61831f-c7d7-4830-bb04-f4867a29ea5d

03 Sep 2024 10:47AM UTC coverage: 0.0%. Remained the same
ec61831f-c7d7-4830-bb04-f4867a29ea5d

push

circleci

web-flow
Merge pull request #116 from gtsitsiridis/master

VariantSeqExtractor fix: allow padding if sequence gets out of bounds

0 of 19 new or added lines in 1 file covered. (0.0%)

1 existing line in 1 file now uncovered.

0 of 1759 relevant lines covered (0.0%)

0.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/kipoiseq/extractors/vcf_seq.py
1
import abc
×
2
from typing import Union
×
3

4
from pyfaidx import Sequence, complement
×
5
from kipoiseq.dataclasses import Interval
×
6
from kipoiseq.extractors import (
×
7
    BaseExtractor,
8
    FastaStringExtractor,
9
    MultiSampleVCF,
10
)
11

12
from kipoiseq import __version__
×
13
from deprecation import deprecated
×
NEW
14
import math
×
15

16
# Names exported by `from kipoiseq.extractors.vcf_seq import *`.
__all__ = [
    'VariantSeqExtractor',
    'SingleVariantVCFSeqExtractor',
    'SingleSeqVCFSeqExtractor',
]
21

22

23
class IntervalSeqBuilder(list):
    """
    String builder for `pyfaidx.Sequence` and `Interval` objects.

    Items are either `Interval` placeholders for reference regions that
    still have to be filled in, or `pyfaidx.Sequence` objects (e.g. alleles).
    `restore` replaces every `Interval` with the matching slice of a fetched
    reference `Sequence`; `concat` then joins all items into one string.
    """

    def restore(self, sequence: Sequence):
        """
        Replace, in place, every `Interval` item with the matching slice of
        `sequence`.

        Args:
          sequence: `pyfaidx.Sequence` covering all intervals in this builder.
            Its `start` is used to turn the absolute interval coordinates
            into offsets within `sequence`.
        """
        for i, item in enumerate(self):
            if not isinstance(item, Interval):
                continue
            # item.start can be bigger than item.end; such an "inverted"
            # interval contributes an empty slice.
            interval_len = max(0, item.end - item.start)
            start = item.start - sequence.start
            self[i] = sequence[start: start + interval_len]

    def _concat(self):
        """Yield the raw string of every item; all items must be `Sequence`s."""
        for item in self:
            if not isinstance(item, Sequence):
                raise TypeError('Intervals should be restored with `restore`'
                                ' method before calling concat method!')
            yield item.seq

    def concat(self):
        """
        Build the string from sequence objects.

        Returns:
          str: the final sequence.

        Raises:
          TypeError: if some `Interval` was not converted by `restore` yet.
        """
        return ''.join(self._concat())
58

59

60
class VariantSeqExtractor(BaseExtractor):
    """
    Extract a reference sequence with a set of variants (SNVs and indels)
    applied to it.
    """

    def __init__(self, fasta_file: str = None, reference_sequence: BaseExtractor = None, use_strand=True):
        """
        Sequence extractor which allows to obtain the alternative sequence,
        given some interval and variants inside this interval.

        Exactly one of `fasta_file` and `reference_sequence` must be given.

        Args:
            fasta_file: path to the fasta file (can be gzipped)
            reference_sequence: extractor returning the reference sequence given some interval
            use_strand (bool): if True, the extracted sequence
                is reverse complemented in case interval.strand == "-"

        Raises:
            ValueError: if both or neither of `fasta_file` and
                `reference_sequence` are given.
        """
        self._use_strand = use_strand

        if fasta_file is not None:
            if reference_sequence is not None:
                raise ValueError(
                    "only one of fasta_file or reference_sequence can be specified, not both")
            # The reference is always read on the forward strand; strand
            # handling (reverse-complement) is done at the end of `extract`.
            self._ref_seq_extractor = FastaStringExtractor(
                fasta_file, use_strand=False)
        else:
            if reference_sequence is None:
                raise ValueError(
                    "either fasta_file or reference_sequence has to be specified")
            self._ref_seq_extractor = reference_sequence

    @property
    @deprecated(deprecated_in="1.0",
                # removed_in="2.0",
                current_version=__version__,
                details="Use `ref_seq_extractor` instead")
    def fasta(self):
        return self._ref_seq_extractor

    @property
    def ref_seq_extractor(self) -> BaseExtractor:
        """

        Returns:
            The reference sequence extractor of this object
        """
        return self._ref_seq_extractor

    def extract(self, interval, variants, anchor, fixed_len=True, use_strand=None, chrom_len=math.inf,
                is_padding=False, **kwargs):
        """
        Args:
            interval: pybedtools.Interval Region of interest from
                which to query the sequence. 0-based
            variants: List[cyvcf2.Variant]: variants overlapping the `interval`.
                can also be indels. 1-based
            anchor: absolute position w.r.t. the interval start. (0-based).
                E.g. for an interval of `chr1:10-20` the anchor of 10 denotes
                the point chr1:10 in the 0-based coordinate system.
                The anchor is clamped to `[interval.start, interval.end]`.
            fixed_len: if True, the return sequence will have the same length
                as the `interval` (e.g. `interval.end - interval.start`)
            use_strand (bool, optional): if True, the extracted sequence
                is reverse complemented in case interval.strand == "-".
                Overrides `self.use_strand`
            chrom_len: length of the chromosome. Reference positions below 0
                or at/after `chrom_len` are never fetched. If
                chrom_len == math.inf, the length of the chromosome is not checked.
            is_padding: if True (and `fixed_len` is True), the sequence is
                padded with 'N's if it can't extend to the fixed length,
                e.g. at chromosome boundaries.

        Returns:
            A single sequence (`str`) with all the variants applied.

        Raises:
            ValueError: if `fixed_len` is True, `is_padding` is False and the
                sequence can't be extended to the fixed length.
        """
        # Preprocessing
        anchor = max(min(anchor, interval.end), interval.start)
        variant_pairs = self._variant_to_sequence(variants)

        # 1. Split variants overlapping with anchor
        # and interval start end if not fixed_len
        variant_pairs = self._split_overlapping(variant_pairs, anchor)

        if not fixed_len:
            variant_pairs = self._split_overlapping(
                variant_pairs, interval.start, which='right')
            variant_pairs = self._split_overlapping(
                variant_pairs, interval.end, which='left')

        variant_pairs = list(variant_pairs)

        # 2. split the variants into upstream and downstream
        # and sort the variants in each interval
        upstream_variants = sorted(
            [pair for pair in variant_pairs if pair[0].start >= anchor],
            key=lambda x: x[0].start
        )

        downstream_variants = sorted(
            [pair for pair in variant_pairs if pair[0].start < anchor],
            key=lambda x: x[0].start,
            reverse=True
        )

        # 3. Extend start and end position for deletions
        if fixed_len:
            istart, iend = self._updated_interval(
                interval, upstream_variants, downstream_variants)
        else:
            istart, iend = interval.start, interval.end

        # Never fetch outside the chromosome. Intervals are half-open, so the
        # largest valid end is `chrom_len` itself; clamping to `chrom_len - 1`
        # would drop the last base of the chromosome.
        istart = max(istart, 0)
        iend = min(iend, chrom_len)

        # 4. Iterate from the anchor point outwards. At each
        # register the interval from which to take the reference sequence
        # as well as the interval for the variant
        down_sb = self._downstream_builder(
            downstream_variants, interval, anchor, istart)

        up_sb = self._upstream_builder(
            upstream_variants, interval, anchor, iend)

        # 5. fetch the sequence and restore intervals in builder
        seq = self._fetch(interval, istart, iend)
        up_sb.restore(seq)
        down_sb.restore(seq)

        # 6. Concat sequences from the upstream and downstream splits. Concat
        # upstream and downstream sequence. Cut to fix the length.
        down_str = down_sb.concat()
        up_str = up_sb.concat()

        if fixed_len:
            down_str, up_str = self._cut_to_fix_len(
                down_str, up_str, interval, anchor, is_padding=is_padding)

        seq = down_str + up_str

        if use_strand is None:
            # NOTE(review): `__init__` only sets `self._use_strand`; the
            # `use_strand` attribute is presumably provided by BaseExtractor
            # (not visible here) — confirm.
            use_strand = self.use_strand
        if use_strand and interval.strand == '-':
            # reverse-complement
            seq = complement(seq)[::-1]

        return seq

    @staticmethod
    def _variant_to_sequence(variants):
        """
        Convert variant objects to (ref, alt) pairs of `pyfaidx.Sequence`
        objects, both starting at the variant's `start`.
        """
        for v in variants:
            ref = Sequence(name=v.chrom, seq=v.ref,
                           start=v.start, end=v.start + len(v.ref))
            alt = Sequence(name=v.chrom, seq=v.alt,
                           start=v.start, end=v.start + len(v.alt))
            yield ref, alt

    @staticmethod
    def _split_overlapping(variant_pairs, anchor, which='both'):
        """
        Split the variants hitting the anchor into two.

        A pair whose reference strictly spans `anchor`
        (`ref.start < anchor < ref.end`) is cut at `anchor`; `which` selects
        whether the 'left' part, the 'right' part or 'both' are yielded.
        Other pairs are yielded unchanged.
        """
        for ref, alt in variant_pairs:
            if ref.start < anchor < ref.end:
                mid = anchor - ref.start
                if which == 'left' or which == 'both':
                    yield ref[:mid], alt[:mid]
                if which == 'right' or which == 'both':
                    yield ref[mid:], alt[mid:]
            else:
                yield ref, alt

    @staticmethod
    def _updated_interval(interval, up_variants, down_variants):
        """
        Widen the reference window by the net length lost to deletions, so a
        fixed-length output can still be filled: deletions upstream of the
        anchor move the end right, deletions downstream move the start left.
        """
        istart = interval.start
        iend = interval.end

        for ref, alt in up_variants:
            diff_len = len(alt) - len(ref)
            if diff_len < 0:
                iend -= diff_len

        for ref, alt in down_variants:
            diff_len = len(alt) - len(ref)
            if diff_len < 0:
                istart += diff_len

        return istart, iend

    @staticmethod
    def _downstream_builder(down_variants, interval, anchor, istart):
        """
        Collect, walking left from `anchor` to `istart`, the reference gaps
        (as `Interval`s) and alternative alleles, then reverse them so the
        builder is in left-to-right order. `down_variants` must be sorted by
        start, descending.
        """
        down_sb = IntervalSeqBuilder()

        prev = anchor
        for ref, alt in down_variants:
            if ref.end <= istart:
                break
            down_sb.append(Interval(interval.chrom, ref.end, prev))
            down_sb.append(alt)
            prev = ref.start
        down_sb.append(Interval(interval.chrom, istart, prev))
        down_sb.reverse()

        return down_sb

    @staticmethod
    def _upstream_builder(up_variants, interval, anchor, iend):
        """
        Collect, walking right from `anchor` to `iend`, the reference gaps
        (as `Interval`s) and alternative alleles. `up_variants` must be
        sorted by start, ascending.
        """
        up_sb = IntervalSeqBuilder()

        prev = anchor
        for ref, alt in up_variants:
            if ref.start >= iend:
                break
            up_sb.append(Interval(interval.chrom, prev, ref.start))
            up_sb.append(alt)
            prev = ref.end
        up_sb.append(Interval(interval.chrom, prev, iend))

        return up_sb

    def _fetch(self, interval, istart, iend):
        """Fetch the forward-strand reference `[istart, iend)` as a `Sequence`."""
        # fetch interval, ignore strand
        seq = self.ref_seq_extractor.extract(
            Interval(interval.chrom, istart, iend))
        seq = Sequence(name=interval.chrom, seq=seq, start=istart, end=iend)
        return seq

    @staticmethod
    def _cut_to_fix_len(down_str, up_str, interval, anchor, is_padding=False):
        """
        Trim or pad both halves so that `down_str` has exactly
        `anchor - interval.start` characters and `up_str` exactly
        `interval.end - anchor` characters.

        `down_str` is trimmed/padded on its left end and `up_str` on its right
        end, so the result stays aligned on the anchor.

        Raises:
            ValueError: if a half is too short and `is_padding` is False.
        """
        down_len = anchor - interval.start
        down_diff = len(down_str) - down_len
        if down_diff > 0:
            # Keep the last `down_len` characters. The positive offset also
            # handles down_len == 0, where `down_str[-0:]` would keep the
            # whole string instead of none of it.
            down_str = down_str[down_diff:]
        elif down_diff < 0:
            if is_padding:
                down_str = 'N' * -down_diff + down_str
            else:
                raise ValueError("padding should be set to True, if the sequence can't extend to the fixed length")

        up_len = interval.end - anchor
        up_diff = len(up_str) - up_len
        if up_diff > 0:
            up_str = up_str[:up_len]
        elif up_diff < 0:
            if is_padding:
                up_str = up_str + 'N' * -up_diff
            else:
                raise ValueError("padding should be set to True, if the sequence can't extend to the fixed length")

        return down_str, up_str
304

305

306
class _BaseVCFSeqExtractor(BaseExtractor, metaclass=abc.ABCMeta):
    """
    Base class to fetch sequences in which variants from a vcf file are
    applied to the reference sequence.

    Subclasses implement `extract`, typically fetching variants through
    `self.vcf` and applying them with `self.variant_extractor`.
    """

    def __init__(self, fasta_file, vcf_file):
        """
        Args:
          fasta_file: path to the fasta file (can be gzipped)
          vcf_file: path to the vcf file (needs to be bgzipped and indexed)
        """
        self.fasta_file = fasta_file
        self.vcf_file = vcf_file
        self.variant_extractor = VariantSeqExtractor(fasta_file)
        self.vcf = MultiSampleVCF(vcf_file)

    @abc.abstractmethod
    def extract(self, interval: Interval, *args, **kwargs) -> str:
        raise NotImplementedError()
326

327

328
class SingleVariantVCFSeqExtractor(_BaseVCFSeqExtractor):
    """
    Yield, for a given interval, one sequence per variant fetched from the
    vcf file; each sequence has exactly that single variant applied.
    """

    def extract(self, interval, anchor=None, sample_id=None, fixed_len=True):
        """
        Args:
          interval: region to extract.
          anchor: forwarded unchanged to `VariantSeqExtractor.extract`.
            NOTE(review): that method clamps the anchor with `min`/`max`,
            so the default `None` cannot be used as-is — pass an anchor.
          sample_id: forwarded to `self.vcf.fetch_variants`.
          fixed_len: forwarded to `VariantSeqExtractor.extract`.

        Yields:
          str: one alternative sequence per fetched variant.
        """
        fetched = self.vcf.fetch_variants(interval, sample_id)
        for single_variant in fetched:
            alt_seq = self.variant_extractor.extract(
                interval,
                variants=[single_variant],
                anchor=anchor,
                fixed_len=fixed_len,
            )
            yield alt_seq
342

343

344
class SingleSeqVCFSeqExtractor(_BaseVCFSeqExtractor):
    """
    Return a single sequence for a given interval with every variant fetched
    from the vcf file applied at once.
    """

    def extract(self, interval, anchor=None, sample_id=None, fixed_len=True):
        """
        Args:
          interval: region to extract.
          anchor: forwarded unchanged to `VariantSeqExtractor.extract`.
          sample_id: forwarded to `self.vcf.fetch_variants`.
          fixed_len: forwarded to `VariantSeqExtractor.extract`.

        Returns:
          str: the sequence with all fetched variants applied.
        """
        applied = self.variant_extractor
        all_variants = self.vcf.fetch_variants(interval, sample_id)
        return applied.extract(
            interval,
            variants=all_variants,
            anchor=anchor,
            fixed_len=fixed_len,
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc