sgkit-dev / bio2zarr / build 14242614654

03 Apr 2025 12:10PM UTC. Coverage: 98.771% (-0.1%) from 98.867%.

Pull Request #339: Common writer for plink and ICF
Merge 965501277 into 5439661d3 (github, web-flow)

771 of 782 new or added lines in 6 files covered (98.59%).
12 existing lines in 2 files are now uncovered.
2571 of 2603 relevant lines covered (98.77%).
5.92 hits per line.

Source file: /bio2zarr/vcf2zarr/icf.py (98.65% covered)

1
import collections
6✔
2
import contextlib
6✔
3
import dataclasses
6✔
4
import json
6✔
5
import logging
6✔
6
import math
6✔
7
import pathlib
6✔
8
import pickle
6✔
9
import shutil
6✔
10
import sys
6✔
11
from functools import partial
6✔
12
from typing import Any
6✔
13

14
import numcodecs
6✔
15
import numpy as np
6✔
16

17
from bio2zarr import schema
6✔
18

19
from .. import constants, core, provenance, vcf_utils
6✔
20

21
logger = logging.getLogger(__name__)
6✔
22

23

24
@dataclasses.dataclass
6✔
25
class VcfFieldSummary(core.JsonDataclass):
6✔
26
    num_chunks: int = 0
6✔
27
    compressed_size: int = 0
6✔
28
    uncompressed_size: int = 0
6✔
29
    max_number: int = 0  # Corresponds to VCF Number field, depends on context
6✔
30
    # Only defined for numeric fields
31
    max_value: Any = -math.inf
6✔
32
    min_value: Any = math.inf
6✔
33

34
    def update(self, other):
6✔
35
        self.num_chunks += other.num_chunks
6✔
36
        self.compressed_size += other.compressed_size
6✔
37
        self.uncompressed_size += other.uncompressed_size
6✔
38
        self.max_number = max(self.max_number, other.max_number)
6✔
39
        self.min_value = min(self.min_value, other.min_value)
6✔
40
        self.max_value = max(self.max_value, other.max_value)
6✔
41

42
    @staticmethod
6✔
43
    def fromdict(d):
6✔
44
        return VcfFieldSummary(**d)
6✔
45

46

47
@dataclasses.dataclass(order=True)
6✔
48
class VcfField:
6✔
49
    category: str
6✔
50
    name: str
6✔
51
    vcf_number: str
6✔
52
    vcf_type: str
6✔
53
    description: str
6✔
54
    summary: VcfFieldSummary
6✔
55

56
    @staticmethod
6✔
57
    def from_header(definition):
6✔
58
        category = definition["HeaderType"]
6✔
59
        name = definition["ID"]
6✔
60
        vcf_number = definition["Number"]
6✔
61
        vcf_type = definition["Type"]
6✔
62
        return VcfField(
6✔
63
            category=category,
64
            name=name,
65
            vcf_number=vcf_number,
66
            vcf_type=vcf_type,
67
            description=definition["Description"].strip('"'),
68
            summary=VcfFieldSummary(),
69
        )
70

71
    @staticmethod
6✔
72
    def fromdict(d):
6✔
73
        f = VcfField(**d)
6✔
74
        f.summary = VcfFieldSummary(**d["summary"])
6✔
75
        return f
6✔
76

77
    @property
6✔
78
    def full_name(self):
6✔
79
        if self.category == "fixed":
6✔
80
            return self.name
6✔
81
        return f"{self.category}/{self.name}"
6✔
82

83
    def smallest_dtype(self):
6✔
84
        """
85
        Returns the smallest dtype suitable for this field based
86
        on its type and observed values.
87
        """
88
        s = self.summary
6✔
89
        if self.vcf_type == "Float":
6✔
90
            ret = "f4"
6✔
91
        elif self.vcf_type == "Integer":
6✔
92
            if not math.isfinite(s.max_value):
6✔
93
                # All missing values; use i1. Note we should have some API to
94
                # check more explicitly for missingness:
95
                # https://github.com/sgkit-dev/bio2zarr/issues/131
96
                ret = "i1"
6✔
97
            else:
98
                ret = core.min_int_dtype(s.min_value, s.max_value)
6✔
99
        elif self.vcf_type == "Flag":
6✔
100
            ret = "bool"
6✔
101
        elif self.vcf_type == "Character":
6✔
102
            ret = "U1"
6✔
103
        else:
104
            assert self.vcf_type == "String"
6✔
105
            ret = "O"
6✔
106
        return ret
6✔
107

108

109
@dataclasses.dataclass
6✔
110
class VcfPartition:
6✔
111
    vcf_path: str
6✔
112
    region: str
6✔
113
    num_records: int = -1
6✔
114

115

116
ICF_METADATA_FORMAT_VERSION = "0.4"
6✔
117
ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc(
6✔
118
    cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE
119
)
120

121

122
@dataclasses.dataclass
6✔
123
class IcfMetadata(core.JsonDataclass):
6✔
124
    samples: list
6✔
125
    contigs: list
6✔
126
    filters: list
6✔
127
    fields: list
6✔
128
    partitions: list = None
6✔
129
    format_version: str = None
6✔
130
    compressor: dict = None
6✔
131
    column_chunk_size: int = None
6✔
132
    provenance: dict = None
6✔
133
    num_records: int = -1
6✔
134

135
    @property
6✔
136
    def info_fields(self):
6✔
137
        fields = []
6✔
138
        for field in self.fields:
6✔
139
            if field.category == "INFO":
6✔
140
                fields.append(field)
6✔
141
        return fields
6✔
142

143
    @property
6✔
144
    def format_fields(self):
6✔
145
        fields = []
6✔
146
        for field in self.fields:
6✔
147
            if field.category == "FORMAT":
6✔
148
                fields.append(field)
6✔
149
        return fields
6✔
150

151
    @property
6✔
152
    def num_contigs(self):
6✔
153
        return len(self.contigs)
6✔
154

155
    @property
6✔
156
    def num_filters(self):
6✔
157
        return len(self.filters)
6✔
158

159
    @property
6✔
160
    def num_samples(self):
6✔
161
        return len(self.samples)
6✔
162

163
    @staticmethod
6✔
164
    def fromdict(d):
6✔
165
        if d["format_version"] != ICF_METADATA_FORMAT_VERSION:
6✔
166
            raise ValueError(
6✔
167
                "Intermediate columnar metadata format version mismatch: "
168
                f"{d['format_version']} != {ICF_METADATA_FORMAT_VERSION}"
169
            )
170
        partitions = [VcfPartition(**pd) for pd in d["partitions"]]
6✔
171
        for p in partitions:
6✔
172
            p.region = vcf_utils.Region(**p.region)
6✔
173
        d = d.copy()
6✔
174
        d["partitions"] = partitions
6✔
175
        d["fields"] = [VcfField.fromdict(fd) for fd in d["fields"]]
6✔
176
        d["samples"] = [schema.Sample(**sd) for sd in d["samples"]]
6✔
177
        d["filters"] = [schema.Filter(**fd) for fd in d["filters"]]
6✔
178
        d["contigs"] = [schema.Contig(**cd) for cd in d["contigs"]]
6✔
179
        return IcfMetadata(**d)
6✔
180

181
    def __eq__(self, other):
6✔
182
        if not isinstance(other, IcfMetadata):
6✔
UNCOV
183
            return NotImplemented
×
184
        return (
6✔
185
            self.samples == other.samples
186
            and self.contigs == other.contigs
187
            and self.filters == other.filters
188
            and sorted(self.fields) == sorted(other.fields)
189
        )
190

191

192
def fixed_vcf_field_definitions():
6✔
193
    def make_field_def(name, vcf_type, vcf_number):
6✔
194
        return VcfField(
6✔
195
            category="fixed",
196
            name=name,
197
            vcf_type=vcf_type,
198
            vcf_number=vcf_number,
199
            description="",
200
            summary=VcfFieldSummary(),
201
        )
202

203
    fields = [
6✔
204
        make_field_def("CHROM", "String", "1"),
205
        make_field_def("POS", "Integer", "1"),
206
        make_field_def("QUAL", "Float", "1"),
207
        make_field_def("ID", "String", "."),
208
        make_field_def("FILTERS", "String", "."),
209
        make_field_def("REF", "String", "1"),
210
        make_field_def("ALT", "String", "."),
211
        make_field_def("rlen", "Integer", "1"),  # computed field
212
    ]
213
    return fields
6✔
214

215

216
def scan_vcf(path, target_num_partitions):
6✔
217
    with vcf_utils.VcfFile(path) as vcf_file:
6✔
218
        vcf = vcf_file.vcf
6✔
219
        filters = []
6✔
220
        pass_index = -1
6✔
221
        for h in vcf.header_iter():
6✔
222
            if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str):
6✔
223
                try:
6✔
224
                    description = h["Description"].strip('"')
6✔
UNCOV
225
                except KeyError:
×
UNCOV
226
                    description = ""
×
227
                if h["ID"] == "PASS":
6✔
228
                    pass_index = len(filters)
6✔
229
                filters.append(schema.Filter(h["ID"], description))
6✔
230

231
        # Ensure PASS is the first filter if present
232
        if pass_index > 0:
6✔
UNCOV
233
            pass_filter = filters.pop(pass_index)
×
UNCOV
234
            filters.insert(0, pass_filter)
×
235

236
        fields = fixed_vcf_field_definitions()
6✔
237
        for h in vcf.header_iter():
6✔
238
            if h["HeaderType"] in ["INFO", "FORMAT"]:
6✔
239
                field = VcfField.from_header(h)
6✔
240
                if h["HeaderType"] == "FORMAT" and field.name == "GT":
6✔
241
                    field.vcf_type = "Integer"
6✔
242
                    field.vcf_number = "."
6✔
243
                fields.append(field)
6✔
244

245
        try:
6✔
246
            contig_lengths = vcf.seqlens
6✔
247
        except AttributeError:
6✔
248
            contig_lengths = [None for _ in vcf.seqnames]
6✔
249

250
        metadata = IcfMetadata(
6✔
251
            samples=[schema.Sample(sample_id) for sample_id in vcf.samples],
252
            contigs=[
253
                schema.Contig(contig_id, length)
254
                for contig_id, length in zip(vcf.seqnames, contig_lengths)
255
            ],
256
            filters=filters,
257
            fields=fields,
258
            partitions=[],
259
            num_records=sum(vcf_file.contig_record_counts().values()),
260
        )
261

262
        regions = vcf_file.partition_into_regions(num_parts=target_num_partitions)
6✔
263
        for region in regions:
6✔
264
            metadata.partitions.append(
6✔
265
                VcfPartition(
266
                    # TODO should this be fully resolving the path? Otherwise it's all
267
                    # relative to the original WD
268
                    vcf_path=str(path),
269
                    region=region,
270
                )
271
            )
272
        logger.info(
6✔
273
            f"Split {path} into {len(metadata.partitions)} "
274
            f"partitions target={target_num_partitions})"
275
        )
276
        core.update_progress(1)
6✔
277
        return metadata, vcf.raw_header
6✔
278

279

280
def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
6✔
281
    logger.info(
6✔
282
        f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
283
        f" partitions."
284
    )
285
    # An easy mistake to make is to pass the same file twice. Check this early on.
286
    for path, count in collections.Counter(paths).items():
6✔
287
        if not path.exists():  # NEEDS TEST
6✔
UNCOV
288
            raise FileNotFoundError(path)
×
289
        if count > 1:
6✔
290
            raise ValueError(f"Duplicate path provided: {path}")
6✔
291

292
    progress_config = core.ProgressConfig(
6✔
293
        total=len(paths),
294
        units="files",
295
        title="Scan",
296
        show=show_progress,
297
    )
298
    with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
6✔
299
        for path in paths:
6✔
300
            pwm.submit(
6✔
301
                scan_vcf,
302
                path,
303
                max(1, target_num_partitions // len(paths)),
304
            )
305
        results = list(pwm.results_as_completed())
6✔
306

307
    # Sort to make the ordering deterministic
308
    results.sort(key=lambda t: t[0].partitions[0].vcf_path)
6✔
309
    # We just take the first header, assuming the others
310
    # are compatible.
311
    all_partitions = []
6✔
312
    total_records = 0
6✔
313
    contigs = {}
6✔
314
    for metadata, _ in results:
6✔
315
        for partition in metadata.partitions:
6✔
316
            logger.debug(f"Scanned partition {partition}")
6✔
317
            all_partitions.append(partition)
6✔
318
        for contig in metadata.contigs:
6✔
319
            if contig.id in contigs:
6✔
320
                if contig != contigs[contig.id]:
6✔
321
                    raise ValueError(
4✔
322
                        "Incompatible contig definitions: "
323
                        f"{contig} != {contigs[contig.id]}"
324
                    )
325
            else:
326
                contigs[contig.id] = contig
6✔
327
        total_records += metadata.num_records
6✔
328
        metadata.num_records = 0
6✔
329
        metadata.partitions = []
6✔
330

331
    contig_union = list(contigs.values())
6✔
332
    for metadata, _ in results:
6✔
333
        metadata.contigs = contig_union
6✔
334

335
    icf_metadata, header = results[0]
6✔
336
    for metadata, _ in results[1:]:
6✔
337
        if metadata != icf_metadata:
6✔
338
            raise ValueError("Incompatible VCF chunks")
6✔
339

340
    # Note: this will be infinity here if any of the chunks has an index
341
    # that doesn't keep track of the number of records per-contig
342
    icf_metadata.num_records = total_records
6✔
343

344
    # Sort by contig (in the order they appear in the header) first,
345
    # then by start coordinate
346
    contig_index_map = {contig.id: j for j, contig in enumerate(metadata.contigs)}
6✔
347
    all_partitions.sort(
6✔
348
        key=lambda x: (contig_index_map[x.region.contig], x.region.start)
349
    )
350
    icf_metadata.partitions = all_partitions
6✔
351
    logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.")
6✔
352
    return icf_metadata, header
6✔
353

354

355
def sanitise_value_bool(shape, value):
6✔
356
    x = True
6✔
357
    if value is None:
6✔
358
        x = False
6✔
359
    return x
6✔
360

361

362
def sanitise_value_float_scalar(shape, value):
6✔
363
    x = value
6✔
364
    if value is None:
6✔
365
        x = [constants.FLOAT32_MISSING]
6✔
366
    return x[0]
6✔
367

368

369
def sanitise_value_int_scalar(shape, value):
6✔
370
    x = value
6✔
371
    if value is None:
6✔
372
        x = [constants.INT_MISSING]
6✔
373
    else:
374
        x = sanitise_int_array(value, ndmin=1, dtype=np.int32)
6✔
375
    return x[0]
6✔
376

377

378
def sanitise_value_string_scalar(shape, value):
6✔
379
    if value is None:
6✔
380
        return "."
6✔
381
    else:
382
        return value[0]
6✔
383

384

385
def sanitise_value_string_1d(shape, value):
6✔
386
    if value is None:
6✔
387
        return np.full(shape, ".", dtype="O")
6✔
388
    else:
389
        value = drop_empty_second_dim(value)
6✔
390
        result = np.full(shape, "", dtype=value.dtype)
6✔
391
        result[: value.shape[0]] = value
6✔
392
        return result
6✔
393

394

395
def sanitise_value_string_2d(shape, value):
6✔
396
    if value is None:
6✔
397
        return np.full(shape, ".", dtype="O")
6✔
398
    else:
399
        result = np.full(shape, "", dtype="O")
6✔
400
        if value.ndim == 2:
6✔
401
            result[: value.shape[0], : value.shape[1]] = value
6✔
402
        else:
403
            # Convert 1D array into 2D with appropriate shape
404
            for k, val in enumerate(value):
6✔
405
                result[k, : len(val)] = val
6✔
406
        return result
6✔
407

408

409
def drop_empty_second_dim(value):
6✔
410
    assert len(value.shape) == 1 or value.shape[1] == 1
6✔
411
    if len(value.shape) == 2 and value.shape[1] == 1:
6✔
412
        value = value[..., 0]
6✔
413
    return value
6✔
414

415

416
def sanitise_value_float_1d(shape, value):
6✔
417
    if value is None:
6✔
418
        return np.full(shape, constants.FLOAT32_MISSING)
6✔
419
    else:
420
        value = np.array(value, ndmin=1, dtype=np.float32, copy=True)
6✔
421
        # numpy will map None values to NaN, but we need a
422
        # specific NaN
423
        value[np.isnan(value)] = constants.FLOAT32_MISSING
6✔
424
        value = drop_empty_second_dim(value)
6✔
425
        result = np.full(shape, constants.FLOAT32_FILL, dtype=np.float32)
6✔
426
        result[: value.shape[0]] = value
6✔
427
        return result
6✔
428

429

430
def sanitise_value_float_2d(shape, value):
6✔
431
    if value is None:
6✔
432
        return np.full(shape, constants.FLOAT32_MISSING)
6✔
433
    else:
434
        value = np.array(value, ndmin=2, dtype=np.float32, copy=True)
6✔
435
        result = np.full(shape, constants.FLOAT32_FILL, dtype=np.float32)
6✔
436
        result[:, : value.shape[1]] = value
6✔
437
        return result
6✔
438

439

440
def sanitise_int_array(value, ndmin, dtype):
6✔
441
    if isinstance(value, tuple):
6✔
UNCOV
442
        value = [
×
443
            constants.VCF_INT_MISSING if x is None else x for x in value
444
        ]  # NEEDS TEST
445
    value = np.array(value, ndmin=ndmin, copy=True)
6✔
446
    value[value == constants.VCF_INT_MISSING] = -1
6✔
447
    value[value == constants.VCF_INT_FILL] = -2
6✔
448
    # TODO watch out for clipping here!
449
    return value.astype(dtype)
6✔
450
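# Editorial sketch (not part of the original source): the sentinel mapping
# above replaces the VCF missing/fill markers with the internal -1/-2 codes,
# e.g.
#
#   sanitise_int_array(
#       [3, constants.VCF_INT_MISSING, constants.VCF_INT_FILL],
#       ndmin=1, dtype=np.int8,
#   )
#   # -> array([ 3, -1, -2], dtype=int8)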

451

452
def sanitise_value_int_1d(shape, value):
6✔
453
    if value is None:
6✔
454
        return np.full(shape, -1)
6✔
455
    else:
456
        value = sanitise_int_array(value, 1, np.int32)
6✔
457
        value = drop_empty_second_dim(value)
6✔
458
        result = np.full(shape, -2, dtype=np.int32)
6✔
459
        result[: value.shape[0]] = value
6✔
460
        return result
6✔
461

462

463
def sanitise_value_int_2d(shape, value):
6✔
464
    if value is None:
6✔
465
        return np.full(shape, -1)
6✔
466
    else:
467
        value = sanitise_int_array(value, 2, np.int32)
6✔
468
        result = np.full(shape, -2, dtype=np.int32)
6✔
469
        result[:, : value.shape[1]] = value
6✔
470
        return result
6✔
471

472

473
missing_value_map = {
6✔
474
    "Integer": constants.INT_MISSING,
475
    "Float": constants.FLOAT32_MISSING,
476
    "String": constants.STR_MISSING,
477
    "Character": constants.STR_MISSING,
478
    "Flag": False,
479
}
480

481

482
class VcfValueTransformer:
6✔
483
    """
484
    Transform VCF values into the stored intermediate format used
485
    in the IntermediateColumnarFormat, and update field summaries.
486
    """
487

488
    def __init__(self, field, num_samples):
6✔
489
        self.field = field
6✔
490
        self.num_samples = num_samples
6✔
491
        self.dimension = 1
6✔
492
        if field.category == "FORMAT":
6✔
493
            self.dimension = 2
6✔
494
        self.missing = missing_value_map[field.vcf_type]
6✔
495

496
    @staticmethod
6✔
497
    def factory(field, num_samples):
6✔
498
        if field.vcf_type in ("Integer", "Flag"):
6✔
499
            return IntegerValueTransformer(field, num_samples)
6✔
500
        if field.vcf_type == "Float":
6✔
501
            return FloatValueTransformer(field, num_samples)
6✔
502
        if field.name in ["REF", "FILTERS", "ALT", "ID", "CHROM"]:
6✔
503
            return SplitStringValueTransformer(field, num_samples)
6✔
504
        return StringValueTransformer(field, num_samples)
6✔
505

506
    def transform(self, vcf_value):
6✔
507
        if isinstance(vcf_value, tuple):
6✔
508
            vcf_value = [self.missing if v is None else v for v in vcf_value]
6✔
509
        value = np.array(vcf_value, ndmin=self.dimension, copy=True)
6✔
510
        return value
6✔
511

512
    def transform_and_update_bounds(self, vcf_value):
6✔
513
        if vcf_value is None:
6✔
514
            return None
6✔
515
        # print(self, self.field.full_name, "T", vcf_value)
516
        value = self.transform(vcf_value)
6✔
517
        self.update_bounds(value)
6✔
518
        return value
6✔
519

520

521
class IntegerValueTransformer(VcfValueTransformer):
6✔
522
    def update_bounds(self, value):
6✔
523
        summary = self.field.summary
6✔
524
        # Mask out missing and fill values
525
        # print(value)
526
        a = value[value >= constants.MIN_INT_VALUE]
6✔
527
        if a.size > 0:
6✔
528
            summary.max_value = int(max(summary.max_value, np.max(a)))
6✔
529
            summary.min_value = int(min(summary.min_value, np.min(a)))
6✔
530
        number = value.shape[-1]
6✔
531
        summary.max_number = max(summary.max_number, number)
6✔
532

533

534
class FloatValueTransformer(VcfValueTransformer):
6✔
535
    def update_bounds(self, value):
6✔
536
        summary = self.field.summary
6✔
537
        summary.max_value = float(max(summary.max_value, np.max(value)))
6✔
538
        summary.min_value = float(min(summary.min_value, np.min(value)))
6✔
539
        number = value.shape[-1]
6✔
540
        summary.max_number = max(summary.max_number, number)
6✔
541

542

543
class StringValueTransformer(VcfValueTransformer):
6✔
544
    def update_bounds(self, value):
6✔
545
        summary = self.field.summary
6✔
546
        if self.field.category == "FORMAT":
6✔
547
            number = max(len(v) for v in value)
6✔
548
        else:
549
            number = value.shape[-1]
6✔
550
        # TODO would be nice to report string lengths, but not
551
        # really necessary.
552
        summary.max_number = max(summary.max_number, number)
6✔
553

554
    def transform(self, vcf_value):
6✔
555
        if self.dimension == 1:
6✔
556
            value = np.array(list(vcf_value.split(",")))
6✔
557
        else:
558
            # TODO can we make this faster??
559
            value = np.array([v.split(",") for v in vcf_value], dtype="O")
6✔
560
            # print("HERE", vcf_value, value)
561
            # for v in vcf_value:
562
            #     print("\t", type(v), len(v), v.split(","))
563
        # print("S: ", self.dimension, ":", value.shape, value)
564
        return value
6✔
565

566

567
class SplitStringValueTransformer(StringValueTransformer):
6✔
568
    def transform(self, vcf_value):
6✔
569
        if vcf_value is None:
6✔
UNCOV
570
            return self.missing  # NEEDS TEST
×
571
        assert self.dimension == 1
6✔
572
        return np.array(vcf_value, ndmin=1, dtype="str")
6✔
573

574

575
def get_vcf_field_path(base_path, vcf_field):
6✔
576
    if vcf_field.category == "fixed":
6✔
577
        return base_path / vcf_field.name
6✔
578
    return base_path / vcf_field.category / vcf_field.name
6✔
579

580

581
class IntermediateColumnarFormatField:
6✔
582
    def __init__(self, icf, vcf_field):
6✔
583
        self.vcf_field = vcf_field
6✔
584
        self.path = get_vcf_field_path(icf.path, vcf_field)
6✔
585
        self.compressor = icf.compressor
6✔
586
        self.num_partitions = icf.num_partitions
6✔
587
        self.num_records = icf.num_records
6✔
588
        self.partition_record_index = icf.partition_record_index
6✔
589
        # A map of partition id to the cumulative number of records
590
        # in chunks within that partition
591
        self._chunk_record_index = {}
6✔
592

593
    @property
6✔
594
    def name(self):
6✔
595
        return self.vcf_field.full_name
6✔
596

597
    def partition_path(self, partition_id):
6✔
598
        return self.path / f"p{partition_id}"
6✔
599

600
    def __repr__(self):
6✔
601
        partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
6✔
602
        return (
6✔
603
            f"IntermediateColumnarFormatField(name={self.name}, "
604
            f"partition_chunks={partition_chunks}, "
605
            f"path={self.path})"
606
        )
607

608
    def num_chunks(self, partition_id):
6✔
609
        return len(self.chunk_record_index(partition_id)) - 1
6✔
610

611
    def chunk_record_index(self, partition_id):
6✔
612
        if partition_id not in self._chunk_record_index:
6✔
613
            index_path = self.partition_path(partition_id) / "chunk_index"
6✔
614
            with open(index_path, "rb") as f:
6✔
615
                a = pickle.load(f)
6✔
616
            assert len(a) > 1
6✔
617
            assert a[0] == 0
6✔
618
            self._chunk_record_index[partition_id] = a
6✔
619
        return self._chunk_record_index[partition_id]
6✔
620

621
    def read_chunk(self, path):
6✔
622
        with open(path, "rb") as f:
6✔
623
            pkl = self.compressor.decode(f.read())
6✔
624
        return pickle.loads(pkl)
6✔
625

626
    def chunk_num_records(self, partition_id):
6✔
627
        return np.diff(self.chunk_record_index(partition_id))
6✔
628

629
    def chunks(self, partition_id, start_chunk=0):
6✔
630
        partition_path = self.partition_path(partition_id)
6✔
631
        chunk_cumulative_records = self.chunk_record_index(partition_id)
6✔
632
        chunk_num_records = np.diff(chunk_cumulative_records)
6✔
633
        for count, cumulative in zip(
6✔
634
            chunk_num_records[start_chunk:], chunk_cumulative_records[start_chunk + 1 :]
635
        ):
636
            path = partition_path / f"{cumulative}"
6✔
637
            chunk = self.read_chunk(path)
6✔
638
            if len(chunk) != count:
6✔
639
                raise ValueError(f"Corruption detected in chunk: {path}")
6✔
640
            yield chunk
6✔
641

642
    def iter_values(self, start=None, stop=None):
6✔
643
        start = 0 if start is None else start
6✔
644
        stop = self.num_records if stop is None else stop
6✔
645
        start_partition = (
6✔
646
            np.searchsorted(self.partition_record_index, start, side="right") - 1
647
        )
648
        offset = self.partition_record_index[start_partition]
6✔
649
        assert offset <= start
6✔
650
        chunk_offset = start - offset
6✔
651

652
        chunk_record_index = self.chunk_record_index(start_partition)
6✔
653
        start_chunk = (
6✔
654
            np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1
655
        )
656
        record_id = offset + chunk_record_index[start_chunk]
6✔
657
        assert record_id <= start
6✔
658
        logger.debug(
6✔
659
            f"Read {self.vcf_field.full_name} slice [{start}:{stop}]:"
660
            f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}"
661
        )
662
        for chunk in self.chunks(start_partition, start_chunk):
6✔
663
            for record in chunk:
6✔
664
                if record_id == stop:
6✔
665
                    return
6✔
666
                if record_id >= start:
6✔
667
                    yield record
6✔
668
                record_id += 1
6✔
669
        assert record_id > start
6✔
670
        for partition_id in range(start_partition + 1, self.num_partitions):
6✔
671
            for chunk in self.chunks(partition_id):
6✔
672
                for record in chunk:
6✔
673
                    if record_id == stop:
6✔
674
                        return
6✔
675
                    yield record
6✔
676
                    record_id += 1
6✔
677
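    # Editorial worked example (not part of the original source) for the index
    # arithmetic above: with partition_record_index = [0, 100, 250] and a
    # chunk_record_index of [0, 40, 80, 150] for partition 1, a read starting
    # at record 130 gives start_partition = 1, offset = 100, chunk_offset = 30,
    # start_chunk = 0 and record_id = 100; the first 30 records of that chunk
    # are skipped before yielding begins.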

678
    # Note: this involves some computation so should arguably be a method,
679
    # but we make it a property for consistency with xarray, etc.
680
    @property
6✔
681
    def values(self):
6✔
682
        ret = [None] * self.num_records
6✔
683
        j = 0
6✔
684
        for partition_id in range(self.num_partitions):
6✔
685
            for chunk in self.chunks(partition_id):
6✔
686
                for record in chunk:
6✔
687
                    ret[j] = record
6✔
688
                    j += 1
6✔
689
        assert j == self.num_records
6✔
690
        return ret
6✔
691

692
    def sanitiser_factory(self, shape):
6✔
693
        assert len(shape) <= 2
6✔
694
        if self.vcf_field.vcf_type == "Flag":
6✔
695
            assert len(shape) == 0
6✔
696
            return partial(sanitise_value_bool, shape)
6✔
697
        elif self.vcf_field.vcf_type == "Float":
6✔
698
            if len(shape) == 0:
6✔
699
                return partial(sanitise_value_float_scalar, shape)
6✔
700
            elif len(shape) == 1:
6✔
701
                return partial(sanitise_value_float_1d, shape)
6✔
702
            else:
703
                return partial(sanitise_value_float_2d, shape)
6✔
704
        elif self.vcf_field.vcf_type == "Integer":
6✔
705
            if len(shape) == 0:
6✔
706
                return partial(sanitise_value_int_scalar, shape)
6✔
707
            elif len(shape) == 1:
6✔
708
                return partial(sanitise_value_int_1d, shape)
6✔
709
            else:
710
                return partial(sanitise_value_int_2d, shape)
6✔
711
        else:
712
            assert self.vcf_field.vcf_type in ("String", "Character")
6✔
713
            if len(shape) == 0:
6✔
714
                return partial(sanitise_value_string_scalar, shape)
6✔
715
            elif len(shape) == 1:
6✔
716
                return partial(sanitise_value_string_1d, shape)
6✔
717
            else:
718
                return partial(sanitise_value_string_2d, shape)
6✔
719

720

721
@dataclasses.dataclass
6✔
722
class IcfFieldWriter:
6✔
723
    vcf_field: VcfField
6✔
724
    path: pathlib.Path
6✔
725
    transformer: VcfValueTransformer
6✔
726
    compressor: Any
6✔
727
    max_buffered_bytes: int
6✔
728
    buff: list[Any] = dataclasses.field(default_factory=list)
6✔
729
    buffered_bytes: int = 0
6✔
730
    chunk_index: list[int] = dataclasses.field(default_factory=lambda: [0])
6✔
731
    num_records: int = 0
6✔
732

733
    def append(self, val):
6✔
734
        val = self.transformer.transform_and_update_bounds(val)
6✔
735
        assert val is None or isinstance(val, np.ndarray)
6✔
736
        self.buff.append(val)
6✔
737
        val_bytes = sys.getsizeof(val)
6✔
738
        self.buffered_bytes += val_bytes
6✔
739
        self.num_records += 1
6✔
740
        if self.buffered_bytes >= self.max_buffered_bytes:
6✔
741
            logger.debug(
6✔
742
                f"Flush {self.path} buffered={self.buffered_bytes} "
743
                f"max={self.max_buffered_bytes}"
744
            )
745
            self.write_chunk()
6✔
746
            self.buff.clear()
6✔
747
            self.buffered_bytes = 0
6✔
748

749
    def write_chunk(self):
6✔
750
        # Update index
751
        self.chunk_index.append(self.num_records)
6✔
752
        path = self.path / f"{self.num_records}"
6✔
753
        logger.debug(f"Start write: {path}")
6✔
754
        pkl = pickle.dumps(self.buff)
6✔
755
        compressed = self.compressor.encode(pkl)
6✔
756
        with open(path, "wb") as f:
6✔
757
            f.write(compressed)
6✔
758

759
        # Update the summary
760
        self.vcf_field.summary.num_chunks += 1
6✔
761
        self.vcf_field.summary.compressed_size += len(compressed)
6✔
762
        self.vcf_field.summary.uncompressed_size += self.buffered_bytes
6✔
763
        logger.debug(f"Finish write: {path}")
6✔
764

765
    def flush(self):
6✔
766
        logger.debug(
6✔
767
            f"Flush {self.path} records={len(self.buff)} buffered={self.buffered_bytes}"
768
        )
769
        if len(self.buff) > 0:
6✔
770
            self.write_chunk()
6✔
771
        with open(self.path / "chunk_index", "wb") as f:
6✔
772
            a = np.array(self.chunk_index, dtype=int)
6✔
773
            pickle.dump(a, f)
6✔
774

775

776
class IcfPartitionWriter(contextlib.AbstractContextManager):
6✔
777
    """
778
    Writes the data for an IntermediateColumnarFormat partition.
779
    """
780

781
    def __init__(
6✔
782
        self,
783
        icf_metadata,
784
        out_path,
785
        partition_index,
786
    ):
787
        self.partition_index = partition_index
6✔
788
        # chunk_size is in megabytes
789
        max_buffered_bytes = icf_metadata.column_chunk_size * 2**20
6✔
790
        assert max_buffered_bytes > 0
6✔
791
        compressor = numcodecs.get_codec(icf_metadata.compressor)
6✔
792

793
        self.field_writers = {}
6✔
794
        num_samples = len(icf_metadata.samples)
6✔
795
        for vcf_field in icf_metadata.fields:
6✔
796
            field_path = get_vcf_field_path(out_path, vcf_field)
6✔
797
            field_partition_path = field_path / f"p{partition_index}"
6✔
798
            # Should be robust to running explode_partition twice.
799
            field_partition_path.mkdir(exist_ok=True)
6✔
800
            transformer = VcfValueTransformer.factory(vcf_field, num_samples)
6✔
801
            self.field_writers[vcf_field.full_name] = IcfFieldWriter(
6✔
802
                vcf_field,
803
                field_partition_path,
804
                transformer,
805
                compressor,
806
                max_buffered_bytes,
807
            )
808

809
    @property
6✔
810
    def field_summaries(self):
6✔
811
        return {
6✔
812
            name: field.vcf_field.summary for name, field in self.field_writers.items()
813
        }
814

815
    def append(self, name, value):
6✔
816
        self.field_writers[name].append(value)
6✔
817

818
    def __exit__(self, exc_type, exc_val, exc_tb):
6✔
819
        if exc_type is None:
6✔
820
            for field in self.field_writers.values():
6✔
821
                field.flush()
6✔
822
        return False
6✔
823

824

825
class IntermediateColumnarFormat(collections.abc.Mapping):
6✔
826
    def __init__(self, path):
6✔
827
        self.path = pathlib.Path(path)
6✔
828
        # TODO raise a more informative error here telling people this
829
        # directory is either a WIP or the wrong format.
830
        with open(self.path / "metadata.json") as f:
6✔
831
            self.metadata = IcfMetadata.fromdict(json.load(f))
6✔
832
        with open(self.path / "header.txt") as f:
6✔
833
            self.vcf_header = f.read()
6✔
834
        self.compressor = numcodecs.get_codec(self.metadata.compressor)
6✔
835
        self.fields = {}
6✔
836
        partition_num_records = [
6✔
837
            partition.num_records for partition in self.metadata.partitions
838
        ]
839
        # Allow us to find which partition a given record is in
840
        self.partition_record_index = np.cumsum([0, *partition_num_records])
6✔
841
        for field in self.metadata.fields:
6✔
842
            self.fields[field.full_name] = IntermediateColumnarFormatField(self, field)
6✔
843
        logger.info(
6✔
844
            f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, "
845
            f"records={self.num_records}, fields={self.num_fields})"
846
        )
847

848
    def __repr__(self):
6✔
849
        return (
6✔
850
            f"IntermediateColumnarFormat(fields={len(self)}, "
851
            f"partitions={self.num_partitions}, "
852
            f"records={self.num_records}, path={self.path})"
853
        )
854

855
    def __getitem__(self, key):
6✔
856
        return self.fields[key]
6✔
857

858
    def __iter__(self):
6✔
859
        return iter(self.fields)
6✔
860

861
    def __len__(self):
6✔
862
        return len(self.fields)
6✔
863

864
    def summary_table(self):
6✔
865
        data = []
6✔
866
        for name, icf_field in self.fields.items():
6✔
867
            summary = icf_field.vcf_field.summary
6✔
868
            d = {
6✔
869
                "name": name,
870
                "type": icf_field.vcf_field.vcf_type,
871
                "chunks": summary.num_chunks,
872
                "size": core.display_size(summary.uncompressed_size),
873
                "compressed": core.display_size(summary.compressed_size),
874
                "max_n": summary.max_number,
875
                "min_val": core.display_number(summary.min_value),
876
                "max_val": core.display_number(summary.max_value),
877
            }
878

879
            data.append(d)
6✔
880
        return data
6✔
881

882
    @property
6✔
883
    def num_records(self):
6✔
884
        return self.metadata.num_records
6✔
885

886
    @property
6✔
887
    def num_partitions(self):
6✔
888
        return len(self.metadata.partitions)
6✔
889

890
    @property
6✔
891
    def samples(self):
6✔
892
        return [sample.id for sample in self.metadata.samples]
6✔
893

894
    @property
6✔
895
    def num_samples(self):
6✔
896
        return len(self.metadata.samples)
6✔
897

898
    @property
6✔
899
    def num_fields(self):
6✔
900
        return len(self.fields)
6✔
901

902
    @property
6✔
903
    def root_attrs(self):
6✔
904
        return {
6✔
905
            "vcf_header": self.vcf_header,
906
        }
907

908
    def iter_alleles(self, start, stop, num_alleles):
6✔
909
        ref_field = self.fields["REF"]
6✔
910
        alt_field = self.fields["ALT"]
6✔
911

912
        for ref, alt in zip(
6✔
913
            ref_field.iter_values(start, stop),
914
            alt_field.iter_values(start, stop),
915
        ):
916
            alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
6✔
917
            alleles[0] = ref[0]
6✔
918
            alleles[1 : 1 + len(alt)] = alt
6✔
919
            yield alleles
6✔
920

921
    def iter_id(self, start, stop):
6✔
922
        for value in self.fields["ID"].iter_values(start, stop):
6✔
923
            if value is not None:
6✔
924
                yield value[0]
6✔
925
            else:
926
                yield None
6✔
927

928
    def iter_filters(self, start, stop):
6✔
929
        source_field = self.fields["FILTERS"]
6✔
930
        lookup = {filt.id: index for index, filt in enumerate(self.metadata.filters)}
6✔
931

932
        for filter_values in source_field.iter_values(start, stop):
6✔
933
            filters = np.zeros(len(self.metadata.filters), dtype=bool)
6✔
934
            if filter_values is not None:
6✔
935
                for filter_id in filter_values:
6✔
936
                    try:
6✔
937
                        filters[lookup[filter_id]] = True
6✔
NEW
UNCOV
938
                    except KeyError:
×
NEW
UNCOV
939
                        raise ValueError(
×
940
                            f"Filter '{filter_id}' was not defined in the header."
941
                        ) from None
942
            yield filters
6✔
943

944
    def iter_contig(self, start, stop):
6✔
945
        source_field = self.fields["CHROM"]
6✔
946
        lookup = {
6✔
947
            contig.id: index for index, contig in enumerate(self.metadata.contigs)
948
        }
949

950
        for value in source_field.iter_values(start, stop):
6✔
951
            # Note: because we are using the indexes to define the lookups
952
            # and we always have an index, it seems that the contig lookup
953
            # will always succeed. However, if anyone ever does hit a KeyError
954
            # here, please do open an issue with a reproducible example!
955
            yield lookup[value[0]]
6✔
956

957
    def iter_field(self, field_name, shape, start, stop):
6✔
958
        source_field = self.fields[field_name]
6✔
959
        sanitiser = source_field.sanitiser_factory(shape)
6✔
960
        for value in source_field.iter_values(start, stop):
6✔
961
            yield sanitiser(value)
6✔
962

963
    def iter_genotypes(self, shape, start, stop):
6✔
964
        source_field = self.fields["FORMAT/GT"]
6✔
965
        for value in source_field.iter_values(start, stop):
6✔
966
            genotypes = value[:, :-1] if value is not None else None
6✔
967
            phased = value[:, -1] if value is not None else None
6✔
968
            sanitised_genotypes = sanitise_value_int_2d(shape, genotypes)
6✔
969
            sanitised_phased = sanitise_value_int_1d(shape[:-1], phased)
6✔
970
            yield sanitised_genotypes, sanitised_phased
6✔
971
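    # Editorial note (not part of the original source): this relies on the
    # layout written for FORMAT/GT by process_partition, assuming cyvcf2's
    # genotype.array() convention of one row per sample holding the allele
    # indexes followed by a trailing phased flag, so value[:, :-1] are the
    # calls and value[:, -1] the phasing; e.g. a diploid row [0, 1, 1] is the
    # phased call 0|1.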

972

973
@dataclasses.dataclass
6✔
974
class IcfPartitionMetadata(core.JsonDataclass):
6✔
975
    num_records: int
6✔
976
    last_position: int
6✔
977
    field_summaries: dict
6✔
978

979
    @staticmethod
6✔
980
    def fromdict(d):
6✔
981
        md = IcfPartitionMetadata(**d)
6✔
982
        for k, v in md.field_summaries.items():
6✔
983
            md.field_summaries[k] = VcfFieldSummary.fromdict(v)
6✔
984
        return md
6✔
985

986

987
def check_overlapping_partitions(partitions):
6✔
988
    for i in range(1, len(partitions)):
6✔
989
        prev_region = partitions[i - 1].region
6✔
990
        current_region = partitions[i].region
6✔
991
        if prev_region.contig == current_region.contig:
6✔
992
            assert prev_region.end is not None
6✔
993
            # Regions are *inclusive*
994
            if prev_region.end >= current_region.start:
6✔
995
                raise ValueError(
6✔
996
                    f"Overlapping VCF regions in partitions {i - 1} and {i}: "
997
                    f"{prev_region} and {current_region}"
998
                )
999

1000

1001
def check_field_clobbering(icf_metadata):
6✔
1002
    info_field_names = set(field.name for field in icf_metadata.info_fields)
6✔
1003
    fixed_variant_fields = set(
6✔
1004
        ["contig", "id", "id_mask", "position", "allele", "filter", "quality"]
1005
    )
1006
    intersection = info_field_names & fixed_variant_fields
6✔
1007
    if len(intersection) > 0:
6✔
1008
        raise ValueError(
6✔
1009
            f"INFO field name(s) clashing with VCF Zarr spec: {intersection}"
1010
        )
1011

1012
    format_field_names = set(field.name for field in icf_metadata.format_fields)
6✔
1013
    fixed_variant_fields = set(["genotype", "genotype_phased", "genotype_mask"])
6✔
1014
    intersection = format_field_names & fixed_variant_fields
6✔
1015
    if len(intersection) > 0:
6✔
1016
        raise ValueError(
6✔
1017
            f"FORMAT field name(s) clashing with VCF Zarr spec: {intersection}"
1018
        )
1019

1020

1021
@dataclasses.dataclass
6✔
1022
class IcfWriteSummary(core.JsonDataclass):
6✔
1023
    num_partitions: int
6✔
1024
    num_samples: int
6✔
1025
    num_variants: int
6✔
1026

1027

1028
class IntermediateColumnarFormatWriter:
6✔
1029
    def __init__(self, path):
6✔
1030
        self.path = pathlib.Path(path)
6✔
1031
        self.wip_path = self.path / "wip"
6✔
1032
        self.metadata = None
6✔
1033

1034
    @property
6✔
1035
    def num_partitions(self):
6✔
1036
        return len(self.metadata.partitions)
6✔
1037

1038
    def init(
6✔
1039
        self,
1040
        vcfs,
1041
        *,
1042
        column_chunk_size=16,
1043
        worker_processes=1,
1044
        target_num_partitions=None,
1045
        show_progress=False,
1046
        compressor=None,
1047
    ):
1048
        if self.path.exists():
6✔
UNCOV
1049
            raise ValueError(f"ICF path already exists: {self.path}")
×
1050
        if compressor is None:
6✔
1051
            compressor = ICF_DEFAULT_COMPRESSOR
6✔
1052
        vcfs = [pathlib.Path(vcf) for vcf in vcfs]
6✔
1053
        target_num_partitions = max(target_num_partitions, len(vcfs))
6✔
1054

1055
        # TODO move scan_vcfs into this class
1056
        icf_metadata, header = scan_vcfs(
6✔
1057
            vcfs,
1058
            worker_processes=worker_processes,
1059
            show_progress=show_progress,
1060
            target_num_partitions=target_num_partitions,
1061
        )
1062
        check_field_clobbering(icf_metadata)
6✔
1063
        self.metadata = icf_metadata
6✔
1064
        self.metadata.format_version = ICF_METADATA_FORMAT_VERSION
6✔
1065
        self.metadata.compressor = compressor.get_config()
6✔
1066
        self.metadata.column_chunk_size = column_chunk_size
6✔
1067
        # Bare minimum here for provenance - would be nice to include versions of key
1068
        # dependencies as well.
1069
        self.metadata.provenance = {"source": f"bio2zarr-{provenance.__version__}"}
6✔
1070

1071
        self.mkdirs()
6✔
1072

1073
        # Note: this is needed for the current version of the vcfzarr spec, but it's
1074
        # probably going to be dropped.
1075
        # https://github.com/pystatgen/vcf-zarr-spec/issues/15
1076
        # May be useful to keep lying around still though?
1077
        logger.info("Writing VCF header")
6✔
1078
        with open(self.path / "header.txt", "w") as f:
6✔
1079
            f.write(header)
6✔
1080

1081
        logger.info("Writing WIP metadata")
6✔
1082
        with open(self.wip_path / "metadata.json", "w") as f:
6✔
1083
            json.dump(self.metadata.asdict(), f, indent=4)
6✔
1084
        return IcfWriteSummary(
6✔
1085
            num_partitions=self.num_partitions,
1086
            num_variants=icf_metadata.num_records,
1087
            num_samples=icf_metadata.num_samples,
1088
        )
1089

1090
    def mkdirs(self):
6✔
1091
        num_dirs = len(self.metadata.fields)
6✔
1092
        logger.info(f"Creating {num_dirs} field directories")
6✔
1093
        self.path.mkdir()
6✔
1094
        self.wip_path.mkdir()
6✔
1095
        for field in self.metadata.fields:
6✔
1096
            field_path = get_vcf_field_path(self.path, field)
6✔
1097
            field_path.mkdir(parents=True)
6✔
1098

1099
    def load_partition_summaries(self):
6✔
1100
        summaries = []
6✔
1101
        not_found = []
6✔
1102
        for j in range(self.num_partitions):
6✔
1103
            try:
6✔
1104
                with open(self.wip_path / f"p{j}.json") as f:
6✔
1105
                    summaries.append(IcfPartitionMetadata.fromdict(json.load(f)))
6✔
1106
            except FileNotFoundError:
6✔
1107
                not_found.append(j)
6✔
1108
        if len(not_found) > 0:
6✔
1109
            raise FileNotFoundError(
6✔
1110
                f"Partition metadata not found for {len(not_found)}"
1111
                f" partitions: {not_found}"
1112
            )
1113
        return summaries
6✔
1114

1115
    def load_metadata(self):
6✔
1116
        if self.metadata is None:
6✔
1117
            with open(self.wip_path / "metadata.json") as f:
6✔
1118
                self.metadata = IcfMetadata.fromdict(json.load(f))
6✔
1119

1120
    def process_partition(self, partition_index):
6✔
1121
        self.load_metadata()
6✔
1122
        summary_path = self.wip_path / f"p{partition_index}.json"
6✔
1123
        # If someone is rewriting a summary path (for whatever reason), make sure it
1124
        # doesn't look like it's already been completed.
1125
        # NOTE to do this properly we probably need to take a lock on this file - but
1126
        # this simple approach will catch the vast majority of problems.
1127
        if summary_path.exists():
6✔
1128
            summary_path.unlink()
6✔
1129

1130
        partition = self.metadata.partitions[partition_index]
6✔
1131
        logger.info(
6✔
1132
            f"Start p{partition_index} {partition.vcf_path}__{partition.region}"
1133
        )
1134
        info_fields = self.metadata.info_fields
6✔
1135
        format_fields = []
6✔
1136
        has_gt = False
6✔
1137
        for field in self.metadata.format_fields:
6✔
1138
            if field.name == "GT":
6✔
1139
                has_gt = True
6✔
1140
            else:
1141
                format_fields.append(field)
6✔
1142

1143
        last_position = None
6✔
1144
        with IcfPartitionWriter(
6✔
1145
            self.metadata,
1146
            self.path,
1147
            partition_index,
1148
        ) as tcw:
1149
            with vcf_utils.VcfFile(partition.vcf_path) as vcf:
6✔
1150
                num_records = 0
6✔
1151
                for variant in vcf.variants(partition.region):
6✔
1152
                    num_records += 1
6✔
1153
                    last_position = variant.POS
6✔
1154
                    tcw.append("CHROM", variant.CHROM)
6✔
1155
                    tcw.append("POS", variant.POS)
6✔
1156
                    tcw.append("QUAL", variant.QUAL)
6✔
1157
                    tcw.append("ID", variant.ID)
6✔
1158
                    tcw.append("FILTERS", variant.FILTERS)
6✔
1159
                    tcw.append("REF", variant.REF)
6✔
1160
                    tcw.append("ALT", variant.ALT)
6✔
1161
                    tcw.append("rlen", variant.end - variant.start)
6✔
1162
                    for field in info_fields:
6✔
1163
                        tcw.append(field.full_name, variant.INFO.get(field.name, None))
6✔
1164
                    if has_gt:
6✔
1165
                        val = None
6✔
1166
                        if "GT" in variant.FORMAT and variant.genotype is not None:
6✔
1167
                            val = variant.genotype.array()
6✔
1168
                        tcw.append("FORMAT/GT", val)
6✔
1169
                    for field in format_fields:
6✔
1170
                        val = variant.format(field.name)
6✔
1171
                        tcw.append(field.full_name, val)
6✔
1172

1173
                    # Note: an issue with updating the progress per variant here like
1174
                    # this is that we get a significant pause at the end of the counter
1175
                    # while all the "small" fields get flushed. Possibly not much to be
1176
                    # done about it.
1177
                    core.update_progress(1)
6✔
1178
            logger.info(
6✔
1179
                f"Finished reading VCF for partition {partition_index}, "
1180
                f"flushing buffers"
1181
            )
1182

1183
        partition_metadata = IcfPartitionMetadata(
6✔
1184
            num_records=num_records,
1185
            last_position=last_position,
1186
            field_summaries=tcw.field_summaries,
1187
        )
1188
        with open(summary_path, "w") as f:
6✔
1189
            f.write(partition_metadata.asjson())
6✔
1190
        logger.info(
6✔
1191
            f"Finish p{partition_index} {partition.vcf_path}__{partition.region} "
1192
            f"{num_records} records last_pos={last_position}"
1193
        )
1194

1195
    def explode(self, *, worker_processes=1, show_progress=False):
6✔
1196
        self.load_metadata()
6✔
1197
        num_records = self.metadata.num_records
6✔
1198
        if np.isinf(num_records):
6✔
1199
            logger.warning(
6✔
1200
                "Total records unknown, cannot show progress; "
1201
                "reindex VCFs with bcftools index to fix"
1202
            )
1203
            num_records = None
6✔
1204
        num_fields = len(self.metadata.fields)
6✔
1205
        num_samples = len(self.metadata.samples)
6✔
1206
        logger.info(
6✔
1207
            f"Exploding fields={num_fields} samples={num_samples}; "
1208
            f"partitions={self.num_partitions} "
1209
            f"variants={'unknown' if num_records is None else num_records}"
1210
        )
1211
        progress_config = core.ProgressConfig(
6✔
1212
            total=num_records,
1213
            units="vars",
1214
            title="Explode",
1215
            show=show_progress,
1216
        )
1217
        with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
6✔
1218
            for j in range(self.num_partitions):
6✔
1219
                pwm.submit(self.process_partition, j)
6✔
1220

1221
    def explode_partition(self, partition):
6✔
1222
        self.load_metadata()
6✔
1223
        if partition < 0 or partition >= self.num_partitions:
6✔
1224
            raise ValueError("Partition index not in the valid range")
6✔
1225
        self.process_partition(partition)
6✔
1226

1227
    def finalise(self):
6✔
1228
        self.load_metadata()
6✔
1229
        partition_summaries = self.load_partition_summaries()
6✔
1230
        total_records = 0
6✔
1231
        for index, summary in enumerate(partition_summaries):
6✔
1232
            partition_records = summary.num_records
6✔
1233
            self.metadata.partitions[index].num_records = partition_records
6✔
1234
            self.metadata.partitions[index].region.end = summary.last_position
6✔
1235
            total_records += partition_records
6✔
1236
        if not np.isinf(self.metadata.num_records):
6✔
1237
            # Note: this is just telling us that there's a bug in the
1238
            # index based record counting code, but it doesn't actually
1239
            # matter much. We may want to just make this a warning if
1240
            # we hit regular problems.
1241
            assert total_records == self.metadata.num_records
6✔
1242
        self.metadata.num_records = total_records
6✔
1243

1244
        check_overlapping_partitions(self.metadata.partitions)
6✔
1245

1246
        for field in self.metadata.fields:
6✔
1247
            for summary in partition_summaries:
6✔
1248
                field.summary.update(summary.field_summaries[field.full_name])
6✔
1249

1250
        logger.info("Finalising metadata")
6✔
1251
        with open(self.path / "metadata.json", "w") as f:
6✔
1252
            f.write(self.metadata.asjson())
6✔
1253

1254
        logger.debug("Removing WIP directory")
6✔
1255
        shutil.rmtree(self.wip_path)
6✔
1256

1257

1258
def explode(
6✔
1259
    icf_path,
1260
    vcfs,
1261
    *,
1262
    column_chunk_size=16,
1263
    worker_processes=1,
1264
    show_progress=False,
1265
    compressor=None,
1266
):
1267
    writer = IntermediateColumnarFormatWriter(icf_path)
6✔
1268
    writer.init(
6✔
1269
        vcfs,
1270
        # Heuristic to get reasonable worker utilisation with lumpy partition sizing
1271
        target_num_partitions=max(1, worker_processes * 4),
1272
        worker_processes=worker_processes,
1273
        show_progress=show_progress,
1274
        column_chunk_size=column_chunk_size,
1275
        compressor=compressor,
1276
    )
1277
    writer.explode(worker_processes=worker_processes, show_progress=show_progress)
6✔
1278
    writer.finalise()
6✔
1279
    return IntermediateColumnarFormat(icf_path)
6✔
1280

1281

1282
def explode_init(
6✔
1283
    icf_path,
1284
    vcfs,
1285
    *,
1286
    column_chunk_size=16,
1287
    target_num_partitions=1,
1288
    worker_processes=1,
1289
    show_progress=False,
1290
    compressor=None,
1291
):
1292
    writer = IntermediateColumnarFormatWriter(icf_path)
6✔
1293
    return writer.init(
6✔
1294
        vcfs,
1295
        target_num_partitions=target_num_partitions,
1296
        worker_processes=worker_processes,
1297
        show_progress=show_progress,
1298
        column_chunk_size=column_chunk_size,
1299
        compressor=compressor,
1300
    )
1301

1302

1303
def explode_partition(icf_path, partition):
6✔
1304
    writer = IntermediateColumnarFormatWriter(icf_path)
6✔
1305
    writer.explode_partition(partition)
6✔
1306

1307

1308
def explode_finalise(icf_path):
6✔
1309
    writer = IntermediateColumnarFormatWriter(icf_path)
6✔
1310
    writer.finalise()
6✔