
sgkit-dev / bio2zarr / build 14264606989
04 Apr 2025 11:44AM UTC, coverage: 98.771% (-0.1%) from 98.867%

Pull Request #339: Common writer for plink and ICF
Merge d4d724630 into 5439661d3 (github, web-flow)

778 of 789 new or added lines in 6 files covered (98.61%).
1 existing line in 1 file now uncovered.
2572 of 2604 relevant lines covered (98.77%).
5.92 hits per line.

Source File: /bio2zarr/vcf2zarr/icf.py (98.71% covered)

import collections
import contextlib
import dataclasses
import json
import logging
import math
import pathlib
import pickle
import shutil
import sys
from functools import partial
from typing import Any

import numcodecs
import numpy as np

from bio2zarr import schema

from .. import constants, core, provenance, vcf_utils

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class VcfFieldSummary(core.JsonDataclass):
    num_chunks: int = 0
    compressed_size: int = 0
    uncompressed_size: int = 0
    max_number: int = 0  # Corresponds to VCF Number field, depends on context
    # Only defined for numeric fields
    max_value: Any = -math.inf
    min_value: Any = math.inf

    def update(self, other):
        self.num_chunks += other.num_chunks
        self.compressed_size += other.compressed_size
        self.uncompressed_size += other.uncompressed_size
        self.max_number = max(self.max_number, other.max_number)
        self.min_value = min(self.min_value, other.min_value)
        self.max_value = max(self.max_value, other.max_value)

    @staticmethod
    def fromdict(d):
        return VcfFieldSummary(**d)


@dataclasses.dataclass(order=True)
class VcfField:
    category: str
    name: str
    vcf_number: str
    vcf_type: str
    description: str
    summary: VcfFieldSummary

    @staticmethod
    def from_header(definition):
        category = definition["HeaderType"]
        name = definition["ID"]
        vcf_number = definition["Number"]
        vcf_type = definition["Type"]
        return VcfField(
            category=category,
            name=name,
            vcf_number=vcf_number,
            vcf_type=vcf_type,
            description=definition["Description"].strip('"'),
            summary=VcfFieldSummary(),
        )

    @staticmethod
    def fromdict(d):
        f = VcfField(**d)
        f.summary = VcfFieldSummary(**d["summary"])
        return f

    @property
    def full_name(self):
        if self.category == "fixed":
            return self.name
        return f"{self.category}/{self.name}"

    def smallest_dtype(self):
        """
        Returns the smallest dtype suitable for this field, based on its
        type and observed values.
        """
        s = self.summary
        if self.vcf_type == "Float":
            ret = "f4"
        elif self.vcf_type == "Integer":
            if not math.isfinite(s.max_value):
                # All missing values; use i1. Note we should have some API to
                # check more explicitly for missingness:
                # https://github.com/sgkit-dev/bio2zarr/issues/131
                ret = "i1"
            else:
                ret = core.min_int_dtype(s.min_value, s.max_value)
        elif self.vcf_type == "Flag":
            ret = "bool"
        elif self.vcf_type == "Character":
            ret = "U1"
        else:
            assert self.vcf_type == "String"
            ret = "O"
        return ret
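
# Illustrative sketch (not part of the module): a field's summary, accumulated
# while exploding, is what drives dtype selection. The values below are
# hypothetical.
#
#     field = VcfField(
#         category="INFO", name="DP", vcf_number="1", vcf_type="Integer",
#         description="Combined depth", summary=VcfFieldSummary(),
#     )
#     field.summary.min_value = 0
#     field.summary.max_value = 300
#     field.full_name         # "INFO/DP"
#     field.smallest_dtype()  # delegates to core.min_int_dtype(0, 300), e.g. "i2"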


@dataclasses.dataclass
class VcfPartition:
    vcf_path: str
    region: str
    num_records: int = -1


ICF_METADATA_FORMAT_VERSION = "0.4"
ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc(
    cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE
)


@dataclasses.dataclass
class IcfMetadata(core.JsonDataclass):
    samples: list
    contigs: list
    filters: list
    fields: list
    partitions: list = None
    format_version: str = None
    compressor: dict = None
    column_chunk_size: int = None
    provenance: dict = None
    num_records: int = -1

    @property
    def info_fields(self):
        fields = []
        for field in self.fields:
            if field.category == "INFO":
                fields.append(field)
        return fields

    @property
    def format_fields(self):
        fields = []
        for field in self.fields:
            if field.category == "FORMAT":
                fields.append(field)
        return fields

    @property
    def num_contigs(self):
        return len(self.contigs)

    @property
    def num_filters(self):
        return len(self.filters)

    @property
    def num_samples(self):
        return len(self.samples)

    @staticmethod
    def fromdict(d):
        if d["format_version"] != ICF_METADATA_FORMAT_VERSION:
            raise ValueError(
                "Intermediate columnar metadata format version mismatch: "
                f"{d['format_version']} != {ICF_METADATA_FORMAT_VERSION}"
            )
        partitions = [VcfPartition(**pd) for pd in d["partitions"]]
        for p in partitions:
            p.region = vcf_utils.Region(**p.region)
        d = d.copy()
        d["partitions"] = partitions
        d["fields"] = [VcfField.fromdict(fd) for fd in d["fields"]]
        d["samples"] = [schema.Sample(**sd) for sd in d["samples"]]
        d["filters"] = [schema.Filter(**fd) for fd in d["filters"]]
        d["contigs"] = [schema.Contig(**cd) for cd in d["contigs"]]
        return IcfMetadata(**d)

    def __eq__(self, other):
        if not isinstance(other, IcfMetadata):
            return NotImplemented
        return (
            self.samples == other.samples
            and self.contigs == other.contigs
            and self.filters == other.filters
            and sorted(self.fields) == sorted(other.fields)
        )


def fixed_vcf_field_definitions():
    def make_field_def(name, vcf_type, vcf_number):
        return VcfField(
            category="fixed",
            name=name,
            vcf_type=vcf_type,
            vcf_number=vcf_number,
            description="",
            summary=VcfFieldSummary(),
        )

    fields = [
        make_field_def("CHROM", "String", "1"),
        make_field_def("POS", "Integer", "1"),
        make_field_def("QUAL", "Float", "1"),
        make_field_def("ID", "String", "."),
        make_field_def("FILTERS", "String", "."),
        make_field_def("REF", "String", "1"),
        make_field_def("ALT", "String", "."),
        make_field_def("rlen", "Integer", "1"),  # computed field
    ]
    return fields


def scan_vcf(path, target_num_partitions):
    with vcf_utils.VcfFile(path) as vcf_file:
        vcf = vcf_file.vcf
        filters = []
        pass_index = -1
        for h in vcf.header_iter():
            if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str):
                try:
                    description = h["Description"].strip('"')
                except KeyError:
                    description = ""
                if h["ID"] == "PASS":
                    pass_index = len(filters)
                filters.append(schema.Filter(h["ID"], description))

        # Ensure PASS is the first filter if present
        if pass_index > 0:
            pass_filter = filters.pop(pass_index)
            filters.insert(0, pass_filter)

        fields = fixed_vcf_field_definitions()
        for h in vcf.header_iter():
            if h["HeaderType"] in ["INFO", "FORMAT"]:
                field = VcfField.from_header(h)
                if h["HeaderType"] == "FORMAT" and field.name == "GT":
                    field.vcf_type = "Integer"
                    field.vcf_number = "."
                fields.append(field)

        try:
            contig_lengths = vcf.seqlens
        except AttributeError:
            contig_lengths = [None for _ in vcf.seqnames]

        metadata = IcfMetadata(
            samples=[schema.Sample(sample_id) for sample_id in vcf.samples],
            contigs=[
                schema.Contig(contig_id, length)
                for contig_id, length in zip(vcf.seqnames, contig_lengths)
            ],
            filters=filters,
            fields=fields,
            partitions=[],
            num_records=sum(vcf_file.contig_record_counts().values()),
        )

        regions = vcf_file.partition_into_regions(num_parts=target_num_partitions)
        for region in regions:
            metadata.partitions.append(
                VcfPartition(
                    # TODO should this be fully resolving the path? Otherwise it's all
                    # relative to the original WD
                    vcf_path=str(path),
                    region=region,
                )
            )
        logger.info(
            f"Split {path} into {len(metadata.partitions)} "
            f"partitions (target={target_num_partitions})"
        )
        core.update_progress(1)
        return metadata, vcf.raw_header


def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
    logger.info(
        f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
        f" partitions."
    )
    # An easy mistake to make is to pass the same file twice. Check this early on.
    for path, count in collections.Counter(paths).items():
        if not path.exists():  # NEEDS TEST
            raise FileNotFoundError(path)
        if count > 1:
            raise ValueError(f"Duplicate path provided: {path}")

    progress_config = core.ProgressConfig(
        total=len(paths),
        units="files",
        title="Scan",
        show=show_progress,
    )
    with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
        for path in paths:
            pwm.submit(
                scan_vcf,
                path,
                max(1, target_num_partitions // len(paths)),
            )
        results = list(pwm.results_as_completed())

    # Sort to make the ordering deterministic
    results.sort(key=lambda t: t[0].partitions[0].vcf_path)
    # We just take the first header, assuming the others
    # are compatible.
    all_partitions = []
    total_records = 0
    contigs = {}
    for metadata, _ in results:
        for partition in metadata.partitions:
            logger.debug(f"Scanned partition {partition}")
            all_partitions.append(partition)
        for contig in metadata.contigs:
            if contig.id in contigs:
                if contig != contigs[contig.id]:
                    raise ValueError(
                        "Incompatible contig definitions: "
                        f"{contig} != {contigs[contig.id]}"
                    )
            else:
                contigs[contig.id] = contig
        total_records += metadata.num_records
        metadata.num_records = 0
        metadata.partitions = []

    contig_union = list(contigs.values())
    for metadata, _ in results:
        metadata.contigs = contig_union

    icf_metadata, header = results[0]
    for metadata, _ in results[1:]:
        if metadata != icf_metadata:
            raise ValueError("Incompatible VCF chunks")

    # Note: this will be infinity here if any of the chunks has an index
    # that doesn't keep track of the number of records per-contig
    icf_metadata.num_records = total_records

    # Sort by contig (in the order they appear in the header) first,
    # then by start coordinate
    contig_index_map = {contig.id: j for j, contig in enumerate(metadata.contigs)}
    all_partitions.sort(
        key=lambda x: (contig_index_map[x.region.contig], x.region.start)
    )
    icf_metadata.partitions = all_partitions
    logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.")
    return icf_metadata, header


def sanitise_value_bool(shape, value):
    x = True
    if value is None:
        x = False
    return x


def sanitise_value_float_scalar(shape, value):
    x = value
    if value is None:
        x = [constants.FLOAT32_MISSING]
    return x[0]


def sanitise_value_int_scalar(shape, value):
    x = value
    if value is None:
        x = [constants.INT_MISSING]
    else:
        x = sanitise_int_array(value, ndmin=1, dtype=np.int32)
    return x[0]


def sanitise_value_string_scalar(shape, value):
    if value is None:
        return "."
    else:
        return value[0]


def sanitise_value_string_1d(shape, value):
    if value is None:
        return np.full(shape, ".", dtype="O")
    else:
        value = drop_empty_second_dim(value)
        result = np.full(shape, "", dtype=value.dtype)
        result[: value.shape[0]] = value
        return result


def sanitise_value_string_2d(shape, value):
    if value is None:
        return np.full(shape, ".", dtype="O")
    else:
        result = np.full(shape, "", dtype="O")
        if value.ndim == 2:
            result[: value.shape[0], : value.shape[1]] = value
        else:
            # Convert 1D array into 2D with appropriate shape
            for k, val in enumerate(value):
                result[k, : len(val)] = val
        return result


def drop_empty_second_dim(value):
    assert len(value.shape) == 1 or value.shape[1] == 1
    if len(value.shape) == 2 and value.shape[1] == 1:
        value = value[..., 0]
    return value


def sanitise_value_float_1d(shape, value):
    if value is None:
        return np.full(shape, constants.FLOAT32_MISSING)
    else:
        value = np.array(value, ndmin=1, dtype=np.float32, copy=True)
        # numpy will map None values to NaN, but we need a
        # specific NaN
        value[np.isnan(value)] = constants.FLOAT32_MISSING
        value = drop_empty_second_dim(value)
        result = np.full(shape, constants.FLOAT32_FILL, dtype=np.float32)
        result[: value.shape[0]] = value
        return result


def sanitise_value_float_2d(shape, value):
    if value is None:
        return np.full(shape, constants.FLOAT32_MISSING)
    else:
        value = np.array(value, ndmin=2, dtype=np.float32, copy=True)
        result = np.full(shape, constants.FLOAT32_FILL, dtype=np.float32)
        result[:, : value.shape[1]] = value
        return result


def sanitise_int_array(value, ndmin, dtype):
    if isinstance(value, tuple):
        value = [
            constants.VCF_INT_MISSING if x is None else x for x in value
        ]  # NEEDS TEST
    value = np.array(value, ndmin=ndmin, copy=True)
    value[value == constants.VCF_INT_MISSING] = -1
    value[value == constants.VCF_INT_FILL] = -2
    # TODO watch out for clipping here!
    return value.astype(dtype)


def sanitise_value_int_1d(shape, value):
    if value is None:
        return np.full(shape, -1)
    else:
        value = sanitise_int_array(value, 1, np.int32)
        value = drop_empty_second_dim(value)
        result = np.full(shape, -2, dtype=np.int32)
        result[: value.shape[0]] = value
        return result


def sanitise_value_int_2d(shape, value):
    if value is None:
        return np.full(shape, -1)
    else:
        value = sanitise_int_array(value, 2, np.int32)
        result = np.full(shape, -2, dtype=np.int32)
        result[:, : value.shape[1]] = value
        return result
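
# Illustrative sketch (not part of the module): the sanitisers pad ragged VCF
# values out to a fixed shape, with -1 marking missing values and -2 marking
# fill. Assuming the usual integer sentinels in `constants`:
#
#     sanitise_value_int_1d((3,), None)
#     # -> array([-1, -1, -1])
#     sanitise_value_int_1d((3,), np.array([4, 7]))
#     # -> array([ 4,  7, -2], dtype=int32)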


missing_value_map = {
    "Integer": constants.INT_MISSING,
    "Float": constants.FLOAT32_MISSING,
    "String": constants.STR_MISSING,
    "Character": constants.STR_MISSING,
    "Flag": False,
}


class VcfValueTransformer:
    """
    Transform VCF values into the stored intermediate format used
    in the IntermediateColumnarFormat, and update field summaries.
    """

    def __init__(self, field, num_samples):
        self.field = field
        self.num_samples = num_samples
        self.dimension = 1
        if field.category == "FORMAT":
            self.dimension = 2
        self.missing = missing_value_map[field.vcf_type]

    @staticmethod
    def factory(field, num_samples):
        if field.vcf_type in ("Integer", "Flag"):
            return IntegerValueTransformer(field, num_samples)
        if field.vcf_type == "Float":
            return FloatValueTransformer(field, num_samples)
        if field.name in ["REF", "FILTERS", "ALT", "ID", "CHROM"]:
            return SplitStringValueTransformer(field, num_samples)
        return StringValueTransformer(field, num_samples)

    def transform(self, vcf_value):
        if isinstance(vcf_value, tuple):
            vcf_value = [self.missing if v is None else v for v in vcf_value]
        value = np.array(vcf_value, ndmin=self.dimension, copy=True)
        return value

    def transform_and_update_bounds(self, vcf_value):
        if vcf_value is None:
            return None
        # print(self, self.field.full_name, "T", vcf_value)
        value = self.transform(vcf_value)
        self.update_bounds(value)
        return value


class IntegerValueTransformer(VcfValueTransformer):
    def update_bounds(self, value):
        summary = self.field.summary
        # Mask out missing and fill values
        # print(value)
        a = value[value >= constants.MIN_INT_VALUE]
        if a.size > 0:
            summary.max_value = int(max(summary.max_value, np.max(a)))
            summary.min_value = int(min(summary.min_value, np.min(a)))
        number = value.shape[-1]
        summary.max_number = max(summary.max_number, number)


class FloatValueTransformer(VcfValueTransformer):
    def update_bounds(self, value):
        summary = self.field.summary
        summary.max_value = float(max(summary.max_value, np.max(value)))
        summary.min_value = float(min(summary.min_value, np.min(value)))
        number = value.shape[-1]
        summary.max_number = max(summary.max_number, number)


class StringValueTransformer(VcfValueTransformer):
    def update_bounds(self, value):
        summary = self.field.summary
        if self.field.category == "FORMAT":
            number = max(len(v) for v in value)
        else:
            number = value.shape[-1]
        # TODO would be nice to report string lengths, but not
        # really necessary.
        summary.max_number = max(summary.max_number, number)

    def transform(self, vcf_value):
        if self.dimension == 1:
            value = np.array(list(vcf_value.split(",")))
        else:
            # TODO can we make this faster??
            value = np.array([v.split(",") for v in vcf_value], dtype="O")
            # print("HERE", vcf_value, value)
            # for v in vcf_value:
            #     print("\t", type(v), len(v), v.split(","))
        # print("S: ", self.dimension, ":", value.shape, value)
        return value


class SplitStringValueTransformer(StringValueTransformer):
    def transform(self, vcf_value):
        if vcf_value is None:
            return self.missing_value  # NEEDS TEST
        assert self.dimension == 1
        return np.array(vcf_value, ndmin=1, dtype="str")
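
# Illustrative sketch (not part of the module): transformers are obtained from the
# factory and update the field summary as a side effect. `definition` stands for a
# hypothetical INFO/Integer header record.
#
#     field = VcfField.from_header(definition)
#     transformer = VcfValueTransformer.factory(field, num_samples=0)
#     transformer.transform_and_update_bounds((1, None, 8))
#     # -> array with the missing sentinel in place of None;
#     #    field.summary.max_value is now 8 and max_number is 3.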


def get_vcf_field_path(base_path, vcf_field):
    if vcf_field.category == "fixed":
        return base_path / vcf_field.name
    return base_path / vcf_field.category / vcf_field.name


class IntermediateColumnarFormatField:
    def __init__(self, icf, vcf_field):
        self.vcf_field = vcf_field
        self.path = get_vcf_field_path(icf.path, vcf_field)
        self.compressor = icf.compressor
        self.num_partitions = icf.num_partitions
        self.num_records = icf.num_records
        self.partition_record_index = icf.partition_record_index
        # A map of partition id to the cumulative number of records
        # in chunks within that partition
        self._chunk_record_index = {}

    @property
    def name(self):
        return self.vcf_field.full_name

    def partition_path(self, partition_id):
        return self.path / f"p{partition_id}"

    def __repr__(self):
        partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
        return (
            f"IntermediateColumnarFormatField(name={self.name}, "
            f"partition_chunks={partition_chunks}, "
            f"path={self.path})"
        )

    def num_chunks(self, partition_id):
        return len(self.chunk_record_index(partition_id)) - 1

    def chunk_record_index(self, partition_id):
        if partition_id not in self._chunk_record_index:
            index_path = self.partition_path(partition_id) / "chunk_index"
            with open(index_path, "rb") as f:
                a = pickle.load(f)
            assert len(a) > 1
            assert a[0] == 0
            self._chunk_record_index[partition_id] = a
        return self._chunk_record_index[partition_id]

    def read_chunk(self, path):
        with open(path, "rb") as f:
            pkl = self.compressor.decode(f.read())
        return pickle.loads(pkl)

    def chunk_num_records(self, partition_id):
        return np.diff(self.chunk_record_index(partition_id))

    def chunks(self, partition_id, start_chunk=0):
        partition_path = self.partition_path(partition_id)
        chunk_cumulative_records = self.chunk_record_index(partition_id)
        chunk_num_records = np.diff(chunk_cumulative_records)
        for count, cumulative in zip(
            chunk_num_records[start_chunk:], chunk_cumulative_records[start_chunk + 1 :]
        ):
            path = partition_path / f"{cumulative}"
            chunk = self.read_chunk(path)
            if len(chunk) != count:
                raise ValueError(f"Corruption detected in chunk: {path}")
            yield chunk

    def iter_values(self, start=None, stop=None):
        start = 0 if start is None else start
        stop = self.num_records if stop is None else stop
        start_partition = (
            np.searchsorted(self.partition_record_index, start, side="right") - 1
        )
        offset = self.partition_record_index[start_partition]
        assert offset <= start
        chunk_offset = start - offset

        chunk_record_index = self.chunk_record_index(start_partition)
        start_chunk = (
            np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1
        )
        record_id = offset + chunk_record_index[start_chunk]
        assert record_id <= start
        logger.debug(
            f"Read {self.vcf_field.full_name} slice [{start}:{stop}]:"
            f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}"
        )
        for chunk in self.chunks(start_partition, start_chunk):
            for record in chunk:
                if record_id == stop:
                    return
                if record_id >= start:
                    yield record
                record_id += 1
        assert record_id > start
        for partition_id in range(start_partition + 1, self.num_partitions):
            for chunk in self.chunks(partition_id):
                for record in chunk:
                    if record_id == stop:
                        return
                    yield record
                    record_id += 1

    # Note: this involves some computation so should arguably be a method,
    # but we make it a property for consistency with xarray etc.
    @property
    def values(self):
        ret = [None] * self.num_records
        j = 0
        for partition_id in range(self.num_partitions):
            for chunk in self.chunks(partition_id):
                for record in chunk:
                    ret[j] = record
                    j += 1
        assert j == self.num_records
        return ret

    def sanitiser_factory(self, shape):
        assert len(shape) <= 2
        if self.vcf_field.vcf_type == "Flag":
            assert len(shape) == 0
            return partial(sanitise_value_bool, shape)
        elif self.vcf_field.vcf_type == "Float":
            if len(shape) == 0:
                return partial(sanitise_value_float_scalar, shape)
            elif len(shape) == 1:
                return partial(sanitise_value_float_1d, shape)
            else:
                return partial(sanitise_value_float_2d, shape)
        elif self.vcf_field.vcf_type == "Integer":
            if len(shape) == 0:
                return partial(sanitise_value_int_scalar, shape)
            elif len(shape) == 1:
                return partial(sanitise_value_int_1d, shape)
            else:
                return partial(sanitise_value_int_2d, shape)
        else:
            assert self.vcf_field.vcf_type in ("String", "Character")
            if len(shape) == 0:
                return partial(sanitise_value_string_scalar, shape)
            elif len(shape) == 1:
                return partial(sanitise_value_string_1d, shape)
            else:
                return partial(sanitise_value_string_2d, shape)
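
# Illustrative sketch (not part of the module): reading a single column back from
# an exploded ICF directory ("sample.icf" is a hypothetical path; the
# IntermediateColumnarFormat mapping is defined further down in this module).
#
#     icf = IntermediateColumnarFormat("sample.icf")
#     pos = icf["POS"]
#     for value in pos.iter_values(start=100, stop=110):
#         print(value[0])          # each record is a small numpy array
#     all_pos = pos.values         # or materialise the whole column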


@dataclasses.dataclass
class IcfFieldWriter:
    vcf_field: VcfField
    path: pathlib.Path
    transformer: VcfValueTransformer
    compressor: Any
    max_buffered_bytes: int
    buff: list[Any] = dataclasses.field(default_factory=list)
    buffered_bytes: int = 0
    chunk_index: list[int] = dataclasses.field(default_factory=lambda: [0])
    num_records: int = 0

    def append(self, val):
        val = self.transformer.transform_and_update_bounds(val)
        assert val is None or isinstance(val, np.ndarray)
        self.buff.append(val)
        val_bytes = sys.getsizeof(val)
        self.buffered_bytes += val_bytes
        self.num_records += 1
        if self.buffered_bytes >= self.max_buffered_bytes:
            logger.debug(
                f"Flush {self.path} buffered={self.buffered_bytes} "
                f"max={self.max_buffered_bytes}"
            )
            self.write_chunk()
            self.buff.clear()
            self.buffered_bytes = 0

    def write_chunk(self):
        # Update index
        self.chunk_index.append(self.num_records)
        path = self.path / f"{self.num_records}"
        logger.debug(f"Start write: {path}")
        pkl = pickle.dumps(self.buff)
        compressed = self.compressor.encode(pkl)
        with open(path, "wb") as f:
            f.write(compressed)

        # Update the summary
        self.vcf_field.summary.num_chunks += 1
        self.vcf_field.summary.compressed_size += len(compressed)
        self.vcf_field.summary.uncompressed_size += self.buffered_bytes
        logger.debug(f"Finish write: {path}")

    def flush(self):
        logger.debug(
            f"Flush {self.path} records={len(self.buff)} buffered={self.buffered_bytes}"
        )
        if len(self.buff) > 0:
            self.write_chunk()
        with open(self.path / "chunk_index", "wb") as f:
            a = np.array(self.chunk_index, dtype=int)
            pickle.dump(a, f)


class IcfPartitionWriter(contextlib.AbstractContextManager):
    """
    Writes the data for an IntermediateColumnarFormat partition.
    """

    def __init__(
        self,
        icf_metadata,
        out_path,
        partition_index,
    ):
        self.partition_index = partition_index
        # chunk_size is in megabytes
        max_buffered_bytes = icf_metadata.column_chunk_size * 2**20
        assert max_buffered_bytes > 0
        compressor = numcodecs.get_codec(icf_metadata.compressor)

        self.field_writers = {}
        num_samples = len(icf_metadata.samples)
        for vcf_field in icf_metadata.fields:
            field_path = get_vcf_field_path(out_path, vcf_field)
            field_partition_path = field_path / f"p{partition_index}"
            # Should be robust to running explode_partition twice.
            field_partition_path.mkdir(exist_ok=True)
            transformer = VcfValueTransformer.factory(vcf_field, num_samples)
            self.field_writers[vcf_field.full_name] = IcfFieldWriter(
                vcf_field,
                field_partition_path,
                transformer,
                compressor,
                max_buffered_bytes,
            )

    @property
    def field_summaries(self):
        return {
            name: field.vcf_field.summary for name, field in self.field_writers.items()
        }

    def append(self, name, value):
        self.field_writers[name].append(value)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            for field in self.field_writers.values():
                field.flush()
        return False
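
# Illustrative sketch (not part of the module): on disk each field is a directory
# with one subdirectory per partition; chunks are compressed pickles named by the
# cumulative record count at which they were flushed, alongside a pickled
# "chunk_index". The paths and counts below are hypothetical.
#
#     <icf>/INFO/DP/p0/chunk_index   # e.g. array([0, 1000, 2000, 2500])
#     <icf>/INFO/DP/p0/1000          # records 0..999
#     <icf>/INFO/DP/p0/2000          # records 1000..1999
#     <icf>/INFO/DP/p0/2500          # records 2000..2499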


class IntermediateColumnarFormat(collections.abc.Mapping):
    def __init__(self, path):
        self.path = pathlib.Path(path)
        # TODO raise a more informative error here telling people this
        # directory is either a WIP or the wrong format.
        with open(self.path / "metadata.json") as f:
            self.metadata = IcfMetadata.fromdict(json.load(f))
        with open(self.path / "header.txt") as f:
            self.vcf_header = f.read()
        self.compressor = numcodecs.get_codec(self.metadata.compressor)
        self.fields = {}
        partition_num_records = [
            partition.num_records for partition in self.metadata.partitions
        ]
        # Allow us to find which partition a given record is in
        self.partition_record_index = np.cumsum([0, *partition_num_records])
        for field in self.metadata.fields:
            self.fields[field.full_name] = IntermediateColumnarFormatField(self, field)
        logger.info(
            f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, "
            f"records={self.num_records}, fields={self.num_fields})"
        )

    def __repr__(self):
        return (
            f"IntermediateColumnarFormat(fields={len(self)}, "
            f"partitions={self.num_partitions}, "
            f"records={self.num_records}, path={self.path})"
        )

    def __getitem__(self, key):
        return self.fields[key]

    def __iter__(self):
        return iter(self.fields)

    def __len__(self):
        return len(self.fields)

    def summary_table(self):
        data = []
        for name, icf_field in self.fields.items():
            summary = icf_field.vcf_field.summary
            d = {
                "name": name,
                "type": icf_field.vcf_field.vcf_type,
                "chunks": summary.num_chunks,
                "size": core.display_size(summary.uncompressed_size),
                "compressed": core.display_size(summary.compressed_size),
                "max_n": summary.max_number,
                "min_val": core.display_number(summary.min_value),
                "max_val": core.display_number(summary.max_value),
            }

            data.append(d)
        return data

    @property
    def num_records(self):
        return self.metadata.num_records

    @property
    def num_partitions(self):
        return len(self.metadata.partitions)

    @property
    def samples(self):
        return [sample.id for sample in self.metadata.samples]

    @property
    def num_samples(self):
        return len(self.metadata.samples)

    @property
    def num_fields(self):
        return len(self.fields)

    @property
    def root_attrs(self):
        return {
            "vcf_header": self.vcf_header,
        }

    def iter_alleles(self, start, stop, num_alleles):
        ref_field = self.fields["REF"]
        alt_field = self.fields["ALT"]

        for ref, alt in zip(
            ref_field.iter_values(start, stop),
            alt_field.iter_values(start, stop),
        ):
            alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
            alleles[0] = ref[0]
            alleles[1 : 1 + len(alt)] = alt
            yield alleles

    def iter_id(self, start, stop):
        for value in self.fields["ID"].iter_values(start, stop):
            if value is not None:
                yield value[0]
            else:
                yield None

    def iter_filters(self, start, stop):
        source_field = self.fields["FILTERS"]
        lookup = {filt.id: index for index, filt in enumerate(self.metadata.filters)}

        for filter_values in source_field.iter_values(start, stop):
            filters = np.zeros(len(self.metadata.filters), dtype=bool)
            if filter_values is not None:
                for filter_id in filter_values:
                    try:
                        filters[lookup[filter_id]] = True
                    except KeyError:
                        raise ValueError(
                            f"Filter '{filter_id}' was not defined in the header."
                        ) from None
            yield filters

    def iter_contig(self, start, stop):
        source_field = self.fields["CHROM"]
        lookup = {
            contig.id: index for index, contig in enumerate(self.metadata.contigs)
        }

        for value in source_field.iter_values(start, stop):
            # Note: because we are using the indexes to define the lookups
            # and we always have an index, it seems that the contig lookup
            # will always succeed. However, if anyone ever does hit a KeyError
            # here, please do open an issue with a reproducible example!
            yield lookup[value[0]]

    def iter_field(self, field_name, shape, start, stop):
        source_field = self.fields[field_name]
        sanitiser = source_field.sanitiser_factory(shape)
        for value in source_field.iter_values(start, stop):
            yield sanitiser(value)

    def iter_genotypes(self, shape, start, stop):
        source_field = self.fields["FORMAT/GT"]
        for value in source_field.iter_values(start, stop):
            genotypes = value[:, :-1] if value is not None else None
            phased = value[:, -1] if value is not None else None
            sanitised_genotypes = sanitise_value_int_2d(shape, genotypes)
            sanitised_phased = sanitise_value_int_1d(shape[:-1], phased)
            yield sanitised_genotypes, sanitised_phased

    def generate_schema(
        self, variants_chunk_size=None, samples_chunk_size=None, local_alleles=None
    ):
        # Import schema here to avoid circular import
        from bio2zarr import schema

        m = self.num_records
        n = self.num_samples
        if samples_chunk_size is None:
            samples_chunk_size = 10_000
        if variants_chunk_size is None:
            variants_chunk_size = 1000
        if local_alleles is None:
            local_alleles = False
        logger.info(
            f"Generating schema with chunks={variants_chunk_size, samples_chunk_size}"
        )

        def spec_from_field(field, array_name=None):
            return schema.ZarrArraySpec.from_field(
                field,
                num_samples=n,
                num_variants=m,
                samples_chunk_size=samples_chunk_size,
                variants_chunk_size=variants_chunk_size,
                array_name=array_name,
            )

        def fixed_field_spec(
            name,
            dtype,
            vcf_field=None,
            shape=(m,),
            dimensions=("variants",),
            chunks=None,
        ):
            return schema.ZarrArraySpec.new(
                vcf_field=vcf_field,
                name=name,
                dtype=dtype,
                shape=shape,
                description="",
                dimensions=dimensions,
                chunks=chunks or [variants_chunk_size],
            )

        alt_field = self.fields["ALT"]
        max_alleles = alt_field.vcf_field.summary.max_number + 1

        array_specs = [
            fixed_field_spec(
                name="variant_contig",
                dtype=core.min_int_dtype(0, self.metadata.num_contigs),
            ),
            fixed_field_spec(
                name="variant_filter",
                dtype="bool",
                shape=(m, self.metadata.num_filters),
                dimensions=["variants", "filters"],
                chunks=(variants_chunk_size, self.metadata.num_filters),
            ),
            fixed_field_spec(
                name="variant_allele",
                dtype="O",
                shape=(m, max_alleles),
                dimensions=["variants", "alleles"],
                chunks=(variants_chunk_size, max_alleles),
            ),
            fixed_field_spec(
                name="variant_id",
                dtype="O",
            ),
            fixed_field_spec(
                name="variant_id_mask",
                dtype="bool",
            ),
        ]
        name_map = {field.full_name: field for field in self.metadata.fields}

        # Only three of the fixed fields have a direct one-to-one mapping.
        array_specs.extend(
            [
                spec_from_field(name_map["QUAL"], array_name="variant_quality"),
                spec_from_field(name_map["POS"], array_name="variant_position"),
                spec_from_field(name_map["rlen"], array_name="variant_length"),
            ]
        )
        array_specs.extend(
            [spec_from_field(field) for field in self.metadata.info_fields]
        )

        gt_field = None
        for field in self.metadata.format_fields:
            if field.name == "GT":
                gt_field = field
                continue
            array_specs.append(spec_from_field(field))

        if gt_field is not None and n > 0:
            ploidy = max(gt_field.summary.max_number - 1, 1)
            shape = [m, n]
            chunks = [variants_chunk_size, samples_chunk_size]
            dimensions = ["variants", "samples"]
            array_specs.append(
                schema.ZarrArraySpec.new(
                    vcf_field=None,
                    name="call_genotype_phased",
                    dtype="bool",
                    shape=list(shape),
                    chunks=list(chunks),
                    dimensions=list(dimensions),
                    description="",
                )
            )
            shape += [ploidy]
            chunks += [ploidy]
            dimensions += ["ploidy"]
            array_specs.append(
                schema.ZarrArraySpec.new(
                    vcf_field=None,
                    name="call_genotype",
                    dtype=gt_field.smallest_dtype(),
                    shape=list(shape),
                    chunks=list(chunks),
                    dimensions=list(dimensions),
                    description="",
                )
            )
            array_specs.append(
                schema.ZarrArraySpec.new(
                    vcf_field=None,
                    name="call_genotype_mask",
                    dtype="bool",
                    shape=list(shape),
                    chunks=list(chunks),
                    dimensions=list(dimensions),
                    description="",
                )
            )

        if local_alleles:
            from bio2zarr.vcf2zarr.vcz import convert_local_allele_field_types

            array_specs = convert_local_allele_field_types(array_specs)

        return schema.VcfZarrSchema(
            format_version=schema.ZARR_SCHEMA_FORMAT_VERSION,
            samples_chunk_size=samples_chunk_size,
            variants_chunk_size=variants_chunk_size,
            fields=array_specs,
            samples=self.metadata.samples,
            contigs=self.metadata.contigs,
            filters=self.metadata.filters,
        )
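
# Illustrative sketch (not part of the module): IntermediateColumnarFormat behaves
# as a read-only Mapping from field names to columns ("sample.icf" is a
# hypothetical path).
#
#     icf = IntermediateColumnarFormat("sample.icf")
#     list(icf)                        # e.g. ["CHROM", "POS", ..., "FORMAT/GT"]
#     icf.summary_table()              # per-field chunk, size and bounds info
#     zarr_schema = icf.generate_schema(variants_chunk_size=10_000)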


@dataclasses.dataclass
class IcfPartitionMetadata(core.JsonDataclass):
    num_records: int
    last_position: int
    field_summaries: dict

    @staticmethod
    def fromdict(d):
        md = IcfPartitionMetadata(**d)
        for k, v in md.field_summaries.items():
            md.field_summaries[k] = VcfFieldSummary.fromdict(v)
        return md


def check_overlapping_partitions(partitions):
    for i in range(1, len(partitions)):
        prev_region = partitions[i - 1].region
        current_region = partitions[i].region
        if prev_region.contig == current_region.contig:
            assert prev_region.end is not None
            # Regions are *inclusive*
            if prev_region.end >= current_region.start:
                raise ValueError(
                    f"Overlapping VCF regions in partitions {i - 1} and {i}: "
                    f"{prev_region} and {current_region}"
                )


def check_field_clobbering(icf_metadata):
    info_field_names = set(field.name for field in icf_metadata.info_fields)
    fixed_variant_fields = set(
        ["contig", "id", "id_mask", "position", "allele", "filter", "quality"]
    )
    intersection = info_field_names & fixed_variant_fields
    if len(intersection) > 0:
        raise ValueError(
            f"INFO field name(s) clashing with VCF Zarr spec: {intersection}"
        )

    format_field_names = set(field.name for field in icf_metadata.format_fields)
    fixed_variant_fields = set(["genotype", "genotype_phased", "genotype_mask"])
    intersection = format_field_names & fixed_variant_fields
    if len(intersection) > 0:
        raise ValueError(
            f"FORMAT field name(s) clashing with VCF Zarr spec: {intersection}"
        )
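
# Illustrative sketch (not part of the module): partition regions are inclusive of
# their end coordinate, so after finalisation neighbouring partitions on the same
# contig must not share a position. With hypothetical regions
#
#     partitions[0].region = vcf_utils.Region(contig="chr1", start=1, end=5000)
#     partitions[1].region = vcf_utils.Region(contig="chr1", start=5000, end=9999)
#
# check_overlapping_partitions(partitions) raises ValueError, since position 5000
# falls in both regions.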


@dataclasses.dataclass
class IcfWriteSummary(core.JsonDataclass):
    num_partitions: int
    num_samples: int
    num_variants: int


class IntermediateColumnarFormatWriter:
    def __init__(self, path):
        self.path = pathlib.Path(path)
        self.wip_path = self.path / "wip"
        self.metadata = None

    @property
    def num_partitions(self):
        return len(self.metadata.partitions)

    def init(
        self,
        vcfs,
        *,
        column_chunk_size=16,
        worker_processes=1,
        target_num_partitions=None,
        show_progress=False,
        compressor=None,
    ):
        if self.path.exists():
            raise ValueError(f"ICF path already exists: {self.path}")
        if compressor is None:
            compressor = ICF_DEFAULT_COMPRESSOR
        vcfs = [pathlib.Path(vcf) for vcf in vcfs]
        target_num_partitions = max(target_num_partitions, len(vcfs))

        # TODO move scan_vcfs into this class
        icf_metadata, header = scan_vcfs(
            vcfs,
            worker_processes=worker_processes,
            show_progress=show_progress,
            target_num_partitions=target_num_partitions,
        )
        check_field_clobbering(icf_metadata)
        self.metadata = icf_metadata
        self.metadata.format_version = ICF_METADATA_FORMAT_VERSION
        self.metadata.compressor = compressor.get_config()
        self.metadata.column_chunk_size = column_chunk_size
        # Bare minimum here for provenance - would be nice to include versions of key
        # dependencies as well.
        self.metadata.provenance = {"source": f"bio2zarr-{provenance.__version__}"}

        self.mkdirs()

        # Note: this is needed for the current version of the vcfzarr spec, but it's
        # probably going to be dropped.
        # https://github.com/pystatgen/vcf-zarr-spec/issues/15
        # May be useful to keep lying around still though?
        logger.info("Writing VCF header")
        with open(self.path / "header.txt", "w") as f:
            f.write(header)

        logger.info("Writing WIP metadata")
        with open(self.wip_path / "metadata.json", "w") as f:
            json.dump(self.metadata.asdict(), f, indent=4)
        return IcfWriteSummary(
            num_partitions=self.num_partitions,
            num_variants=icf_metadata.num_records,
            num_samples=icf_metadata.num_samples,
        )

    def mkdirs(self):
        num_dirs = len(self.metadata.fields)
        logger.info(f"Creating {num_dirs} field directories")
        self.path.mkdir()
        self.wip_path.mkdir()
        for field in self.metadata.fields:
            field_path = get_vcf_field_path(self.path, field)
            field_path.mkdir(parents=True)

    def load_partition_summaries(self):
        summaries = []
        not_found = []
        for j in range(self.num_partitions):
            try:
                with open(self.wip_path / f"p{j}.json") as f:
                    summaries.append(IcfPartitionMetadata.fromdict(json.load(f)))
            except FileNotFoundError:
                not_found.append(j)
        if len(not_found) > 0:
            raise FileNotFoundError(
                f"Partition metadata not found for {len(not_found)}"
                f" partitions: {not_found}"
            )
        return summaries

    def load_metadata(self):
        if self.metadata is None:
            with open(self.wip_path / "metadata.json") as f:
                self.metadata = IcfMetadata.fromdict(json.load(f))

    def process_partition(self, partition_index):
        self.load_metadata()
        summary_path = self.wip_path / f"p{partition_index}.json"
        # If someone is rewriting a summary path (for whatever reason), make sure it
        # doesn't look like it's already been completed.
        # NOTE to do this properly we probably need to take a lock on this file - but
        # this simple approach will catch the vast majority of problems.
        if summary_path.exists():
            summary_path.unlink()

        partition = self.metadata.partitions[partition_index]
        logger.info(
            f"Start p{partition_index} {partition.vcf_path}__{partition.region}"
        )
        info_fields = self.metadata.info_fields
        format_fields = []
        has_gt = False
        for field in self.metadata.format_fields:
            if field.name == "GT":
                has_gt = True
            else:
                format_fields.append(field)

        last_position = None
        with IcfPartitionWriter(
            self.metadata,
            self.path,
            partition_index,
        ) as tcw:
            with vcf_utils.VcfFile(partition.vcf_path) as vcf:
                num_records = 0
                for variant in vcf.variants(partition.region):
                    num_records += 1
                    last_position = variant.POS
                    tcw.append("CHROM", variant.CHROM)
                    tcw.append("POS", variant.POS)
                    tcw.append("QUAL", variant.QUAL)
                    tcw.append("ID", variant.ID)
                    tcw.append("FILTERS", variant.FILTERS)
                    tcw.append("REF", variant.REF)
                    tcw.append("ALT", variant.ALT)
                    tcw.append("rlen", variant.end - variant.start)
                    for field in info_fields:
                        tcw.append(field.full_name, variant.INFO.get(field.name, None))
                    if has_gt:
                        val = None
                        if "GT" in variant.FORMAT and variant.genotype is not None:
                            val = variant.genotype.array()
                        tcw.append("FORMAT/GT", val)
                    for field in format_fields:
                        val = variant.format(field.name)
                        tcw.append(field.full_name, val)

                    # Note: an issue with updating the progress per variant here like
                    # this is that we get a significant pause at the end of the counter
                    # while all the "small" fields get flushed. Possibly not much to be
                    # done about it.
                    core.update_progress(1)
            logger.info(
                f"Finished reading VCF for partition {partition_index}, "
                f"flushing buffers"
            )

        partition_metadata = IcfPartitionMetadata(
            num_records=num_records,
            last_position=last_position,
            field_summaries=tcw.field_summaries,
        )
        with open(summary_path, "w") as f:
            f.write(partition_metadata.asjson())
        logger.info(
            f"Finish p{partition_index} {partition.vcf_path}__{partition.region} "
            f"{num_records} records last_pos={last_position}"
        )

    def explode(self, *, worker_processes=1, show_progress=False):
        self.load_metadata()
        num_records = self.metadata.num_records
        if np.isinf(num_records):
            logger.warning(
                "Total records unknown, cannot show progress; "
                "reindex VCFs with bcftools index to fix"
            )
            num_records = None
        num_fields = len(self.metadata.fields)
        num_samples = len(self.metadata.samples)
        logger.info(
            f"Exploding fields={num_fields} samples={num_samples}; "
            f"partitions={self.num_partitions} "
            f"variants={'unknown' if num_records is None else num_records}"
        )
        progress_config = core.ProgressConfig(
            total=num_records,
            units="vars",
            title="Explode",
            show=show_progress,
        )
        with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
            for j in range(self.num_partitions):
                pwm.submit(self.process_partition, j)

    def explode_partition(self, partition):
        self.load_metadata()
        if partition < 0 or partition >= self.num_partitions:
            raise ValueError("Partition index not in the valid range")
        self.process_partition(partition)

    def finalise(self):
        self.load_metadata()
        partition_summaries = self.load_partition_summaries()
        total_records = 0
        for index, summary in enumerate(partition_summaries):
            partition_records = summary.num_records
            self.metadata.partitions[index].num_records = partition_records
            self.metadata.partitions[index].region.end = summary.last_position
            total_records += partition_records
        if not np.isinf(self.metadata.num_records):
            # Note: this is just telling us that there's a bug in the
            # index based record counting code, but it doesn't actually
            # matter much. We may want to just make this a warning if
            # we hit regular problems.
            assert total_records == self.metadata.num_records
        self.metadata.num_records = total_records

        check_overlapping_partitions(self.metadata.partitions)

        for field in self.metadata.fields:
            for summary in partition_summaries:
                field.summary.update(summary.field_summaries[field.full_name])

        logger.info("Finalising metadata")
        with open(self.path / "metadata.json", "w") as f:
            f.write(self.metadata.asjson())

        logger.debug("Removing WIP directory")
        shutil.rmtree(self.wip_path)


def explode(
    icf_path,
    vcfs,
    *,
    column_chunk_size=16,
    worker_processes=1,
    show_progress=False,
    compressor=None,
):
    writer = IntermediateColumnarFormatWriter(icf_path)
    writer.init(
        vcfs,
        # Heuristic to get reasonable worker utilisation with lumpy partition sizing
        target_num_partitions=max(1, worker_processes * 4),
        worker_processes=worker_processes,
        show_progress=show_progress,
        column_chunk_size=column_chunk_size,
        compressor=compressor,
    )
    writer.explode(worker_processes=worker_processes, show_progress=show_progress)
    writer.finalise()
    return IntermediateColumnarFormat(icf_path)


def explode_init(
    icf_path,
    vcfs,
    *,
    column_chunk_size=16,
    target_num_partitions=1,
    worker_processes=1,
    show_progress=False,
    compressor=None,
):
    writer = IntermediateColumnarFormatWriter(icf_path)
    return writer.init(
        vcfs,
        target_num_partitions=target_num_partitions,
        worker_processes=worker_processes,
        show_progress=show_progress,
        column_chunk_size=column_chunk_size,
        compressor=compressor,
    )


def explode_partition(icf_path, partition):
    writer = IntermediateColumnarFormatWriter(icf_path)
    writer.explode_partition(partition)


def explode_finalise(icf_path):
    writer = IntermediateColumnarFormatWriter(icf_path)
    writer.finalise()