
sgkit-dev / bio2zarr / build 14242614654

03 Apr 2025 12:10PM UTC coverage: 98.771% (-0.1%) from 98.867%

Pull Request #339: Common writer for plink and ICF
Merge 965501277 into 5439661d3 (via github / web-flow)

771 of 782 new or added lines in 6 files covered (98.59%).
12 existing lines in 2 files now uncovered.
2571 of 2603 relevant lines covered (98.77%).
5.92 hits per line.

Source File: /bio2zarr/writer.py (98.17% covered)
import dataclasses
import json
import logging
import os
import pathlib
import shutil

import numcodecs
import numpy as np
import zarr

from bio2zarr import constants, core, provenance, schema, zarr_utils

logger = logging.getLogger(__name__)

def sanitise_int_array(value, ndmin, dtype):
    if isinstance(value, tuple):
        value = [
            constants.VCF_INT_MISSING if x is None else x for x in value
        ]  # NEEDS TEST
    value = np.array(value, ndmin=ndmin, copy=True)
    value[value == constants.VCF_INT_MISSING] = -1
    value[value == constants.VCF_INT_FILL] = -2
    # TODO watch out for clipping here!
    return value.astype(dtype)

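# Example (illustrative): sanitise_int_array maps the VCF sentinel values to
# the vcf-zarr conventions of -1 for missing and -2 for fill/pad, so a
# FORMAT tuple of (10, None) becomes array([10, -1]) with ndmin=1.
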
def compute_la_field(genotypes):
    """
    Computes the value of the LA field for each sample given the genotypes
    for a variant. The LA field lists the unique alleles observed for
    each sample, including the REF.
    """
    v = 2**31 - 1
    if np.any(genotypes >= v):
        raise ValueError("Extreme allele value not supported")
    G = genotypes.astype(np.int32)
    if len(G) > 0:
        # Anything < 0 gets mapped to -2 (pad) in the output, which comes last.
        # So, to get this sorting correct, we remap to the largest value for
        # sorting, then map back. We promote the genotypes up to 32 bit for
        # convenience here, assuming that we'll never have an allele of 2**31 - 1.
        assert np.all(G != v)
        G[G < 0] = v
        G.sort(axis=1)
        G[G[:, 0] == G[:, 1], 1] = -2
        # Equal values result in padding also
        G[G == v] = -2
    return G.astype(genotypes.dtype)

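# Worked example (illustrative): for diploid genotypes
#     [[0, 0], [0, 2], [-1, 1]]
# compute_la_field returns
#     [[0, -2], [0, 2], [1, -2]]
# i.e. each row holds the sorted unique alleles for that sample, with
# missing calls dropped and -2 used as right padding.
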
def compute_lad_field(ad, la):
    assert ad.shape[0] == la.shape[0]
    assert la.shape[1] == 2
    lad = np.full((ad.shape[0], 2), -2, dtype=ad.dtype)
    homs = np.where((la[:, 0] != -2) & (la[:, 1] == -2))
    lad[homs, 0] = ad[homs, la[homs, 0]]
    hets = np.where(la[:, 1] != -2)
    lad[hets, 0] = ad[hets, la[hets, 0]]
    lad[hets, 1] = ad[hets, la[hets, 1]]
    return lad

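# Worked example (illustrative): with per-sample depths over all alleles
#     ad = [[10, 0, 5], [8, 2, 0]]
# and local alleles la = [[0, 2], [0, -2]], compute_lad_field selects the
# depths for the local alleles only, giving [[10, 5], [8, -2]].
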
def pl_index(a, b):
    """
    Returns the PL index for alleles a and b.
    """
    return b * (b + 1) // 2 + a

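# This is the standard VCF ordering of genotype likelihoods: for a <= b the
# genotype a/b has index b * (b + 1) / 2 + a, so
#     0/0 -> 0, 0/1 -> 1, 1/1 -> 2, 0/2 -> 3, 1/2 -> 4, 2/2 -> 5
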
def compute_lpl_field(pl, la):
    lpl = np.full((pl.shape[0], 3), -2, dtype=pl.dtype)

    homs = np.where((la[:, 0] != -2) & (la[:, 1] == -2))
    a = la[homs, 0]
    lpl[homs, 0] = pl[homs, pl_index(a, a)]

    hets = np.where(la[:, 1] != -2)[0]
    a = la[hets, 0]
    b = la[hets, 1]
    lpl[hets, 0] = pl[hets, pl_index(a, a)]
    lpl[hets, 1] = pl[hets, pl_index(a, b)]
    lpl[hets, 2] = pl[hets, pl_index(b, b)]

    return lpl

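# Worked example (illustrative): for a sample with la = [0, 2] and the six
# PL values [p00, p01, p11, p02, p12, p22] for a three-allele site, the
# localised field is [p00, p02, p22] = pl[[0, 3, 5]], via pl_index(0, 0),
# pl_index(0, 2) and pl_index(2, 2).
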
@dataclasses.dataclass
class LocalisableFieldDescriptor:
    array_name: str
    vcf_field: str
    sanitise: callable
    convert: callable


localisable_fields = [
    LocalisableFieldDescriptor(
        "call_LAD", "FORMAT/AD", sanitise_int_array, compute_lad_field
    ),
    LocalisableFieldDescriptor(
        "call_LPL", "FORMAT/PL", sanitise_int_array, compute_lpl_field
    ),
]


@dataclasses.dataclass
class VcfZarrPartition:
    start: int
    stop: int

    @staticmethod
    def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None):
        num_chunks = int(np.ceil(num_records / chunk_size))
        if max_chunks is not None:
            num_chunks = min(num_chunks, max_chunks)
        partitions = []
        splits = np.array_split(np.arange(num_chunks), min(num_partitions, num_chunks))
        for chunk_slice in splits:
            start_chunk = int(chunk_slice[0])
            stop_chunk = int(chunk_slice[-1]) + 1
            start_index = start_chunk * chunk_size
            stop_index = min(stop_chunk * chunk_size, num_records)
            partitions.append(VcfZarrPartition(start_index, stop_index))
        return partitions

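# Worked example (illustrative): generate_partitions(1000, 100, 3) splits the
# 10 variant chunks as [0:4], [4:7], [7:10], giving partitions covering
# records [0, 400), [400, 700) and [700, 1000). Partition boundaries always
# fall on chunk boundaries, so workers never write to the same chunk.
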
VZW_METADATA_FORMAT_VERSION = "0.1"


@dataclasses.dataclass
class VcfZarrWriterMetadata(core.JsonDataclass):
    format_version: str
    source_path: str
    schema: schema.VcfZarrSchema
    dimension_separator: str
    partitions: list
    provenance: dict

    @staticmethod
    def fromdict(d):
        if d["format_version"] != VZW_METADATA_FORMAT_VERSION:
            raise ValueError(
                "VcfZarrWriter format version mismatch: "
                f"{d['format_version']} != {VZW_METADATA_FORMAT_VERSION}"
            )
        ret = VcfZarrWriterMetadata(**d)
        ret.schema = schema.VcfZarrSchema.fromdict(ret.schema)
        ret.partitions = [VcfZarrPartition(**p) for p in ret.partitions]
        return ret


@dataclasses.dataclass
class VcfZarrWriteSummary(core.JsonDataclass):
    num_partitions: int
    num_samples: int
    num_variants: int
    num_chunks: int
    max_encoding_memory: str


class VcfZarrWriter:
    def __init__(self, source_type, path):
        self.source_type = source_type
        self.path = pathlib.Path(path)
        self.wip_path = self.path / "wip"
        self.arrays_path = self.wip_path / "arrays"
        self.partitions_path = self.wip_path / "partitions"
        self.metadata = None
        self.source = None

    @property
    def schema(self):
        return self.metadata.schema

    @property
    def num_partitions(self):
        return len(self.metadata.partitions)

    def has_genotypes(self):
        for field in self.schema.fields:
            if field.name == "call_genotype":
                return True
        return False

    def has_local_alleles(self):
        for field in self.schema.fields:
            if field.name == "call_LA" and field.vcf_field is None:
                return True
        return False

    #######################
    # init
    #######################

    def init(
        self,
        source,
        *,
        target_num_partitions,
        schema,
        dimension_separator=None,
        max_variant_chunks=None,
    ):
        self.source = source
        if self.path.exists():
            raise ValueError("Zarr path already exists")  # NEEDS TEST
        schema.validate()
        partitions = VcfZarrPartition.generate_partitions(
            self.source.num_records,
            schema.variants_chunk_size,
            target_num_partitions,
            max_chunks=max_variant_chunks,
        )
        # Default to using nested directories following the Zarr v3 default.
        # This seems to require version 2.17+ to work properly
        dimension_separator = (
            "/" if dimension_separator is None else dimension_separator
        )
        self.metadata = VcfZarrWriterMetadata(
            format_version=VZW_METADATA_FORMAT_VERSION,
            source_path=str(self.source.path),
            schema=schema,
            dimension_separator=dimension_separator,
            partitions=partitions,
            # Bare minimum here for provenance - see comments above
            provenance={"source": f"bio2zarr-{provenance.__version__}"},
        )

        self.path.mkdir()
        root = zarr.open(store=self.path, mode="a", **zarr_utils.ZARR_FORMAT_KWARGS)
        root.attrs.update(
            {
                "vcf_zarr_version": "0.2",
                "source": f"bio2zarr-{provenance.__version__}",
            }
        )
        root.attrs.update(self.source.root_attrs)

        # Doing this synchronously - this is fine surely
        self.encode_samples(root)
        self.encode_filter_id(root)
        self.encode_contig_id(root)

        self.wip_path.mkdir()
        self.arrays_path.mkdir()
        self.partitions_path.mkdir()
        root = zarr.open(
            store=self.arrays_path, mode="a", **zarr_utils.ZARR_FORMAT_KWARGS
        )

        total_chunks = 0
        for field in self.schema.fields:
            a = self.init_array(root, field, partitions[-1].stop)
            total_chunks += a.nchunks

        logger.info("Writing WIP metadata")
        with open(self.wip_path / "metadata.json", "w") as f:
            json.dump(self.metadata.asdict(), f, indent=4)

        return VcfZarrWriteSummary(
            num_variants=self.source.num_records,
            num_samples=self.source.num_samples,
            num_partitions=self.num_partitions,
            num_chunks=total_chunks,
            max_encoding_memory=core.display_size(self.get_max_encoding_memory()),
        )

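    # Typical lifecycle (an illustrative sketch; actual source objects come
    # from the format-specific readers, e.g. the ICF and plink sources this
    # common writer supports):
    #
    #     writer = VcfZarrWriter(source_type, "out.vcz")
    #     summary = writer.init(source, target_num_partitions=16, schema=schema)
    #     writer.encode_all_partitions(worker_processes=4)
    #     writer.finalise()
    #
    # init() lays out the WIP directory structure and array definitions,
    # encode_partition()/encode_all_partitions() fill the chunk data, and
    # finalise() moves completed arrays into place and consolidates metadata.
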
    def encode_samples(self, root):
        if [s.id for s in self.schema.samples] != self.source.samples:
            raise ValueError("Subsetting or reordering samples not supported currently")
        array = root.array(
            "sample_id",
            data=[sample.id for sample in self.schema.samples],
            shape=len(self.schema.samples),
            dtype="str",
            compressor=schema.DEFAULT_ZARR_COMPRESSOR,
            chunks=(self.schema.samples_chunk_size,),
        )
        array.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
        logger.debug("Samples done")

    def encode_contig_id(self, root):
        array = root.array(
            "contig_id",
            data=[contig.id for contig in self.schema.contigs],
            shape=len(self.schema.contigs),
            dtype="str",
            compressor=schema.DEFAULT_ZARR_COMPRESSOR,
        )
        array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
        if all(contig.length is not None for contig in self.schema.contigs):
            array = root.array(
                "contig_length",
                data=[contig.length for contig in self.schema.contigs],
                shape=len(self.schema.contigs),
                dtype=np.int64,
                compressor=schema.DEFAULT_ZARR_COMPRESSOR,
            )
            array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]

    def encode_filter_id(self, root):
        # TODO need a way to store description also
        # https://github.com/sgkit-dev/vcf-zarr-spec/issues/19
        array = root.array(
            "filter_id",
            data=[filt.id for filt in self.schema.filters],
            shape=len(self.schema.filters),
            dtype="str",
            compressor=schema.DEFAULT_ZARR_COMPRESSOR,
        )
        array.attrs["_ARRAY_DIMENSIONS"] = ["filters"]

    def init_array(self, root, array_spec, variants_dim_size):
        kwargs = dict(zarr_utils.ZARR_FORMAT_KWARGS)
        filters = [numcodecs.get_codec(filt) for filt in array_spec.filters]
        if array_spec.dtype == "O":
            if zarr_utils.zarr_v3():
                filters = [*list(filters), numcodecs.VLenUTF8()]
            else:
                kwargs["object_codec"] = numcodecs.VLenUTF8()

        if not zarr_utils.zarr_v3():
            kwargs["dimension_separator"] = self.metadata.dimension_separator

        shape = list(array_spec.shape)
        # Truncate the variants dimension if max_variant_chunks was specified
        shape[0] = variants_dim_size
        a = root.empty(
            name=array_spec.name,
            shape=shape,
            chunks=array_spec.chunks,
            dtype=array_spec.dtype,
            compressor=numcodecs.get_codec(array_spec.compressor),
            filters=filters,
            **kwargs,
        )
        a.attrs.update(
            {
                "description": array_spec.description,
                # Dimension names are part of the spec in Zarr v3
                "_ARRAY_DIMENSIONS": array_spec.dimensions,
            }
        )
        logger.debug(f"Initialised {a}")
        return a

    #######################
    # encode_partition
    #######################

    def load_metadata(self):
        if self.metadata is None:
            with open(self.wip_path / "metadata.json") as f:
                self.metadata = VcfZarrWriterMetadata.fromdict(json.load(f))
            self.source = self.source_type(self.metadata.source_path)

    def partition_path(self, partition_index):
        return self.partitions_path / f"p{partition_index}"

    def wip_partition_path(self, partition_index):
        return self.partitions_path / f"wip_p{partition_index}"

    def wip_partition_array_path(self, partition_index, name):
        return self.wip_partition_path(partition_index) / name

    def partition_array_path(self, partition_index, name):
        return self.partition_path(partition_index) / name

    def encode_partition(self, partition_index):
        self.load_metadata()
        if partition_index < 0 or partition_index >= self.num_partitions:
            raise ValueError("Partition index not in the valid range")
        partition_path = self.wip_partition_path(partition_index)
        partition_path.mkdir(exist_ok=True)
        logger.info(f"Encoding partition {partition_index} to {partition_path}")

        all_field_names = [field.name for field in self.schema.fields]
        if "variant_id" in all_field_names:
            self.encode_id_partition(partition_index)
        if "variant_filter" in all_field_names:
            self.encode_filters_partition(partition_index)
        if "variant_contig" in all_field_names:
            self.encode_contig_partition(partition_index)
        self.encode_alleles_partition(partition_index)
        for array_spec in self.schema.fields:
            if array_spec.vcf_field is not None:
                self.encode_array_partition(array_spec, partition_index)
        if self.has_genotypes():
            self.encode_genotypes_partition(partition_index)
            self.encode_genotype_mask_partition(partition_index)
        if self.has_local_alleles():
            self.encode_local_alleles_partition(partition_index)
            self.encode_local_allele_fields_partition(partition_index)

        final_path = self.partition_path(partition_index)
        logger.info(f"Finalising {partition_index} at {final_path}")
        if final_path.exists():
            logger.warning(f"Removing existing partition at {final_path}")
            shutil.rmtree(final_path)
        os.rename(partition_path, final_path)

    def init_partition_array(self, partition_index, name):
        field_map = self.schema.field_map()
        array_spec = field_map[name]
        # Create an empty array like the definition
        src = self.arrays_path / array_spec.name
        # Overwrite any existing WIP files
        wip_path = self.wip_partition_array_path(partition_index, array_spec.name)
        shutil.copytree(src, wip_path, dirs_exist_ok=True)
        array = zarr.open_array(store=wip_path, mode="a")
        partition = self.metadata.partitions[partition_index]
        ba = core.BufferedArray(array, partition.start, name)
        logger.info(
            f"Start partition {partition_index} array {name} <{array.dtype}> "
            f"{array.shape} @ {wip_path}"
        )
        return ba

    def finalise_partition_array(self, partition_index, buffered_array):
        buffered_array.flush()
        logger.info(
            f"Completed partition {partition_index} array {buffered_array.name} "
            f"max_memory={core.display_size(buffered_array.max_buff_size)}"
        )

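    # All of the encode_*_partition methods below follow the same buffered
    # pattern (sketched here for a generic field; `values` stands in for
    # whichever source iterator applies):
    #
    #     ba = self.init_partition_array(partition_index, name)
    #     for value in values:
    #         j = ba.next_buffer_row()  # next row in the chunk-sized buffer
    #         ba.buff[j] = value
    #     self.finalise_partition_array(partition_index, ba)
    #
    # so roughly one variant-chunk buffer per open array is held in memory.
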
    def encode_array_partition(self, array_spec, partition_index):
        partition = self.metadata.partitions[partition_index]
        ba = self.init_partition_array(partition_index, array_spec.name)
        for value in self.source.iter_field(
            array_spec.vcf_field,
            ba.buff.shape[1:],
            partition.start,
            partition.stop,
        ):
            j = ba.next_buffer_row()
            ba.buff[j] = value

        self.finalise_partition_array(partition_index, ba)

    def encode_genotypes_partition(self, partition_index):
        partition = self.metadata.partitions[partition_index]
        gt = self.init_partition_array(partition_index, "call_genotype")
        gt_phased = self.init_partition_array(partition_index, "call_genotype_phased")

        for genotype, phased in self.source.iter_genotypes(
            gt.buff.shape[1:], partition.start, partition.stop
        ):
            j = gt.next_buffer_row()
            gt.buff[j] = genotype

            j_phased = gt_phased.next_buffer_row()
            gt_phased.buff[j_phased] = phased

        self.finalise_partition_array(partition_index, gt)
        self.finalise_partition_array(partition_index, gt_phased)

    def encode_genotype_mask_partition(self, partition_index):
        partition = self.metadata.partitions[partition_index]
        gt_mask = self.init_partition_array(partition_index, "call_genotype_mask")
        # Read back in the genotypes so we can compute the mask
        gt_array = zarr.open_array(
            store=self.wip_partition_array_path(partition_index, "call_genotype"),
            mode="r",
        )
        for genotypes in core.first_dim_slice_iter(
            gt_array, partition.start, partition.stop
        ):
            # TODO check is this the correct semantics when we are padding
            # with mixed ploidies?
            j = gt_mask.next_buffer_row()
            gt_mask.buff[j] = genotypes < 0
        self.finalise_partition_array(partition_index, gt_mask)

    def encode_local_alleles_partition(self, partition_index):
        partition = self.metadata.partitions[partition_index]
        call_LA = self.init_partition_array(partition_index, "call_LA")

        gt_array = zarr.open_array(
            store=self.wip_partition_array_path(partition_index, "call_genotype"),
            mode="r",
        )
        for genotypes in core.first_dim_slice_iter(
            gt_array, partition.start, partition.stop
        ):
            la = compute_la_field(genotypes)
            j = call_LA.next_buffer_row()
            call_LA.buff[j] = la
        self.finalise_partition_array(partition_index, call_LA)

    def encode_local_allele_fields_partition(self, partition_index):
        partition = self.metadata.partitions[partition_index]
        la_array = zarr.open_array(
            store=self.wip_partition_array_path(partition_index, "call_LA"),
            mode="r",
        )
        # We go through the localisable fields one-by-one so that we don't need to
        # keep several large arrays in memory at once for each partition.
        field_map = self.schema.field_map()
        for descriptor in localisable_fields:
            if descriptor.array_name not in field_map:
                continue
            assert field_map[descriptor.array_name].vcf_field is None

            buff = self.init_partition_array(partition_index, descriptor.array_name)
            source = self.source.fields[descriptor.vcf_field].iter_values(
                partition.start, partition.stop
            )
            for la in core.first_dim_slice_iter(
                la_array, partition.start, partition.stop
            ):
                raw_value = next(source)
                value = descriptor.sanitise(raw_value, 2, raw_value.dtype)
                j = buff.next_buffer_row()
                buff.buff[j] = descriptor.convert(value, la)
            self.finalise_partition_array(partition_index, buff)

    def encode_alleles_partition(self, partition_index):
        alleles = self.init_partition_array(partition_index, "variant_allele")
        partition = self.metadata.partitions[partition_index]

        for value in self.source.iter_alleles(
            partition.start, partition.stop, alleles.array.shape[1]
        ):
            j = alleles.next_buffer_row()
            alleles.buff[j] = value

        self.finalise_partition_array(partition_index, alleles)

    def encode_id_partition(self, partition_index):
        vid = self.init_partition_array(partition_index, "variant_id")
        vid_mask = self.init_partition_array(partition_index, "variant_id_mask")
        partition = self.metadata.partitions[partition_index]

        for value in self.source.iter_id(partition.start, partition.stop):
            j = vid.next_buffer_row()
            k = vid_mask.next_buffer_row()
            assert j == k
            if value is not None:
                vid.buff[j] = value
                vid_mask.buff[j] = False
            else:
                vid.buff[j] = constants.STR_MISSING
                vid_mask.buff[j] = True

        self.finalise_partition_array(partition_index, vid)
        self.finalise_partition_array(partition_index, vid_mask)

    def encode_filters_partition(self, partition_index):
        var_filter = self.init_partition_array(partition_index, "variant_filter")
        partition = self.metadata.partitions[partition_index]

        for filter_values in self.source.iter_filters(partition.start, partition.stop):
            j = var_filter.next_buffer_row()
            var_filter.buff[j] = filter_values

        self.finalise_partition_array(partition_index, var_filter)

    def encode_contig_partition(self, partition_index):
        contig = self.init_partition_array(partition_index, "variant_contig")
        partition = self.metadata.partitions[partition_index]

        for contig_index in self.source.iter_contig(partition.start, partition.stop):
            j = contig.next_buffer_row()
            contig.buff[j] = contig_index

        self.finalise_partition_array(partition_index, contig)

    #######################
    # finalise
    #######################


573
    def finalise_array(self, name):
6✔
574
        logger.info(f"Finalising {name}")
6✔
575
        final_path = self.path / name
6✔
576
        if final_path.exists():
6✔
577
            # NEEDS TEST
NEW
578
            raise ValueError(f"Array {name} already exists")
×
579
        for partition in range(self.num_partitions):
6✔
580
            # Move all the files in partition dir to dest dir
581
            src = self.partition_array_path(partition, name)
6✔
582
            if not src.exists():
6✔
583
                # Needs test
NEW
584
                raise ValueError(f"Partition {partition} of {name} does not exist")
×
585
            dest = self.arrays_path / name
6✔
586
            # This is Zarr v2 specific. Chunks in v3 with start with "c" prefix.
587
            chunk_files = [
6✔
588
                path for path in src.iterdir() if not path.name.startswith(".")
589
            ]
590
            # TODO check for a count of then number of files. If we require a
591
            # dimension_separator of "/" then we could make stronger assertions
592
            # here, as we'd always have num_variant_chunks
593
            logger.debug(
6✔
594
                f"Moving {len(chunk_files)} chunks for {name} partition {partition}"
595
            )
596
            for chunk_file in chunk_files:
6✔
597
                os.rename(chunk_file, dest / chunk_file.name)
6✔
598
        # Finally, once all the chunks have moved into the arrays dir,
599
        # we move it out of wip
600
        os.rename(self.arrays_path / name, self.path / name)
6✔
601
        core.update_progress(1)
6✔
602

    def finalise(self, show_progress=False):
        self.load_metadata()

        logger.info(f"Scanning {self.num_partitions} partitions")
        missing = []
        # TODO may need a progress bar here
        for partition_id in range(self.num_partitions):
            if not self.partition_path(partition_id).exists():
                missing.append(partition_id)
        if len(missing) > 0:
            raise FileNotFoundError(f"Partitions not encoded: {missing}")

        progress_config = core.ProgressConfig(
            total=len(self.schema.fields),
            title="Finalise",
            units="array",
            show=show_progress,
        )
        # NOTE: it's not clear that adding more workers will make this quicker,
        # as it's just going to be causing contention on the file system.
        # Something to check empirically in some deployments.
        # FIXME we're just using worker_processes=0 here to hook into the
        # SynchronousExecutor which is intended for testing purposes so
        # that we get test coverage. Should fix this either by allowing
        # for multiple workers, or making a standard wrapper for tqdm
        # that allows us to have a consistent look and feel.
        with core.ParallelWorkManager(0, progress_config) as pwm:
            for field in self.schema.fields:
                pwm.submit(self.finalise_array, field.name)
        logger.debug(f"Removing {self.wip_path}")
        shutil.rmtree(self.wip_path)
        logger.info("Consolidating Zarr metadata")
        zarr.consolidate_metadata(self.path)

    #######################
    # index
    #######################

    def create_index(self):
        """Create an index to support efficient region queries."""

        indexer = VcfZarrIndexer(self.path)
        indexer.create_index()

    ######################
    # encode_all_partitions
    ######################

    def get_max_encoding_memory(self):
        """
        Return the approximate maximum memory used to encode a variant chunk.
        """
        max_encoding_mem = 0
        for array_spec in self.schema.fields:
            max_encoding_mem = max(max_encoding_mem, array_spec.variant_chunk_nbytes)
        gt_mem = 0
        if self.has_genotypes():
            gt_mem = sum(
                field.variant_chunk_nbytes
                for field in self.schema.fields
                if field.name.startswith("call_genotype")
            )
        return max(max_encoding_mem, gt_mem)

    def encode_all_partitions(
        self, *, worker_processes=1, show_progress=False, max_memory=None
    ):
        max_memory = core.parse_max_memory(max_memory)
        self.load_metadata()
        num_partitions = self.num_partitions
        per_worker_memory = self.get_max_encoding_memory()
        logger.info(
            f"Encoding Zarr over {num_partitions} partitions with "
            f"{worker_processes} workers and {core.display_size(per_worker_memory)} "
            "per worker"
        )
        # Each partition requires per_worker_memory bytes, so to prevent more than
        # max_memory being used, we clamp the number of workers
        max_num_workers = max_memory // per_worker_memory
        if max_num_workers < worker_processes:
            logger.warning(
                f"Limiting number of workers to {max_num_workers} to "
                "keep within specified memory budget of "
                f"{core.display_size(max_memory)}"
            )
        if max_num_workers <= 0:
            raise ValueError(
                f"Insufficient memory to encode a partition: "
                f"{core.display_size(per_worker_memory)} > "
                f"{core.display_size(max_memory)}"
            )
        num_workers = min(max_num_workers, worker_processes)

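        # Worked example (illustrative): with per_worker_memory of 2 GiB and a
        # max_memory budget of 7 GiB, max_num_workers = 7 // 2 = 3, so a
        # request for worker_processes=4 is clamped to num_workers = 3.
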
        total_bytes = 0
        for array_spec in self.schema.fields:
            # Open the array definition to get the total size
            total_bytes += zarr.open(self.arrays_path / array_spec.name).nbytes

        progress_config = core.ProgressConfig(
            total=total_bytes,
            title="Encode",
            units="B",
            show=show_progress,
        )
        with core.ParallelWorkManager(num_workers, progress_config) as pwm:
            for partition_index in range(num_partitions):
                pwm.submit(self.encode_partition, partition_index)

class VcfZarr:
    def __init__(self, path):
        if not (path / ".zmetadata").exists():
            raise ValueError("Not in VcfZarr format")  # NEEDS TEST
        self.path = path
        self.root = zarr.open(path, mode="r")

    def summary_table(self):
        data = []
        arrays = [(core.du(self.path / a.basename), a) for _, a in self.root.arrays()]
        arrays.sort(key=lambda x: x[0])
        for stored, array in reversed(arrays):
            d = {
                "name": array.name,
                "dtype": str(array.dtype),
                "stored": core.display_size(stored),
                "size": core.display_size(array.nbytes),
                "ratio": core.display_number(array.nbytes / stored),
                "nchunks": str(array.nchunks),
                "chunk_size": core.display_size(array.nbytes / array.nchunks),
                "avg_chunk_stored": core.display_size(int(stored / array.nchunks)),
                "shape": str(array.shape),
                "chunk_shape": str(array.chunks),
                "compressor": str(array.compressor),
                "filters": str(array.filters),
            }
            data.append(d)
        return data

class VcfZarrIndexer:
    """
    Creates an index for efficient region queries in a VCF Zarr dataset.
    """

    def __init__(self, path):
        self.path = pathlib.Path(path)

    def create_index(self):
        """Create an index to support efficient region queries."""
        root = zarr.open_group(store=self.path, mode="r+")

        if (
            "variant_contig" not in root
            or "variant_position" not in root
            or "variant_length" not in root
        ):
            logger.warning("Cannot create index: required arrays not found")
            return

        contig = root["variant_contig"]
        pos = root["variant_position"]
        length = root["variant_length"]

        assert contig.cdata_shape == pos.cdata_shape

        index = []

        logger.info("Creating region index")
        for v_chunk in range(pos.cdata_shape[0]):
            c = contig.blocks[v_chunk]
            p = pos.blocks[v_chunk]
            e = p + length.blocks[v_chunk] - 1

            # create a row for each contig in the chunk
            d = np.diff(c, append=-1)
            c_start_idx = 0
            for c_end_idx in np.nonzero(d)[0]:
                assert c[c_start_idx] == c[c_end_idx]
                index.append(
                    (
                        v_chunk,  # chunk index
                        c[c_start_idx],  # contig ID
                        p[c_start_idx],  # start
                        p[c_end_idx],  # end
                        np.max(e[c_start_idx : c_end_idx + 1]),  # max end
                        c_end_idx - c_start_idx + 1,  # num records
                    )
                )
                c_start_idx = c_end_idx + 1

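        # Worked example (illustrative): for a chunk with contigs c = [0, 0, 1],
        # positions p = [100, 200, 50] and variant lengths [10, 1, 5], the end
        # positions are e = [109, 200, 54] and np.diff(c, append=-1) is nonzero
        # at indices 1 and 2, so two rows are emitted:
        #     (v_chunk, 0, 100, 200, 200, 2) and (v_chunk, 1, 50, 50, 54, 1).
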
        index = np.array(index, dtype=pos.dtype)
        kwargs = {}
        if not zarr_utils.zarr_v3():
            kwargs["dimension_separator"] = "/"
        array = root.array(
            "region_index",
            data=index,
            shape=index.shape,
            chunks=index.shape,
            dtype=index.dtype,
            compressor=numcodecs.Blosc("zstd", clevel=9, shuffle=0),
            fill_value=None,
            **kwargs,
        )
        array.attrs["_ARRAY_DIMENSIONS"] = [
            "region_index_values",
            "region_index_fields",
        ]

        logger.info("Consolidating Zarr metadata")
        zarr.consolidate_metadata(self.path)