sgkit-dev / bio2zarr, build 12312346477

13 Dec 2024 08:41AM UTC coverage: 98.25% (-0.7%) from 98.91%

Pull Request #281: Draft bed2zarr code (github / web-flow)
Merge a398a9196 into 883a37e81

138 of 154 new or added lines in 3 files covered (89.61%).
13 existing lines in 4 files are now uncovered.
2583 of 2629 relevant lines covered (98.25%).
0.98 hits per line.

Source file: /bio2zarr/vcf2zarr/icf.py (99.05% covered)
import collections
import contextlib
import dataclasses
import json
import logging
import math
import pathlib
import pickle
import shutil
import sys
from typing import Any

import numcodecs
import numpy as np

from .. import constants, core, provenance, vcf_utils

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class VcfFieldSummary(core.JsonDataclass):
    num_chunks: int = 0
    compressed_size: int = 0
    uncompressed_size: int = 0
    max_number: int = 0  # Corresponds to VCF Number field, depends on context
    # Only defined for numeric fields
    max_value: Any = -math.inf
    min_value: Any = math.inf

    def update(self, other):
        self.num_chunks += other.num_chunks
        self.compressed_size += other.compressed_size
        self.uncompressed_size += other.uncompressed_size
        self.max_number = max(self.max_number, other.max_number)
        self.min_value = min(self.min_value, other.min_value)
        self.max_value = max(self.max_value, other.max_value)

    @staticmethod
    def fromdict(d):
        return VcfFieldSummary(**d)

@dataclasses.dataclass
class VcfField:
    category: str
    name: str
    vcf_number: str
    vcf_type: str
    description: str
    summary: VcfFieldSummary

    @staticmethod
    def from_header(definition):
        category = definition["HeaderType"]
        name = definition["ID"]
        vcf_number = definition["Number"]
        vcf_type = definition["Type"]
        return VcfField(
            category=category,
            name=name,
            vcf_number=vcf_number,
            vcf_type=vcf_type,
            description=definition["Description"].strip('"'),
            summary=VcfFieldSummary(),
        )

    @staticmethod
    def fromdict(d):
        f = VcfField(**d)
        f.summary = VcfFieldSummary(**d["summary"])
        return f

    @property
    def full_name(self):
        if self.category == "fixed":
            return self.name
        return f"{self.category}/{self.name}"

    def smallest_dtype(self):
        """
        Returns the smallest dtype suitable for this field based
        on type, and values.
        """
        s = self.summary
        if self.vcf_type == "Float":
            ret = "f4"
        elif self.vcf_type == "Integer":
            if not math.isfinite(s.max_value):
                # All missing values; use i1. Note we should have some API to
                # check more explicitly for missingness:
                # https://github.com/sgkit-dev/bio2zarr/issues/131
                ret = "i1"
            else:
                ret = core.min_int_dtype(s.min_value, s.max_value)
        elif self.vcf_type == "Flag":
            ret = "bool"
        elif self.vcf_type == "Character":
            ret = "U1"
        else:
            assert self.vcf_type == "String"
            ret = "O"
        return ret

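# Editor's note: illustrative sketch, not part of the original module. It
# shows the kind of bounds-based selection that smallest_dtype() delegates to
# core.min_int_dtype, whose exact behaviour is assumed here: pick the
# narrowest signed integer dtype that can hold both observed bounds.
#
#     >>> import numpy as np
#     >>> def min_int_dtype_sketch(min_value, max_value):
#     ...     for dtype in ["i1", "i2", "i4"]:
#     ...         info = np.iinfo(dtype)
#     ...         if info.min <= min_value and max_value <= info.max:
#     ...             return dtype
#     ...     return "i8"
#     >>> min_int_dtype_sketch(0, 300)
#     'i2'
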
@dataclasses.dataclass
class VcfPartition:
    vcf_path: str
    region: str
    num_records: int = -1


ICF_METADATA_FORMAT_VERSION = "0.4"
ICF_DEFAULT_COMPRESSOR = numcodecs.Blosc(
    cname="zstd", clevel=7, shuffle=numcodecs.Blosc.NOSHUFFLE
)


@dataclasses.dataclass
class Contig:
    id: str
    length: int = None


@dataclasses.dataclass
class Sample:
    id: str


@dataclasses.dataclass
class Filter:
    id: str
    description: str = ""


@dataclasses.dataclass
class IcfMetadata(core.JsonDataclass):
    samples: list
    contigs: list
    filters: list
    fields: list
    partitions: list = None
    format_version: str = None
    compressor: dict = None
    column_chunk_size: int = None
    provenance: dict = None
    num_records: int = -1

    @property
    def info_fields(self):
        fields = []
        for field in self.fields:
            if field.category == "INFO":
                fields.append(field)
        return fields

    @property
    def format_fields(self):
        fields = []
        for field in self.fields:
            if field.category == "FORMAT":
                fields.append(field)
        return fields

    @property
    def num_contigs(self):
        return len(self.contigs)

    @property
    def num_filters(self):
        return len(self.filters)

    @property
    def num_samples(self):
        return len(self.samples)

    @staticmethod
    def fromdict(d):
        if d["format_version"] != ICF_METADATA_FORMAT_VERSION:
            raise ValueError(
                "Intermediate columnar metadata format version mismatch: "
                f"{d['format_version']} != {ICF_METADATA_FORMAT_VERSION}"
            )
        partitions = [VcfPartition(**pd) for pd in d["partitions"]]
        for p in partitions:
            p.region = vcf_utils.Region(**p.region)
        d = d.copy()
        d["partitions"] = partitions
        d["fields"] = [VcfField.fromdict(fd) for fd in d["fields"]]
        d["samples"] = [Sample(**sd) for sd in d["samples"]]
        d["filters"] = [Filter(**fd) for fd in d["filters"]]
        d["contigs"] = [Contig(**cd) for cd in d["contigs"]]
        return IcfMetadata(**d)

def fixed_vcf_field_definitions():
    def make_field_def(name, vcf_type, vcf_number):
        return VcfField(
            category="fixed",
            name=name,
            vcf_type=vcf_type,
            vcf_number=vcf_number,
            description="",
            summary=VcfFieldSummary(),
        )

    fields = [
        make_field_def("CHROM", "String", "1"),
        make_field_def("POS", "Integer", "1"),
        make_field_def("QUAL", "Float", "1"),
        make_field_def("ID", "String", "."),
        make_field_def("FILTERS", "String", "."),
        make_field_def("REF", "String", "1"),
        make_field_def("ALT", "String", "."),
        make_field_def("rlen", "Integer", "1"),  # computed field
    ]
    return fields

def scan_vcf(path, target_num_partitions, *, local_alleles):
    with vcf_utils.IndexedVcf(path) as indexed_vcf:
        vcf = indexed_vcf.vcf
        filters = []
        pass_index = -1
        for h in vcf.header_iter():
            if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str):
                try:
                    description = h["Description"].strip('"')
                except KeyError:
                    description = ""
                if h["ID"] == "PASS":
                    pass_index = len(filters)
                filters.append(Filter(h["ID"], description))

        # Ensure PASS is the first filter if present
        if pass_index > 0:
            pass_filter = filters.pop(pass_index)
            filters.insert(0, pass_filter)

        # Indicates whether vcf2zarr can introduce local alleles
        can_localize = False
        should_add_laa_field = True
        should_add_lpl_field = True
        fields = fixed_vcf_field_definitions()
        for h in vcf.header_iter():
            if h["HeaderType"] in ["INFO", "FORMAT"]:
                field = VcfField.from_header(h)
                if h["HeaderType"] == "FORMAT" and field.name == "GT":
                    field.vcf_type = "Integer"
                    field.vcf_number = "."
                fields.append(field)
                if field.category == "FORMAT":
                    if field.name == "PL":
                        can_localize = True
                    if field.name == "LAA":
                        should_add_laa_field = False
                    if field.name == "LPL":
                        should_add_lpl_field = False

        if local_alleles and can_localize:
            if should_add_laa_field:
                laa_field = VcfField(
                    category="FORMAT",
                    name="LAA",
                    vcf_type="Integer",
                    vcf_number=".",
                    description="1-based indices into ALT, indicating which alleles"
                    " are relevant (local) for the current sample",
                    summary=VcfFieldSummary(),
                )
                fields.append(laa_field)
            if should_add_lpl_field:
                lpl_field = VcfField(
                    category="FORMAT",
                    name="LPL",
                    vcf_type="Integer",
                    vcf_number="LG",
                    description="Local-allele representation of PL",
                    summary=VcfFieldSummary(),
                )
                fields.append(lpl_field)

        try:
            contig_lengths = vcf.seqlens
        except AttributeError:
            contig_lengths = [None for _ in vcf.seqnames]

        metadata = IcfMetadata(
            samples=[Sample(sample_id) for sample_id in vcf.samples],
            contigs=[
                Contig(contig_id, length)
                for contig_id, length in zip(vcf.seqnames, contig_lengths)
            ],
            filters=filters,
            fields=fields,
            partitions=[],
            num_records=sum(indexed_vcf.contig_record_counts().values()),
        )

        regions = indexed_vcf.partition_into_regions(num_parts=target_num_partitions)
        for region in regions:
            metadata.partitions.append(
                VcfPartition(
                    # TODO should this be fully resolving the path? Otherwise it's all
                    # relative to the original WD
                    vcf_path=str(path),
                    region=region,
                )
            )
        logger.info(
            f"Split {path} into {len(metadata.partitions)} "
            f"partitions (target={target_num_partitions})"
        )
        core.update_progress(1)
        return metadata, vcf.raw_header

def scan_vcfs(
    paths,
    show_progress,
    target_num_partitions,
    worker_processes=1,
    *,
    local_alleles,
):
    logger.info(
        f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
        f" partitions."
    )
    # An easy mistake to make is to pass the same file twice. Check this early on.
    for path, count in collections.Counter(paths).items():
        if not path.exists():  # NEEDS TEST
            raise FileNotFoundError(path)
        if count > 1:
            raise ValueError(f"Duplicate path provided: {path}")

    progress_config = core.ProgressConfig(
        total=len(paths),
        units="files",
        title="Scan",
        show=show_progress,
    )
    with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
        for path in paths:
            pwm.submit(
                scan_vcf,
                path,
                max(1, target_num_partitions // len(paths)),
                local_alleles=local_alleles,
            )
        results = list(pwm.results_as_completed())

    # Sort to make the ordering deterministic
    results.sort(key=lambda t: t[0].partitions[0].vcf_path)
    # We just take the first header, assuming the others
    # are compatible.
    all_partitions = []
    total_records = 0
    for metadata, _ in results:
        for partition in metadata.partitions:
            logger.debug(f"Scanned partition {partition}")
            all_partitions.append(partition)
        total_records += metadata.num_records
        metadata.num_records = 0
        metadata.partitions = []

    icf_metadata, header = results[0]
    for metadata, _ in results[1:]:
        if metadata != icf_metadata:
            raise ValueError("Incompatible VCF chunks")

    # Note: this will be infinity here if any of the chunks has an index
    # that doesn't keep track of the number of records per-contig
    icf_metadata.num_records = total_records

    # Sort by contig (in the order they appear in the header) first,
    # then by start coordinate
    contig_index_map = {contig.id: j for j, contig in enumerate(metadata.contigs)}
    all_partitions.sort(
        key=lambda x: (contig_index_map[x.region.contig], x.region.start)
    )
    icf_metadata.partitions = all_partitions
    logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.")
    return icf_metadata, header

def sanitise_value_bool(buff, j, value):
    x = True
    if value is None:
        x = False
    buff[j] = x


def sanitise_value_float_scalar(buff, j, value):
    x = value
    if value is None:
        x = [constants.FLOAT32_MISSING]
    buff[j] = x[0]


def sanitise_value_int_scalar(buff, j, value):
    x = value
    if value is None:
        # print("MISSING", INT_MISSING, INT_FILL)
        x = [constants.INT_MISSING]
    else:
        x = sanitise_int_array(value, ndmin=1, dtype=np.int32)
    buff[j] = x[0]


def sanitise_value_string_scalar(buff, j, value):
    if value is None:
        buff[j] = "."
    else:
        buff[j] = value[0]


def sanitise_value_string_1d(buff, j, value):
    if value is None:
        buff[j] = "."
    else:
        # value = np.array(value, ndmin=1, dtype=buff.dtype, copy=False)
        # FIXME failure isn't coming from here, it seems to be from an
        # incorrectly detected dimension in the zarr array
        # The dimensions look all wrong, and the dtype should be Object
        # not str
        value = drop_empty_second_dim(value)
        buff[j] = ""
        buff[j, : value.shape[0]] = value


def sanitise_value_string_2d(buff, j, value):
    if value is None:
        buff[j] = "."
    else:
        # print(buff.shape, value.dtype, value)
        # assert value.ndim == 2
        buff[j] = ""
        if value.ndim == 2:
            buff[j, :, : value.shape[1]] = value
        else:
            # TODO check if this is still necessary
            for k, val in enumerate(value):
                buff[j, k, : len(val)] = val


def drop_empty_second_dim(value):
    assert len(value.shape) == 1 or value.shape[1] == 1
    if len(value.shape) == 2 and value.shape[1] == 1:
        value = value[..., 0]
    return value


def sanitise_value_float_1d(buff, j, value):
    if value is None:
        buff[j] = constants.FLOAT32_MISSING
    else:
        value = np.array(value, ndmin=1, dtype=buff.dtype, copy=True)
        # numpy will map None values to NaN, but we need a
        # specific NaN
        value[np.isnan(value)] = constants.FLOAT32_MISSING
        value = drop_empty_second_dim(value)
        buff[j] = constants.FLOAT32_FILL
        buff[j, : value.shape[0]] = value


def sanitise_value_float_2d(buff, j, value):
    if value is None:
        buff[j] = constants.FLOAT32_MISSING
    else:
        # print("value = ", value)
        value = np.array(value, ndmin=2, dtype=buff.dtype, copy=True)
        buff[j] = constants.FLOAT32_FILL
        buff[j, :, : value.shape[1]] = value


def sanitise_int_array(value, ndmin, dtype):
    if isinstance(value, tuple):
        value = [
            constants.VCF_INT_MISSING if x is None else x for x in value
        ]  # NEEDS TEST
    value = np.array(value, ndmin=ndmin, copy=True)
    value[value == constants.VCF_INT_MISSING] = -1
    value[value == constants.VCF_INT_FILL] = -2
    # TODO watch out for clipping here!
    return value.astype(dtype)


def sanitise_value_int_1d(buff, j, value):
    if value is None:
        buff[j] = -1
    else:
        value = sanitise_int_array(value, 1, buff.dtype)
        value = drop_empty_second_dim(value)
        buff[j] = -2
        buff[j, : value.shape[0]] = value


def sanitise_value_int_2d(buff, j, value):
    if value is None:
        buff[j] = -1
    else:
        value = sanitise_int_array(value, 2, buff.dtype)
        buff[j] = -2
        buff[j, :, : value.shape[1]] = value

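# Editor's note: illustrative addition, not part of the original module.
# sanitise_int_array above maps the VCF sentinel values onto the fixed
# sentinels used in these buffers: VCF_INT_MISSING becomes -1 (missing) and
# VCF_INT_FILL becomes -2 (fill/padding); int buffers are pre-filled with -2
# before the sanitised values are copied in. For example (output assumed):
#
#     >>> sanitise_int_array((None, 3), ndmin=1, dtype=np.int32)
#     array([-1,  3], dtype=int32)
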
def compute_laa_field(variant) -> np.ndarray:
    """
    Computes the value of the LAA field for each sample given a variant.

    The LAA field is a list of one-based indices into the ALT alleles
    that indicates which alternate alleles are observed in the sample.

    This method infers which alleles are observed from the GT field.
    """
    sample_count = variant.num_called + variant.num_unknown
    alt_allele_count = len(variant.ALT)
    allele_count = alt_allele_count + 1
    allele_counts = np.zeros((sample_count, allele_count), dtype=int)

    if "GT" in variant.FORMAT:
        # The last element of each sample's genotype indicates the phasing
        # and is not an allele.
        genotypes = variant.genotype.array()[:, :-1]
        genotypes.clip(0, None, out=genotypes)
        genotype_allele_counts = np.apply_along_axis(
            np.bincount, axis=1, arr=genotypes, minlength=allele_count
        )
        allele_counts += genotype_allele_counts

    allele_counts[:, 0] = 0  # We don't count the reference allele
    max_row_length = 1

    def nonzero_pad(arr: np.ndarray, *, length: int):
        nonlocal max_row_length
        alleles = arr.nonzero()[0]
        max_row_length = max(max_row_length, len(alleles))
        pad_length = length - len(alleles)
        return np.pad(
            alleles,
            (0, pad_length),
            mode="constant",
            constant_values=constants.INT_FILL,
        )

    alleles = np.apply_along_axis(
        nonzero_pad, axis=1, arr=allele_counts, length=max(1, alt_allele_count)
    )
    alleles = alleles[:, :max_row_length]

    return alleles

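# Editor's note: illustrative worked example, not part of the original
# module. For a site with two ALT alleles and diploid genotypes 0/2 and 1/1,
# compute_laa_field counts alleles per sample, zeroes the reference column,
# and keeps the (1-based) indices of the observed ALT alleles, padding rows
# with INT_FILL and trimming to the longest row:
#
#     sample 1: GT 0/2 -> allele counts [1, 0, 1] -> observed ALT indices [2]
#     sample 2: GT 1/1 -> allele counts [0, 2, 0] -> observed ALT indices [1]
#
# giving an LAA array of [[2], [1]].
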
def compute_lpl_field(variant, laa_val: np.ndarray) -> np.ndarray:
    assert laa_val is not None

    la_val = np.zeros((laa_val.shape[0], laa_val.shape[1] + 1), dtype=laa_val.dtype)
    la_val[:, 1:] = laa_val
    ploidy = variant.ploidy

    if "PL" not in variant.FORMAT:
        sample_count = variant.num_called + variant.num_unknown
        local_allele_count = la_val.shape[1]

        if ploidy == 1:
            local_genotype_count = local_allele_count
        elif ploidy == 2:
            local_genotype_count = local_allele_count * (local_allele_count + 1) // 2
        else:
            raise ValueError(f"Cannot handle ploidy = {ploidy}")

        return np.full((sample_count, local_genotype_count), constants.INT_MISSING)

    # Compute a and b
    if ploidy == 1:
        a = la_val
        b = np.zeros_like(la_val)
    elif ploidy == 2:
        repeats = np.arange(1, la_val.shape[1] + 1)
        b = np.repeat(la_val, repeats, axis=1)
        arange_tile = np.tile(np.arange(la_val.shape[1]), (la_val.shape[1], 1))
        tril_indices = np.tril_indices_from(arange_tile)
        a_index = np.tile(arange_tile[tril_indices], (b.shape[0], 1))
        row_index = np.arange(la_val.shape[0]).reshape(-1, 1)
        a = la_val[row_index, a_index]
    else:
        raise ValueError(f"Cannot handle ploidy = {ploidy}")

    # Compute n, the local indices of the PL field
    n = (b * (b + 1) / 2 + a).astype(int)

    pl_val = variant.format("PL")
    pl_val[pl_val == constants.VCF_INT_MISSING] = constants.INT_MISSING
    # When the PL value is missing in all samples, pl_val has shape (sample_count, 1).
    # In that case, we need to broadcast the PL value.
    if pl_val.shape[1] < n.shape[1]:
        pl_val = np.broadcast_to(pl_val, n.shape)
    row_index = np.arange(pl_val.shape[0]).reshape(-1, 1)
    lpl_val = pl_val[row_index, n]
    lpl_val[b == constants.INT_FILL] = constants.INT_FILL

    return lpl_val

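# Editor's note: illustrative worked example, not part of the original
# module. For diploid calls, compute_lpl_field maps each local genotype
# (a, b) with a <= b onto the standard VCF PL ordering via
# n = b * (b + 1) / 2 + a. With local alleles [0, 2], the local genotypes are
# (0, 0), (0, 2) and (2, 2), which select PL entries 0, 3 and 5:
#
#     >>> import numpy as np
#     >>> a = np.array([0, 0, 2])
#     >>> b = np.array([0, 2, 2])
#     >>> (b * (b + 1) // 2 + a).tolist()
#     [0, 3, 5]
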
missing_value_map = {
    "Integer": constants.INT_MISSING,
    "Float": constants.FLOAT32_MISSING,
    "String": constants.STR_MISSING,
    "Character": constants.STR_MISSING,
    "Flag": False,
}


class VcfValueTransformer:
    """
    Transform VCF values into the stored intermediate format used
    in the IntermediateColumnarFormat, and update field summaries.
    """

    def __init__(self, field, num_samples):
        self.field = field
        self.num_samples = num_samples
        self.dimension = 1
        if field.category == "FORMAT":
            self.dimension = 2
        self.missing = missing_value_map[field.vcf_type]

    @staticmethod
    def factory(field, num_samples):
        if field.vcf_type in ("Integer", "Flag"):
            return IntegerValueTransformer(field, num_samples)
        if field.vcf_type == "Float":
            return FloatValueTransformer(field, num_samples)
        if field.name in ["REF", "FILTERS", "ALT", "ID", "CHROM"]:
            return SplitStringValueTransformer(field, num_samples)
        return StringValueTransformer(field, num_samples)

    def transform(self, vcf_value):
        if isinstance(vcf_value, tuple):
            vcf_value = [self.missing if v is None else v for v in vcf_value]
        value = np.array(vcf_value, ndmin=self.dimension, copy=True)
        return value

    def transform_and_update_bounds(self, vcf_value):
        if vcf_value is None:
            return None
        # print(self, self.field.full_name, "T", vcf_value)
        value = self.transform(vcf_value)
        self.update_bounds(value)
        return value


class IntegerValueTransformer(VcfValueTransformer):
    def update_bounds(self, value):
        summary = self.field.summary
        # Mask out missing and fill values
        # print(value)
        a = value[value >= constants.MIN_INT_VALUE]
        if a.size > 0:
            summary.max_value = int(max(summary.max_value, np.max(a)))
            summary.min_value = int(min(summary.min_value, np.min(a)))
        number = value.shape[-1]
        summary.max_number = max(summary.max_number, number)


class FloatValueTransformer(VcfValueTransformer):
    def update_bounds(self, value):
        summary = self.field.summary
        summary.max_value = float(max(summary.max_value, np.max(value)))
        summary.min_value = float(min(summary.min_value, np.min(value)))
        number = value.shape[-1]
        summary.max_number = max(summary.max_number, number)


class StringValueTransformer(VcfValueTransformer):
    def update_bounds(self, value):
        summary = self.field.summary
        if self.field.category == "FORMAT":
            number = max(len(v) for v in value)
        else:
            number = value.shape[-1]
        # TODO would be nice to report string lengths, but not
        # really necessary.
        summary.max_number = max(summary.max_number, number)

    def transform(self, vcf_value):
        if self.dimension == 1:
            value = np.array(list(vcf_value.split(",")))
        else:
            # TODO can we make this faster??
            value = np.array([v.split(",") for v in vcf_value], dtype="O")
            # print("HERE", vcf_value, value)
            # for v in vcf_value:
            #     print("\t", type(v), len(v), v.split(","))
        # print("S: ", self.dimension, ":", value.shape, value)
        return value


class SplitStringValueTransformer(StringValueTransformer):
    def transform(self, vcf_value):
        if vcf_value is None:
            return self.missing_value  # NEEDS TEST
        assert self.dimension == 1
        return np.array(vcf_value, ndmin=1, dtype="str")


def get_vcf_field_path(base_path, vcf_field):
    if vcf_field.category == "fixed":
        return base_path / vcf_field.name
    return base_path / vcf_field.category / vcf_field.name

class IntermediateColumnarFormatField:
    def __init__(self, icf, vcf_field):
        self.vcf_field = vcf_field
        self.path = get_vcf_field_path(icf.path, vcf_field)
        self.compressor = icf.compressor
        self.num_partitions = icf.num_partitions
        self.num_records = icf.num_records
        self.partition_record_index = icf.partition_record_index
        # A map of partition id to the cumulative number of records
        # in chunks within that partition
        self._chunk_record_index = {}

    @property
    def name(self):
        return self.vcf_field.full_name

    def partition_path(self, partition_id):
        return self.path / f"p{partition_id}"

    def __repr__(self):
        partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
        return (
            f"IntermediateColumnarFormatField(name={self.name}, "
            f"partition_chunks={partition_chunks}, "
            f"path={self.path})"
        )

    def num_chunks(self, partition_id):
        return len(self.chunk_record_index(partition_id)) - 1

    def chunk_record_index(self, partition_id):
        if partition_id not in self._chunk_record_index:
            index_path = self.partition_path(partition_id) / "chunk_index"
            with open(index_path, "rb") as f:
                a = pickle.load(f)
            assert len(a) > 1
            assert a[0] == 0
            self._chunk_record_index[partition_id] = a
        return self._chunk_record_index[partition_id]

    def read_chunk(self, path):
        with open(path, "rb") as f:
            pkl = self.compressor.decode(f.read())
        return pickle.loads(pkl)

    def chunk_num_records(self, partition_id):
        return np.diff(self.chunk_record_index(partition_id))

    def chunks(self, partition_id, start_chunk=0):
        partition_path = self.partition_path(partition_id)
        chunk_cumulative_records = self.chunk_record_index(partition_id)
        chunk_num_records = np.diff(chunk_cumulative_records)
        for count, cumulative in zip(
            chunk_num_records[start_chunk:], chunk_cumulative_records[start_chunk + 1 :]
        ):
            path = partition_path / f"{cumulative}"
            chunk = self.read_chunk(path)
            if len(chunk) != count:
                raise ValueError(f"Corruption detected in chunk: {path}")
            yield chunk

    def iter_values(self, start=None, stop=None):
        start = 0 if start is None else start
        stop = self.num_records if stop is None else stop
        start_partition = (
            np.searchsorted(self.partition_record_index, start, side="right") - 1
        )
        offset = self.partition_record_index[start_partition]
        assert offset <= start
        chunk_offset = start - offset

        chunk_record_index = self.chunk_record_index(start_partition)
        start_chunk = (
            np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1
        )
        record_id = offset + chunk_record_index[start_chunk]
        assert record_id <= start
        logger.debug(
            f"Read {self.vcf_field.full_name} slice [{start}:{stop}]:"
            f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}"
        )
        for chunk in self.chunks(start_partition, start_chunk):
            for record in chunk:
                if record_id == stop:
                    return
                if record_id >= start:
                    yield record
                record_id += 1
        assert record_id > start
        for partition_id in range(start_partition + 1, self.num_partitions):
            for chunk in self.chunks(partition_id):
                for record in chunk:
                    if record_id == stop:
                        return
                    yield record
                    record_id += 1

    # Note: this involves some computation so should arguably be a method,
    # but making a property for consistency with xarray etc
    @property
    def values(self):
        ret = [None] * self.num_records
        j = 0
        for partition_id in range(self.num_partitions):
            for chunk in self.chunks(partition_id):
                for record in chunk:
                    ret[j] = record
                    j += 1
        assert j == self.num_records
        return ret

    def sanitiser_factory(self, shape):
        """
        Return a function that sanitises values from this column
        and writes them into a buffer of the specified shape.
        """
        assert len(shape) <= 3
        if self.vcf_field.vcf_type == "Flag":
            assert len(shape) == 1
            return sanitise_value_bool
        elif self.vcf_field.vcf_type == "Float":
            if len(shape) == 1:
                return sanitise_value_float_scalar
            elif len(shape) == 2:
                return sanitise_value_float_1d
            else:
                return sanitise_value_float_2d
        elif self.vcf_field.vcf_type == "Integer":
            if len(shape) == 1:
                return sanitise_value_int_scalar
            elif len(shape) == 2:
                return sanitise_value_int_1d
            else:
                return sanitise_value_int_2d
        else:
            assert self.vcf_field.vcf_type in ("String", "Character")
            if len(shape) == 1:
                return sanitise_value_string_scalar
            elif len(shape) == 2:
                return sanitise_value_string_1d
            else:
                return sanitise_value_string_2d

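# Editor's note: illustrative sketch, not part of the original module.
# iter_values() above locates the starting partition (and, within it, the
# starting chunk) with np.searchsorted on cumulative record indexes. For
# example, with partition_record_index = [0, 100, 250], record 130 falls in
# partition 1:
#
#     >>> import numpy as np
#     >>> index = np.array([0, 100, 250])
#     >>> int(np.searchsorted(index, 130, side="right") - 1)
#     1
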
@dataclasses.dataclass
class IcfFieldWriter:
    vcf_field: VcfField
    path: pathlib.Path
    transformer: VcfValueTransformer
    compressor: Any
    max_buffered_bytes: int
    buff: list[Any] = dataclasses.field(default_factory=list)
    buffered_bytes: int = 0
    chunk_index: list[int] = dataclasses.field(default_factory=lambda: [0])
    num_records: int = 0

    def append(self, val):
        val = self.transformer.transform_and_update_bounds(val)
        assert val is None or isinstance(val, np.ndarray)
        self.buff.append(val)
        val_bytes = sys.getsizeof(val)
        self.buffered_bytes += val_bytes
        self.num_records += 1
        if self.buffered_bytes >= self.max_buffered_bytes:
            logger.debug(
                f"Flush {self.path} buffered={self.buffered_bytes} "
                f"max={self.max_buffered_bytes}"
            )
            self.write_chunk()
            self.buff.clear()
            self.buffered_bytes = 0

    def write_chunk(self):
        # Update index
        self.chunk_index.append(self.num_records)
        path = self.path / f"{self.num_records}"
        logger.debug(f"Start write: {path}")
        pkl = pickle.dumps(self.buff)
        compressed = self.compressor.encode(pkl)
        with open(path, "wb") as f:
            f.write(compressed)

        # Update the summary
        self.vcf_field.summary.num_chunks += 1
        self.vcf_field.summary.compressed_size += len(compressed)
        self.vcf_field.summary.uncompressed_size += self.buffered_bytes
        logger.debug(f"Finish write: {path}")

    def flush(self):
        logger.debug(
            f"Flush {self.path} records={len(self.buff)} buffered={self.buffered_bytes}"
        )
        if len(self.buff) > 0:
            self.write_chunk()
        with open(self.path / "chunk_index", "wb") as f:
            a = np.array(self.chunk_index, dtype=int)
            pickle.dump(a, f)

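# Editor's note: illustrative sketch, not part of the original module.
# The chunk_index written by IcfFieldWriter.flush() holds cumulative record
# counts (with a leading 0), and each chunk file is named after the
# cumulative count at the time it was written. Per-chunk record counts are
# recovered with np.diff, as chunk_num_records() does on the read side:
#
#     >>> import numpy as np
#     >>> chunk_index = np.array([0, 1000, 1800])
#     >>> np.diff(chunk_index).tolist()  # records in chunk files "1000" and "1800"
#     [1000, 800]
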
class IcfPartitionWriter(contextlib.AbstractContextManager):
    """
    Writes the data for an IntermediateColumnarFormat partition.
    """

    def __init__(
        self,
        icf_metadata,
        out_path,
        partition_index,
    ):
        self.partition_index = partition_index
        # chunk_size is in megabytes
        max_buffered_bytes = icf_metadata.column_chunk_size * 2**20
        assert max_buffered_bytes > 0
        compressor = numcodecs.get_codec(icf_metadata.compressor)

        self.field_writers = {}
        num_samples = len(icf_metadata.samples)
        for vcf_field in icf_metadata.fields:
            field_path = get_vcf_field_path(out_path, vcf_field)
            field_partition_path = field_path / f"p{partition_index}"
            # Should be robust to running explode_partition twice.
            field_partition_path.mkdir(exist_ok=True)
            transformer = VcfValueTransformer.factory(vcf_field, num_samples)
            self.field_writers[vcf_field.full_name] = IcfFieldWriter(
                vcf_field,
                field_partition_path,
                transformer,
                compressor,
                max_buffered_bytes,
            )

    @property
    def field_summaries(self):
        return {
            name: field.vcf_field.summary for name, field in self.field_writers.items()
        }

    def append(self, name, value):
        self.field_writers[name].append(value)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            for field in self.field_writers.values():
                field.flush()
        return False

class IntermediateColumnarFormat(collections.abc.Mapping):
    def __init__(self, path):
        self.path = pathlib.Path(path)
        # TODO raise a more informative error here telling people this
        # directory is either a WIP or the wrong format.
        with open(self.path / "metadata.json") as f:
            self.metadata = IcfMetadata.fromdict(json.load(f))
        with open(self.path / "header.txt") as f:
            self.vcf_header = f.read()
        self.compressor = numcodecs.get_codec(self.metadata.compressor)
        self.fields = {}
        partition_num_records = [
            partition.num_records for partition in self.metadata.partitions
        ]
        # Allow us to find which partition a given record is in
        self.partition_record_index = np.cumsum([0, *partition_num_records])
        for field in self.metadata.fields:
            self.fields[field.full_name] = IntermediateColumnarFormatField(self, field)
        logger.info(
            f"Loaded IntermediateColumnarFormat(partitions={self.num_partitions}, "
            f"records={self.num_records}, fields={self.num_fields})"
        )

    def __repr__(self):
        return (
            f"IntermediateColumnarFormat(fields={len(self)}, "
            f"partitions={self.num_partitions}, "
            f"records={self.num_records}, path={self.path})"
        )

    def __getitem__(self, key):
        return self.fields[key]

    def __iter__(self):
        return iter(self.fields)

    def __len__(self):
        return len(self.fields)

    def summary_table(self):
        data = []
        for name, icf_field in self.fields.items():
            summary = icf_field.vcf_field.summary
            d = {
                "name": name,
                "type": icf_field.vcf_field.vcf_type,
                "chunks": summary.num_chunks,
                "size": core.display_size(summary.uncompressed_size),
                "compressed": core.display_size(summary.compressed_size),
                "max_n": summary.max_number,
                "min_val": core.display_number(summary.min_value),
                "max_val": core.display_number(summary.max_value),
            }

            data.append(d)
        return data

    @property
    def num_records(self):
        return self.metadata.num_records

    @property
    def num_partitions(self):
        return len(self.metadata.partitions)

    @property
    def num_samples(self):
        return len(self.metadata.samples)

    @property
    def num_fields(self):
        return len(self.fields)

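# Editor's note: a minimal usage sketch (editor's addition, not part of the
# original module), assuming an ICF directory has already been written by
# explode() below. The path and field name are illustrative only.
#
#     icf = IntermediateColumnarFormat("example.icf")
#     print(icf.num_records, icf.num_partitions, icf.num_fields)
#     positions = icf["POS"].values                 # all decoded values for one field
#     window = list(icf["POS"].iter_values(0, 10))  # or just a range of records
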
@dataclasses.dataclass
class IcfPartitionMetadata(core.JsonDataclass):
    num_records: int
    last_position: int
    field_summaries: dict

    @staticmethod
    def fromdict(d):
        md = IcfPartitionMetadata(**d)
        for k, v in md.field_summaries.items():
            md.field_summaries[k] = VcfFieldSummary.fromdict(v)
        return md


def check_overlapping_partitions(partitions):
    for i in range(1, len(partitions)):
        prev_region = partitions[i - 1].region
        current_region = partitions[i].region
        if prev_region.contig == current_region.contig:
            assert prev_region.end is not None
            # Regions are *inclusive*
            if prev_region.end >= current_region.start:
                raise ValueError(
                    f"Overlapping VCF regions in partitions {i - 1} and {i}: "
                    f"{prev_region} and {current_region}"
                )


def check_field_clobbering(icf_metadata):
    info_field_names = set(field.name for field in icf_metadata.info_fields)
    fixed_variant_fields = set(
        ["contig", "id", "id_mask", "position", "allele", "filter", "quality"]
    )
    intersection = info_field_names & fixed_variant_fields
    if len(intersection) > 0:
        raise ValueError(
            f"INFO field name(s) clashing with VCF Zarr spec: {intersection}"
        )

    format_field_names = set(field.name for field in icf_metadata.format_fields)
    fixed_variant_fields = set(["genotype", "genotype_phased", "genotype_mask"])
    intersection = format_field_names & fixed_variant_fields
    if len(intersection) > 0:
        raise ValueError(
            f"FORMAT field name(s) clashing with VCF Zarr spec: {intersection}"
        )


@dataclasses.dataclass
class IcfWriteSummary(core.JsonDataclass):
    num_partitions: int
    num_samples: int
    num_variants: int

class IntermediateColumnarFormatWriter:
    def __init__(self, path):
        self.path = pathlib.Path(path)
        self.wip_path = self.path / "wip"
        self.metadata = None

    @property
    def num_partitions(self):
        return len(self.metadata.partitions)

    def init(
        self,
        vcfs,
        *,
        column_chunk_size=16,
        worker_processes=1,
        target_num_partitions=None,
        show_progress=False,
        compressor=None,
        local_alleles=None,
    ):
        if self.path.exists():
            raise ValueError(f"ICF path already exists: {self.path}")
        if compressor is None:
            compressor = ICF_DEFAULT_COMPRESSOR
        if local_alleles is None:
            local_alleles = False
        vcfs = [pathlib.Path(vcf) for vcf in vcfs]
        target_num_partitions = max(target_num_partitions, len(vcfs))

        # TODO move scan_vcfs into this class
        icf_metadata, header = scan_vcfs(
            vcfs,
            worker_processes=worker_processes,
            show_progress=show_progress,
            target_num_partitions=target_num_partitions,
            local_alleles=local_alleles,
        )
        check_field_clobbering(icf_metadata)
        self.metadata = icf_metadata
        self.metadata.format_version = ICF_METADATA_FORMAT_VERSION
        self.metadata.compressor = compressor.get_config()
        self.metadata.column_chunk_size = column_chunk_size
        # Bare minimum here for provenance - would be nice to include versions of key
        # dependencies as well.
        self.metadata.provenance = {"source": f"bio2zarr-{provenance.__version__}"}

        self.mkdirs()

        # Note: this is needed for the current version of the vcfzarr spec, but it's
        # probably going to be dropped.
        # https://github.com/pystatgen/vcf-zarr-spec/issues/15
        # May be useful to keep lying around still though?
        logger.info("Writing VCF header")
        with open(self.path / "header.txt", "w") as f:
            f.write(header)

        logger.info("Writing WIP metadata")
        with open(self.wip_path / "metadata.json", "w") as f:
            json.dump(self.metadata.asdict(), f, indent=4)
        return IcfWriteSummary(
            num_partitions=self.num_partitions,
            num_variants=icf_metadata.num_records,
            num_samples=icf_metadata.num_samples,
        )

    def mkdirs(self):
        num_dirs = len(self.metadata.fields)
        logger.info(f"Creating {num_dirs} field directories")
        self.path.mkdir()
        self.wip_path.mkdir()
        for field in self.metadata.fields:
            field_path = get_vcf_field_path(self.path, field)
            field_path.mkdir(parents=True)

    def load_partition_summaries(self):
        summaries = []
        not_found = []
        for j in range(self.num_partitions):
            try:
                with open(self.wip_path / f"p{j}.json") as f:
                    summaries.append(IcfPartitionMetadata.fromdict(json.load(f)))
            except FileNotFoundError:
                not_found.append(j)
        if len(not_found) > 0:
            raise FileNotFoundError(
                f"Partition metadata not found for {len(not_found)}"
                f" partitions: {not_found}"
            )
        return summaries

    def load_metadata(self):
        if self.metadata is None:
            with open(self.wip_path / "metadata.json") as f:
                self.metadata = IcfMetadata.fromdict(json.load(f))

    def process_partition(self, partition_index):
        self.load_metadata()
        summary_path = self.wip_path / f"p{partition_index}.json"
        # If someone is rewriting a summary path (for whatever reason), make sure it
        # doesn't look like it's already been completed.
        # NOTE to do this properly we probably need to take a lock on this file - but
        # this simple approach will catch the vast majority of problems.
        if summary_path.exists():
            summary_path.unlink()

        partition = self.metadata.partitions[partition_index]
        logger.info(
            f"Start p{partition_index} {partition.vcf_path}__{partition.region}"
        )
        info_fields = self.metadata.info_fields
        format_fields = []
        has_gt = False
        for field in self.metadata.format_fields:
            if field.name == "GT":
                has_gt = True
            else:
                format_fields.append(field)

        format_field_names = [format_field.name for format_field in format_fields]
        if "LAA" in format_field_names and "LPL" in format_field_names:
            laa_index = format_field_names.index("LAA")
            lpl_index = format_field_names.index("LPL")
            # LAA needs to come before LPL
            if lpl_index < laa_index:
                format_fields[laa_index], format_fields[lpl_index] = (
                    format_fields[lpl_index],
                    format_fields[laa_index],
                )

        last_position = None
        with IcfPartitionWriter(
            self.metadata,
            self.path,
            partition_index,
        ) as tcw:
            with vcf_utils.IndexedVcf(partition.vcf_path) as ivcf:
                num_records = 0
                for variant in ivcf.variants(partition.region):
                    num_records += 1
                    last_position = variant.POS
                    tcw.append("CHROM", variant.CHROM)
                    tcw.append("POS", variant.POS)
                    tcw.append("QUAL", variant.QUAL)
                    tcw.append("ID", variant.ID)
                    tcw.append("FILTERS", variant.FILTERS)
                    tcw.append("REF", variant.REF)
                    tcw.append("ALT", variant.ALT)
                    tcw.append("rlen", variant.end - variant.start)
                    for field in info_fields:
                        tcw.append(field.full_name, variant.INFO.get(field.name, None))
                    if has_gt:
                        if variant.genotype is None:
                            val = None
                        else:
                            val = variant.genotype.array()
                        tcw.append("FORMAT/GT", val)
                    laa_val = None
                    for field in format_fields:
                        if field.name == "LAA":
                            if "LAA" not in variant.FORMAT:
                                laa_val = compute_laa_field(variant)
                            else:
                                laa_val = variant.format("LAA")
                            val = laa_val
                        elif field.name == "LPL" and "LPL" not in variant.FORMAT:
                            val = compute_lpl_field(variant, laa_val)
                        else:
                            val = variant.format(field.name)
                        tcw.append(field.full_name, val)

                    # Note: an issue with updating the progress per variant here like
                    # this is that we get a significant pause at the end of the counter
                    # while all the "small" fields get flushed. Possibly not much to be
                    # done about it.
                    core.update_progress(1)
            logger.info(
                f"Finished reading VCF for partition {partition_index}, "
                f"flushing buffers"
            )

        partition_metadata = IcfPartitionMetadata(
            num_records=num_records,
            last_position=last_position,
            field_summaries=tcw.field_summaries,
        )
        with open(summary_path, "w") as f:
            f.write(partition_metadata.asjson())
        logger.info(
            f"Finish p{partition_index} {partition.vcf_path}__{partition.region} "
            f"{num_records} records last_pos={last_position}"
        )

    def explode(self, *, worker_processes=1, show_progress=False):
        self.load_metadata()
        num_records = self.metadata.num_records
        if np.isinf(num_records):
            logger.warning(
                "Total records unknown, cannot show progress; "
                "reindex VCFs with bcftools index to fix"
            )
            num_records = None
        num_fields = len(self.metadata.fields)
        num_samples = len(self.metadata.samples)
        logger.info(
            f"Exploding fields={num_fields} samples={num_samples}; "
            f"partitions={self.num_partitions} "
            f"variants={'unknown' if num_records is None else num_records}"
        )
        progress_config = core.ProgressConfig(
            total=num_records,
            units="vars",
            title="Explode",
            show=show_progress,
        )
        with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
            for j in range(self.num_partitions):
                pwm.submit(self.process_partition, j)

    def explode_partition(self, partition):
        self.load_metadata()
        if partition < 0 or partition >= self.num_partitions:
            raise ValueError("Partition index not in the valid range")
        self.process_partition(partition)

    def finalise(self):
        self.load_metadata()
        partition_summaries = self.load_partition_summaries()
        total_records = 0
        for index, summary in enumerate(partition_summaries):
            partition_records = summary.num_records
            self.metadata.partitions[index].num_records = partition_records
            self.metadata.partitions[index].region.end = summary.last_position
            total_records += partition_records
        if not np.isinf(self.metadata.num_records):
            # Note: this is just telling us that there's a bug in the
            # index based record counting code, but it doesn't actually
            # matter much. We may want to just make this a warning if
            # we hit regular problems.
            assert total_records == self.metadata.num_records
        self.metadata.num_records = total_records

        check_overlapping_partitions(self.metadata.partitions)

        for field in self.metadata.fields:
            for summary in partition_summaries:
                field.summary.update(summary.field_summaries[field.full_name])

        logger.info("Finalising metadata")
        with open(self.path / "metadata.json", "w") as f:
            f.write(self.metadata.asjson())

        logger.debug("Removing WIP directory")
        shutil.rmtree(self.wip_path)

def explode(
    icf_path,
    vcfs,
    *,
    column_chunk_size=16,
    worker_processes=1,
    show_progress=False,
    compressor=None,
    local_alleles=None,
):
    writer = IntermediateColumnarFormatWriter(icf_path)
    writer.init(
        vcfs,
        # Heuristic to get reasonable worker utilisation with lumpy partition sizing
        target_num_partitions=max(1, worker_processes * 4),
        worker_processes=worker_processes,
        show_progress=show_progress,
        column_chunk_size=column_chunk_size,
        compressor=compressor,
        local_alleles=local_alleles,
    )
    writer.explode(worker_processes=worker_processes, show_progress=show_progress)
    writer.finalise()
    return IntermediateColumnarFormat(icf_path)


def explode_init(
    icf_path,
    vcfs,
    *,
    column_chunk_size=16,
    target_num_partitions=1,
    worker_processes=1,
    show_progress=False,
    compressor=None,
    local_alleles=None,
):
    writer = IntermediateColumnarFormatWriter(icf_path)
    return writer.init(
        vcfs,
        target_num_partitions=target_num_partitions,
        worker_processes=worker_processes,
        show_progress=show_progress,
        column_chunk_size=column_chunk_size,
        compressor=compressor,
        local_alleles=local_alleles,
    )


def explode_partition(icf_path, partition):
    writer = IntermediateColumnarFormatWriter(icf_path)
    writer.explode_partition(partition)


def explode_finalise(icf_path):
    writer = IntermediateColumnarFormatWriter(icf_path)
    writer.finalise()
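
# Editor's note: a minimal usage sketch (editor's addition, not part of the
# original module) showing how the functions above fit together. File names
# are illustrative only.
#
# One-shot conversion in a single process tree:
#
#     icf = explode("out.icf", ["sample.vcf.gz"], worker_processes=4)
#
# Or split across jobs (e.g. on a cluster): initialise once, explode each
# partition independently, then finalise:
#
#     summary = explode_init("out.icf", ["sample.vcf.gz"], target_num_partitions=10)
#     for j in range(summary.num_partitions):
#         explode_partition("out.icf", j)
#     explode_finalise("out.icf")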