sgkit-dev / bio2zarr / 15018922208

14 May 2025 10:57AM UTC coverage: 98.166% (+0.01%) from 98.153%
Merge fae8afa58 into 587a29e79
Pull Request #385: Optional deps

25 of 25 new or added lines in 5 files covered. (100.0%)

1 existing line in 1 file now uncovered.

2784 of 2836 relevant lines covered (98.17%)

3.93 hits per line

Source File

/bio2zarr/vcf.py (98.73%)
1
import collections
4✔
2
import contextlib
4✔
3
import dataclasses
4✔
4
import json
4✔
5
import logging
4✔
6
import math
4✔
7
import pathlib
4✔
8
import pickle
4✔
9
import shutil
4✔
10
import sys
4✔
11
import tempfile
4✔
12
from functools import partial
4✔
13
from typing import Any
4✔
14

15
import numcodecs
4✔
16
import numpy as np
4✔
17

18
from . import constants, core, provenance, vcf_utils, vcz
4✔
19

20
logger = logging.getLogger(__name__)
4✔
21

22

23
@dataclasses.dataclass
4✔
24
class VcfFieldSummary(core.JsonDataclass):
4✔
25
    num_chunks: int = 0
4✔
26
    compressed_size: int = 0
4✔
27
    uncompressed_size: int = 0
4✔
28
    max_number: int = 0  # Corresponds to VCF Number field, depends on context
4✔
29
    # Only defined for numeric fields
30
    max_value: Any = -math.inf
4✔
31
    min_value: Any = math.inf
4✔
32

33
    def update(self, other):
4✔
34
        self.num_chunks += other.num_chunks
4✔
35
        self.compressed_size += other.compressed_size
4✔
36
        self.uncompressed_size += other.uncompressed_size
4✔
37
        self.max_number = max(self.max_number, other.max_number)
4✔
38
        self.min_value = min(self.min_value, other.min_value)
4✔
39
        self.max_value = max(self.max_value, other.max_value)
4✔
40

41
    @staticmethod
4✔
42
    def fromdict(d):
4✔
43
        return VcfFieldSummary(**d)
4✔
44

45

46
@dataclasses.dataclass(order=True)
4✔
47
class VcfField:
4✔
48
    category: str
4✔
49
    name: str
4✔
50
    vcf_number: str
4✔
51
    vcf_type: str
4✔
52
    description: str
4✔
53
    summary: VcfFieldSummary
4✔
54

55
    @staticmethod
4✔
56
    def from_header(definition):
4✔
57
        category = definition["HeaderType"]
4✔
58
        name = definition["ID"]
4✔
59
        vcf_number = definition["Number"]
4✔
60
        vcf_type = definition["Type"]
4✔
61
        return VcfField(
4✔
62
            category=category,
63
            name=name,
64
            vcf_number=vcf_number,
65
            vcf_type=vcf_type,
66
            description=definition["Description"].strip('"'),
67
            summary=VcfFieldSummary(),
68
        )
69

70
    @staticmethod
4✔
71
    def fromdict(d):
4✔
72
        f = VcfField(**d)
4✔
73
        f.summary = VcfFieldSummary(**d["summary"])
4✔
74
        return f
4✔
75

76
    @property
4✔
77
    def full_name(self):
4✔
78
        if self.category == "fixed":
4✔
79
            return self.name
4✔
80
        return f"{self.category}/{self.name}"
4✔
81

82
    @property
4✔
83
    def max_number(self):
4✔
84
        if self.vcf_number in ("R", "A", "G", "."):
4✔
85
            return self.summary.max_number
4✔
86
        else:
87
            # use declared number if larger than max found
88
            return max(self.summary.max_number, int(self.vcf_number))
4✔
89

90
    def smallest_dtype(self):
4✔
91
        """
92
        Returns the smallest dtype suitable for this field based
93
        on its type and values.
94
        """
95
        s = self.summary
4✔
96
        if self.vcf_type == "Float":
4✔
97
            ret = "f4"
4✔
98
        elif self.vcf_type == "Integer":
4✔
99
            if not math.isfinite(s.max_value):
4✔
100
                # All missing values; use i1. Note we should have some API to
101
                # check more explicitly for missingness:
102
                # https://github.com/sgkit-dev/bio2zarr/issues/131
103
                ret = "i1"
4✔
104
            else:
105
                ret = core.min_int_dtype(s.min_value, s.max_value)
4✔
106
        elif self.vcf_type == "Flag":
4✔
107
            ret = "bool"
4✔
108
        elif self.vcf_type == "Character":
4✔
109
            ret = "U1"
4✔
110
        else:
111
            assert self.vcf_type == "String"
4✔
112
            ret = "O"
4✔
113
        return ret
4✔
114

115
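A minimal sketch of the idea behind smallest_dtype for Integer fields: pick the narrowest signed type that covers the observed bounds. The helper below is hypothetical and stands in for core.min_int_dtype, whose real implementation lives in bio2zarr.core and may make different choices (for example around the missing/fill sentinels).

import numpy as np

def pick_int_dtype(min_value, max_value):
    # Hypothetical stand-in: narrowest signed integer type holding both bounds.
    for dtype in ("i1", "i2", "i4", "i8"):
        info = np.iinfo(dtype)
        if info.min <= min_value and max_value <= info.max:
            return dtype
    raise OverflowError("value range too large for int64")

print(pick_int_dtype(-1, 100))  # i1
print(pick_int_dtype(0, 300))   # i2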

116
@dataclasses.dataclass
4✔
117
class VcfPartition:
4✔
118
    vcf_path: str
4✔
119
    region: str
4✔
120
    num_records: int = -1
4✔
121

122

123
ICF_METADATA_FORMAT_VERSION = "0.4"
4✔
124
ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc(
4✔
125
    cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE
126
)
127

128

129
@dataclasses.dataclass
4✔
130
class IcfMetadata(core.JsonDataclass):
4✔
131
    samples: list
4✔
132
    contigs: list
4✔
133
    filters: list
4✔
134
    fields: list
4✔
135
    partitions: list = None
4✔
136
    format_version: str = None
4✔
137
    compressor: dict = None
4✔
138
    column_chunk_size: int = None
4✔
139
    provenance: dict = None
4✔
140
    num_records: int = -1
4✔
141

142
    @property
4✔
143
    def info_fields(self):
4✔
144
        fields = []
4✔
145
        for field in self.fields:
4✔
146
            if field.category == "INFO":
4✔
147
                fields.append(field)
4✔
148
        return fields
4✔
149

150
    @property
4✔
151
    def format_fields(self):
4✔
152
        fields = []
4✔
153
        for field in self.fields:
4✔
154
            if field.category == "FORMAT":
4✔
155
                fields.append(field)
4✔
156
        return fields
4✔
157

158
    @property
4✔
159
    def num_contigs(self):
4✔
160
        return len(self.contigs)
4✔
161

162
    @property
4✔
163
    def num_filters(self):
4✔
164
        return len(self.filters)
4✔
165

166
    @property
4✔
167
    def num_samples(self):
4✔
168
        return len(self.samples)
4✔
169

170
    @staticmethod
4✔
171
    def fromdict(d):
4✔
172
        if d["format_version"] != ICF_METADATA_FORMAT_VERSION:
4✔
173
            raise ValueError(
4✔
174
                "Intermediate columnar metadata format version mismatch: "
175
                f"{d['format_version']} != {ICF_METADATA_FORMAT_VERSION}"
176
            )
177
        partitions = [VcfPartition(**pd) for pd in d["partitions"]]
4✔
178
        for p in partitions:
4✔
179
            p.region = vcf_utils.Region(**p.region)
4✔
180
        d = d.copy()
4✔
181
        d["partitions"] = partitions
4✔
182
        d["fields"] = [VcfField.fromdict(fd) for fd in d["fields"]]
4✔
183
        d["samples"] = [vcz.Sample(**sd) for sd in d["samples"]]
4✔
184
        d["filters"] = [vcz.Filter(**fd) for fd in d["filters"]]
4✔
185
        d["contigs"] = [vcz.Contig(**cd) for cd in d["contigs"]]
4✔
186
        return IcfMetadata(**d)
4✔
187

188
    def __eq__(self, other):
4✔
189
        if not isinstance(other, IcfMetadata):
4✔
190
            return NotImplemented
×
191
        return (
4✔
192
            self.samples == other.samples
193
            and self.contigs == other.contigs
194
            and self.filters == other.filters
195
            and sorted(self.fields) == sorted(other.fields)
196
        )
197

198

199
def fixed_vcf_field_definitions():
4✔
200
    def make_field_def(name, vcf_type, vcf_number):
4✔
201
        return VcfField(
4✔
202
            category="fixed",
203
            name=name,
204
            vcf_type=vcf_type,
205
            vcf_number=vcf_number,
206
            description="",
207
            summary=VcfFieldSummary(),
208
        )
209

210
    fields = [
4✔
211
        make_field_def("CHROM", "String", "1"),
212
        make_field_def("POS", "Integer", "1"),
213
        make_field_def("QUAL", "Float", "1"),
214
        make_field_def("ID", "String", "."),
215
        make_field_def("FILTERS", "String", "."),
216
        make_field_def("REF", "String", "1"),
217
        make_field_def("ALT", "String", "."),
218
        make_field_def("rlen", "Integer", "1"),  # computed field
219
    ]
220
    return fields
4✔
221

222

223
def scan_vcf(path, target_num_partitions):
4✔
224
    with vcf_utils.VcfFile(path) as vcf_file:
4✔
225
        vcf = vcf_file.vcf
4✔
226
        filters = []
4✔
227
        pass_index = -1
4✔
228
        for h in vcf.header_iter():
4✔
229
            if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str):
4✔
230
                try:
4✔
231
                    description = h["Description"].strip('"')
4✔
232
                except KeyError:
×
233
                    description = ""
×
234
                if h["ID"] == "PASS":
4✔
235
                    pass_index = len(filters)
4✔
236
                filters.append(vcz.Filter(h["ID"], description))
4✔
237

238
        # Ensure PASS is the first filter if present
239
        if pass_index > 0:
4✔
240
            pass_filter = filters.pop(pass_index)
×
241
            filters.insert(0, pass_filter)
×
242

243
        fields = fixed_vcf_field_definitions()
4✔
244
        for h in vcf.header_iter():
4✔
245
            if h["HeaderType"] in ["INFO", "FORMAT"]:
4✔
246
                field = VcfField.from_header(h)
4✔
247
                if h["HeaderType"] == "FORMAT" and field.name == "GT":
4✔
248
                    field.vcf_type = "Integer"
4✔
249
                    field.vcf_number = "."
4✔
250
                fields.append(field)
4✔
251

252
        try:
4✔
253
            contig_lengths = vcf.seqlens
4✔
254
        except AttributeError:
4✔
255
            contig_lengths = [None for _ in vcf.seqnames]
4✔
256

257
        metadata = IcfMetadata(
4✔
258
            samples=[vcz.Sample(sample_id) for sample_id in vcf.samples],
259
            contigs=[
260
                vcz.Contig(contig_id, length)
261
                for contig_id, length in zip(vcf.seqnames, contig_lengths)
262
            ],
263
            filters=filters,
264
            fields=fields,
265
            partitions=[],
266
            num_records=sum(vcf_file.contig_record_counts().values()),
267
        )
268

269
        regions = vcf_file.partition_into_regions(num_parts=target_num_partitions)
4✔
270
        for region in regions:
4✔
271
            metadata.partitions.append(
4✔
272
                VcfPartition(
273
                    # TODO should this be fully resolving the path? Otherwise it's all
274
                    # relative to the original WD
275
                    vcf_path=str(path),
276
                    region=region,
277
                )
278
            )
279
        logger.info(
4✔
280
            f"Split {path} into {len(metadata.partitions)} "
281
            f"partitions (target={target_num_partitions})"
282
        )
283
        core.update_progress(1)
4✔
284
        return metadata, vcf.raw_header
4✔
285

286
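The PASS reordering above can be pictured with plain lists; this is just the same pop/insert pattern on strings, not a call into the real header parsing.

filters = ["q10", "PASS", "s50"]
pass_index = filters.index("PASS")

if pass_index > 0:
    # Move PASS to the front so it ends up at filter index 0.
    filters.insert(0, filters.pop(pass_index))

print(filters)  # ['PASS', 'q10', 's50']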

287
def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
4✔
288
    logger.info(
4✔
289
        f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
290
        f" partitions."
291
    )
292
    # An easy mistake to make is to pass the same file twice. Check this early on.
293
    for path, count in collections.Counter(paths).items():
4✔
294
        if not path.exists():  # NEEDS TEST
4✔
295
            raise FileNotFoundError(path)
×
296
        if count > 1:
4✔
297
            raise ValueError(f"Duplicate path provided: {path}")
4✔
298

299
    progress_config = core.ProgressConfig(
4✔
300
        total=len(paths),
301
        units="files",
302
        title="Scan",
303
        show=show_progress,
304
    )
305
    with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
4✔
306
        for path in paths:
4✔
307
            pwm.submit(
4✔
308
                scan_vcf,
309
                path,
310
                max(1, target_num_partitions // len(paths)),
311
            )
312
        results = list(pwm.results_as_completed())
4✔
313

314
    # Sort to make the ordering deterministic
315
    results.sort(key=lambda t: t[0].partitions[0].vcf_path)
4✔
316
    # We just take the first header, assuming the others
317
    # are compatible.
318
    all_partitions = []
4✔
319
    total_records = 0
4✔
320
    contigs = {}
4✔
321
    for metadata, _ in results:
4✔
322
        for partition in metadata.partitions:
4✔
323
            logger.debug(f"Scanned partition {partition}")
4✔
324
            all_partitions.append(partition)
4✔
325
        for contig in metadata.contigs:
4✔
326
            if contig.id in contigs:
4✔
327
                if contig != contigs[contig.id]:
4✔
UNCOV
328
                    raise ValueError(
3✔
329
                        "Incompatible contig definitions: "
330
                        f"{contig} != {contigs[contig.id]}"
331
                    )
332
            else:
333
                contigs[contig.id] = contig
4✔
334
        total_records += metadata.num_records
4✔
335
        metadata.num_records = 0
4✔
336
        metadata.partitions = []
4✔
337

338
    contig_union = list(contigs.values())
4✔
339
    for metadata, _ in results:
4✔
340
        metadata.contigs = contig_union
4✔
341

342
    icf_metadata, header = results[0]
4✔
343
    for metadata, _ in results[1:]:
4✔
344
        if metadata != icf_metadata:
4✔
345
            raise ValueError("Incompatible VCF chunks")
4✔
346

347
    # Note: this will be infinity here if any of the chunks has an index
348
    # that doesn't keep track of the number of records per-contig
349
    icf_metadata.num_records = total_records
4✔
350

351
    # Sort by contig (in the order they appear in the header) first,
352
    # then by start coordinate
353
    contig_index_map = {contig.id: j for j, contig in enumerate(metadata.contigs)}
4✔
354
    all_partitions.sort(
4✔
355
        key=lambda x: (contig_index_map[x.region.contig], x.region.start)
356
    )
357
    icf_metadata.partitions = all_partitions
4✔
358
    logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.")
4✔
359
    return icf_metadata, header
4✔
360

361
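The final sort above orders partitions by the contig's position in the header and then by start coordinate. A toy illustration with (contig, start) tuples standing in for Region objects:

regions = [("chr2", 1), ("chr1", 5_000_000), ("chr1", 1), ("chr2", 2_000_000)]
contig_index_map = {"chr1": 0, "chr2": 1}  # header order

regions.sort(key=lambda r: (contig_index_map[r[0]], r[1]))
print(regions)
# [('chr1', 1), ('chr1', 5000000), ('chr2', 1), ('chr2', 2000000)]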

362
def sanitise_value_bool(shape, value):
4✔
363
    x = True
4✔
364
    if value is None:
4✔
365
        x = False
4✔
366
    return x
4✔
367

368

369
def sanitise_value_float_scalar(shape, value):
4✔
370
    x = value
4✔
371
    if value is None:
4✔
372
        x = [constants.FLOAT32_MISSING]
4✔
373
    return x[0]
4✔
374

375

376
def sanitise_value_int_scalar(shape, value):
4✔
377
    x = value
4✔
378
    if value is None:
4✔
379
        x = [constants.INT_MISSING]
4✔
380
    else:
381
        x = sanitise_int_array(value, ndmin=1, dtype=np.int32)
4✔
382
    return x[0]
4✔
383

384

385
def sanitise_value_string_scalar(shape, value):
4✔
386
    if value is None:
4✔
387
        return "."
4✔
388
    else:
389
        return value[0]
4✔
390

391

392
def sanitise_value_string_1d(shape, value):
4✔
393
    if value is None:
4✔
394
        return np.full(shape, ".", dtype="O")
4✔
395
    else:
396
        value = drop_empty_second_dim(value)
4✔
397
        result = np.full(shape, "", dtype=value.dtype)
4✔
398
        result[: value.shape[0]] = value
4✔
399
        return result
4✔
400

401

402
def sanitise_value_string_2d(shape, value):
4✔
403
    if value is None:
4✔
404
        return np.full(shape, ".", dtype="O")
4✔
405
    else:
406
        result = np.full(shape, "", dtype="O")
4✔
407
        if value.ndim == 2:
4✔
408
            result[: value.shape[0], : value.shape[1]] = value
4✔
409
        else:
410
            # Convert 1D array into 2D with appropriate shape
411
            for k, val in enumerate(value):
4✔
412
                result[k, : len(val)] = val
4✔
413
        return result
4✔
414

415

416
def drop_empty_second_dim(value):
4✔
417
    assert len(value.shape) == 1 or value.shape[1] == 1
4✔
418
    if len(value.shape) == 2 and value.shape[1] == 1:
4✔
419
        value = value[..., 0]
4✔
420
    return value
4✔
421

422

423
def sanitise_value_float_1d(shape, value):
4✔
424
    if value is None:
4✔
425
        return np.full(shape, constants.FLOAT32_MISSING)
4✔
426
    else:
427
        value = np.array(value, ndmin=1, dtype=np.float32, copy=True)
4✔
428
        # numpy will map None values to NaN, but we need a
429
        # specific NaN
430
        value[np.isnan(value)] = constants.FLOAT32_MISSING
4✔
431
        value = drop_empty_second_dim(value)
4✔
432
        result = np.full(shape, constants.FLOAT32_FILL, dtype=np.float32)
4✔
433
        result[: value.shape[0]] = value
4✔
434
        return result
4✔
435

436

437
def sanitise_value_float_2d(shape, value):
4✔
438
    if value is None:
4✔
439
        return np.full(shape, constants.FLOAT32_MISSING)
4✔
440
    else:
441
        value = np.array(value, ndmin=2, dtype=np.float32, copy=True)
4✔
442
        result = np.full(shape, constants.FLOAT32_FILL, dtype=np.float32)
4✔
443
        result[:, : value.shape[1]] = value
4✔
444
        return result
4✔
445

446

447
def sanitise_int_array(value, ndmin, dtype):
4✔
448
    if isinstance(value, tuple):
4✔
449
        value = [
×
450
            constants.VCF_INT_MISSING if x is None else x for x in value
451
        ]  # NEEDS TEST
452
    value = np.array(value, ndmin=ndmin, copy=True)
4✔
453
    value[value == constants.VCF_INT_MISSING] = -1
4✔
454
    value[value == constants.VCF_INT_FILL] = -2
4✔
455
    # TODO watch out for clipping here!
456
    return value.astype(dtype)
4✔
457

458

459
def sanitise_value_int_1d(shape, value):
4✔
460
    if value is None:
4✔
461
        return np.full(shape, -1)
4✔
462
    else:
463
        value = sanitise_int_array(value, 1, np.int32)
4✔
464
        value = drop_empty_second_dim(value)
4✔
465
        result = np.full(shape, -2, dtype=np.int32)
4✔
466
        result[: value.shape[0]] = value
4✔
467
        return result
4✔
468

469

470
def sanitise_value_int_2d(shape, value):
4✔
471
    if value is None:
4✔
472
        return np.full(shape, -1)
4✔
473
    else:
474
        value = sanitise_int_array(value, 2, np.int32)
4✔
475
        result = np.full(shape, -2, dtype=np.int32)
4✔
476
        result[:, : value.shape[1]] = value
4✔
477
        return result
4✔
478

479
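A rough illustration of what the integer sanitisers produce, using the -1 (missing) and -2 (fill) sentinels that appear literally in the functions above rather than the constants module:

import numpy as np

shape = (3, 2)                       # (samples, max_number) for a FORMAT field
value = np.array([[7], [5], [9]])    # each sample supplied one value

result = np.full(shape, -2, dtype=np.int32)   # -2 fills the absent slots
result[:, : value.shape[1]] = value
print(result)
# [[ 7 -2]
#  [ 5 -2]
#  [ 9 -2]]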

480
missing_value_map = {
4✔
481
    "Integer": constants.INT_MISSING,
482
    "Float": constants.FLOAT32_MISSING,
483
    "String": constants.STR_MISSING,
484
    "Character": constants.STR_MISSING,
485
    "Flag": False,
486
}
487

488

489
class VcfValueTransformer:
4✔
490
    """
491
    Transform VCF values into the stored intermediate format used
492
    in the IntermediateColumnarFormat, and update field summaries.
493
    """
494

495
    def __init__(self, field, num_samples):
4✔
496
        self.field = field
4✔
497
        self.num_samples = num_samples
4✔
498
        self.dimension = 1
4✔
499
        if field.category == "FORMAT":
4✔
500
            self.dimension = 2
4✔
501
        self.missing = missing_value_map[field.vcf_type]
4✔
502

503
    @staticmethod
4✔
504
    def factory(field, num_samples):
4✔
505
        if field.vcf_type in ("Integer", "Flag"):
4✔
506
            return IntegerValueTransformer(field, num_samples)
4✔
507
        if field.vcf_type == "Float":
4✔
508
            return FloatValueTransformer(field, num_samples)
4✔
509
        if field.name in ["REF", "FILTERS", "ALT", "ID", "CHROM"]:
4✔
510
            return SplitStringValueTransformer(field, num_samples)
4✔
511
        return StringValueTransformer(field, num_samples)
4✔
512

513
    def transform(self, vcf_value):
4✔
514
        if isinstance(vcf_value, tuple):
4✔
515
            vcf_value = [self.missing if v is None else v for v in vcf_value]
4✔
516
        value = np.array(vcf_value, ndmin=self.dimension, copy=True)
4✔
517
        return value
4✔
518

519
    def transform_and_update_bounds(self, vcf_value):
4✔
520
        if vcf_value is None:
4✔
521
            return None
4✔
522
        # print(self, self.field.full_name, "T", vcf_value)
523
        value = self.transform(vcf_value)
4✔
524
        self.update_bounds(value)
4✔
525
        return value
4✔
526

527

528
class IntegerValueTransformer(VcfValueTransformer):
4✔
529
    def update_bounds(self, value):
4✔
530
        summary = self.field.summary
4✔
531
        # Mask out missing and fill values
532
        # print(value)
533
        a = value[value >= constants.MIN_INT_VALUE]
4✔
534
        if a.size > 0:
4✔
535
            summary.max_value = int(max(summary.max_value, np.max(a)))
4✔
536
            summary.min_value = int(min(summary.min_value, np.min(a)))
4✔
537
        number = value.shape[-1]
4✔
538
        summary.max_number = max(summary.max_number, number)
4✔
539

540

541
class FloatValueTransformer(VcfValueTransformer):
4✔
542
    def update_bounds(self, value):
4✔
543
        summary = self.field.summary
4✔
544
        summary.max_value = float(max(summary.max_value, np.max(value)))
4✔
545
        summary.min_value = float(min(summary.min_value, np.min(value)))
4✔
546
        number = value.shape[-1]
4✔
547
        summary.max_number = max(summary.max_number, number)
4✔
548

549

550
class StringValueTransformer(VcfValueTransformer):
4✔
551
    def update_bounds(self, value):
4✔
552
        summary = self.field.summary
4✔
553
        if self.field.category == "FORMAT":
4✔
554
            number = max(len(v) for v in value)
4✔
555
        else:
556
            number = value.shape[-1]
4✔
557
        # TODO would be nice to report string lengths, but not
558
        # really necessary.
559
        summary.max_number = max(summary.max_number, number)
4✔
560

561
    def transform(self, vcf_value):
4✔
562
        if self.dimension == 1:
4✔
563
            value = np.array(list(vcf_value.split(",")))
4✔
564
        else:
565
            # TODO can we make this faster??
566
            value = np.array([v.split(",") for v in vcf_value], dtype="O")
4✔
567
            # print("HERE", vcf_value, value)
568
            # for v in vcf_value:
569
            #     print("\t", type(v), len(v), v.split(","))
570
        # print("S: ", self.dimension, ":", value.shape, value)
571
        return value
4✔
572

573

574
class SplitStringValueTransformer(StringValueTransformer):
4✔
575
    def transform(self, vcf_value):
4✔
576
        if vcf_value is None:
4✔
577
            return self.missing_value  # NEEDS TEST
×
578
        assert self.dimension == 1
4✔
579
        return np.array(vcf_value, ndmin=1, dtype="str")
4✔
580

581
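A sketch of the comma splitting StringValueTransformer does for a FORMAT field, with a plain list standing in for the per-sample strings the VCF parser hands back:

import numpy as np

vcf_value = ["a,b", "c", "d,e,f"]     # one comma-separated string per sample

value = np.array([v.split(",") for v in vcf_value], dtype="O")
print(value)                          # object array of per-sample lists
print(max(len(v) for v in value))     # 3, which feeds summary.max_number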

582
def get_vcf_field_path(base_path, vcf_field):
4✔
583
    if vcf_field.category == "fixed":
4✔
584
        return base_path / vcf_field.name
4✔
585
    return base_path / vcf_field.category / vcf_field.name
4✔
586

587

588
class IntermediateColumnarFormatField:
4✔
589
    def __init__(self, icf, vcf_field):
4✔
590
        self.vcf_field = vcf_field
4✔
591
        self.path = get_vcf_field_path(icf.path, vcf_field)
4✔
592
        self.compressor = icf.compressor
4✔
593
        self.num_partitions = icf.num_partitions
4✔
594
        self.num_records = icf.num_records
4✔
595
        self.partition_record_index = icf.partition_record_index
4✔
596
        # A map of partition id to the cumulative number of records
597
        # in chunks within that partition
598
        self._chunk_record_index = {}
4✔
599

600
    @property
4✔
601
    def name(self):
4✔
602
        return self.vcf_field.full_name
4✔
603

604
    def partition_path(self, partition_id):
4✔
605
        return self.path / f"p{partition_id}"
4✔
606

607
    def __repr__(self):
4✔
608
        partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
4✔
609
        return (
4✔
610
            f"IntermediateColumnarFormatField(name={self.name}, "
611
            f"partition_chunks={partition_chunks}, "
612
            f"path={self.path})"
613
        )
614

615
    def num_chunks(self, partition_id):
4✔
616
        return len(self.chunk_record_index(partition_id)) - 1
4✔
617

618
    def chunk_record_index(self, partition_id):
4✔
619
        if partition_id not in self._chunk_record_index:
4✔
620
            index_path = self.partition_path(partition_id) / "chunk_index"
4✔
621
            with open(index_path, "rb") as f:
4✔
622
                a = pickle.load(f)
4✔
623
            assert len(a) > 1
4✔
624
            assert a[0] == 0
4✔
625
            self._chunk_record_index[partition_id] = a
4✔
626
        return self._chunk_record_index[partition_id]
4✔
627

628
    def read_chunk(self, path):
4✔
629
        with open(path, "rb") as f:
4✔
630
            pkl = self.compressor.decode(f.read())
4✔
631
        return pickle.loads(pkl)
4✔
632

633
    def chunk_num_records(self, partition_id):
4✔
634
        return np.diff(self.chunk_record_index(partition_id))
4✔
635

636
    def chunks(self, partition_id, start_chunk=0):
4✔
637
        partition_path = self.partition_path(partition_id)
4✔
638
        chunk_cumulative_records = self.chunk_record_index(partition_id)
4✔
639
        chunk_num_records = np.diff(chunk_cumulative_records)
4✔
640
        for count, cumulative in zip(
4✔
641
            chunk_num_records[start_chunk:],
642
            chunk_cumulative_records[start_chunk + 1 :],
643
        ):
644
            path = partition_path / f"{cumulative}"
4✔
645
            chunk = self.read_chunk(path)
4✔
646
            if len(chunk) != count:
4✔
647
                raise ValueError(f"Corruption detected in chunk: {path}")
4✔
648
            yield chunk
4✔
649

650
    def iter_values(self, start=None, stop=None):
4✔
651
        start = 0 if start is None else start
4✔
652
        stop = self.num_records if stop is None else stop
4✔
653
        start_partition = (
4✔
654
            np.searchsorted(self.partition_record_index, start, side="right") - 1
655
        )
656
        offset = self.partition_record_index[start_partition]
4✔
657
        assert offset <= start
4✔
658
        chunk_offset = start - offset
4✔
659

660
        chunk_record_index = self.chunk_record_index(start_partition)
4✔
661
        start_chunk = (
4✔
662
            np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1
663
        )
664
        record_id = offset + chunk_record_index[start_chunk]
4✔
665
        assert record_id <= start
4✔
666
        logger.debug(
4✔
667
            f"Read {self.vcf_field.full_name} slice [{start}:{stop}]:"
668
            f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}"
669
        )
670
        for chunk in self.chunks(start_partition, start_chunk):
4✔
671
            for record in chunk:
4✔
672
                if record_id == stop:
4✔
673
                    return
4✔
674
                if record_id >= start:
4✔
675
                    yield record
4✔
676
                record_id += 1
4✔
677
        assert record_id > start
4✔
678
        for partition_id in range(start_partition + 1, self.num_partitions):
4✔
679
            for chunk in self.chunks(partition_id):
4✔
680
                for record in chunk:
4✔
681
                    if record_id == stop:
4✔
682
                        return
4✔
683
                    yield record
4✔
684
                    record_id += 1
4✔
685

686
    # Note: this involves some computation so should arguably be a method,
687
    # but we make it a property for consistency with xarray etc.
688
    @property
4✔
689
    def values(self):
4✔
690
        ret = [None] * self.num_records
4✔
691
        j = 0
4✔
692
        for partition_id in range(self.num_partitions):
4✔
693
            for chunk in self.chunks(partition_id):
4✔
694
                for record in chunk:
4✔
695
                    ret[j] = record
4✔
696
                    j += 1
4✔
697
        assert j == self.num_records
4✔
698
        return ret
4✔
699

700
    def sanitiser_factory(self, shape):
4✔
701
        assert len(shape) <= 2
4✔
702
        if self.vcf_field.vcf_type == "Flag":
4✔
703
            assert len(shape) == 0
4✔
704
            return partial(sanitise_value_bool, shape)
4✔
705
        elif self.vcf_field.vcf_type == "Float":
4✔
706
            if len(shape) == 0:
4✔
707
                return partial(sanitise_value_float_scalar, shape)
4✔
708
            elif len(shape) == 1:
4✔
709
                return partial(sanitise_value_float_1d, shape)
4✔
710
            else:
711
                return partial(sanitise_value_float_2d, shape)
4✔
712
        elif self.vcf_field.vcf_type == "Integer":
4✔
713
            if len(shape) == 0:
4✔
714
                return partial(sanitise_value_int_scalar, shape)
4✔
715
            elif len(shape) == 1:
4✔
716
                return partial(sanitise_value_int_1d, shape)
4✔
717
            else:
718
                return partial(sanitise_value_int_2d, shape)
4✔
719
        else:
720
            assert self.vcf_field.vcf_type in ("String", "Character")
4✔
721
            if len(shape) == 0:
4✔
722
                return partial(sanitise_value_string_scalar, shape)
4✔
723
            elif len(shape) == 1:
4✔
724
                return partial(sanitise_value_string_1d, shape)
4✔
725
            else:
726
                return partial(sanitise_value_string_2d, shape)
4✔
727

728
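The slicing in iter_values leans on cumulative record counts plus np.searchsorted. A self-contained sketch of that lookup with toy numbers (not real ICF data):

import numpy as np

# Cumulative counts: partition boundaries, then chunk boundaries within the
# selected partition (both start at 0, like the on-disk chunk_index arrays).
partition_record_index = np.array([0, 100, 250, 400])
chunk_record_index = np.array([0, 40, 90, 150])

start = 175   # absolute record to start reading from

start_partition = np.searchsorted(partition_record_index, start, side="right") - 1
offset = partition_record_index[start_partition]                   # 100
start_chunk = np.searchsorted(chunk_record_index, start - offset, side="right") - 1

print(start_partition, start_chunk)   # 1 1 -> partition 1, its second chunk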

729
@dataclasses.dataclass
4✔
730
class IcfFieldWriter:
4✔
731
    vcf_field: VcfField
4✔
732
    path: pathlib.Path
4✔
733
    transformer: VcfValueTransformer
4✔
734
    compressor: Any
4✔
735
    max_buffered_bytes: int
4✔
736
    buff: list[Any] = dataclasses.field(default_factory=list)
4✔
737
    buffered_bytes: int = 0
4✔
738
    chunk_index: list[int] = dataclasses.field(default_factory=lambda: [0])
4✔
739
    num_records: int = 0
4✔
740

741
    def append(self, val):
4✔
742
        val = self.transformer.transform_and_update_bounds(val)
4✔
743
        assert val is None or isinstance(val, np.ndarray)
4✔
744
        self.buff.append(val)
4✔
745
        val_bytes = sys.getsizeof(val)
4✔
746
        self.buffered_bytes += val_bytes
4✔
747
        self.num_records += 1
4✔
748
        if self.buffered_bytes >= self.max_buffered_bytes:
4✔
749
            logger.debug(
4✔
750
                f"Flush {self.path} buffered={self.buffered_bytes} "
751
                f"max={self.max_buffered_bytes}"
752
            )
753
            self.write_chunk()
4✔
754
            self.buff.clear()
4✔
755
            self.buffered_bytes = 0
4✔
756

757
    def write_chunk(self):
4✔
758
        # Update index
759
        self.chunk_index.append(self.num_records)
4✔
760
        path = self.path / f"{self.num_records}"
4✔
761
        logger.debug(f"Start write: {path}")
4✔
762
        pkl = pickle.dumps(self.buff)
4✔
763
        compressed = self.compressor.encode(pkl)
4✔
764
        with open(path, "wb") as f:
4✔
765
            f.write(compressed)
4✔
766

767
        # Update the summary
768
        self.vcf_field.summary.num_chunks += 1
4✔
769
        self.vcf_field.summary.compressed_size += len(compressed)
4✔
770
        self.vcf_field.summary.uncompressed_size += self.buffered_bytes
4✔
771
        logger.debug(f"Finish write: {path}")
4✔
772

773
    def flush(self):
4✔
774
        logger.debug(
4✔
775
            f"Flush {self.path} records={len(self.buff)} buffered={self.buffered_bytes}"
776
        )
777
        if len(self.buff) > 0:
4✔
778
            self.write_chunk()
4✔
779
        with open(self.path / "chunk_index", "wb") as f:
4✔
780
            a = np.array(self.chunk_index, dtype=int)
4✔
781
            pickle.dump(a, f)
4✔
782

783
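The chunk_index that flush() pickles is cumulative; per-chunk record counts are recovered later with np.diff (see chunk_num_records in the reader class above). A toy round trip of that bookkeeping:

import numpy as np

chunk_index = [0]                              # starts at 0, as in IcfFieldWriter
for cumulative_records in (1000, 2500, 3100):  # running total at each flush
    chunk_index.append(cumulative_records)

a = np.array(chunk_index, dtype=int)
print(np.diff(a))   # [1000 1500  600] records per chunk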

784
class IcfPartitionWriter(contextlib.AbstractContextManager):
4✔
785
    """
786
    Writes the data for an IntermediateColumnarFormat partition.
787
    """
788

789
    def __init__(
4✔
790
        self,
791
        icf_metadata,
792
        out_path,
793
        partition_index,
794
    ):
795
        self.partition_index = partition_index
4✔
796
        # chunk_size is in megabytes
797
        max_buffered_bytes = icf_metadata.column_chunk_size * 2**20
4✔
798
        assert max_buffered_bytes > 0
4✔
799
        compressor = numcodecs.get_codec(icf_metadata.compressor)
4✔
800

801
        self.field_writers = {}
4✔
802
        num_samples = len(icf_metadata.samples)
4✔
803
        for vcf_field in icf_metadata.fields:
4✔
804
            field_path = get_vcf_field_path(out_path, vcf_field)
4✔
805
            field_partition_path = field_path / f"p{partition_index}"
4✔
806
            # Should be robust to running explode_partition twice.
807
            field_partition_path.mkdir(exist_ok=True)
4✔
808
            transformer = VcfValueTransformer.factory(vcf_field, num_samples)
4✔
809
            self.field_writers[vcf_field.full_name] = IcfFieldWriter(
4✔
810
                vcf_field,
811
                field_partition_path,
812
                transformer,
813
                compressor,
814
                max_buffered_bytes,
815
            )
816

817
    @property
4✔
818
    def field_summaries(self):
4✔
819
        return {
4✔
820
            name: field.vcf_field.summary for name, field in self.field_writers.items()
821
        }
822

823
    def append(self, name, value):
4✔
824
        self.field_writers[name].append(value)
4✔
825

826
    def __exit__(self, exc_type, exc_val, exc_tb):
4✔
827
        if exc_type is None:
4✔
828
            for field in self.field_writers.values():
4✔
829
                field.flush()
4✔
830
        return False
4✔
831

832

833
def convert_local_allele_field_types(fields, schema_instance):
4✔
834
    """
835
    Update the specified list of fields to include the LAA field, and to convert
836
    any supported localisable fields to the L* counterpart.
837

838
    Note that we currently support only two ALT alleles per sample, and so the
839
    dimensions of these fields are fixed by that requirement. Later versions may
840
    use summary data stored in the ICF to make different choices, if information
841
    about subsequent alleles (not in the actual genotype calls) should also be
842
    stored.
843
    """
844
    fields_by_name = {field.name: field for field in fields}
4✔
845
    gt = fields_by_name["call_genotype"]
4✔
846

847
    if schema_instance.get_shape(["ploidy"])[0] != 2:
4✔
848
        raise ValueError("Local alleles only supported on diploid data")
4✔
849

850
    dimensions = gt.dimensions[:-1]
4✔
851

852
    la = vcz.ZarrArraySpec(
4✔
853
        name="call_LA",
854
        dtype="i1",
855
        dimensions=(*dimensions, "local_alleles"),
856
        description=(
857
            "0-based indices into REF+ALT, indicating which alleles"
858
            " are relevant (local) for the current sample"
859
        ),
860
    )
861
    schema_instance.dimensions["local_alleles"] = vcz.VcfZarrDimension(
4✔
862
        size=schema_instance.dimensions["ploidy"].size
863
    )
864

865
    ad = fields_by_name.get("call_AD", None)
4✔
866
    if ad is not None:
4✔
867
        # TODO check if call_LAD is in the list already
868
        ad.name = "call_LAD"
4✔
869
        ad.source = None
4✔
870
        ad.dimensions = (*dimensions, "local_alleles_AD")
4✔
871
        ad.description += " (local-alleles)"
4✔
872
        schema_instance.dimensions["local_alleles_AD"] = vcz.VcfZarrDimension(size=2)
4✔
873

874
    pl = fields_by_name.get("call_PL", None)
4✔
875
    if pl is not None:
4✔
876
        # TODO check if call_LPL is in the list already
877
        pl.name = "call_LPL"
4✔
878
        pl.source = None
4✔
879
        pl.description += " (local-alleles)"
4✔
880
        pl.dimensions = (*dimensions, "local_" + pl.dimensions[-1].split("_")[-1])
4✔
881
        schema_instance.dimensions["local_" + pl.dimensions[-1].split("_")[-1]] = (
4✔
882
            vcz.VcfZarrDimension(size=3)
883
        )
884

885
    return [*fields, la]
4✔
886

887
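For the diploid, two-local-allele case this function supports, the localised field widths can be sanity-checked by hand; this is a back-of-envelope calculation, not a call into the schema code.

import math

ploidy = 2
local_alleles = 2        # size given to the new "local_alleles" dimension

lad_size = local_alleles                                    # call_LAD width
lpl_size = math.comb(local_alleles + ploidy - 1, ploidy)    # call_LPL width

print(lad_size, lpl_size)   # 2 3, matching the fixed dimension sizes above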

888
class IntermediateColumnarFormat(vcz.Source):
4✔
889
    def __init__(self, path):
4✔
890
        self._path = pathlib.Path(path)
4✔
891
        # TODO raise a more informative error here telling people this
892
        # directory is either a WIP or the wrong format.
893
        with open(self.path / "metadata.json") as f:
4✔
894
            self.metadata = IcfMetadata.fromdict(json.load(f))
4✔
895
        with open(self.path / "header.txt") as f:
4✔
896
            self.vcf_header = f.read()
4✔
897
        self.compressor = numcodecs.get_codec(self.metadata.compressor)
4✔
898
        self.fields = {}
4✔
899
        partition_num_records = [
4✔
900
            partition.num_records for partition in self.metadata.partitions
901
        ]
902
        # Allow us to find which partition a given record is in
903
        self.partition_record_index = np.cumsum([0, *partition_num_records])
4✔
904
        self.gt_field = None
4✔
905
        for field in self.metadata.fields:
4✔
906
            self.fields[field.full_name] = IntermediateColumnarFormatField(self, field)
4✔
907
            if field.name == "GT":
4✔
908
                self.gt_field = field
4✔
909

910
        logger.info(
4✔
911
            f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, "
912
            f"records={self.num_records}, fields={self.num_fields})"
913
        )
914

915
    def __repr__(self):
4✔
916
        return (
4✔
917
            f"IntermediateColumnarFormat(fields={len(self.fields)}, "
918
            f"partitions={self.num_partitions}, "
919
            f"records={self.num_records}, path={self.path})"
920
        )
921

922
    def summary_table(self):
4✔
923
        data = []
4✔
924
        for name, icf_field in self.fields.items():
4✔
925
            summary = icf_field.vcf_field.summary
4✔
926
            d = {
4✔
927
                "name": name,
928
                "type": icf_field.vcf_field.vcf_type,
929
                "chunks": summary.num_chunks,
930
                "size": core.display_size(summary.uncompressed_size),
931
                "compressed": core.display_size(summary.compressed_size),
932
                "max_n": summary.max_number,
933
                "min_val": core.display_number(summary.min_value),
934
                "max_val": core.display_number(summary.max_value),
935
            }
936

937
            data.append(d)
4✔
938
        return data
4✔
939

940
    @property
4✔
941
    def path(self):
4✔
942
        return self._path
4✔
943

944
    @property
4✔
945
    def num_records(self):
4✔
946
        return self.metadata.num_records
4✔
947

948
    @property
4✔
949
    def num_partitions(self):
4✔
950
        return len(self.metadata.partitions)
4✔
951

952
    @property
4✔
953
    def samples(self):
4✔
954
        return self.metadata.samples
4✔
955

956
    @property
4✔
957
    def contigs(self):
4✔
958
        return self.metadata.contigs
4✔
959

960
    @property
4✔
961
    def filters(self):
4✔
962
        return self.metadata.filters
4✔
963

964
    @property
4✔
965
    def num_samples(self):
4✔
966
        return len(self.metadata.samples)
4✔
967

968
    @property
4✔
969
    def num_fields(self):
4✔
970
        return len(self.fields)
4✔
971

972
    @property
4✔
973
    def root_attrs(self):
4✔
974
        return {
4✔
975
            "vcf_header": self.vcf_header,
976
        }
977

978
    def iter_id(self, start, stop):
4✔
979
        for value in self.fields["ID"].iter_values(start, stop):
4✔
980
            if value is not None:
4✔
981
                yield value[0]
4✔
982
            else:
983
                yield None
4✔
984

985
    def iter_filters(self, start, stop):
4✔
986
        source_field = self.fields["FILTERS"]
4✔
987
        lookup = {filt.id: index for index, filt in enumerate(self.metadata.filters)}
4✔
988

989
        for filter_values in source_field.iter_values(start, stop):
4✔
990
            filters = np.zeros(len(self.metadata.filters), dtype=bool)
4✔
991
            if filter_values is not None:
4✔
992
                for filter_id in filter_values:
4✔
993
                    try:
4✔
994
                        filters[lookup[filter_id]] = True
4✔
995
                    except KeyError:
×
996
                        raise ValueError(
×
997
                            f"Filter '{filter_id}' was not defined in the header."
998
                        ) from None
999
            yield filters
4✔
1000

1001
    def iter_contig(self, start, stop):
4✔
1002
        source_field = self.fields["CHROM"]
4✔
1003
        lookup = {
4✔
1004
            contig.id: index for index, contig in enumerate(self.metadata.contigs)
1005
        }
1006

1007
        for value in source_field.iter_values(start, stop):
4✔
1008
            # Note: because we are using the indexes to define the lookups
1009
            # and we always have an index, it seems that the contig lookup
1010
            # will always succeed. However, if anyone ever does hit a KeyError
1011
            # here, please do open an issue with a reproducible example!
1012
            yield lookup[value[0]]
4✔
1013

1014
    def iter_field(self, field_name, shape, start, stop):
4✔
1015
        source_field = self.fields[field_name]
4✔
1016
        sanitiser = source_field.sanitiser_factory(shape)
4✔
1017
        for value in source_field.iter_values(start, stop):
4✔
1018
            yield sanitiser(value)
4✔
1019

1020
    def iter_alleles(self, start, stop, num_alleles):
4✔
1021
        ref_field = self.fields["REF"]
4✔
1022
        alt_field = self.fields["ALT"]
4✔
1023

1024
        for ref, alt in zip(
4✔
1025
            ref_field.iter_values(start, stop),
1026
            alt_field.iter_values(start, stop),
1027
        ):
1028
            alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
4✔
1029
            alleles[0] = ref[0]
4✔
1030
            alleles[1 : 1 + len(alt)] = alt
4✔
1031
            yield alleles
4✔
1032

1033
    def iter_genotypes(self, shape, start, stop):
4✔
1034
        source_field = self.fields["FORMAT/GT"]
4✔
1035
        for value in source_field.iter_values(start, stop):
4✔
1036
            genotypes = value[:, :-1] if value is not None else None
4✔
1037
            phased = value[:, -1] if value is not None else None
4✔
1038
            sanitised_genotypes = sanitise_value_int_2d(shape, genotypes)
4✔
1039
            sanitised_phased = sanitise_value_int_1d(shape[:-1], phased)
4✔
1040
            yield sanitised_genotypes, sanitised_phased
4✔
1041

1042
    def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
4✔
1043
        if self.gt_field is None or shape is None:
4✔
1044
            for alleles in self.iter_alleles(start, stop, num_alleles):
4✔
1045
                yield alleles, (None, None)
4✔
1046
        else:
1047
            yield from zip(
4✔
1048
                self.iter_alleles(start, stop, num_alleles),
1049
                self.iter_genotypes(shape, start, stop),
1050
            )
1051

1052
    def generate_schema(
4✔
1053
        self, variants_chunk_size=None, samples_chunk_size=None, local_alleles=None
1054
    ):
1055
        m = self.num_records
4✔
1056
        n = self.num_samples
4✔
1057
        if local_alleles is None:
4✔
1058
            local_alleles = False
4✔
1059

1060
        max_alleles = max(self.fields["ALT"].vcf_field.summary.max_number + 1, 2)
4✔
1061
        dimensions = {
4✔
1062
            "variants": vcz.VcfZarrDimension(
1063
                size=m, chunk_size=variants_chunk_size or vcz.DEFAULT_VARIANT_CHUNK_SIZE
1064
            ),
1065
            "samples": vcz.VcfZarrDimension(
1066
                size=n, chunk_size=samples_chunk_size or vcz.DEFAULT_SAMPLE_CHUNK_SIZE
1067
            ),
1068
            # ploidy and genotypes added conditionally below
1069
            "alleles": vcz.VcfZarrDimension(size=max_alleles),
1070
            "alt_alleles": vcz.VcfZarrDimension(size=max_alleles - 1),
1071
            "filters": vcz.VcfZarrDimension(size=self.metadata.num_filters),
1072
        }
1073

1074
        # Add ploidy and genotypes dimensions only when needed
1075
        max_genotypes = 0
4✔
1076
        for field in self.metadata.format_fields:
4✔
1077
            if field.vcf_number == "G":
4✔
1078
                max_genotypes = max(max_genotypes, field.summary.max_number)
4✔
1079
        if self.gt_field is not None:
4✔
1080
            ploidy = max(self.gt_field.summary.max_number - 1, 1)
4✔
1081
            dimensions["ploidy"] = vcz.VcfZarrDimension(size=ploidy)
4✔
1082
            max_genotypes = math.comb(max_alleles + ploidy - 1, ploidy)
4✔
1083
            dimensions["genotypes"] = vcz.VcfZarrDimension(size=max_genotypes)
4✔
1084
        else:
1085
            if max_genotypes > 0:
4✔
1086
                # there is no GT field, but there is at least one Number=G field,
1087
                # so need to define genotypes dimension
1088
                dimensions["genotypes"] = vcz.VcfZarrDimension(size=max_genotypes)
4✔
1089

1090
        schema_instance = vcz.VcfZarrSchema(
4✔
1091
            format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
1092
            dimensions=dimensions,
1093
            fields=[],
1094
        )
1095

1096
        logger.info(
4✔
1097
            "Generating schema with chunks="
1098
            f"variants={dimensions['variants'].chunk_size}, "
1099
            f"samples={dimensions['samples'].chunk_size}"
1100
        )
1101

1102
        def spec_from_field(field, array_name=None):
4✔
1103
            return vcz.ZarrArraySpec.from_field(
4✔
1104
                field,
1105
                schema_instance,
1106
                array_name=array_name,
1107
            )
1108

1109
        def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
4✔
1110
            compressor = (
4✔
1111
                vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config()
1112
                if dtype == "bool"
1113
                else None
1114
            )
1115
            return vcz.ZarrArraySpec(
4✔
1116
                source=source,
1117
                name=name,
1118
                dtype=dtype,
1119
                description="",
1120
                dimensions=dimensions,
1121
                compressor=compressor,
1122
            )
1123

1124
        array_specs = [
4✔
1125
            fixed_field_spec(
1126
                name="variant_contig",
1127
                dtype=core.min_int_dtype(0, self.metadata.num_contigs),
1128
            ),
1129
            fixed_field_spec(
1130
                name="variant_filter",
1131
                dtype="bool",
1132
                dimensions=["variants", "filters"],
1133
            ),
1134
            fixed_field_spec(
1135
                name="variant_allele",
1136
                dtype="O",
1137
                dimensions=["variants", "alleles"],
1138
            ),
1139
            fixed_field_spec(
1140
                name="variant_id",
1141
                dtype="O",
1142
            ),
1143
            fixed_field_spec(
1144
                name="variant_id_mask",
1145
                dtype="bool",
1146
            ),
1147
        ]
1148
        name_map = {field.full_name: field for field in self.metadata.fields}
4✔
1149

1150
        # Only three of the fixed fields have a direct one-to-one mapping.
1151
        array_specs.extend(
4✔
1152
            [
1153
                spec_from_field(name_map["QUAL"], array_name="variant_quality"),
1154
                spec_from_field(name_map["POS"], array_name="variant_position"),
1155
                spec_from_field(name_map["rlen"], array_name="variant_length"),
1156
            ]
1157
        )
1158
        array_specs.extend(
4✔
1159
            [spec_from_field(field) for field in self.metadata.info_fields]
1160
        )
1161

1162
        for field in self.metadata.format_fields:
4✔
1163
            if field.name == "GT":
4✔
1164
                continue
4✔
1165
            array_specs.append(spec_from_field(field))
4✔
1166

1167
        if self.gt_field is not None and n > 0:
4✔
1168
            array_specs.append(
4✔
1169
                vcz.ZarrArraySpec(
1170
                    name="call_genotype_phased",
1171
                    dtype="bool",
1172
                    dimensions=["variants", "samples"],
1173
                    description="",
1174
                    compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
1175
                )
1176
            )
1177
            array_specs.append(
4✔
1178
                vcz.ZarrArraySpec(
1179
                    name="call_genotype",
1180
                    dtype=self.gt_field.smallest_dtype(),
1181
                    dimensions=["variants", "samples", "ploidy"],
1182
                    description="",
1183
                    compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),
1184
                )
1185
            )
1186
            array_specs.append(
4✔
1187
                vcz.ZarrArraySpec(
1188
                    name="call_genotype_mask",
1189
                    dtype="bool",
1190
                    dimensions=["variants", "samples", "ploidy"],
1191
                    description="",
1192
                    compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
1193
                )
1194
            )
1195

1196
        if local_alleles:
4✔
1197
            array_specs = convert_local_allele_field_types(array_specs, schema_instance)
4✔
1198

1199
        schema_instance.fields = array_specs
4✔
1200
        return schema_instance
4✔
1201

1202

1203
@dataclasses.dataclass
4✔
1204
class IcfPartitionMetadata(core.JsonDataclass):
4✔
1205
    num_records: int
4✔
1206
    last_position: int
4✔
1207
    field_summaries: dict
4✔
1208

1209
    @staticmethod
4✔
1210
    def fromdict(d):
4✔
1211
        md = IcfPartitionMetadata(**d)
4✔
1212
        for k, v in md.field_summaries.items():
4✔
1213
            md.field_summaries[k] = VcfFieldSummary.fromdict(v)
4✔
1214
        return md
4✔
1215

1216

1217
def check_overlapping_partitions(partitions):
4✔
1218
    for i in range(1, len(partitions)):
4✔
1219
        prev_region = partitions[i - 1].region
4✔
1220
        current_region = partitions[i].region
4✔
1221
        if prev_region.contig == current_region.contig:
4✔
1222
            assert prev_region.end is not None
4✔
1223
            # Regions are *inclusive*
1224
            if prev_region.end >= current_region.start:
4✔
1225
                raise ValueError(
4✔
1226
                    f"Overlapping VCF regions in partitions {i - 1} and {i}: "
1227
                    f"{prev_region} and {current_region}"
1228
                )
1229

1230
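Because region ends are inclusive, two partitions on the same contig conflict as soon as prev.end >= current.start. A minimal sketch with tuples in place of Region objects:

# (contig, start, end) with inclusive ends, in sorted order
prev_region = ("chr1", 1, 1_000_000)
current_region = ("chr1", 1_000_000, 2_000_000)

if prev_region[0] == current_region[0] and prev_region[2] >= current_region[1]:
    print("overlapping partitions")   # fires: both regions contain 1_000_000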

1231
def check_field_clobbering(icf_metadata):
4✔
1232
    info_field_names = set(field.name for field in icf_metadata.info_fields)
4✔
1233
    fixed_variant_fields = set(
4✔
1234
        ["contig", "id", "id_mask", "position", "allele", "filter", "quality"]
1235
    )
1236
    intersection = info_field_names & fixed_variant_fields
4✔
1237
    if len(intersection) > 0:
4✔
1238
        raise ValueError(
4✔
1239
            f"INFO field name(s) clashing with VCF Zarr spec: {intersection}"
1240
        )
1241

1242
    format_field_names = set(field.name for field in icf_metadata.format_fields)
4✔
1243
    fixed_variant_fields = set(["genotype", "genotype_phased", "genotype_mask"])
4✔
1244
    intersection = format_field_names & fixed_variant_fields
4✔
1245
    if len(intersection) > 0:
4✔
1246
        raise ValueError(
4✔
1247
            f"FORMAT field name(s) clashing with VCF Zarr spec: {intersection}"
1248
        )
1249

1250
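The clobbering check is a plain set intersection against names reserved by the VCF Zarr spec; a tiny illustration of the failing case (the INFO field name here is made up):

info_field_names = {"DP", "position"}
fixed_variant_fields = {"contig", "id", "id_mask", "position", "allele", "filter", "quality"}

clashes = info_field_names & fixed_variant_fields
print(clashes)   # {'position'} -> check_field_clobbering would raise ValueError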

1251
@dataclasses.dataclass
4✔
1252
class IcfWriteSummary(core.JsonDataclass):
4✔
1253
    num_partitions: int
4✔
1254
    num_samples: int
4✔
1255
    num_variants: int
4✔
1256

1257

1258
class IntermediateColumnarFormatWriter:
4✔
1259
    def __init__(self, path):
4✔
1260
        self.path = pathlib.Path(path)
4✔
1261
        self.wip_path = self.path / "wip"
4✔
1262
        self.metadata = None
4✔
1263

1264
    @property
4✔
1265
    def num_partitions(self):
4✔
1266
        return len(self.metadata.partitions)
4✔
1267

1268
    def init(
4✔
1269
        self,
1270
        vcfs,
1271
        *,
1272
        column_chunk_size=16,
1273
        worker_processes=1,
1274
        target_num_partitions=None,
1275
        show_progress=False,
1276
        compressor=None,
1277
    ):
1278
        if self.path.exists():
4✔
1279
            raise ValueError(f"ICF path already exists: {self.path}")
×
1280
        if compressor is None:
4✔
1281
            compressor = ICF_DEFAULT_COMPRESSOR
4✔
1282
        vcfs = [pathlib.Path(vcf) for vcf in vcfs]
4✔
1283
        target_num_partitions = max(target_num_partitions, len(vcfs))
4✔
1284

1285
        # TODO move scan_vcfs into this class
1286
        icf_metadata, header = scan_vcfs(
4✔
1287
            vcfs,
1288
            worker_processes=worker_processes,
1289
            show_progress=show_progress,
1290
            target_num_partitions=target_num_partitions,
1291
        )
1292
        check_field_clobbering(icf_metadata)
4✔
1293
        self.metadata = icf_metadata
4✔
1294
        self.metadata.format_version = ICF_METADATA_FORMAT_VERSION
4✔
1295
        self.metadata.compressor = compressor.get_config()
4✔
1296
        self.metadata.column_chunk_size = column_chunk_size
4✔
1297
        # Bare minimum here for provenance - would be nice to include versions of key
1298
        # dependencies as well.
1299
        self.metadata.provenance = {"source": f"bio2zarr-{provenance.__version__}"}
4✔
1300

1301
        self.mkdirs()
4✔
1302

1303
        # Note: this is needed for the current version of the vcfzarr spec, but it's
1304
        # probably going to be dropped.
        # https://github.com/pystatgen/vcf-zarr-spec/issues/15
        # May be useful to keep lying around still though?
        logger.info("Writing VCF header")
4✔
        with open(self.path / "header.txt", "w") as f:
4✔
            f.write(header)
4✔

        logger.info("Writing WIP metadata")
4✔
        with open(self.wip_path / "metadata.json", "w") as f:
4✔
            json.dump(self.metadata.asdict(), f, indent=4)
4✔
        return IcfWriteSummary(
4✔
            num_partitions=self.num_partitions,
            num_variants=icf_metadata.num_records,
            num_samples=icf_metadata.num_samples,
        )

    def mkdirs(self):
4✔
        num_dirs = len(self.metadata.fields)
4✔
        logger.info(f"Creating {num_dirs} field directories")
4✔
        self.path.mkdir()
4✔
        self.wip_path.mkdir()
4✔
        for field in self.metadata.fields:
4✔
            field_path = get_vcf_field_path(self.path, field)
4✔
            field_path.mkdir(parents=True)
4✔

    def load_partition_summaries(self):
4✔
        summaries = []
4✔
        not_found = []
4✔
        for j in range(self.num_partitions):
4✔
            try:
4✔
                with open(self.wip_path / f"p{j}.json") as f:
4✔
                    summaries.append(IcfPartitionMetadata.fromdict(json.load(f)))
4✔
            except FileNotFoundError:
4✔
                not_found.append(j)
4✔
        if len(not_found) > 0:
4✔
            raise FileNotFoundError(
4✔
                f"Partition metadata not found for {len(not_found)}"
                f" partitions: {not_found}"
            )
        return summaries
4✔

    def load_metadata(self):
4✔
        if self.metadata is None:
4✔
            with open(self.wip_path / "metadata.json") as f:
4✔
                self.metadata = IcfMetadata.fromdict(json.load(f))
4✔

    def process_partition(self, partition_index):
4✔
        self.load_metadata()
4✔
        summary_path = self.wip_path / f"p{partition_index}.json"
4✔
        # If someone is rewriting a summary path (for whatever reason), make sure it
        # doesn't look like it's already been completed.
        # NOTE to do this properly we probably need to take a lock on this file - but
        # this simple approach will catch the vast majority of problems.
        if summary_path.exists():
4✔
            summary_path.unlink()
4✔

        partition = self.metadata.partitions[partition_index]
4✔
        logger.info(
4✔
            f"Start p{partition_index} {partition.vcf_path}__{partition.region}"
        )
        info_fields = self.metadata.info_fields
4✔
        format_fields = []
4✔
        has_gt = False
4✔
        for field in self.metadata.format_fields:
4✔
            if field.name == "GT":
4✔
                has_gt = True
4✔
            else:
                format_fields.append(field)
4✔

        last_position = None
4✔
        with IcfPartitionWriter(
4✔
            self.metadata,
            self.path,
            partition_index,
        ) as tcw:
            with vcf_utils.VcfFile(partition.vcf_path) as vcf:
4✔
                num_records = 0
4✔
                for variant in vcf.variants(partition.region):
4✔
                    num_records += 1
4✔
                    last_position = variant.POS
4✔
                    tcw.append("CHROM", variant.CHROM)
4✔
                    tcw.append("POS", variant.POS)
4✔
                    tcw.append("QUAL", variant.QUAL)
4✔
                    tcw.append("ID", variant.ID)
4✔
                    tcw.append("FILTERS", variant.FILTERS)
4✔
                    tcw.append("REF", variant.REF)
4✔
                    tcw.append("ALT", variant.ALT)
4✔
                    tcw.append("rlen", variant.end - variant.start)
4✔
                    for field in info_fields:
4✔
                        tcw.append(field.full_name, variant.INFO.get(field.name, None))
4✔
                    if has_gt:
4✔
                        val = None
4✔
                        if "GT" in variant.FORMAT and variant.genotype is not None:
4✔
                            val = variant.genotype.array()
4✔
                        tcw.append("FORMAT/GT", val)
4✔
                    for field in format_fields:
4✔
                        val = variant.format(field.name)
4✔
                        tcw.append(field.full_name, val)
4✔

                    # Note: an issue with updating the progress per variant here like
                    # this is that we get a significant pause at the end of the counter
                    # while all the "small" fields get flushed. Possibly not much to be
                    # done about it.
                    core.update_progress(1)
4✔
            logger.info(
4✔
                f"Finished reading VCF for partition {partition_index}, "
                f"flushing buffers"
            )

        partition_metadata = IcfPartitionMetadata(
4✔
            num_records=num_records,
            last_position=last_position,
            field_summaries=tcw.field_summaries,
        )
        with open(summary_path, "w") as f:
4✔
            f.write(partition_metadata.asjson())
4✔
        logger.info(
4✔
            f"Finish p{partition_index} {partition.vcf_path}__{partition.region} "
            f"{num_records} records last_pos={last_position}"
        )

    def explode(self, *, worker_processes=1, show_progress=False):
4✔
        self.load_metadata()
4✔
        num_records = self.metadata.num_records
4✔
        if np.isinf(num_records):
4✔
            logger.warning(
4✔
                "Total records unknown, cannot show progress; "
                "reindex VCFs with bcftools index to fix"
            )
            num_records = None
4✔
        num_fields = len(self.metadata.fields)
4✔
        num_samples = len(self.metadata.samples)
4✔
        logger.info(
4✔
            f"Exploding fields={num_fields} samples={num_samples}; "
            f"partitions={self.num_partitions} "
            f"variants={'unknown' if num_records is None else num_records}"
        )
        progress_config = core.ProgressConfig(
4✔
            total=num_records,
            units="vars",
            title="Explode",
            show=show_progress,
        )
        with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
4✔
            for j in range(self.num_partitions):
4✔
                pwm.submit(self.process_partition, j)
4✔

    def explode_partition(self, partition):
4✔
        self.load_metadata()
4✔
        if partition < 0 or partition >= self.num_partitions:
4✔
            raise ValueError("Partition index not in the valid range")
4✔
        self.process_partition(partition)
4✔

    def finalise(self):
4✔
        self.load_metadata()
4✔
        partition_summaries = self.load_partition_summaries()
4✔
        total_records = 0
4✔
        for index, summary in enumerate(partition_summaries):
4✔
            partition_records = summary.num_records
4✔
            self.metadata.partitions[index].num_records = partition_records
4✔
            self.metadata.partitions[index].region.end = summary.last_position
4✔
            total_records += partition_records
4✔
        if not np.isinf(self.metadata.num_records):
4✔
            # Note: this is just telling us that there's a bug in the
            # index based record counting code, but it doesn't actually
            # matter much. We may want to just make this a warning if
            # we hit regular problems.
            assert total_records == self.metadata.num_records
4✔
        self.metadata.num_records = total_records
4✔

        check_overlapping_partitions(self.metadata.partitions)
4✔

        for field in self.metadata.fields:
4✔
            for summary in partition_summaries:
4✔
                field.summary.update(summary.field_summaries[field.full_name])
4✔

        logger.info("Finalising metadata")
4✔
        with open(self.path / "metadata.json", "w") as f:
4✔
            f.write(self.metadata.asjson())
4✔

        logger.debug("Removing WIP directory")
4✔
        shutil.rmtree(self.wip_path)
4✔


def explode(
4✔
    icf_path,
    vcfs,
    *,
    column_chunk_size=16,
    worker_processes=1,
    show_progress=False,
    compressor=None,
):
    writer = IntermediateColumnarFormatWriter(icf_path)
4✔
    writer.init(
4✔
        vcfs,
        # Heuristic to get reasonable worker utilisation with lumpy partition sizing
        target_num_partitions=max(1, worker_processes * 4),
        worker_processes=worker_processes,
        show_progress=show_progress,
        column_chunk_size=column_chunk_size,
        compressor=compressor,
    )
    writer.explode(worker_processes=worker_processes, show_progress=show_progress)
4✔
    writer.finalise()
4✔
    return IntermediateColumnarFormat(icf_path)
4✔
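
# A minimal usage sketch for the one-shot explode() above (not part of the
# module); the "chr20" paths and the worker count are hypothetical.
#
#     icf_store = explode(
#         "chr20.icf", ["chr20.vcf.gz"], worker_processes=4, show_progress=True
#     )
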
def explode_init(
4✔
    icf_path,
    vcfs,
    *,
    column_chunk_size=16,
    target_num_partitions=1,
    worker_processes=1,
    show_progress=False,
    compressor=None,
):
    writer = IntermediateColumnarFormatWriter(icf_path)
4✔
    return writer.init(
4✔
        vcfs,
        target_num_partitions=target_num_partitions,
        worker_processes=worker_processes,
        show_progress=show_progress,
        column_chunk_size=column_chunk_size,
        compressor=compressor,
    )


def explode_partition(icf_path, partition):
4✔
    writer = IntermediateColumnarFormatWriter(icf_path)
4✔
    writer.explode_partition(partition)
4✔


def explode_finalise(icf_path):
4✔
    writer = IntermediateColumnarFormatWriter(icf_path)
4✔
    writer.finalise()
4✔
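
# A sketch of the equivalent partitioned workflow using explode_init /
# explode_partition / explode_finalise (hypothetical paths; in practice the
# per-partition calls would typically run as separate jobs rather than a loop).
#
#     summary = explode_init("chr20.icf", ["chr20.vcf.gz"], target_num_partitions=8)
#     for j in range(summary.num_partitions):
#         explode_partition("chr20.icf", j)
#     explode_finalise("chr20.icf")
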
def inspect(path):
4✔
    path = pathlib.Path(path)
4✔
    if not path.exists():
4✔
        raise ValueError(f"Path not found: {path}")
4✔
    if (path / "metadata.json").exists():
4✔
        obj = IntermediateColumnarFormat(path)
4✔
    # NOTE: this is too strict, we should support more general Zarrs, see #276
    elif (path / ".zmetadata").exists():
4✔
        obj = vcz.VcfZarr(path)
4✔
    else:
        raise ValueError(f"{path} not in ICF or VCF Zarr format")
4✔
    return obj.summary_table()
4✔
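
# Usage sketch for inspect() (hypothetical path): accepts either an ICF
# directory or a VCF Zarr store and returns its summary table.
#
#     table = inspect("chr20.icf")
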
def mkschema(
4✔
    if_path,
    out,
    *,
    variants_chunk_size=None,
    samples_chunk_size=None,
    local_alleles=None,
):
    store = IntermediateColumnarFormat(if_path)
4✔
    spec = store.generate_schema(
4✔
        variants_chunk_size=variants_chunk_size,
        samples_chunk_size=samples_chunk_size,
        local_alleles=local_alleles,
    )
    out.write(spec.asjson())
4✔
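
# Usage sketch for mkschema() (hypothetical paths): `out` is any writable text
# file object, so the generated schema JSON can be saved and edited before
# encoding.
#
#     with open("chr20.schema.json", "w") as out:
#         mkschema("chr20.icf", out, variants_chunk_size=10_000)
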
def convert(
4✔
    vcfs,
    out_path,
    *,
    variants_chunk_size=None,
    samples_chunk_size=None,
    worker_processes=1,
    local_alleles=None,
    show_progress=False,
    icf_path=None,
):
    if icf_path is None:
4✔
        cm = temp_icf_path(prefix="vcf2zarr")
4✔
    else:
        cm = contextlib.nullcontext(icf_path)
4✔

    with cm as icf_path:
4✔
        explode(
4✔
            icf_path,
            vcfs,
            worker_processes=worker_processes,
            show_progress=show_progress,
        )
        encode(
4✔
            icf_path,
            out_path,
            variants_chunk_size=variants_chunk_size,
            samples_chunk_size=samples_chunk_size,
            worker_processes=worker_processes,
            show_progress=show_progress,
            local_alleles=local_alleles,
        )
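
# Usage sketch for the one-shot convert() above (hypothetical paths); when
# icf_path is omitted the intermediate ICF is written to a temporary directory
# and discarded afterwards.
#
#     convert(["chr20.vcf.gz"], "chr20.vcz", worker_processes=4, show_progress=True)
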
@contextlib.contextmanager
4✔
def temp_icf_path(prefix=None):
4✔
    with tempfile.TemporaryDirectory(prefix=prefix) as tmp:
4✔
        yield pathlib.Path(tmp) / "icf"
4✔
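
# Usage sketch for temp_icf_path() (illustrative only): the yielded path sits
# inside a TemporaryDirectory, so anything written there is removed on exit.
#
#     with temp_icf_path(prefix="vcf2zarr") as icf_path:
#         explode(icf_path, ["chr20.vcf.gz"])
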
def encode(
4✔
    icf_path,
    zarr_path,
    schema_path=None,
    variants_chunk_size=None,
    samples_chunk_size=None,
    max_variant_chunks=None,
    dimension_separator=None,
    max_memory=None,
    local_alleles=None,
    worker_processes=1,
    show_progress=False,
):
    # Rough heuristic to split work up enough to keep utilisation high
    target_num_partitions = max(1, worker_processes * 4)
4✔
    encode_init(
4✔
        icf_path,
        zarr_path,
        target_num_partitions,
        schema_path=schema_path,
        variants_chunk_size=variants_chunk_size,
        samples_chunk_size=samples_chunk_size,
        local_alleles=local_alleles,
        max_variant_chunks=max_variant_chunks,
        dimension_separator=dimension_separator,
    )
    vzw = vcz.VcfZarrWriter(IntermediateColumnarFormat, zarr_path)
4✔
    vzw.encode_all_partitions(
4✔
        worker_processes=worker_processes,
        show_progress=show_progress,
        max_memory=max_memory,
    )
    vzw.finalise(show_progress)
4✔
    vzw.create_index()
4✔
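
# Usage sketch for the one-shot encode() above (hypothetical paths); a schema
# written by mkschema() can be supplied via schema_path instead of passing
# chunk sizes directly.
#
#     encode("chr20.icf", "chr20.vcz", worker_processes=4, show_progress=True)
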
def encode_init(
4✔
    icf_path,
    zarr_path,
    target_num_partitions,
    *,
    schema_path=None,
    variants_chunk_size=None,
    samples_chunk_size=None,
    local_alleles=None,
    max_variant_chunks=None,
    dimension_separator=None,
    max_memory=None,
    worker_processes=1,
    show_progress=False,
):
    icf_store = IntermediateColumnarFormat(icf_path)
4✔
    if schema_path is None:
4✔
        schema_instance = icf_store.generate_schema(
4✔
            variants_chunk_size=variants_chunk_size,
            samples_chunk_size=samples_chunk_size,
            local_alleles=local_alleles,
        )
    else:
        logger.info(f"Reading schema from {schema_path}")
4✔
        if variants_chunk_size is not None or samples_chunk_size is not None:
4✔
            raise ValueError(
×
                "Cannot specify schema along with chunk sizes"
            )  # NEEDS TEST
        with open(schema_path) as f:
4✔
            schema_instance = vcz.VcfZarrSchema.fromjson(f.read())
4✔
    zarr_path = pathlib.Path(zarr_path)
4✔
    vzw = vcz.VcfZarrWriter("icf", zarr_path)
4✔
    return vzw.init(
4✔
        icf_store,
        target_num_partitions=target_num_partitions,
        schema=schema_instance,
        dimension_separator=dimension_separator,
        max_variant_chunks=max_variant_chunks,
    )


def encode_partition(zarr_path, partition):
4✔
    writer_instance = vcz.VcfZarrWriter(IntermediateColumnarFormat, zarr_path)
4✔
    writer_instance.encode_partition(partition)
4✔


def encode_finalise(zarr_path, show_progress=False):
4✔
    writer_instance = vcz.VcfZarrWriter(IntermediateColumnarFormat, zarr_path)
4✔
    writer_instance.finalise(show_progress=show_progress)
4✔
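
# A sketch of the corresponding partitioned encode workflow (hypothetical paths;
# the fixed count of 8 partitions is illustrative, and would normally be taken
# from the metadata written by encode_init).
#
#     encode_init("chr20.icf", "chr20.vcz", 8)
#     for j in range(8):
#         encode_partition("chr20.vcz", j)
#     encode_finalise("chr20.vcz", show_progress=True)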