# bio2zarr/vcf_utils.py
import contextlib
import gzip
import logging
import os
import pathlib
import struct
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import IO, Any

import humanfriendly
import numpy as np

from bio2zarr import core
from bio2zarr.typing import PathType

logger = logging.getLogger(__name__)

CSI_EXTENSION = ".csi"
TABIX_EXTENSION = ".tbi"
TABIX_LINEAR_INDEX_INTERVAL_SIZE = 1 << 14  # 16kb interval size


def ceildiv(a: int, b: int) -> int:
    """Safe integer ceiling division."""
    return -(-a // b)
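
# For example, ceildiv(7, 2) == 4 whereas 7 // 2 == 3: negating twice turns
# floor division into ceiling division without going through floats.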


def get_file_offset(vfp: int) -> int:
    """Convert a block compressed virtual file pointer to a file offset."""
    address_mask = 0xFFFFFFFFFFFF
    return vfp >> 16 & address_mask
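
# A BGZF virtual file pointer packs the compressed file offset into its upper
# 48 bits and the offset within the uncompressed block into its lower 16 bits.
# An illustrative round trip with made-up values:
#
#   vfp = (1234 << 16) | 56   # block at byte 1234, record at byte 56 within it
#   get_file_offset(vfp)      # -> 1234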


def read_bytes_as_value(f: IO[Any], fmt: str, nodata: Any | None = None) -> Any:
    """Read bytes using a `struct` format string and return the unpacked data value.

    Parameters
    ----------
    f : IO[Any]
        The IO stream to read bytes from.
    fmt : str
        A Python `struct` format string.
    nodata : Optional[Any], optional
        The value to return in case there is no further data in the stream,
        by default None

    Returns
    -------
    Any
        The unpacked data value read from the stream.
    """
    data = f.read(struct.calcsize(fmt))
    if not data:
        return nodata
    values = struct.Struct(fmt).unpack(data)
    assert len(values) == 1
    return values[0]
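
# A minimal sketch of the helper in action, using an in-memory stream:
#
#   import io
#   buf = io.BytesIO(struct.pack("<i", 42))
#   read_bytes_as_value(buf, "<i")             # -> 42
#   read_bytes_as_value(buf, "<i", nodata=-1)  # -> -1 (stream exhausted)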


def read_bytes_as_tuple(f: IO[Any], fmt: str) -> Sequence[Any]:
    """Read bytes using a `struct` format string and return the unpacked data values.

    Parameters
    ----------
    f : IO[Any]
        The IO stream to read bytes from.
    fmt : str
        A Python `struct` format string.

    Returns
    -------
    Sequence[Any]
        The unpacked data values read from the stream.
    """
    data = f.read(struct.calcsize(fmt))
    return struct.Struct(fmt).unpack(data)
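
# Same idea for multi-value formats, e.g.:
#
#   buf = io.BytesIO(struct.pack("<IQ", 3, 7))
#   read_bytes_as_tuple(buf, "<IQ")  # -> (3, 7)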


@dataclass
class Region:
    """
    A htslib style region, where coordinates are 1-based and inclusive.
    """

    contig: str
    start: int | None = None
    end: int | None = None

    def __post_init__(self):
        assert self.contig is not None
        if self.start is None:
            self.start = 1
        else:
            self.start = int(self.start)
            assert self.start > 0
        if self.end is not None:
            self.end = int(self.end)
            assert self.end >= self.start

    def __str__(self):
        s = f"{self.contig}"
        if self.start is not None:
            s += f":{self.start}-"
        if self.end is not None:
            s += str(self.end)
        return s

    # TODO add "parse" class method for when we accept regions
    # as input
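
# The string form follows htslib region conventions, e.g.:
#
#   str(Region("chr1", 100, 200))  # -> "chr1:100-200"
#   str(Region("chr1"))            # -> "chr1:1-" (start defaults to 1)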


@dataclass
class Chunk:
    cnk_beg: int
    cnk_end: int


@dataclass
class CSIBin:
    bin: int
    loffset: int
    chunks: Sequence[Chunk]


RECORD_COUNT_UNKNOWN = np.inf


@dataclass
class CSIIndex:
    min_shift: int
    depth: int
    aux: str
    bins: Sequence[Sequence[CSIBin]]
    record_counts: Sequence[int]
    n_no_coor: int

    def parse_vcf_aux(self):
        assert len(self.aux) > 0
        # The first 7 int32 values (28 bytes) look like a tabix header, but
        # it is unclear how to interpret them, and the n_ref value doesn't
        # seem to correspond to the number of contigs anyway, so they are
        # ignored for now.
        # values = struct.Struct("<7i").unpack(self.aux[:28])
        # tabix_header = Header(*values, 0)
        names = self.aux[28:]
        # Convert \0-terminated names to strings
        sequence_names = [str(name, "utf-8") for name in names.split(b"\x00")[:-1]]
        return sequence_names

    def offsets(self) -> Any:
        pseudo_bin = bin_limit(self.min_shift, self.depth) + 1

        file_offsets = []
        contig_indexes = []
        positions = []
        for contig_index, bins in enumerate(self.bins):
            # bins may be in any order within a contig, so sort by loffset
            for bin in sorted(bins, key=lambda b: b.loffset):
                if bin.bin == pseudo_bin:
                    continue  # skip pseudo bins
                file_offset = get_file_offset(bin.loffset)
                position = get_first_locus_in_bin(self, bin.bin)
                file_offsets.append(file_offset)
                contig_indexes.append(contig_index)
                positions.append(position)

        return np.array(file_offsets), np.array(contig_indexes), np.array(positions)


def bin_limit(min_shift: int, depth: int) -> int:
    """Defined in CSI spec"""
    return ((1 << (depth + 1) * 3) - 1) // 7


def get_first_bin_in_level(level: int) -> int:
    return ((1 << level * 3) - 1) // 7


def get_level_size(level: int) -> int:
    return 1 << level * 3


def get_level_for_bin(csi: CSIIndex, bin: int) -> int:
    for i in range(csi.depth, -1, -1):
        if bin >= get_first_bin_in_level(i):
            return i
    raise ValueError(f"Cannot find level for bin {bin}.")  # pragma: no cover


def get_first_locus_in_bin(csi: CSIIndex, bin: int) -> int:
    level = get_level_for_bin(csi, bin)
    first_bin_on_level = get_first_bin_in_level(level)
    level_size = get_level_size(level)
    max_span = 1 << (csi.min_shift + 3 * csi.depth)
    return (bin - first_bin_on_level) * (max_span // level_size) + 1
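
# Worked example with the standard tabix parameters (min_shift=14, depth=5):
#
#   bin_limit(14, 5)           # -> 37449, so the pseudo-bin number is 37450,
#                              #    matching the constant in read_tabix below
#   get_first_bin_in_level(5)  # -> 4681 (first bin on the finest level)
#   get_level_size(5)          # -> 32768 bins on that level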


def read_csi(file: PathType, storage_options: dict[str, str] | None = None) -> CSIIndex:
    """Parse a CSI file into a `CSIIndex` object.

    Parameters
    ----------
    file : PathType
        The path to the CSI file.

    Returns
    -------
    CSIIndex
        An object representing a CSI index.

    Raises
    ------
    ValueError
        If the file is not a CSI file.
    """
    with gzip.open(file) as f:
        magic = read_bytes_as_value(f, "4s")
        if magic != b"CSI\x01":
            raise ValueError("File not in CSI format.")

        min_shift, depth, l_aux = read_bytes_as_tuple(f, "<3i")
        aux = read_bytes_as_value(f, f"{l_aux}s", "")
        n_ref = read_bytes_as_value(f, "<i")

        pseudo_bin = bin_limit(min_shift, depth) + 1

        bins = []
        record_counts = []

        if n_ref > 0:
            for _ in range(n_ref):
                n_bin = read_bytes_as_value(f, "<i")
                seq_bins = []
                # Distinguish between counts that are zero because the sequence
                # isn't there, vs counts that aren't in the index.
                record_count = 0 if n_bin == 0 else RECORD_COUNT_UNKNOWN
                for _ in range(n_bin):
                    bin, loffset, n_chunk = read_bytes_as_tuple(f, "<IQi")
                    chunks = []
                    for _ in range(n_chunk):
                        chunk = Chunk(*read_bytes_as_tuple(f, "<QQ"))
                        chunks.append(chunk)
                    seq_bins.append(CSIBin(bin, loffset, chunks))

                    if bin == pseudo_bin:
                        assert len(chunks) == 2
                        n_mapped, n_unmapped = chunks[1].cnk_beg, chunks[1].cnk_end
                        record_count = n_mapped + n_unmapped
                bins.append(seq_bins)
                record_counts.append(record_count)

        n_no_coor = read_bytes_as_value(f, "<Q", 0)

        assert len(f.read(1)) == 0

        return CSIIndex(min_shift, depth, aux, bins, record_counts, n_no_coor)
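
# A sketch of typical use (the path is hypothetical):
#
#   csi = read_csi("example.vcf.gz.csi")
#   file_offsets, contig_indexes, positions = csi.offsets()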


@dataclass
class Header:
    n_ref: int
    format: int
    col_seq: int
    col_beg: int
    col_end: int
    meta: int
    skip: int
    l_nm: int


@dataclass
class TabixBin:
    bin: int
    chunks: Sequence[Chunk]


@dataclass
class TabixIndex:
    header: Header
    sequence_names: Sequence[str]
    bins: Sequence[Sequence[TabixBin]]
    linear_indexes: Sequence[Sequence[int]]
    record_counts: Sequence[int]
    n_no_coor: int

    def offsets(self) -> Any:
        # Combine the linear indexes into one stacked array
        linear_indexes = self.linear_indexes
        linear_index = np.hstack([np.array(li) for li in linear_indexes])

        # Create file offsets for each element in the linear index
        file_offsets = np.array([get_file_offset(vfp) for vfp in linear_index])

        # Calculate corresponding contigs and positions for each element in
        # the linear index
        contig_indexes = np.hstack(
            [np.full(len(li), i) for (i, li) in enumerate(linear_indexes)]
        )
        # positions are 1-based and inclusive
        positions = np.hstack(
            [
                np.arange(len(li)) * TABIX_LINEAR_INDEX_INTERVAL_SIZE + 1
                for li in linear_indexes
            ]
        )
        assert len(file_offsets) == len(contig_indexes)
        assert len(file_offsets) == len(positions)

        return file_offsets, contig_indexes, positions
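
# Each linear-index slot i covers the 16kb window starting at locus
# i * 16384 + 1, so the positions computed above run 1, 16385, 32769, ...
# within every contig.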


def read_tabix(
    file: PathType, storage_options: dict[str, str] | None = None
) -> TabixIndex:
    """Parse a tabix file into a `TabixIndex` object.

    Parameters
    ----------
    file : PathType
        The path to the tabix file.

    Returns
    -------
    TabixIndex
        An object representing a tabix index.

    Raises
    ------
    ValueError
        If the file is not a tabix file.
    """
    with gzip.open(file) as f:
        magic = read_bytes_as_value(f, "4s")
        if magic != b"TBI\x01":
            raise ValueError("File not in Tabix format.")

        header = Header(*read_bytes_as_tuple(f, "<8i"))

        sequence_names = []
        bins = []
        linear_indexes = []
        record_counts = []

        if header.l_nm > 0:
            names = read_bytes_as_value(f, f"<{header.l_nm}s")
            # Convert \0-terminated names to strings
            sequence_names = [str(name, "utf-8") for name in names.split(b"\x00")[:-1]]

            for _ in range(header.n_ref):
                n_bin = read_bytes_as_value(f, "<i")
                seq_bins = []
                # Distinguish between counts that are zero because the sequence
                # isn't there, vs counts that aren't in the index.
                record_count = 0 if n_bin == 0 else RECORD_COUNT_UNKNOWN
                for _ in range(n_bin):
                    bin, n_chunk = read_bytes_as_tuple(f, "<Ii")
                    chunks = []
                    for _ in range(n_chunk):
                        chunk = Chunk(*read_bytes_as_tuple(f, "<QQ"))
                        chunks.append(chunk)
                    seq_bins.append(TabixBin(bin, chunks))

                    if bin == 37450:  # pseudo-bin, see section 5.2 of BAM spec
                        assert len(chunks) == 2
                        n_mapped, n_unmapped = chunks[1].cnk_beg, chunks[1].cnk_end
                        record_count = n_mapped + n_unmapped
                n_intv = read_bytes_as_value(f, "<i")
                linear_index = []
                for _ in range(n_intv):
                    ioff = read_bytes_as_value(f, "<Q")
                    linear_index.append(ioff)
                bins.append(seq_bins)
                linear_indexes.append(linear_index)
                record_counts.append(record_count)

        n_no_coor = read_bytes_as_value(f, "<Q", 0)

        assert len(f.read(1)) == 0

        return TabixIndex(
            header, sequence_names, bins, linear_indexes, record_counts, n_no_coor
        )
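
# A sketch of typical use (the path is hypothetical):
#
#   tbx = read_tabix("example.vcf.gz.tbi")
#   dict(zip(tbx.sequence_names, tbx.record_counts))  # records per contig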


class VcfFileType(Enum):
    VCF = ".vcf"
    BCF = ".bcf"


class VcfIndexType(Enum):
    CSI = ".csi"
    TABIX = ".tbi"


class VcfFile(contextlib.AbstractContextManager):
    @core.requires_optional_dependency("cyvcf2", "vcf")
    def __init__(self, vcf_path, index_path=None):
        import cyvcf2

        self.vcf = None
        self.file_type = None
        self.index_type = None

        vcf_path = pathlib.Path(vcf_path)
        if not vcf_path.exists():
            raise FileNotFoundError(vcf_path)
        if index_path is None:
            index_path = vcf_path.with_suffix(
                vcf_path.suffix + VcfIndexType.TABIX.value
            )
            if not index_path.exists():
                index_path = vcf_path.with_suffix(
                    vcf_path.suffix + VcfIndexType.CSI.value
                )
                if not index_path.exists():
                    # No supported index found
                    index_path = None
        else:
            index_path = pathlib.Path(index_path)
            if not index_path.exists():
                raise FileNotFoundError(
                    f"Specified index path {index_path} does not exist"
                )

        self.vcf_path = vcf_path
        self.index_path = index_path
        if index_path is not None:
            if index_path.suffix == VcfIndexType.CSI.value:
                self.index_type = VcfIndexType.CSI
            elif index_path.suffix == VcfIndexType.TABIX.value:
                self.index_type = VcfIndexType.TABIX
                self.file_type = VcfFileType.VCF
            else:
                raise ValueError("Only .tbi or .csi indexes are supported.")

        self.vcf = cyvcf2.VCF(vcf_path)
        if self.index_path is not None:
            self.vcf.set_index(str(self.index_path))

        logger.debug(f"Loaded {vcf_path} with index {self.index_path}")
        self.sequence_names = None

        self.index = None
        if self.index_type == VcfIndexType.CSI:
            # Determine the file-type based on the "aux" field.
            self.index = read_csi(self.index_path)
            self.file_type = VcfFileType.BCF
            if len(self.index.aux) > 0:
                self.file_type = VcfFileType.VCF
                self.sequence_names = self.index.parse_vcf_aux()
            else:
                self.sequence_names = self.vcf.seqnames
        elif self.index_type == VcfIndexType.TABIX:
            self.index = read_tabix(self.index_path)
            self.file_type = VcfFileType.VCF
            self.sequence_names = self.index.sequence_names
        else:
            assert self.index is None
            var = next(self.vcf)
            self.sequence_names = [var.CHROM]
            self.vcf.close()
            # There doesn't seem to be a way to reset the iterator
            self.vcf = cyvcf2.VCF(vcf_path)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.vcf is not None:
            self.vcf.close()
            self.vcf = None
        return False

    def contig_record_counts(self):
        if self.index is None:
            return {self.sequence_names[0]: RECORD_COUNT_UNKNOWN}
        d = dict(zip(self.sequence_names, self.index.record_counts))
        if self.file_type == VcfFileType.BCF:
            d = {k: v for k, v in d.items() if v > 0}
        return d

    def count_variants(self, region):
        return sum(1 for _ in self.variants(region))

    def variants(self, region=None):
        if self.index is None:
            contig = self.sequence_names[0]
            if region is not None:
                assert region.contig == contig
            for var in self.vcf:
                if var.CHROM != contig:
                    raise ValueError("Multi-contig VCFs must be indexed")
                yield var
        else:
            start = 1 if region.start is None else region.start
            for var in self.vcf(str(region)):
                # Need to filter because of indels overlapping the region
                if var.POS >= start:
                    yield var

    def _filter_empty_and_refine(self, regions):
        """
        Return all regions in the specified list that have one or more records,
        and refine the start coordinate of the region to be the actual first coord.

        Because this is a relatively expensive operation requiring seeking around
        the file, we return the results as an iterator.
        """
        for region in regions:
            var = next(self.variants(region), None)
            if var is not None:
                region.start = var.POS
                yield region

    def partition_into_regions(
        self,
        num_parts: int | None = None,
        target_part_size: None | int | str = None,
    ):
        if num_parts is None and target_part_size is None:
            raise ValueError("One of num_parts or target_part_size must be specified")

        if num_parts is not None and target_part_size is not None:
            raise ValueError(
                "Only one of num_parts or target_part_size may be specified"
            )

        if num_parts is not None and num_parts < 1:
            raise ValueError("num_parts must be positive")

        if target_part_size is not None:
            if isinstance(target_part_size, int):
                target_part_size_bytes = target_part_size
            else:
                target_part_size_bytes = humanfriendly.parse_size(target_part_size)
            if target_part_size_bytes < 1:
                raise ValueError("target_part_size must be positive")

        if self.index is None:
            return [Region(self.sequence_names[0])]

        # Calculate the desired part file boundaries
        file_length = os.stat(self.vcf_path).st_size
        if num_parts is not None:
            target_part_size_bytes = file_length // num_parts
        elif target_part_size_bytes is not None:
            num_parts = ceildiv(file_length, target_part_size_bytes)
        part_lengths = target_part_size_bytes * np.arange(num_parts, dtype=int)
        file_offsets, region_contig_indexes, region_positions = self.index.offsets()

        # Search the file offsets to find which indexes the part lengths fall at
        ind = np.searchsorted(file_offsets, part_lengths)

        # Drop any parts that are greater than the file offsets
        # (these will be covered by a region with no end)
        ind = np.delete(ind, ind >= len(file_offsets))

        # Drop any duplicates
        ind = np.unique(ind)

        # Calculate region contig and start for each index
        region_contigs = region_contig_indexes[ind]
        region_starts = region_positions[ind]

        # Build region query strings
        regions = []
        for i in range(len(region_starts)):
            contig = self.sequence_names[region_contigs[i]]
            start = region_starts[i]

            if i == len(region_starts) - 1:  # final region
                regions.append(Region(contig, start))
            else:
                next_contig = self.sequence_names[region_contigs[i + 1]]
                next_start = region_starts[i + 1]
                end = next_start - 1  # subtract one since positions are inclusive
                if next_contig == contig:  # contig doesn't change
                    regions.append(Region(contig, start, end))
                else:
                    # contig changes, so need two regions (or possibly more if any
                    # sequences were skipped)
                    regions.append(Region(contig, start))
                    for ri in range(region_contigs[i] + 1, region_contigs[i + 1]):
                        regions.append(Region(self.sequence_names[ri]))
                    if end >= 1:
                        regions.append(Region(next_contig, 1, end))

        # Add any remaining sequences at the end that have records
        for ri in range(region_contigs[-1] + 1, len(self.sequence_names)):
            if self.index.record_counts[ri] > 0:
                regions.append(Region(self.sequence_names[ri]))

        return self._filter_empty_and_refine(regions)
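
# A sketch of typical use (file names hypothetical; cyvcf2 must be installed):
#
#   with VcfFile("example.vcf.gz") as vcf:
#       for region in vcf.partition_into_regions(target_part_size="100MB"):
#           print(region, vcf.count_variants(region))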