sgkit-dev / bio2zarr / build 18940786519

30 Oct 2025 12:34PM UTC coverage: 98.361% (+0.07%) from 98.292%

Pull Request #422 (github / web-flow): Update intro.md with VCF Zarr conversion info
Merge b33cd4646 into 3f955cef0

2880 of 2928 relevant lines covered (98.36%)

3.93 hits per line

Source File

/bio2zarr/vcf.py (98.95% covered)
1
import collections
4✔
2
import contextlib
4✔
3
import dataclasses
4✔
4
import json
4✔
5
import logging
4✔
6
import math
4✔
7
import pathlib
4✔
8
import pickle
4✔
9
import re
4✔
10
import shutil
4✔
11
import sys
4✔
12
import tempfile
4✔
13
from functools import partial
4✔
14
from typing import Any
4✔
15

16
import numcodecs
4✔
17
import numpy as np
4✔
18

19
from . import constants, core, provenance, vcf_utils, vcz
4✔
20

21
logger = logging.getLogger(__name__)
4✔
22

23

24
@dataclasses.dataclass
4✔
25
class VcfFieldSummary(core.JsonDataclass):
4✔
26
    num_chunks: int = 0
4✔
27
    compressed_size: int = 0
4✔
28
    uncompressed_size: int = 0
4✔
29
    max_number: int = 0  # Corresponds to VCF Number field, depends on context
4✔
30
    # Only defined for numeric fields
31
    max_value: Any = -math.inf
4✔
32
    min_value: Any = math.inf
4✔
33

34
    def update(self, other):
4✔
35
        self.num_chunks += other.num_chunks
4✔
36
        self.compressed_size += other.compressed_size
4✔
37
        self.uncompressed_size += other.uncompressed_size
4✔
38
        self.max_number = max(self.max_number, other.max_number)
4✔
39
        self.min_value = min(self.min_value, other.min_value)
4✔
40
        self.max_value = max(self.max_value, other.max_value)
4✔
41

42
    @staticmethod
4✔
43
    def fromdict(d):
4✔
44
        return VcfFieldSummary(**d)
4✔
45

46

47
@dataclasses.dataclass(order=True)
4✔
48
class VcfField:
4✔
49
    category: str
4✔
50
    name: str
4✔
51
    vcf_number: str
4✔
52
    vcf_type: str
4✔
53
    description: str
4✔
54
    summary: VcfFieldSummary
4✔
55

56
    @staticmethod
4✔
57
    def from_header(definition):
4✔
58
        category = definition["HeaderType"]
4✔
59
        name = definition["ID"]
4✔
60
        vcf_number = definition["Number"]
4✔
61
        vcf_type = definition["Type"]
4✔
62
        return VcfField(
4✔
63
            category=category,
64
            name=name,
65
            vcf_number=vcf_number,
66
            vcf_type=vcf_type,
67
            description=definition["Description"].strip('"'),
68
            summary=VcfFieldSummary(),
69
        )
70

71
    @staticmethod
4✔
72
    def fromdict(d):
4✔
73
        f = VcfField(**d)
4✔
74
        f.summary = VcfFieldSummary(**d["summary"])
4✔
75
        return f
4✔
76

77
    @property
4✔
78
    def full_name(self):
4✔
79
        if self.category == "fixed":
4✔
80
            return self.name
4✔
81
        return f"{self.category}/{self.name}"
4✔
82

83
    @property
4✔
84
    def max_number(self):
4✔
85
        if self.vcf_number in ("R", "A", "G", "."):
4✔
86
            return self.summary.max_number
4✔
87
        else:
88
            # use declared number if larger than max found
89
            return max(self.summary.max_number, int(self.vcf_number))
4✔
90

91
    def smallest_dtype(self):
4✔
92
        """
93
        Returns the smallest dtype suitable for this field based
94
        on type, and values.
95
        """
96
        s = self.summary
4✔
97
        if self.vcf_type == "Float":
4✔
98
            ret = "f4"
4✔
99
        elif self.vcf_type == "Integer":
4✔
100
            if not math.isfinite(s.max_value):
4✔
101
                # All missing values; use i1. Note we should have some API to
102
                # check more explicitly for missingness:
103
                # https://github.com/sgkit-dev/bio2zarr/issues/131
104
                ret = "i1"
4✔
105
            else:
106
                ret = core.min_int_dtype(s.min_value, s.max_value)
4✔
107
        elif self.vcf_type == "Flag":
4✔
108
            ret = "bool"
4✔
109
        elif self.vcf_type == "Character":
4✔
110
            ret = "U1"
4✔
111
        else:
112
            assert self.vcf_type == "String"
4✔
113
            ret = "O"
4✔
114
        return ret
4✔
115

116

117
@dataclasses.dataclass
4✔
118
class VcfPartition:
4✔
119
    vcf_path: str
4✔
120
    region: str
4✔
121
    num_records: int = -1
4✔
122

123

124
ICF_METADATA_FORMAT_VERSION = "0.4"
4✔
125
ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc(
4✔
126
    cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE
127
)
128

129

130
@dataclasses.dataclass
4✔
131
class IcfMetadata(core.JsonDataclass):
4✔
132
    samples: list
4✔
133
    contigs: list
4✔
134
    filters: list
4✔
135
    fields: list
4✔
136
    partitions: list = None
4✔
137
    format_version: str = None
4✔
138
    compressor: dict = None
4✔
139
    column_chunk_size: int = None
4✔
140
    provenance: dict = None
4✔
141
    num_records: int = -1
4✔
142

143
    @property
4✔
144
    def info_fields(self):
4✔
145
        fields = []
4✔
146
        for field in self.fields:
4✔
147
            if field.category == "INFO":
4✔
148
                fields.append(field)
4✔
149
        return fields
4✔
150

151
    @property
4✔
152
    def format_fields(self):
4✔
153
        fields = []
4✔
154
        for field in self.fields:
4✔
155
            if field.category == "FORMAT":
4✔
156
                fields.append(field)
4✔
157
        return fields
4✔
158

159
    @property
4✔
160
    def num_contigs(self):
4✔
161
        return len(self.contigs)
4✔
162

163
    @property
4✔
164
    def num_filters(self):
4✔
165
        return len(self.filters)
4✔
166

167
    @property
4✔
168
    def num_samples(self):
4✔
169
        return len(self.samples)
4✔
170

171
    @staticmethod
4✔
172
    def fromdict(d):
4✔
173
        if d["format_version"] != ICF_METADATA_FORMAT_VERSION:
4✔
174
            raise ValueError(
4✔
175
                "Intermediate columnar metadata format version mismatch: "
176
                f"{d['format_version']} != {ICF_METADATA_FORMAT_VERSION}"
177
            )
178
        partitions = [VcfPartition(**pd) for pd in d["partitions"]]
4✔
179
        for p in partitions:
4✔
180
            p.region = vcf_utils.Region(**p.region)
4✔
181
        d = d.copy()
4✔
182
        d["partitions"] = partitions
4✔
183
        d["fields"] = [VcfField.fromdict(fd) for fd in d["fields"]]
4✔
184
        d["samples"] = [vcz.Sample(**sd) for sd in d["samples"]]
4✔
185
        d["filters"] = [vcz.Filter(**fd) for fd in d["filters"]]
4✔
186
        d["contigs"] = [vcz.Contig(**cd) for cd in d["contigs"]]
4✔
187
        return IcfMetadata(**d)
4✔
188

189
    def __eq__(self, other):
4✔
190
        if not isinstance(other, IcfMetadata):
4✔
191
            return NotImplemented
×
192
        return (
4✔
193
            self.samples == other.samples
194
            and self.contigs == other.contigs
195
            and self.filters == other.filters
196
            and sorted(self.fields) == sorted(other.fields)
197
        )
198

199

200
def fixed_vcf_field_definitions():
4✔
201
    def make_field_def(name, vcf_type, vcf_number):
4✔
202
        return VcfField(
4✔
203
            category="fixed",
204
            name=name,
205
            vcf_type=vcf_type,
206
            vcf_number=vcf_number,
207
            description="",
208
            summary=VcfFieldSummary(),
209
        )
210

211
    fields = [
4✔
212
        make_field_def("CHROM", "String", "1"),
213
        make_field_def("POS", "Integer", "1"),
214
        make_field_def("QUAL", "Float", "1"),
215
        make_field_def("ID", "String", "."),
216
        make_field_def("FILTERS", "String", "."),
217
        make_field_def("REF", "String", "1"),
218
        make_field_def("ALT", "String", "."),
219
        make_field_def("rlen", "Integer", "1"),  # computed field
220
    ]
221
    return fields
4✔
222

223

224
def scan_vcf(path, target_num_partitions):
4✔
225
    with vcf_utils.VcfFile(path) as vcf_file:
4✔
226
        vcf = vcf_file.vcf
4✔
227
        filters = []
4✔
228
        pass_index = -1
4✔
229
        for h in vcf.header_iter():
4✔
230
            if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str):
4✔
231
                try:
4✔
232
                    description = h["Description"].strip('"')
4✔
233
                except KeyError:
×
234
                    description = ""
×
235
                if h["ID"] == "PASS":
4✔
236
                    pass_index = len(filters)
4✔
237
                filters.append(vcz.Filter(h["ID"], description))
4✔
238

239
        # Ensure PASS is the first filter if present
240
        if pass_index > 0:
4✔
241
            pass_filter = filters.pop(pass_index)
×
242
            filters.insert(0, pass_filter)
×
243

244
        fields = fixed_vcf_field_definitions()
4✔
245
        for h in vcf.header_iter():
4✔
246
            if h["HeaderType"] in ["INFO", "FORMAT"]:
4✔
247
                field = VcfField.from_header(h)
4✔
248
                if h["HeaderType"] == "FORMAT" and field.name == "GT":
4✔
249
                    field.vcf_type = "Integer"
4✔
250
                    field.vcf_number = "."
4✔
251
                fields.append(field)
4✔
252

253
        try:
4✔
254
            contig_lengths = vcf.seqlens
4✔
255
        except AttributeError:
4✔
256
            contig_lengths = [None for _ in vcf.seqnames]
4✔
257

258
        metadata = IcfMetadata(
4✔
259
            samples=[vcz.Sample(sample_id) for sample_id in vcf.samples],
260
            contigs=[
261
                vcz.Contig(contig_id, length)
262
                for contig_id, length in zip(vcf.seqnames, contig_lengths)
263
            ],
264
            filters=filters,
265
            fields=fields,
266
            partitions=[],
267
            num_records=sum(vcf_file.contig_record_counts().values()),
268
        )
269

270
        regions = vcf_file.partition_into_regions(num_parts=target_num_partitions)
4✔
271
        for region in regions:
4✔
272
            metadata.partitions.append(
4✔
273
                VcfPartition(
274
                    # TODO should this be fully resolving the path? Otherwise it's all
275
                    # relative to the original WD
276
                    vcf_path=str(path),
277
                    region=region,
278
                )
279
            )
280
        logger.info(
4✔
281
            f"Split {path} into {len(metadata.partitions)} "
282
            f"partitions target={target_num_partitions})"
283
        )
284
        core.update_progress(1)
4✔
285
        return metadata, vcf.raw_header
4✔
286

287

288
def scan_vcfs(
4✔
289
    paths,
290
    show_progress,
291
    target_num_partitions,
292
    worker_processes=core.DEFAULT_WORKER_PROCESSES,
293
):
294
    logger.info(
4✔
295
        f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
296
        f" partitions."
297
    )
298
    # An easy mistake to make is to pass the same file twice. Check this early on.
299
    for path, count in collections.Counter(paths).items():
4✔
300
        if not path.exists():  # NEEDS TEST
4✔
301
            raise FileNotFoundError(path)
×
302
        if count > 1:
4✔
303
            raise ValueError(f"Duplicate path provided: {path}")
4✔
304

305
    progress_config = core.ProgressConfig(
4✔
306
        total=len(paths),
307
        units="files",
308
        title="Scan",
309
        show=show_progress,
310
    )
311
    with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
4✔
312
        for path in paths:
4✔
313
            pwm.submit(
4✔
314
                scan_vcf,
315
                path,
316
                max(1, target_num_partitions // len(paths)),
317
            )
318
        results = list(pwm.results_as_completed())
4✔
319

320
    # Sort to make the ordering deterministic
321
    results.sort(key=lambda t: t[0].partitions[0].vcf_path)
4✔
322
    # We just take the first header, assuming the others
323
    # are compatible.
324
    all_partitions = []
4✔
325
    total_records = 0
4✔
326
    contigs = {}
4✔
327
    for metadata, _ in results:
4✔
328
        for partition in metadata.partitions:
4✔
329
            logger.debug(f"Scanned partition {partition}")
4✔
330
            all_partitions.append(partition)
4✔
331
        for contig in metadata.contigs:
4✔
332
            if contig.id in contigs:
4✔
333
                if contig != contigs[contig.id]:
4✔
334
                    raise ValueError(
4✔
335
                        "Incompatible contig definitions: "
336
                        f"{contig} != {contigs[contig.id]}"
337
                    )
338
            else:
339
                contigs[contig.id] = contig
4✔
340
        total_records += metadata.num_records
4✔
341
        metadata.num_records = 0
4✔
342
        metadata.partitions = []
4✔
343

344
    contig_union = list(contigs.values())
4✔
345
    for metadata, _ in results:
4✔
346
        metadata.contigs = contig_union
4✔
347

348
    icf_metadata, header = results[0]
4✔
349
    for metadata, _ in results[1:]:
4✔
350
        if metadata != icf_metadata:
4✔
351
            raise ValueError("Incompatible VCF chunks")
4✔
352

353
    # Note: this will be infinity here if any of the chunks has an index
354
    # that doesn't keep track of the number of records per-contig
355
    icf_metadata.num_records = total_records
4✔
356

357
    # Sort by contig (in the order they appear in the header) first,
358
    # then by start coordinate
359
    contig_index_map = {contig.id: j for j, contig in enumerate(metadata.contigs)}
4✔
360
    all_partitions.sort(
4✔
361
        key=lambda x: (contig_index_map[x.region.contig], x.region.start)
362
    )
363
    icf_metadata.partitions = all_partitions
4✔
364
    logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.")
4✔
365
    return icf_metadata, header
4✔
366

367

368
def sanitise_value_bool(shape, value):
4✔
369
    x = True
4✔
370
    if value is None:
4✔
371
        x = False
4✔
372
    return x
4✔
373

374

375
def sanitise_value_float_scalar(shape, value):
4✔
376
    x = value
4✔
377
    if value is None:
4✔
378
        x = [constants.FLOAT32_MISSING]
4✔
379
    return x[0]
4✔
380

381

382
def sanitise_value_int_scalar(shape, value):
4✔
383
    x = value
4✔
384
    if value is None:
4✔
385
        x = [constants.INT_MISSING]
4✔
386
    else:
387
        x = sanitise_int_array(value, ndmin=1, dtype=np.int32)
4✔
388
    return x[0]
4✔
389

390

391
def sanitise_value_string_scalar(shape, value):
4✔
392
    if value is None:
4✔
393
        return "."
4✔
394
    else:
395
        return value[0]
4✔
396

397

398
def sanitise_value_string_1d(shape, value):
4✔
399
    if value is None:
4✔
400
        return np.full(shape, ".", dtype="O")
4✔
401
    else:
402
        value = drop_empty_second_dim(value)
4✔
403
        result = np.full(shape, "", dtype=value.dtype)
4✔
404
        result[: value.shape[0]] = value
4✔
405
        return result
4✔
406

407

408
def sanitise_value_string_2d(shape, value):
4✔
409
    if value is None:
4✔
410
        return np.full(shape, ".", dtype="O")
4✔
411
    else:
412
        result = np.full(shape, "", dtype="O")
4✔
413
        if value.ndim == 2:
4✔
414
            result[: value.shape[0], : value.shape[1]] = value
4✔
415
        else:
416
            # Convert 1D array into 2D with appropriate shape
417
            for k, val in enumerate(value):
4✔
418
                result[k, : len(val)] = val
4✔
419
        return result
4✔
420

421

422
def drop_empty_second_dim(value):
4✔
423
    assert len(value.shape) == 1 or value.shape[1] == 1
4✔
424
    if len(value.shape) == 2 and value.shape[1] == 1:
4✔
425
        value = value[..., 0]
4✔
426
    return value
4✔
427

428

429
def sanitise_value_float_1d(shape, value):
4✔
430
    if value is None:
4✔
431
        return np.full(shape, constants.FLOAT32_MISSING)
4✔
432
    else:
433
        value = np.array(value, ndmin=1, dtype=np.float32, copy=True)
4✔
434
        # numpy will map None values to NaN, but we need a
435
        # specific NaN
436
        value[np.isnan(value)] = constants.FLOAT32_MISSING
4✔
437
        value = drop_empty_second_dim(value)
4✔
438
        result = np.full(shape, constants.FLOAT32_FILL, dtype=np.float32)
4✔
439
        result[: value.shape[0]] = value
4✔
440
        return result
4✔
441

442

443
def sanitise_value_float_2d(shape, value):
4✔
444
    if value is None:
4✔
445
        return np.full(shape, constants.FLOAT32_MISSING)
4✔
446
    else:
447
        value = np.array(value, ndmin=2, dtype=np.float32, copy=True)
4✔
448
        result = np.full(shape, constants.FLOAT32_FILL, dtype=np.float32)
4✔
449
        result[:, : value.shape[1]] = value
4✔
450
        return result
4✔
451

452

453
def sanitise_int_array(value, ndmin, dtype):
4✔
454
    if isinstance(value, tuple):
4✔
455
        value = [
×
456
            constants.VCF_INT_MISSING if x is None else x for x in value
457
        ]  # NEEDS TEST
458
    value = np.array(value, ndmin=ndmin, copy=True)
4✔
459
    value[value == constants.VCF_INT_MISSING] = -1
4✔
460
    value[value == constants.VCF_INT_FILL] = -2
4✔
461
    # TODO watch out for clipping here!
462
    return value.astype(dtype)
4✔
463

464

465
def sanitise_value_int_1d(shape, value):
4✔
466
    if value is None:
4✔
467
        return np.full(shape, -1)
4✔
468
    else:
469
        value = sanitise_int_array(value, 1, np.int32)
4✔
470
        value = drop_empty_second_dim(value)
4✔
471
        result = np.full(shape, -2, dtype=np.int32)
4✔
472
        result[: value.shape[0]] = value
4✔
473
        return result
4✔
474

475

476
def sanitise_value_int_2d(shape, value):
4✔
477
    if value is None:
4✔
478
        return np.full(shape, -1)
4✔
479
    else:
480
        value = sanitise_int_array(value, 2, np.int32)
4✔
481
        result = np.full(shape, -2, dtype=np.int32)
4✔
482
        result[:, : value.shape[1]] = value
4✔
483
        return result
4✔
484

485

486
missing_value_map = {
4✔
487
    "Integer": constants.INT_MISSING,
488
    "Float": constants.FLOAT32_MISSING,
489
    "String": constants.STR_MISSING,
490
    "Character": constants.STR_MISSING,
491
    "Flag": False,
492
}
493

494

495
class VcfValueTransformer:
4✔
496
    """
497
    Transform VCF values into the stored intermediate format used
498
    in the IntermediateColumnarFormat, and update field summaries.
499
    """
500

501
    def __init__(self, field, num_samples):
4✔
502
        self.field = field
4✔
503
        self.num_samples = num_samples
4✔
504
        self.dimension = 1
4✔
505
        if field.category == "FORMAT":
4✔
506
            self.dimension = 2
4✔
507
        self.missing = missing_value_map[field.vcf_type]
4✔
508

509
    @staticmethod
4✔
510
    def factory(field, num_samples):
4✔
511
        if field.vcf_type in ("Integer", "Flag"):
4✔
512
            return IntegerValueTransformer(field, num_samples)
4✔
513
        if field.vcf_type == "Float":
4✔
514
            return FloatValueTransformer(field, num_samples)
4✔
515
        if field.name in ["REF", "FILTERS", "ALT", "ID", "CHROM"]:
4✔
516
            return SplitStringValueTransformer(field, num_samples)
4✔
517
        return StringValueTransformer(field, num_samples)
4✔
518

519
    def transform(self, vcf_value):
4✔
520
        if isinstance(vcf_value, tuple):
4✔
521
            vcf_value = [self.missing if v is None else v for v in vcf_value]
4✔
522
        value = np.array(vcf_value, ndmin=self.dimension, copy=True)
4✔
523
        return value
4✔
524

525
    def transform_and_update_bounds(self, vcf_value):
4✔
526
        if vcf_value is None:
4✔
527
            return None
4✔
528
        # print(self, self.field.full_name, "T", vcf_value)
529
        value = self.transform(vcf_value)
4✔
530
        self.update_bounds(value)
4✔
531
        return value
4✔
532

533

534
class IntegerValueTransformer(VcfValueTransformer):
4✔
535
    def update_bounds(self, value):
4✔
536
        summary = self.field.summary
4✔
537
        # Mask out missing and fill values
538
        # print(value)
539
        a = value[value >= constants.MIN_INT_VALUE]
4✔
540
        if a.size > 0:
4✔
541
            summary.max_value = int(max(summary.max_value, np.max(a)))
4✔
542
            summary.min_value = int(min(summary.min_value, np.min(a)))
4✔
543
        number = value.shape[-1]
4✔
544
        summary.max_number = max(summary.max_number, number)
4✔
545

546

547
class FloatValueTransformer(VcfValueTransformer):
4✔
548
    def update_bounds(self, value):
4✔
549
        summary = self.field.summary
4✔
550
        summary.max_value = float(max(summary.max_value, np.max(value)))
4✔
551
        summary.min_value = float(min(summary.min_value, np.min(value)))
4✔
552
        number = value.shape[-1]
4✔
553
        summary.max_number = max(summary.max_number, number)
4✔
554

555

556
class StringValueTransformer(VcfValueTransformer):
4✔
557
    def update_bounds(self, value):
4✔
558
        summary = self.field.summary
4✔
559
        if self.field.category == "FORMAT":
4✔
560
            number = max(len(v) for v in value)
4✔
561
        else:
562
            number = value.shape[-1]
4✔
563
        # TODO would be nice to report string lengths, but not
564
        # really necessary.
565
        summary.max_number = max(summary.max_number, number)
4✔
566

567
    def transform(self, vcf_value):
4✔
568
        if self.dimension == 1:
4✔
569
            value = np.array(list(vcf_value.split(",")))
4✔
570
        else:
571
            # TODO can we make this faster??
572
            value = np.array([v.split(",") for v in vcf_value], dtype="O")
4✔
573
            # print("HERE", vcf_value, value)
574
            # for v in vcf_value:
575
            #     print("\t", type(v), len(v), v.split(","))
576
        # print("S: ", self.dimension, ":", value.shape, value)
577
        return value
4✔
578

579

580
class SplitStringValueTransformer(StringValueTransformer):
4✔
581
    def transform(self, vcf_value):
4✔
582
        if vcf_value is None:
4✔
583
            return self.missing_value  # NEEDS TEST
×
584
        assert self.dimension == 1
4✔
585
        return np.array(vcf_value, ndmin=1, dtype="str")
4✔
586

587

588
def get_vcf_field_path(base_path, vcf_field):
4✔
589
    if vcf_field.category == "fixed":
4✔
590
        return base_path / vcf_field.name
4✔
591
    return base_path / vcf_field.category / vcf_field.name
4✔
592

593

594
class IntermediateColumnarFormatField:
4✔
595
    def __init__(self, icf, vcf_field):
4✔
596
        self.vcf_field = vcf_field
4✔
597
        self.path = get_vcf_field_path(icf.path, vcf_field)
4✔
598
        self.compressor = icf.compressor
4✔
599
        self.num_partitions = icf.num_partitions
4✔
600
        self.num_records = icf.num_records
4✔
601
        self.partition_record_index = icf.partition_record_index
4✔
602
        # A map of partition id to the cumulative number of records
603
        # in chunks within that partition
604
        self._chunk_record_index = {}
4✔
605

606
    @property
4✔
607
    def name(self):
4✔
608
        return self.vcf_field.full_name
4✔
609

610
    def partition_path(self, partition_id):
4✔
611
        return self.path / f"p{partition_id}"
4✔
612

613
    def __repr__(self):
4✔
614
        partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
4✔
615
        return (
4✔
616
            f"IntermediateColumnarFormatField(name={self.name}, "
617
            f"partition_chunks={partition_chunks}, "
618
            f"path={self.path})"
619
        )
620

621
    def num_chunks(self, partition_id):
4✔
622
        return len(self.chunk_record_index(partition_id)) - 1
4✔
623

624
    def chunk_record_index(self, partition_id):
4✔
625
        if partition_id not in self._chunk_record_index:
4✔
626
            index_path = self.partition_path(partition_id) / "chunk_index"
4✔
627
            with open(index_path, "rb") as f:
4✔
628
                a = pickle.load(f)
4✔
629
            assert len(a) > 1
4✔
630
            assert a[0] == 0
4✔
631
            self._chunk_record_index[partition_id] = a
4✔
632
        return self._chunk_record_index[partition_id]
4✔
633

634
    def read_chunk(self, path):
4✔
635
        with open(path, "rb") as f:
4✔
636
            pkl = self.compressor.decode(f.read())
4✔
637
        return pickle.loads(pkl)
4✔
638

639
    def chunk_num_records(self, partition_id):
4✔
640
        return np.diff(self.chunk_record_index(partition_id))
4✔
641

642
    def chunks(self, partition_id, start_chunk=0):
4✔
643
        partition_path = self.partition_path(partition_id)
4✔
644
        chunk_cumulative_records = self.chunk_record_index(partition_id)
4✔
645
        chunk_num_records = np.diff(chunk_cumulative_records)
4✔
646
        for count, cumulative in zip(
4✔
647
            chunk_num_records[start_chunk:],
648
            chunk_cumulative_records[start_chunk + 1 :],
649
        ):
650
            path = partition_path / f"{cumulative}"
4✔
651
            chunk = self.read_chunk(path)
4✔
652
            if len(chunk) != count:
4✔
653
                raise ValueError(f"Corruption detected in chunk: {path}")
4✔
654
            yield chunk
4✔
655

656
    def iter_values(self, start=None, stop=None):
4✔
657
        start = 0 if start is None else start
4✔
658
        stop = self.num_records if stop is None else stop
4✔
659
        start_partition = (
4✔
660
            np.searchsorted(self.partition_record_index, start, side="right") - 1
661
        )
662
        offset = self.partition_record_index[start_partition]
4✔
663
        assert offset <= start
4✔
664
        chunk_offset = start - offset
4✔
665

666
        chunk_record_index = self.chunk_record_index(start_partition)
4✔
667
        start_chunk = (
4✔
668
            np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1
669
        )
670
        record_id = offset + chunk_record_index[start_chunk]
4✔
671
        assert record_id <= start
4✔
672
        logger.debug(
4✔
673
            f"Read {self.vcf_field.full_name} slice [{start}:{stop}]:"
674
            f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}"
675
        )
676
        for chunk in self.chunks(start_partition, start_chunk):
4✔
677
            for record in chunk:
4✔
678
                if record_id == stop:
4✔
679
                    return
4✔
680
                if record_id >= start:
4✔
681
                    yield record
4✔
682
                record_id += 1
4✔
683
        assert record_id > start
4✔
684
        for partition_id in range(start_partition + 1, self.num_partitions):
4✔
685
            for chunk in self.chunks(partition_id):
4✔
686
                for record in chunk:
4✔
687
                    if record_id == stop:
4✔
688
                        return
4✔
689
                    yield record
4✔
690
                    record_id += 1
4✔
691

692
    # Note: this involves some computation so should arguably be a method,
693
    # but we make it a property for consistency with xarray etc.
694
    @property
4✔
695
    def values(self):
4✔
696
        ret = [None] * self.num_records
4✔
697
        j = 0
4✔
698
        for partition_id in range(self.num_partitions):
4✔
699
            for chunk in self.chunks(partition_id):
4✔
700
                for record in chunk:
4✔
701
                    ret[j] = record
4✔
702
                    j += 1
4✔
703
        assert j == self.num_records
4✔
704
        return ret
4✔
705

706
    def sanitiser_factory(self, shape):
4✔
707
        assert len(shape) <= 2
4✔
708
        if self.vcf_field.vcf_type == "Flag":
4✔
709
            assert len(shape) == 0
4✔
710
            return partial(sanitise_value_bool, shape)
4✔
711
        elif self.vcf_field.vcf_type == "Float":
4✔
712
            if len(shape) == 0:
4✔
713
                return partial(sanitise_value_float_scalar, shape)
4✔
714
            elif len(shape) == 1:
4✔
715
                return partial(sanitise_value_float_1d, shape)
4✔
716
            else:
717
                return partial(sanitise_value_float_2d, shape)
4✔
718
        elif self.vcf_field.vcf_type == "Integer":
4✔
719
            if len(shape) == 0:
4✔
720
                return partial(sanitise_value_int_scalar, shape)
4✔
721
            elif len(shape) == 1:
4✔
722
                return partial(sanitise_value_int_1d, shape)
4✔
723
            else:
724
                return partial(sanitise_value_int_2d, shape)
4✔
725
        else:
726
            assert self.vcf_field.vcf_type in ("String", "Character")
4✔
727
            if len(shape) == 0:
4✔
728
                return partial(sanitise_value_string_scalar, shape)
4✔
729
            elif len(shape) == 1:
4✔
730
                return partial(sanitise_value_string_1d, shape)
4✔
731
            else:
732
                return partial(sanitise_value_string_2d, shape)
4✔
733

734

735
@dataclasses.dataclass
4✔
736
class IcfFieldWriter:
4✔
737
    vcf_field: VcfField
4✔
738
    path: pathlib.Path
4✔
739
    transformer: VcfValueTransformer
4✔
740
    compressor: Any
4✔
741
    max_buffered_bytes: int
4✔
742
    buff: list[Any] = dataclasses.field(default_factory=list)
4✔
743
    buffered_bytes: int = 0
4✔
744
    chunk_index: list[int] = dataclasses.field(default_factory=lambda: [0])
4✔
745
    num_records: int = 0
4✔
746

747
    def append(self, val):
4✔
748
        val = self.transformer.transform_and_update_bounds(val)
4✔
749
        assert val is None or isinstance(val, np.ndarray)
4✔
750
        self.buff.append(val)
4✔
751
        val_bytes = sys.getsizeof(val)
4✔
752
        self.buffered_bytes += val_bytes
4✔
753
        self.num_records += 1
4✔
754
        if self.buffered_bytes >= self.max_buffered_bytes:
4✔
755
            logger.debug(
4✔
756
                f"Flush {self.path} buffered={self.buffered_bytes} "
757
                f"max={self.max_buffered_bytes}"
758
            )
759
            self.write_chunk()
4✔
760
            self.buff.clear()
4✔
761
            self.buffered_bytes = 0
4✔
762

763
    def write_chunk(self):
4✔
764
        # Update index
765
        self.chunk_index.append(self.num_records)
4✔
766
        path = self.path / f"{self.num_records}"
4✔
767
        logger.debug(f"Start write: {path}")
4✔
768
        pkl = pickle.dumps(self.buff)
4✔
769
        compressed = self.compressor.encode(pkl)
4✔
770
        with open(path, "wb") as f:
4✔
771
            f.write(compressed)
4✔
772

773
        # Update the summary
774
        self.vcf_field.summary.num_chunks += 1
4✔
775
        self.vcf_field.summary.compressed_size += len(compressed)
4✔
776
        self.vcf_field.summary.uncompressed_size += self.buffered_bytes
4✔
777
        logger.debug(f"Finish write: {path}")
4✔
778

779
    def flush(self):
4✔
780
        logger.debug(
4✔
781
            f"Flush {self.path} records={len(self.buff)} buffered={self.buffered_bytes}"
782
        )
783
        if len(self.buff) > 0:
4✔
784
            self.write_chunk()
4✔
785
        with open(self.path / "chunk_index", "wb") as f:
4✔
786
            a = np.array(self.chunk_index, dtype=int)
4✔
787
            pickle.dump(a, f)
4✔
788

789

790
class IcfPartitionWriter(contextlib.AbstractContextManager):
4✔
791
    """
792
    Writes the data for an IntermediateColumnarFormat partition.
793
    """
794

795
    def __init__(
4✔
796
        self,
797
        icf_metadata,
798
        out_path,
799
        partition_index,
800
    ):
801
        self.partition_index = partition_index
4✔
802
        # chunk_size is in megabytes
803
        max_buffered_bytes = icf_metadata.column_chunk_size * 2**20
4✔
804
        assert max_buffered_bytes > 0
4✔
805
        compressor = numcodecs.get_codec(icf_metadata.compressor)
4✔
806

807
        self.field_writers = {}
4✔
808
        num_samples = len(icf_metadata.samples)
4✔
809
        for vcf_field in icf_metadata.fields:
4✔
810
            field_path = get_vcf_field_path(out_path, vcf_field)
4✔
811
            field_partition_path = field_path / f"p{partition_index}"
4✔
812
            # Should be robust to running explode_partition twice.
813
            field_partition_path.mkdir(exist_ok=True)
4✔
814
            transformer = VcfValueTransformer.factory(vcf_field, num_samples)
4✔
815
            self.field_writers[vcf_field.full_name] = IcfFieldWriter(
4✔
816
                vcf_field,
817
                field_partition_path,
818
                transformer,
819
                compressor,
820
                max_buffered_bytes,
821
            )
822

823
    @property
4✔
824
    def field_summaries(self):
4✔
825
        return {
4✔
826
            name: field.vcf_field.summary for name, field in self.field_writers.items()
827
        }
828

829
    def append(self, name, value):
4✔
830
        self.field_writers[name].append(value)
4✔
831

832
    def __exit__(self, exc_type, exc_val, exc_tb):
4✔
833
        if exc_type is None:
4✔
834
            for field in self.field_writers.values():
4✔
835
                field.flush()
4✔
836
        return False
4✔
837

838

839
def convert_local_allele_field_types(fields, schema_instance):
4✔
840
    """
841
    Update the specified list of fields to include the LAA field, and to convert
842
    any supported localisable fields to the L* counterpart.
843

844
    Note that we currently support only two ALT alleles per sample, and so the
845
    dimensions of these fields are fixed by that requirement. Later versions may
846
    use summary data stored in the ICF to make different choices, if information
847
    about subsequent alleles (not in the actual genotype calls) should also be
848
    stored.
849
    """
850
    fields_by_name = {field.name: field for field in fields}
4✔
851
    gt = fields_by_name["call_genotype"]
4✔
852

853
    if schema_instance.get_shape(["ploidy"])[0] != 2:
4✔
854
        raise ValueError("Local alleles only supported on diploid data")
4✔
855

856
    dimensions = gt.dimensions[:-1]
4✔
857

858
    la = vcz.ZarrArraySpec(
4✔
859
        name="call_LA",
860
        dtype="i1",
861
        dimensions=(*dimensions, "local_alleles"),
862
        description=(
863
            "0-based indices into REF+ALT, indicating which alleles"
864
            " are relevant (local) for the current sample"
865
        ),
866
    )
867
    schema_instance.dimensions["local_alleles"] = vcz.VcfZarrDimension.unchunked(
4✔
868
        schema_instance.dimensions["ploidy"].size
869
    )
870

871
    ad = fields_by_name.get("call_AD", None)
4✔
872
    if ad is not None:
4✔
873
        # TODO check if call_LAD is in the list already
874
        ad.name = "call_LAD"
4✔
875
        ad.source = None
4✔
876
        ad.dimensions = (*dimensions, "local_alleles_AD")
4✔
877
        ad.description += " (local-alleles)"
4✔
878
        schema_instance.dimensions["local_alleles_AD"] = vcz.VcfZarrDimension.unchunked(
4✔
879
            2
880
        )
881

882
    pl = fields_by_name.get("call_PL", None)
4✔
883
    if pl is not None:
4✔
884
        # TODO check if call_LPL is in the list already
885
        pl.name = "call_LPL"
4✔
886
        pl.source = None
4✔
887
        pl.description += " (local-alleles)"
4✔
888
        pl.dimensions = (*dimensions, "local_" + pl.dimensions[-1].split("_")[-1])
4✔
889
        schema_instance.dimensions["local_" + pl.dimensions[-1].split("_")[-1]] = (
4✔
890
            vcz.VcfZarrDimension.unchunked(3)
891
        )
892

893
    return [*fields, la]
4✔
894

895

896
class IntermediateColumnarFormat(vcz.Source):
4✔
897
    def __init__(self, path):
4✔
898
        self._path = pathlib.Path(path)
4✔
899
        # TODO raise a more informative error here telling people this
900
        # directory is either a WIP or the wrong format.
901
        with open(self.path / "metadata.json") as f:
4✔
902
            self.metadata = IcfMetadata.fromdict(json.load(f))
4✔
903
        with open(self.path / "header.txt") as f:
4✔
904
            self.vcf_header = f.read()
4✔
905
        self.compressor = numcodecs.get_codec(self.metadata.compressor)
4✔
906
        self.fields = {}
4✔
907
        partition_num_records = [
4✔
908
            partition.num_records for partition in self.metadata.partitions
909
        ]
910
        # Allow us to find which partition a given record is in
911
        self.partition_record_index = np.cumsum([0, *partition_num_records])
4✔
912
        self.gt_field = None
4✔
913
        for field in self.metadata.fields:
4✔
914
            self.fields[field.full_name] = IntermediateColumnarFormatField(self, field)
4✔
915
            if field.name == "GT":
4✔
916
                self.gt_field = field
4✔
917

918
        logger.info(
4✔
919
            f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, "
920
            f"records={self.num_records}, fields={self.num_fields})"
921
        )
922

923
    def __repr__(self):
4✔
924
        return (
4✔
925
            f"IntermediateColumnarFormat(fields={len(self.fields)}, "
926
            f"partitions={self.num_partitions}, "
927
            f"records={self.num_records}, path={self.path})"
928
        )
929

930
    def summary_table(self):
4✔
931
        data = []
4✔
932
        for name, icf_field in self.fields.items():
4✔
933
            summary = icf_field.vcf_field.summary
4✔
934
            d = {
4✔
935
                "name": name,
936
                "type": icf_field.vcf_field.vcf_type,
937
                "chunks": summary.num_chunks,
938
                "size": core.display_size(summary.uncompressed_size),
939
                "compressed": core.display_size(summary.compressed_size),
940
                "max_n": summary.max_number,
941
                "min_val": core.display_number(summary.min_value),
942
                "max_val": core.display_number(summary.max_value),
943
            }
944

945
            data.append(d)
4✔
946
        return data
4✔
947

948
    @property
4✔
949
    def path(self):
4✔
950
        return self._path
4✔
951

952
    @property
4✔
953
    def num_records(self):
4✔
954
        return self.metadata.num_records
4✔
955

956
    @property
4✔
957
    def num_partitions(self):
4✔
958
        return len(self.metadata.partitions)
4✔
959

960
    @property
4✔
961
    def samples(self):
4✔
962
        return self.metadata.samples
4✔
963

964
    @property
4✔
965
    def contigs(self):
4✔
966
        return self.metadata.contigs
4✔
967

968
    @property
4✔
969
    def filters(self):
4✔
970
        return self.metadata.filters
4✔
971

972
    @property
4✔
973
    def num_samples(self):
4✔
974
        return len(self.metadata.samples)
4✔
975

976
    @property
4✔
977
    def num_fields(self):
4✔
978
        return len(self.fields)
4✔
979

980
    @property
4✔
981
    def root_attrs(self):
4✔
982
        meta_information_pattern = re.compile("##([^=]+)=(.*)")
4✔
983
        vcf_meta_information = []
4✔
984
        for line in self.vcf_header.split("\n"):
4✔
985
            match = re.fullmatch(meta_information_pattern, line)
4✔
986
            if match:
4✔
987
                key = match.group(1)
4✔
988
                if key in ("contig", "FILTER", "INFO", "FORMAT"):
4✔
989
                    # these fields are stored in Zarr arrays
990
                    continue
4✔
991
                value = match.group(2)
4✔
992
                vcf_meta_information.append((key, value))
4✔
993
        return {
4✔
994
            "vcf_meta_information": vcf_meta_information,
995
        }
996

997
    def iter_id(self, start, stop):
4✔
998
        for value in self.fields["ID"].iter_values(start, stop):
4✔
999
            if value is not None:
4✔
1000
                yield value[0]
4✔
1001
            else:
1002
                yield None
4✔
1003

1004
    def iter_filters(self, start, stop):
4✔
1005
        source_field = self.fields["FILTERS"]
4✔
1006
        lookup = {filt.id: index for index, filt in enumerate(self.metadata.filters)}
4✔
1007

1008
        for filter_values in source_field.iter_values(start, stop):
4✔
1009
            filters = np.zeros(len(self.metadata.filters), dtype=bool)
4✔
1010
            if filter_values is not None:
4✔
1011
                for filter_id in filter_values:
4✔
1012
                    try:
4✔
1013
                        filters[lookup[filter_id]] = True
4✔
1014
                    except KeyError:
4✔
1015
                        raise ValueError(
4✔
1016
                            f"Filter '{filter_id}' was not defined in the header."
1017
                        ) from None
1018
            yield filters
4✔
1019

1020
    def iter_contig(self, start, stop):
4✔
1021
        source_field = self.fields["CHROM"]
4✔
1022
        lookup = {
4✔
1023
            contig.id: index for index, contig in enumerate(self.metadata.contigs)
1024
        }
1025

1026
        for value in source_field.iter_values(start, stop):
4✔
1027
            # Note: because we are using the indexes to define the lookups
1028
            # and we always have an index, it seems that the contig lookup
1029
            # will always succeed. However, if anyone ever does hit a KeyError
1030
            # here, please do open an issue with a reproducible example!
1031
            yield lookup[value[0]]
4✔
1032

1033
    def iter_field(self, field_name, shape, start, stop):
4✔
1034
        source_field = self.fields[field_name]
4✔
1035
        sanitiser = source_field.sanitiser_factory(shape)
4✔
1036
        for value in source_field.iter_values(start, stop):
4✔
1037
            yield sanitiser(value)
4✔
1038

1039
    def iter_alleles(self, start, stop, num_alleles):
4✔
1040
        ref_field = self.fields["REF"]
4✔
1041
        alt_field = self.fields["ALT"]
4✔
1042

1043
        for ref, alt in zip(
4✔
1044
            ref_field.iter_values(start, stop),
1045
            alt_field.iter_values(start, stop),
1046
        ):
1047
            alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
4✔
1048
            alleles[0] = ref[0]
4✔
1049
            alleles[1 : 1 + len(alt)] = alt
4✔
1050
            yield alleles
4✔
1051

1052
    def iter_genotypes(self, shape, start, stop):
4✔
1053
        source_field = self.fields["FORMAT/GT"]
4✔
1054
        for value in source_field.iter_values(start, stop):
4✔
1055
            genotypes = value[:, :-1] if value is not None else None
4✔
1056
            phased = value[:, -1] if value is not None else None
4✔
1057
            sanitised_genotypes = sanitise_value_int_2d(shape, genotypes)
4✔
1058
            sanitised_phased = sanitise_value_int_1d(shape[:-1], phased)
4✔
1059
            # Force haploids to always be phased
1060
            # https://github.com/sgkit-dev/bio2zarr/issues/399
1061
            if sanitised_genotypes.shape[1] == 1:
4✔
1062
                sanitised_phased[:] = True
4✔
1063
            yield sanitised_genotypes, sanitised_phased
4✔
1064

1065
    def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
4✔
1066
        variant_lengths = self.fields["rlen"].iter_values(start, stop)
4✔
1067
        if self.gt_field is None or shape is None:
4✔
1068
            for variant_length, alleles in zip(
4✔
1069
                variant_lengths, self.iter_alleles(start, stop, num_alleles)
1070
            ):
1071
                yield vcz.VariantData(variant_length, alleles, None, None)
4✔
1072
        else:
1073
            for variant_length, alleles, (gt, phased) in zip(
4✔
1074
                variant_lengths,
1075
                self.iter_alleles(start, stop, num_alleles),
1076
                self.iter_genotypes(shape, start, stop),
1077
            ):
1078
                yield vcz.VariantData(variant_length, alleles, gt, phased)
4✔
1079

1080
    def generate_schema(
4✔
1081
        self, variants_chunk_size=None, samples_chunk_size=None, local_alleles=None
1082
    ):
1083
        if local_alleles is None:
4✔
1084
            local_alleles = False
4✔
1085

1086
        max_alleles = max(self.fields["ALT"].vcf_field.summary.max_number + 1, 2)
4✔
1087

1088
        # Add ploidy and genotypes dimensions only when needed
1089
        max_genotypes = 0
4✔
1090
        for field in self.metadata.format_fields:
4✔
1091
            if field.vcf_number == "G":
4✔
1092
                max_genotypes = max(max_genotypes, field.summary.max_number)
4✔
1093

1094
        ploidy = None
4✔
1095
        genotypes_size = None
4✔
1096
        if self.gt_field is not None:
4✔
1097
            ploidy = max(self.gt_field.summary.max_number - 1, 1)
4✔
1098
            # NOTE: it's not clear why we're computing this, when we must have had
1099
            # at least one number=G field to require it anyway?
1100
            genotypes_size = math.comb(max_alleles + ploidy - 1, ploidy)
4✔
1101
            # assert max_genotypes == genotypes_size
1102
        else:
1103
            if max_genotypes > 0:
4✔
1104
                # there is no GT field, but there is at least one Number=G field,
1105
                # so we need to define the genotypes dimension
1106
                genotypes_size = max_genotypes
4✔
1107

1108
        dimensions = vcz.standard_dimensions(
4✔
1109
            variants_size=self.num_records,
1110
            variants_chunk_size=variants_chunk_size,
1111
            samples_size=self.num_samples,
1112
            samples_chunk_size=samples_chunk_size,
1113
            alleles_size=max_alleles,
1114
            filters_size=self.metadata.num_filters,
1115
            ploidy_size=ploidy,
1116
            genotypes_size=genotypes_size,
1117
        )
1118

1119
        schema_instance = vcz.VcfZarrSchema(
4✔
1120
            format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
1121
            dimensions=dimensions,
1122
            fields=[],
1123
        )
1124

1125
        logger.info(
4✔
1126
            "Generating schema with chunks="
1127
            f"variants={dimensions['variants'].chunk_size}, "
1128
            f"samples={dimensions['samples'].chunk_size}"
1129
        )
1130

1131
        def spec_from_field(field, array_name=None):
4✔
1132
            return vcz.ZarrArraySpec.from_field(
4✔
1133
                field,
1134
                schema_instance,
1135
                array_name=array_name,
1136
            )
1137

1138
        def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
4✔
1139
            compressor = (
4✔
1140
                vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config()
1141
                if dtype == "bool"
1142
                else None
1143
            )
1144
            return vcz.ZarrArraySpec(
4✔
1145
                source=source,
1146
                name=name,
1147
                dtype=dtype,
1148
                description="",
1149
                dimensions=dimensions,
1150
                compressor=compressor,
1151
            )
1152

1153
        name_map = {field.full_name: field for field in self.metadata.fields}
4✔
1154
        array_specs = [
4✔
1155
            fixed_field_spec(
1156
                name="variant_contig",
1157
                dtype=core.min_int_dtype(0, self.metadata.num_contigs),
1158
            ),
1159
            fixed_field_spec(
1160
                name="variant_filter",
1161
                dtype="bool",
1162
                dimensions=["variants", "filters"],
1163
            ),
1164
            fixed_field_spec(
1165
                name="variant_allele",
1166
                dtype="O",
1167
                dimensions=["variants", "alleles"],
1168
            ),
1169
            fixed_field_spec(
1170
                name="variant_length",
1171
                dtype=name_map["rlen"].smallest_dtype(),
1172
                dimensions=["variants"],
1173
            ),
1174
            fixed_field_spec(
1175
                name="variant_id",
1176
                dtype="O",
1177
            ),
1178
            fixed_field_spec(
1179
                name="variant_id_mask",
1180
                dtype="bool",
1181
            ),
1182
        ]
1183

1184
        # Only two of the fixed fields have a direct one-to-one mapping.
1185
        array_specs.extend(
4✔
1186
            [
1187
                spec_from_field(name_map["QUAL"], array_name="variant_quality"),
1188
                spec_from_field(name_map["POS"], array_name="variant_position"),
1189
            ]
1190
        )
1191
        array_specs.extend(
4✔
1192
            [spec_from_field(field) for field in self.metadata.info_fields]
1193
        )
1194

1195
        for field in self.metadata.format_fields:
4✔
1196
            if field.name == "GT":
4✔
1197
                continue
4✔
1198
            array_specs.append(spec_from_field(field))
4✔
1199

1200
        if self.gt_field is not None and self.num_samples > 0:
4✔
1201
            array_specs.append(
4✔
1202
                vcz.ZarrArraySpec(
1203
                    name="call_genotype_phased",
1204
                    dtype="bool",
1205
                    dimensions=["variants", "samples"],
1206
                    description="",
1207
                    compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
1208
                )
1209
            )
1210
            array_specs.append(
4✔
1211
                vcz.ZarrArraySpec(
1212
                    name="call_genotype",
1213
                    dtype=self.gt_field.smallest_dtype(),
1214
                    dimensions=["variants", "samples", "ploidy"],
1215
                    description="",
1216
                    compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),
1217
                )
1218
            )
1219
            array_specs.append(
4✔
1220
                vcz.ZarrArraySpec(
1221
                    name="call_genotype_mask",
1222
                    dtype="bool",
1223
                    dimensions=["variants", "samples", "ploidy"],
1224
                    description="",
1225
                    compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
1226
                )
1227
            )
1228

1229
        if local_alleles:
4✔
1230
            array_specs = convert_local_allele_field_types(array_specs, schema_instance)
4✔
1231

1232
        schema_instance.fields = array_specs
4✔
1233
        return schema_instance
4✔
1234

1235

1236
@dataclasses.dataclass
4✔
1237
class IcfPartitionMetadata(core.JsonDataclass):
4✔
1238
    num_records: int
4✔
1239
    last_position: int
4✔
1240
    field_summaries: dict
4✔
1241

1242
    @staticmethod
4✔
1243
    def fromdict(d):
4✔
1244
        md = IcfPartitionMetadata(**d)
4✔
1245
        for k, v in md.field_summaries.items():
4✔
1246
            md.field_summaries[k] = VcfFieldSummary.fromdict(v)
4✔
1247
        return md
4✔
1248

1249

1250
def check_overlapping_partitions(partitions):
4✔
1251
    for i in range(1, len(partitions)):
4✔
1252
        prev_region = partitions[i - 1].region
4✔
1253
        current_region = partitions[i].region
4✔
1254
        if prev_region.contig == current_region.contig:
4✔
1255
            assert prev_region.end is not None
4✔
1256
            # Regions are *inclusive*
1257
            if prev_region.end >= current_region.start:
4✔
1258
                raise ValueError(
4✔
1259
                    f"Overlapping VCF regions in partitions {i - 1} and {i}: "
1260
                    f"{prev_region} and {current_region}"
1261
                )
1262

1263

1264
def check_field_clobbering(icf_metadata):
4✔
1265
    info_field_names = set(field.name for field in icf_metadata.info_fields)
4✔
1266
    fixed_variant_fields = set(
4✔
1267
        ["contig", "id", "id_mask", "position", "allele", "filter", "quality"]
1268
    )
1269
    intersection = info_field_names & fixed_variant_fields
4✔
1270
    if len(intersection) > 0:
4✔
1271
        raise ValueError(
4✔
1272
            f"INFO field name(s) clashing with VCF Zarr spec: {intersection}"
1273
        )
1274

1275
    format_field_names = set(field.name for field in icf_metadata.format_fields)
4✔
1276
    fixed_variant_fields = set(["genotype", "genotype_phased", "genotype_mask"])
4✔
1277
    intersection = format_field_names & fixed_variant_fields
4✔
1278
    if len(intersection) > 0:
4✔
1279
        raise ValueError(
4✔
1280
            f"FORMAT field name(s) clashing with VCF Zarr spec: {intersection}"
1281
        )
1282

1283

1284
@dataclasses.dataclass
4✔
1285
class IcfWriteSummary(core.JsonDataclass):
4✔
1286
    num_partitions: int
4✔
1287
    num_samples: int
4✔
1288
    num_variants: int
4✔
1289

1290

1291
class IntermediateColumnarFormatWriter:
4✔
1292
    def __init__(self, path):
4✔
1293
        self.path = pathlib.Path(path)
4✔
1294
        self.wip_path = self.path / "wip"
4✔
1295
        self.metadata = None
4✔
1296

1297
    @property
4✔
1298
    def num_partitions(self):
4✔
1299
        return len(self.metadata.partitions)
4✔
1300

1301
    def init(
4✔
1302
        self,
1303
        vcfs,
1304
        *,
1305
        column_chunk_size=16,
1306
        worker_processes=core.DEFAULT_WORKER_PROCESSES,
1307
        target_num_partitions=None,
1308
        show_progress=False,
1309
        compressor=None,
1310
    ):
1311
        if self.path.exists():
4✔
1312
            raise ValueError(f"ICF path already exists: {self.path}")
×
1313
        if compressor is None:
4✔
1314
            compressor = ICF_DEFAULT_COMPRESSOR
4✔
1315
        vcfs = [pathlib.Path(vcf) for vcf in vcfs]
4✔
1316
        target_num_partitions = max(target_num_partitions, len(vcfs))
4✔
1317

1318
        # TODO move scan_vcfs into this class
1319
        icf_metadata, header = scan_vcfs(
4✔
1320
            vcfs,
1321
            worker_processes=worker_processes,
1322
            show_progress=show_progress,
1323
            target_num_partitions=target_num_partitions,
1324
        )
1325
        check_field_clobbering(icf_metadata)
4✔
1326
        self.metadata = icf_metadata
4✔
1327
        self.metadata.format_version = ICF_METADATA_FORMAT_VERSION
4✔
1328
        self.metadata.compressor = compressor.get_config()
4✔
1329
        self.metadata.column_chunk_size = column_chunk_size
4✔
1330
        # Bare minimum here for provenance - would be nice to include versions of key
1331
        # dependencies as well.
1332
        self.metadata.provenance = {"source": f"bio2zarr-{provenance.__version__}"}
4✔
1333

1334
        self.mkdirs()
4✔
1335

1336
        # Note: this is needed for the current version of the vcfzarr spec, but it's
1337
        # probably going to be dropped.
1338
        # https://github.com/pystatgen/vcf-zarr-spec/issues/15
1339
        # May be useful to keep lying around still though?
1340
        logger.info("Writing VCF header")
4✔
1341
        with open(self.path / "header.txt", "w") as f:
4✔
1342
            f.write(header)
4✔
1343

1344
        logger.info("Writing WIP metadata")
4✔
1345
        with open(self.wip_path / "metadata.json", "w") as f:
4✔
1346
            json.dump(self.metadata.asdict(), f, indent=4)
4✔
1347
        return IcfWriteSummary(
4✔
1348
            num_partitions=self.num_partitions,
1349
            num_variants=icf_metadata.num_records,
1350
            num_samples=icf_metadata.num_samples,
1351
        )
1352

1353
    def mkdirs(self):
4✔
1354
        num_dirs = len(self.metadata.fields)
4✔
1355
        logger.info(f"Creating {num_dirs} field directories")
4✔
1356
        self.path.mkdir()
4✔
1357
        self.wip_path.mkdir()
4✔
1358
        for field in self.metadata.fields:
4✔
1359
            field_path = get_vcf_field_path(self.path, field)
4✔
1360
            field_path.mkdir(parents=True)
4✔
1361

1362
    def load_partition_summaries(self):
4✔
1363
        summaries = []
4✔
1364
        not_found = []
4✔
1365
        for j in range(self.num_partitions):
4✔
1366
            try:
4✔
1367
                with open(self.wip_path / f"p{j}.json") as f:
4✔
1368
                    summaries.append(IcfPartitionMetadata.fromdict(json.load(f)))
4✔
1369
            except FileNotFoundError:
4✔
1370
                not_found.append(j)
4✔
1371
        if len(not_found) > 0:
4✔
1372
            raise FileNotFoundError(
4✔
1373
                f"Partition metadata not found for {len(not_found)}"
1374
                f" partitions: {not_found}"
1375
            )
1376
        return summaries
4✔
1377

1378
    def load_metadata(self):
4✔
1379
        if self.metadata is None:
4✔
1380
            with open(self.wip_path / "metadata.json") as f:
4✔
1381
                self.metadata = IcfMetadata.fromdict(json.load(f))
4✔
1382

1383
    def process_partition(self, partition_index):
4✔
1384
        self.load_metadata()
4✔
1385
        summary_path = self.wip_path / f"p{partition_index}.json"
4✔
1386
        # If someone is rewriting a summary path (for whatever reason), make sure it
1387
        # doesn't look like it's already been completed.
1388
        # NOTE to do this properly we probably need to take a lock on this file - but
1389
        # this simple approach will catch the vast majority of problems.
1390
        if summary_path.exists():
4✔
1391
            summary_path.unlink()
4✔
1392

1393
        partition = self.metadata.partitions[partition_index]
4✔
1394
        logger.info(
4✔
1395
            f"Start p{partition_index} {partition.vcf_path}__{partition.region}"
1396
        )
1397
        info_fields = self.metadata.info_fields
4✔
1398
        format_fields = []
4✔
1399
        has_gt = False
4✔
1400
        for field in self.metadata.format_fields:
4✔
1401
            if field.name == "GT":
4✔
1402
                has_gt = True
4✔
1403
            else:
1404
                format_fields.append(field)
4✔
1405

1406
        last_position = None
4✔
1407
        with IcfPartitionWriter(
4✔
1408
            self.metadata,
1409
            self.path,
1410
            partition_index,
1411
        ) as tcw:
1412
            with vcf_utils.VcfFile(partition.vcf_path) as vcf:
4✔
1413
                num_records = 0
4✔
1414
                for variant in vcf.variants(partition.region):
4✔
1415
                    num_records += 1
4✔
1416
                    last_position = variant.POS
4✔
1417
                    tcw.append("CHROM", variant.CHROM)
4✔
1418
                    tcw.append("POS", variant.POS)
4✔
1419
                    tcw.append("QUAL", variant.QUAL)
4✔
1420
                    tcw.append("ID", variant.ID)
4✔
1421
                    tcw.append("FILTERS", variant.FILTERS)
4✔
1422
                    tcw.append("REF", variant.REF)
4✔
1423
                    tcw.append("ALT", variant.ALT)
4✔
1424
                    tcw.append("rlen", variant.end - variant.start)
4✔
1425
                    for field in info_fields:
4✔
1426
                        tcw.append(field.full_name, variant.INFO.get(field.name, None))
4✔
1427
                    if has_gt:
4✔
1428
                        val = None
4✔
1429
                        if "GT" in variant.FORMAT and variant.genotype is not None:
4✔
1430
                            val = variant.genotype.array()
4✔
1431
                        tcw.append("FORMAT/GT", val)
4✔
1432
                    for field in format_fields:
4✔
1433
                        val = variant.format(field.name)
4✔
1434
                        tcw.append(field.full_name, val)
4✔
1435

1436
                    # Note: an issue with updating the progress per variant here like
1437
                    # this is that we get a significant pause at the end of the counter
1438
                    # while all the "small" fields get flushed. Possibly not much to be
1439
                    # done about it.
1440
                    core.update_progress(1)
4✔
1441
            logger.info(
4✔
1442
                f"Finished reading VCF for partition {partition_index}, "
1443
                f"flushing buffers"
1444
            )
1445

1446
        partition_metadata = IcfPartitionMetadata(
4✔
1447
            num_records=num_records,
1448
            last_position=last_position,
1449
            field_summaries=tcw.field_summaries,
1450
        )
1451
        with open(summary_path, "w") as f:
4✔
1452
            f.write(partition_metadata.asjson())
4✔
1453
        logger.info(
4✔
1454
            f"Finish p{partition_index} {partition.vcf_path}__{partition.region} "
1455
            f"{num_records} records last_pos={last_position}"
1456
        )
1457

1458
    def explode(
4✔
1459
        self, *, worker_processes=core.DEFAULT_WORKER_PROCESSES, show_progress=False
1460
    ):
1461
        self.load_metadata()
4✔
1462
        num_records = self.metadata.num_records
4✔
1463
        if np.isinf(num_records):
4✔
1464
            logger.warning(
4✔
1465
                "Total records unknown, cannot show progress; "
1466
                "reindex VCFs with bcftools index to fix"
1467
            )
1468
            num_records = None
4✔
1469
        num_fields = len(self.metadata.fields)
4✔
1470
        num_samples = len(self.metadata.samples)
4✔
1471
        logger.info(
4✔
1472
            f"Exploding fields={num_fields} samples={num_samples}; "
1473
            f"partitions={self.num_partitions} "
1474
            f"variants={'unknown' if num_records is None else num_records}"
1475
        )
1476
        progress_config = core.ProgressConfig(
4✔
1477
            total=num_records,
1478
            units="vars",
1479
            title="Explode",
1480
            show=show_progress,
1481
        )
1482
        with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
4✔
1483
            for j in range(self.num_partitions):
4✔
1484
                pwm.submit(self.process_partition, j)
4✔
1485

1486
    def explode_partition(self, partition):
4✔
1487
        self.load_metadata()
4✔
1488
        if partition < 0 or partition >= self.num_partitions:
4✔
1489
            raise ValueError("Partition index not in the valid range")
4✔
1490
        self.process_partition(partition)
4✔
1491

1492
    def finalise(self):
4✔
1493
        self.load_metadata()
4✔
1494
        partition_summaries = self.load_partition_summaries()
4✔
1495
        total_records = 0
4✔
1496
        for index, summary in enumerate(partition_summaries):
4✔
1497
            partition_records = summary.num_records
4✔
1498
            self.metadata.partitions[index].num_records = partition_records
4✔
1499
            self.metadata.partitions[index].region.end = summary.last_position
4✔
1500
            total_records += partition_records
4✔
1501
        if not np.isinf(self.metadata.num_records):
4✔
1502
            # Note: this is just telling us that there's a bug in the
1503
            # index based record counting code, but it doesn't actually
1504
            # matter much. We may want to just make this a warning if
1505
            # we hit regular problems.
1506
            assert total_records == self.metadata.num_records
4✔
1507
        self.metadata.num_records = total_records
4✔
1508

1509
        check_overlapping_partitions(self.metadata.partitions)
4✔
1510

1511
        for field in self.metadata.fields:
4✔
1512
            for summary in partition_summaries:
4✔
1513
                field.summary.update(summary.field_summaries[field.full_name])
4✔
1514

1515
        logger.info("Finalising metadata")
4✔
1516
        with open(self.path / "metadata.json", "w") as f:
4✔
1517
            f.write(self.metadata.asjson())
4✔
1518

1519
        logger.debug("Removing WIP directory")
4✔
1520
        shutil.rmtree(self.wip_path)
4✔
1521

1522

1523
def explode(
4✔
1524
    icf_path,
1525
    vcfs,
1526
    *,
1527
    column_chunk_size=16,
1528
    worker_processes=core.DEFAULT_WORKER_PROCESSES,
1529
    show_progress=False,
1530
    compressor=None,
1531
):
1532
    writer = IntermediateColumnarFormatWriter(icf_path)
4✔
1533
    writer.init(
4✔
1534
        vcfs,
1535
        # Heuristic to get reasonable worker utilisation with lumpy partition sizing
1536
        target_num_partitions=max(1, worker_processes * 4),
1537
        worker_processes=worker_processes,
1538
        show_progress=show_progress,
1539
        column_chunk_size=column_chunk_size,
1540
        compressor=compressor,
1541
    )
1542
    writer.explode(worker_processes=worker_processes, show_progress=show_progress)
4✔
1543
    writer.finalise()
4✔
1544
    return IntermediateColumnarFormat(icf_path)
4✔
1545

1546

1547
def explode_init(
4✔
1548
    icf_path,
1549
    vcfs,
1550
    *,
1551
    column_chunk_size=16,
1552
    target_num_partitions=1,
1553
    worker_processes=core.DEFAULT_WORKER_PROCESSES,
1554
    show_progress=False,
1555
    compressor=None,
1556
):
1557
    writer = IntermediateColumnarFormatWriter(icf_path)
4✔
1558
    return writer.init(
4✔
1559
        vcfs,
1560
        target_num_partitions=target_num_partitions,
1561
        worker_processes=worker_processes,
1562
        show_progress=show_progress,
1563
        column_chunk_size=column_chunk_size,
1564
        compressor=compressor,
1565
    )
1566

1567

1568
def explode_partition(icf_path, partition):
4✔
1569
    writer = IntermediateColumnarFormatWriter(icf_path)
4✔
1570
    writer.explode_partition(partition)
4✔
1571

1572

1573
def explode_finalise(icf_path):
4✔
1574
    writer = IntermediateColumnarFormatWriter(icf_path)
4✔
1575
    writer.finalise()
4✔
1576

1577

1578
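# Illustrative sketch (not part of the module): explode_init(), explode_partition()
# and explode_finalise() split the work of the one-shot explode() above so that
# partitions can be processed on separate workers. The paths and partition count
# below are hypothetical:
#
#     summary = explode_init("out.icf", ["chr22.vcf.gz"], target_num_partitions=2)
#     for j in range(summary.num_partitions):
#         explode_partition("out.icf", j)  # each call can run as a separate job
#     explode_finalise("out.icf")
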
def inspect(path):
4✔
    path = pathlib.Path(path)
4✔
    if not path.exists():
4✔
        raise ValueError(f"Path not found: {path}")
4✔
    if (path / "metadata.json").exists():
4✔
        obj = IntermediateColumnarFormat(path)
4✔
    # NOTE: this is too strict, we should support more general Zarrs, see #276
    elif (path / ".zmetadata").exists():
4✔
        obj = vcz.VcfZarr(path)
4✔
    else:
        raise ValueError(f"{path} not in ICF or VCF Zarr format")
4✔
    return obj.summary_table()
4✔


def mkschema(
4✔
    if_path,
    out,
    *,
    variants_chunk_size=None,
    samples_chunk_size=None,
    local_alleles=None,
):
    store = IntermediateColumnarFormat(if_path)
4✔
    spec = store.generate_schema(
4✔
        variants_chunk_size=variants_chunk_size,
        samples_chunk_size=samples_chunk_size,
        local_alleles=local_alleles,
    )
    out.write(spec.asjson())
4✔

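# Illustrative sketch (not part of the module): mkschema() writes the generated
# schema as JSON to the given file object, which can be edited and then passed
# back to encode() via schema_path. The file names and chunk size below are
# hypothetical:
#
#     with open("schema.json", "w") as out:
#         mkschema("out.icf", out, variants_chunk_size=10_000)
#     encode("out.icf", "out.vcz", schema_path="schema.json")
#
# Note that encode_init() rejects explicit chunk sizes when schema_path is
# given, since the schema already fixes the chunking.
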
def convert(
4✔
    vcfs,
    vcz_path,
    *,
    variants_chunk_size=None,
    samples_chunk_size=None,
    worker_processes=core.DEFAULT_WORKER_PROCESSES,
    local_alleles=None,
    show_progress=False,
    icf_path=None,
):
    """
    Convert the VCF data at the specified list of paths
    to VCF Zarr format stored at the specified path.

    .. todo:: Document parameters
    """
    if icf_path is None:
4✔
        cm = temp_icf_path(prefix="vcf2zarr")
4✔
    else:
        cm = contextlib.nullcontext(icf_path)
4✔

    with cm as icf_path:
4✔
        explode(
4✔
            icf_path,
            vcfs,
            worker_processes=worker_processes,
            show_progress=show_progress,
        )
        encode(
4✔
            icf_path,
            vcz_path,
            variants_chunk_size=variants_chunk_size,
            samples_chunk_size=samples_chunk_size,
            worker_processes=worker_processes,
            show_progress=show_progress,
            local_alleles=local_alleles,
        )

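# Illustrative sketch (not part of the module): convert() is the one-shot
# explode-then-encode path. Assuming the module is importable as bio2zarr.vcf
# and "chr22.vcf.gz" is an indexed VCF (both names hypothetical):
#
#     from bio2zarr import vcf
#     vcf.convert(["chr22.vcf.gz"], "chr22.vcz", worker_processes=4, show_progress=True)
#
# When icf_path is None the intermediate columnar format is written to a
# temporary directory (see temp_icf_path below) and discarded afterwards.
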
@contextlib.contextmanager
4✔
def temp_icf_path(prefix=None):
4✔
    with tempfile.TemporaryDirectory(prefix=prefix) as tmp:
4✔
        yield pathlib.Path(tmp) / "icf"
4✔


def encode(
4✔
    icf_path,
    zarr_path,
    schema_path=None,
    variants_chunk_size=None,
    samples_chunk_size=None,
    max_variant_chunks=None,
    dimension_separator=None,
    max_memory=None,
    local_alleles=None,
    worker_processes=core.DEFAULT_WORKER_PROCESSES,
    show_progress=False,
):
    # Rough heuristic to split work up enough to keep utilisation high
    target_num_partitions = max(1, worker_processes * 4)
4✔
    encode_init(
4✔
        icf_path,
        zarr_path,
        target_num_partitions,
        schema_path=schema_path,
        variants_chunk_size=variants_chunk_size,
        samples_chunk_size=samples_chunk_size,
        local_alleles=local_alleles,
        max_variant_chunks=max_variant_chunks,
        dimension_separator=dimension_separator,
    )
    vzw = vcz.VcfZarrWriter(IntermediateColumnarFormat, zarr_path)
4✔
    vzw.encode_all_partitions(
4✔
        worker_processes=worker_processes,
        show_progress=show_progress,
        max_memory=max_memory,
    )
    vzw.finalise(show_progress)
4✔
    vzw.create_index()
4✔


def encode_init(
4✔
    icf_path,
    zarr_path,
    target_num_partitions,
    *,
    schema_path=None,
    variants_chunk_size=None,
    samples_chunk_size=None,
    local_alleles=None,
    max_variant_chunks=None,
    dimension_separator=None,
    max_memory=None,
    worker_processes=core.DEFAULT_WORKER_PROCESSES,
    show_progress=False,
):
    icf_store = IntermediateColumnarFormat(icf_path)
4✔
    if schema_path is None:
4✔
        schema_instance = icf_store.generate_schema(
4✔
            variants_chunk_size=variants_chunk_size,
            samples_chunk_size=samples_chunk_size,
            local_alleles=local_alleles,
        )
    else:
        logger.info(f"Reading schema from {schema_path}")
4✔
        if variants_chunk_size is not None or samples_chunk_size is not None:
4✔
            raise ValueError(
×
                "Cannot specify schema along with chunk sizes"
            )  # NEEDS TEST
        with open(schema_path) as f:
4✔
            schema_instance = vcz.VcfZarrSchema.fromjson(f.read())
4✔
    zarr_path = pathlib.Path(zarr_path)
4✔
    vzw = vcz.VcfZarrWriter("icf", zarr_path)
4✔
    return vzw.init(
4✔
        icf_store,
        target_num_partitions=target_num_partitions,
        schema=schema_instance,
        dimension_separator=dimension_separator,
        max_variant_chunks=max_variant_chunks,
    )


def encode_partition(zarr_path, partition):
4✔
    writer_instance = vcz.VcfZarrWriter(IntermediateColumnarFormat, zarr_path)
4✔
    writer_instance.encode_partition(partition)
4✔


def encode_finalise(zarr_path, show_progress=False):
4✔
    writer_instance = vcz.VcfZarrWriter(IntermediateColumnarFormat, zarr_path)
4✔
    writer_instance.finalise(show_progress=show_progress)
4✔
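
# Illustrative sketch (not part of the module): like explode, encoding can be
# split across workers with encode_init() / encode_partition() / encode_finalise().
# The paths and partition count below are hypothetical:
#
#     encode_init("out.icf", "out.vcz", target_num_partitions=8)
#     for j in range(8):
#         encode_partition("out.vcz", j)  # each call can run as a separate job
#     encode_finalise("out.vcz", show_progress=True)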