sgkit-dev / bio2zarr, build 14369063087
09 Apr 2025 11:42PM UTC. Coverage: 98.409% (-0.4%) from 98.765%

Pull Request #343: Remove schema data
Merge de7b6ce83 into eed60f05b (github / web-flow)

67 of 77 new or added lines in 3 files covered (87.01%).
8 existing lines in 1 file now uncovered.
2598 of 2640 relevant lines covered (98.41%).
5.9 hits per line.

Source File: /bio2zarr/vcz.py (96.85% covered)
------------------------------------------------------------------------
import abc
import dataclasses
import json
import logging
import os
import pathlib
import shutil

import numcodecs
import numpy as np
import zarr

from bio2zarr import constants, core, provenance, zarr_utils

logger = logging.getLogger(__name__)

ZARR_SCHEMA_FORMAT_VERSION = "0.5"
DEFAULT_ZARR_COMPRESSOR = numcodecs.Blosc(cname="zstd", clevel=7)

_fixed_field_descriptions = {
    "variant_contig": "An identifier from the reference genome or an angle-bracketed ID"
    " string pointing to a contig in the assembly file",
    "variant_position": "The reference position",
    "variant_length": "The length of the variant measured in bases",
    "variant_id": "List of unique identifiers where applicable",
    "variant_allele": "List of the reference and alternate alleles",
    "variant_quality": "Phred-scaled quality score",
    "variant_filter": "Filter status of the variant",
}


class Source(abc.ABC):
    @property
    @abc.abstractmethod
    def path(self):
        pass

    @property
    @abc.abstractmethod
    def num_records(self):
        pass

    @property
    @abc.abstractmethod
    def num_samples(self):
        pass

    @property
    @abc.abstractmethod
    def samples(self):
        pass

    @property
    def contigs(self):
        return None

    @property
    def filters(self):
        return None

    @property
    def root_attrs(self):
        return {}

    @abc.abstractmethod
    def iter_alleles(self, start, stop, num_alleles):
        pass

    @abc.abstractmethod
    def iter_genotypes(self, start, stop, num_alleles):
        pass

    def iter_id(self, start, stop):
        return

    def iter_contig(self, start, stop):
        return

    @abc.abstractmethod
    def iter_field(self, field_name, shape, start, stop):
        """Iterate over values for the specified field from start to stop positions."""
        pass

    @abc.abstractmethod
    def generate_schema(self, variants_chunk_size, samples_chunk_size, local_alleles):
        pass

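# ---------------------------------------------------------------------------
# Editor's note: an illustrative sketch, not part of bio2zarr. A concrete
# Source (such as the VCF and plink readers defined in their own modules)
# has to fill in the abstract surface above, roughly along these lines.
# Every name below is hypothetical.
# ---------------------------------------------------------------------------
class _SketchSource(Source):
    def __init__(self, path):
        self._path = pathlib.Path(path)

    @property
    def path(self):
        return self._path

    @property
    def num_records(self):
        return 0  # number of variant records in the underlying file

    @property
    def num_samples(self):
        return 0

    @property
    def samples(self):
        return []  # a list of Sample objects (see the dataclass below)

    def iter_alleles(self, start, stop, num_alleles):
        # Yield one padded allele array per variant in [start, stop)
        yield from ()

    def iter_genotypes(self, start, stop, num_alleles):
        # Yield one (genotypes, phased) pair per variant in [start, stop)
        yield from ()

    def iter_field(self, field_name, shape, start, stop):
        yield from ()

    def generate_schema(self, variants_chunk_size, samples_chunk_size, local_alleles):
        raise NotImplementedError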

@dataclasses.dataclass
class ZarrArraySpec:
    name: str
    dtype: str
    shape: tuple
    chunks: tuple
    dimensions: tuple
    description: str
    vcf_field: str
    compressor: dict
    filters: list

    def __post_init__(self):
        if self.name in _fixed_field_descriptions:
            self.description = self.description or _fixed_field_descriptions[self.name]

        # Ensure these are tuples for ease of comparison and consistency
        self.shape = tuple(self.shape)
        self.chunks = tuple(self.chunks)
        self.dimensions = tuple(self.dimensions)
        self.filters = tuple(self.filters)

    @staticmethod
    def new(**kwargs):
        spec = ZarrArraySpec(
            **kwargs, compressor=DEFAULT_ZARR_COMPRESSOR.get_config(), filters=[]
        )
        spec._choose_compressor_settings()
        return spec

    @staticmethod
    def from_field(
        vcf_field,
        *,
        num_variants,
        num_samples,
        variants_chunk_size,
        samples_chunk_size,
        array_name=None,
    ):
        shape = [num_variants]
        prefix = "variant_"
        dimensions = ["variants"]
        chunks = [variants_chunk_size]
        if vcf_field.category == "FORMAT":
            prefix = "call_"
            shape.append(num_samples)
            chunks.append(samples_chunk_size)
            dimensions.append("samples")
        if array_name is None:
            array_name = prefix + vcf_field.name
        # TODO make an option to add in the empty extra dimension
        if vcf_field.summary.max_number > 1 or vcf_field.full_name == "FORMAT/LAA":
            shape.append(vcf_field.summary.max_number)
            chunks.append(vcf_field.summary.max_number)
            # TODO we should really be checking this to see if the named dimensions
            # are actually correct.
            if vcf_field.vcf_number == "R":
                dimensions.append("alleles")
            elif vcf_field.vcf_number == "A":
                dimensions.append("alt_alleles")
            elif vcf_field.vcf_number == "G":
                dimensions.append("genotypes")
            else:
                dimensions.append(f"{vcf_field.category}_{vcf_field.name}_dim")
        return ZarrArraySpec.new(
            vcf_field=vcf_field.full_name,
            name=array_name,
            dtype=vcf_field.smallest_dtype(),
            shape=shape,
            chunks=chunks,
            dimensions=dimensions,
            description=vcf_field.description,
        )

    def _choose_compressor_settings(self):
        """
        Choose compressor and filter settings based on the size and
        type of the array, plus some heuristics from observed properties
        of VCFs.

        See https://github.com/pystatgen/bio2zarr/discussions/74
        """
        # Default is to not shuffle, because autoshuffle isn't recognised
        # by many Zarr implementations, and shuffling can lead to worse
        # performance in some cases anyway. Turning on shuffle should be a
        # deliberate choice.
        shuffle = numcodecs.Blosc.NOSHUFFLE
        if self.name == "call_genotype" and self.dtype == "i1":
            # call_genotype gets BITSHUFFLE by default as it gets
            # significantly better compression (at a cost of slower
            # decoding)
            shuffle = numcodecs.Blosc.BITSHUFFLE
        elif self.dtype == "bool":
            shuffle = numcodecs.Blosc.BITSHUFFLE

        self.compressor["shuffle"] = shuffle

    @property
    def chunk_nbytes(self):
        """
        Returns the nbytes for a single chunk in this array.
        """
        items = 1
        dim = 0
        for chunk_size in self.chunks:
            size = min(chunk_size, self.shape[dim])
            items *= size
            dim += 1
        # Include sizes for extra dimensions.
        for size in self.shape[dim:]:
            items *= size
        dt = np.dtype(self.dtype)
        return items * dt.itemsize

    @property
    def variant_chunk_nbytes(self):
        """
        Returns the nbytes for a single variant chunk of this array.
        """
        chunk_items = self.chunks[0]
        for size in self.shape[1:]:
            chunk_items *= size
        dt = np.dtype(self.dtype)
        if dt.kind == "O" and "samples" in self.dimensions:
            logger.warning(
                f"Field {self.name} is a string; max memory usage may "
                "be a significant underestimate"
            )
        return chunk_items * dt.itemsize

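# ---------------------------------------------------------------------------
# Editor's note: an illustrative worked example, not part of bio2zarr. For a
# hypothetical int8 FORMAT field over 10,000 variants and 100 samples chunked
# as (1000, 100), one chunk holds 1000 * 100 single-byte items, so
# chunk_nbytes is 100,000.
# ---------------------------------------------------------------------------
_example_spec = ZarrArraySpec.new(
    name="call_DP",
    dtype="i1",
    shape=(10_000, 100),
    chunks=(1000, 100),
    dimensions=("variants", "samples"),
    description="",
    vcf_field="FORMAT/DP",
)
assert _example_spec.chunk_nbytes == 1000 * 100
assert _example_spec.variant_chunk_nbytes == 1000 * 100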

@dataclasses.dataclass
class Contig:
    id: str
    length: int = None


@dataclasses.dataclass
class Sample:
    id: str


@dataclasses.dataclass
class Filter:
    id: str
    description: str = ""


@dataclasses.dataclass
class VcfZarrSchema(core.JsonDataclass):
    format_version: str
    samples_chunk_size: int
    variants_chunk_size: int
    fields: list

    def __init__(
        self,
        format_version: str,
        fields: list,
        variants_chunk_size: int = None,
        samples_chunk_size: int = None,
    ):
        self.format_version = format_version
        self.fields = fields
        if variants_chunk_size is None:
            variants_chunk_size = 1000
        self.variants_chunk_size = variants_chunk_size
        if samples_chunk_size is None:
            samples_chunk_size = 10_000
        self.samples_chunk_size = samples_chunk_size

    def validate(self):
        """
        Checks that the schema is well-formed and within required limits.
        """
        for field in self.fields:
            # This is the Blosc max buffer size
            if field.chunk_nbytes > 2147483647:
                # TODO add some links to documentation here advising how to
                # deal with PL values.
                raise ValueError(
                    f"Field {field.name} chunks are too large "
                    f"({field.chunk_nbytes} > 2**31 - 1 bytes). "
                    "Either generate a schema and drop this field (if you don't "
                    "need it) or reduce the variant or sample chunk sizes."
                )
            # TODO other checks? There must be lots of ways people could mess
            # up the schema leading to cryptic errors.

    def field_map(self):
        return {field.name: field for field in self.fields}

    @staticmethod
    def fromdict(d):
        if d["format_version"] != ZARR_SCHEMA_FORMAT_VERSION:
            raise ValueError(
                "Zarr schema format version mismatch: "
                f"{d['format_version']} != {ZARR_SCHEMA_FORMAT_VERSION}"
            )
        ret = VcfZarrSchema(**d)
        ret.fields = [ZarrArraySpec(**sd) for sd in d["fields"]]
        return ret

    @staticmethod
    def fromjson(s):
        return VcfZarrSchema.fromdict(json.loads(s))

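# ---------------------------------------------------------------------------
# Editor's note: an illustrative sketch, not part of bio2zarr, reusing the
# hypothetical _example_spec defined above. It assumes that asdict() (from
# core.JsonDataclass) produces the same nested-dict form that gets written to
# metadata.json, so a schema survives a JSON round trip.
# ---------------------------------------------------------------------------
_example_schema = VcfZarrSchema(
    format_version=ZARR_SCHEMA_FORMAT_VERSION, fields=[_example_spec]
)
_example_schema.validate()
_roundtrip = VcfZarrSchema.fromjson(json.dumps(_example_schema.asdict()))
assert _roundtrip.fields[0].name == "call_DP"
assert _roundtrip.variants_chunk_size == 1000  # the default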

def sanitise_int_array(value, ndmin, dtype):
    if isinstance(value, tuple):
        value = [
            constants.VCF_INT_MISSING if x is None else x for x in value
        ]  # NEEDS TEST
    value = np.array(value, ndmin=ndmin, copy=True)
    value[value == constants.VCF_INT_MISSING] = -1
    value[value == constants.VCF_INT_FILL] = -2
    # TODO watch out for clipping here!
    return value.astype(dtype)

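# ---------------------------------------------------------------------------
# Editor's note: an illustrative example, not part of bio2zarr. Whatever the
# concrete sentinel values in the constants module are, missing entries map
# to -1 and fill (end-of-list padding) entries map to -2 in the output.
# ---------------------------------------------------------------------------
_raw = [constants.VCF_INT_MISSING, constants.VCF_INT_FILL, 7]
assert sanitise_int_array(_raw, 1, np.int32).tolist() == [-1, -2, 7]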

def compute_la_field(genotypes):
    """
    Computes the value of the LA field for each sample given the genotypes
    for a variant. The LA field lists the unique alleles observed for
    each sample, including the REF.
    """
    v = 2**31 - 1
    if np.any(genotypes >= v):
        raise ValueError("Extreme allele value not supported")
    G = genotypes.astype(np.int32)
    if len(G) > 0:
        # Anything < 0 gets mapped to -2 (pad) in the output, which comes last.
        # So, to get this sorting correctly, we remap to the largest value for
        # sorting, then map back. We promote the genotypes up to 32 bit for
        # convenience here, assuming that we'll never have an allele of 2**31 - 1.
        assert np.all(G != v)
        G[G < 0] = v
        G.sort(axis=1)
        G[G[:, 0] == G[:, 1], 1] = -2
        # Equal values result in padding also
        G[G == v] = -2
    return G.astype(genotypes.dtype)

333

334
def compute_lad_field(ad, la):
6✔
335
    assert ad.shape[0] == la.shape[0]
6✔
336
    assert la.shape[1] == 2
6✔
337
    lad = np.full((ad.shape[0], 2), -2, dtype=ad.dtype)
6✔
338
    homs = np.where((la[:, 0] != -2) & (la[:, 1] == -2))
6✔
339
    lad[homs, 0] = ad[homs, la[homs, 0]]
6✔
340
    hets = np.where(la[:, 1] != -2)
6✔
341
    lad[hets, 0] = ad[hets, la[hets, 0]]
6✔
342
    lad[hets, 1] = ad[hets, la[hets, 1]]
6✔
343
    return lad
6✔
344

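# ---------------------------------------------------------------------------
# Editor's note: an illustrative worked example, not part of bio2zarr. Given
# depths over the global alleles (REF, ALT1, ALT2) and the local alleles
# computed above, only the depths of each sample's local alleles survive.
# ---------------------------------------------------------------------------
_ad = np.array([[10, 5, 0], [0, 8, 2]], dtype=np.int32)
_la = np.array([[0, 1], [1, 2]], dtype=np.int32)
assert compute_lad_field(_ad, _la).tolist() == [[10, 5], [8, 2]]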

def pl_index(a, b):
    """
    Returns the PL index for alleles a and b.
    """
    return b * (b + 1) // 2 + a

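# ---------------------------------------------------------------------------
# Editor's note: an illustrative check, not part of bio2zarr. This is the
# standard VCF ordering of genotype-indexed fields, in which unphased
# genotypes appear as 0/0, 0/1, 1/1, 0/2, 1/2, 2/2, ...
# ---------------------------------------------------------------------------
assert [pl_index(a, b) for b in range(3) for a in range(b + 1)] == list(range(6))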

def compute_lpl_field(pl, la):
    lpl = np.full((pl.shape[0], 3), -2, dtype=pl.dtype)

    homs = np.where((la[:, 0] != -2) & (la[:, 1] == -2))
    a = la[homs, 0]
    lpl[homs, 0] = pl[homs, pl_index(a, a)]

    hets = np.where(la[:, 1] != -2)[0]
    a = la[hets, 0]
    b = la[hets, 1]
    lpl[hets, 0] = pl[hets, pl_index(a, a)]
    lpl[hets, 1] = pl[hets, pl_index(a, b)]
    lpl[hets, 2] = pl[hets, pl_index(b, b)]

    return lpl

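# ---------------------------------------------------------------------------
# Editor's note: an illustrative worked example, not part of bio2zarr. At a
# biallelic site PL is ordered (0/0, 0/1, 1/1). A 0/1 sample keeps all three
# likelihoods; a 1/1 sample keeps only PL(1/1), padded with -2.
# ---------------------------------------------------------------------------
_pl = np.array([[12, 0, 45], [60, 3, 0]], dtype=np.int32)
_la2 = np.array([[0, 1], [1, -2]], dtype=np.int32)
assert compute_lpl_field(_pl, _la2).tolist() == [[12, 0, 45], [0, -2, -2]]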

@dataclasses.dataclass
class LocalisableFieldDescriptor:
    array_name: str
    vcf_field: str
    sanitise: callable
    convert: callable


localisable_fields = [
    LocalisableFieldDescriptor(
        "call_LAD", "FORMAT/AD", sanitise_int_array, compute_lad_field
    ),
    LocalisableFieldDescriptor(
        "call_LPL", "FORMAT/PL", sanitise_int_array, compute_lpl_field
    ),
]


@dataclasses.dataclass
class VcfZarrPartition:
    start: int
    stop: int

    @staticmethod
    def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None):
        num_chunks = int(np.ceil(num_records / chunk_size))
        if max_chunks is not None:
            num_chunks = min(num_chunks, max_chunks)
        partitions = []
        splits = np.array_split(np.arange(num_chunks), min(num_partitions, num_chunks))
        for chunk_slice in splits:
            start_chunk = int(chunk_slice[0])
            stop_chunk = int(chunk_slice[-1]) + 1
            start_index = start_chunk * chunk_size
            stop_index = min(stop_chunk * chunk_size, num_records)
            partitions.append(VcfZarrPartition(start_index, stop_index))
        return partitions

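# ---------------------------------------------------------------------------
# Editor's note: an illustrative worked example, not part of bio2zarr. With
# 25 records in chunks of 10 there are 3 chunks; asking for 2 partitions
# splits them 2 + 1, so partition boundaries always align with chunk
# boundaries and no chunk is written by two partitions.
# ---------------------------------------------------------------------------
_parts = VcfZarrPartition.generate_partitions(25, 10, 2)
assert [(p.start, p.stop) for p in _parts] == [(0, 20), (20, 25)]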

VZW_METADATA_FORMAT_VERSION = "0.1"


@dataclasses.dataclass
class VcfZarrWriterMetadata(core.JsonDataclass):
    format_version: str
    source_path: str
    schema: VcfZarrSchema
    dimension_separator: str
    partitions: list
    provenance: dict

    @staticmethod
    def fromdict(d):
        if d["format_version"] != VZW_METADATA_FORMAT_VERSION:
            raise ValueError(
                "VcfZarrWriter format version mismatch: "
                f"{d['format_version']} != {VZW_METADATA_FORMAT_VERSION}"
            )
        ret = VcfZarrWriterMetadata(**d)
        ret.schema = VcfZarrSchema.fromdict(ret.schema)
        ret.partitions = [VcfZarrPartition(**p) for p in ret.partitions]
        return ret


@dataclasses.dataclass
class VcfZarrWriteSummary(core.JsonDataclass):
    num_partitions: int
    num_samples: int
    num_variants: int
    num_chunks: int
    max_encoding_memory: str


class VcfZarrWriter:
    def __init__(self, source_type, path):
        self.source_type = source_type
        self.path = pathlib.Path(path)
        self.wip_path = self.path / "wip"
        self.arrays_path = self.wip_path / "arrays"
        self.partitions_path = self.wip_path / "partitions"
        self.metadata = None
        self.source = None

    @property
    def schema(self):
        return self.metadata.schema

    @property
    def num_partitions(self):
        return len(self.metadata.partitions)

    def has_genotypes(self):
        for field in self.schema.fields:
            if field.name == "call_genotype":
                return True
        return False

    def has_local_alleles(self):
        for field in self.schema.fields:
            if field.name == "call_LA" and field.vcf_field is None:
                return True
        return False

    #######################
    # init
    #######################

    def init(
        self,
        source,
        *,
        target_num_partitions,
        schema,
        dimension_separator=None,
        max_variant_chunks=None,
    ):
        self.source = source
        if self.path.exists():
            raise ValueError("Zarr path already exists")  # NEEDS TEST
        schema.validate()
        partitions = VcfZarrPartition.generate_partitions(
            self.source.num_records,
            schema.variants_chunk_size,
            target_num_partitions,
            max_chunks=max_variant_chunks,
        )
        # Default to using nested directories following the Zarr v3 default.
        # This seems to require version 2.17+ to work properly
        dimension_separator = (
            "/" if dimension_separator is None else dimension_separator
        )
        self.metadata = VcfZarrWriterMetadata(
            format_version=VZW_METADATA_FORMAT_VERSION,
            source_path=str(self.source.path),
            schema=schema,
            dimension_separator=dimension_separator,
            partitions=partitions,
            # Bare minimum here for provenance - see comments above
            provenance={"source": f"bio2zarr-{provenance.__version__}"},
        )

        self.path.mkdir()
        root = zarr.open(store=self.path, mode="a", **zarr_utils.ZARR_FORMAT_KWARGS)
        root.attrs.update(
            {
                "vcf_zarr_version": "0.2",
                "source": f"bio2zarr-{provenance.__version__}",
            }
        )
        root.attrs.update(self.source.root_attrs)

        # Doing this synchronously - this is fine surely
        self.encode_samples(root)
        if self.source.filters is not None:
            self.encode_filter_id(root)
        if self.source.contigs is not None:
            self.encode_contigs(root)

        self.wip_path.mkdir()
        self.arrays_path.mkdir()
        self.partitions_path.mkdir()
        root = zarr.open(
            store=self.arrays_path, mode="a", **zarr_utils.ZARR_FORMAT_KWARGS
        )

        total_chunks = 0
        for field in self.schema.fields:
            a = self.init_array(root, field, partitions[-1].stop)
            total_chunks += a.nchunks

        logger.info("Writing WIP metadata")
        with open(self.wip_path / "metadata.json", "w") as f:
            json.dump(self.metadata.asdict(), f, indent=4)

        return VcfZarrWriteSummary(
            num_variants=self.source.num_records,
            num_samples=self.source.num_samples,
            num_partitions=self.num_partitions,
            num_chunks=total_chunks,
            max_encoding_memory=core.display_size(self.get_max_encoding_memory()),
        )

    def encode_samples(self, root):
        samples = self.source.samples
        array = root.array(
            "sample_id",
            data=[sample.id for sample in samples],
            shape=len(samples),
            dtype="str",
            compressor=DEFAULT_ZARR_COMPRESSOR,
            chunks=(self.schema.samples_chunk_size,),
        )
        array.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
        logger.debug("Samples done")

    def encode_contigs(self, root):
        contigs = self.source.contigs
        array = root.array(
            "contig_id",
            data=[contig.id for contig in contigs],
            shape=len(contigs),
            dtype="str",
            compressor=DEFAULT_ZARR_COMPRESSOR,
        )
        array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
        if all(contig.length is not None for contig in contigs):
            array = root.array(
                "contig_length",
                data=[contig.length for contig in contigs],
                shape=len(contigs),
                dtype=np.int64,
                compressor=DEFAULT_ZARR_COMPRESSOR,
            )
            array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]

    def encode_filter_id(self, root):
        # TODO need a way to store description also
        # https://github.com/sgkit-dev/vcf-zarr-spec/issues/19
        filters = self.source.filters
        array = root.array(
            "filter_id",
            data=[filt.id for filt in filters],
            shape=len(filters),
            dtype="str",
            compressor=DEFAULT_ZARR_COMPRESSOR,
        )
        array.attrs["_ARRAY_DIMENSIONS"] = ["filters"]

    def init_array(self, root, array_spec, variants_dim_size):
        kwargs = dict(zarr_utils.ZARR_FORMAT_KWARGS)
        filters = [numcodecs.get_codec(filt) for filt in array_spec.filters]
        if array_spec.dtype == "O":
            if zarr_utils.zarr_v3():
                filters = [*list(filters), numcodecs.VLenUTF8()]
            else:
                kwargs["object_codec"] = numcodecs.VLenUTF8()

        if not zarr_utils.zarr_v3():
            kwargs["dimension_separator"] = self.metadata.dimension_separator

        shape = list(array_spec.shape)
        # Truncate the variants dimension if max_variant_chunks was specified
        shape[0] = variants_dim_size
        a = root.empty(
            name=array_spec.name,
            shape=shape,
            chunks=array_spec.chunks,
            dtype=array_spec.dtype,
            compressor=numcodecs.get_codec(array_spec.compressor),
            filters=filters,
            **kwargs,
        )
        a.attrs.update(
            {
                "description": array_spec.description,
                # Dimension names are part of the spec in Zarr v3
                "_ARRAY_DIMENSIONS": array_spec.dimensions,
            }
        )
        logger.debug(f"Initialised {a}")
        return a

    #######################
    # encode_partition
    #######################

    def load_metadata(self):
        if self.metadata is None:
            with open(self.wip_path / "metadata.json") as f:
                self.metadata = VcfZarrWriterMetadata.fromdict(json.load(f))
            self.source = self.source_type(self.metadata.source_path)

    def partition_path(self, partition_index):
        return self.partitions_path / f"p{partition_index}"

    def wip_partition_path(self, partition_index):
        return self.partitions_path / f"wip_p{partition_index}"

    def wip_partition_array_path(self, partition_index, name):
        return self.wip_partition_path(partition_index) / name

    def partition_array_path(self, partition_index, name):
        return self.partition_path(partition_index) / name

    def encode_partition(self, partition_index):
        self.load_metadata()
        if partition_index < 0 or partition_index >= self.num_partitions:
            raise ValueError("Partition index not in the valid range")
        partition_path = self.wip_partition_path(partition_index)
        partition_path.mkdir(exist_ok=True)
        logger.info(f"Encoding partition {partition_index} to {partition_path}")

        all_field_names = [field.name for field in self.schema.fields]
        if "variant_id" in all_field_names:
            self.encode_id_partition(partition_index)
        if "variant_filter" in all_field_names:
            self.encode_filters_partition(partition_index)
        if "variant_contig" in all_field_names:
            self.encode_contig_partition(partition_index)
        self.encode_alleles_partition(partition_index)
        for array_spec in self.schema.fields:
            if array_spec.vcf_field is not None:
                self.encode_array_partition(array_spec, partition_index)
        if self.has_genotypes():
            self.encode_genotypes_partition(partition_index)
            self.encode_genotype_mask_partition(partition_index)
        if self.has_local_alleles():
            self.encode_local_alleles_partition(partition_index)
            self.encode_local_allele_fields_partition(partition_index)

        final_path = self.partition_path(partition_index)
        logger.info(f"Finalising {partition_index} at {final_path}")
        if final_path.exists():
            logger.warning(f"Removing existing partition at {final_path}")
            shutil.rmtree(final_path)
        os.rename(partition_path, final_path)

    def init_partition_array(self, partition_index, name):
        field_map = self.schema.field_map()
        array_spec = field_map[name]
        # Create an empty array like the definition
        src = self.arrays_path / array_spec.name
        # Overwrite any existing WIP files
        wip_path = self.wip_partition_array_path(partition_index, array_spec.name)
        shutil.copytree(src, wip_path, dirs_exist_ok=True)
        array = zarr.open_array(store=wip_path, mode="a")
        partition = self.metadata.partitions[partition_index]
        ba = core.BufferedArray(array, partition.start, name)
        logger.info(
            f"Start partition {partition_index} array {name} <{array.dtype}> "
            f"{array.shape} @ {wip_path}"
        )
        return ba

    def finalise_partition_array(self, partition_index, buffered_array):
        buffered_array.flush()
        logger.info(
            f"Completed partition {partition_index} array {buffered_array.name} "
            f"max_memory={core.display_size(buffered_array.max_buff_size)}"
        )

    def encode_array_partition(self, array_spec, partition_index):
        partition = self.metadata.partitions[partition_index]
        ba = self.init_partition_array(partition_index, array_spec.name)
        for value in self.source.iter_field(
            array_spec.vcf_field,
            ba.buff.shape[1:],
            partition.start,
            partition.stop,
        ):
            j = ba.next_buffer_row()
            ba.buff[j] = value

        self.finalise_partition_array(partition_index, ba)

    def encode_genotypes_partition(self, partition_index):
        partition = self.metadata.partitions[partition_index]
        gt = self.init_partition_array(partition_index, "call_genotype")
        gt_phased = self.init_partition_array(partition_index, "call_genotype_phased")

        for genotype, phased in self.source.iter_genotypes(
            gt.buff.shape[1:], partition.start, partition.stop
        ):
            j = gt.next_buffer_row()
            gt.buff[j] = genotype

            j_phased = gt_phased.next_buffer_row()
            gt_phased.buff[j_phased] = phased

        self.finalise_partition_array(partition_index, gt)
        self.finalise_partition_array(partition_index, gt_phased)

    def encode_genotype_mask_partition(self, partition_index):
        partition = self.metadata.partitions[partition_index]
        gt_mask = self.init_partition_array(partition_index, "call_genotype_mask")
        # Read back in the genotypes so we can compute the mask
        gt_array = zarr.open_array(
            store=self.wip_partition_array_path(partition_index, "call_genotype"),
            mode="r",
        )
        for genotypes in core.first_dim_slice_iter(
            gt_array, partition.start, partition.stop
        ):
            # TODO check is this the correct semantics when we are padding
            # with mixed ploidies?
            j = gt_mask.next_buffer_row()
            gt_mask.buff[j] = genotypes < 0
        self.finalise_partition_array(partition_index, gt_mask)

    def encode_local_alleles_partition(self, partition_index):
        partition = self.metadata.partitions[partition_index]
        call_LA = self.init_partition_array(partition_index, "call_LA")

        gt_array = zarr.open_array(
            store=self.wip_partition_array_path(partition_index, "call_genotype"),
            mode="r",
        )
        for genotypes in core.first_dim_slice_iter(
            gt_array, partition.start, partition.stop
        ):
            la = compute_la_field(genotypes)
            j = call_LA.next_buffer_row()
            call_LA.buff[j] = la
        self.finalise_partition_array(partition_index, call_LA)

    def encode_local_allele_fields_partition(self, partition_index):
        partition = self.metadata.partitions[partition_index]
        la_array = zarr.open_array(
            store=self.wip_partition_array_path(partition_index, "call_LA"),
            mode="r",
        )
        # We go through the localisable fields one-by-one so that we don't need to
        # keep several large arrays in memory at once for each partition.
        field_map = self.schema.field_map()
        for descriptor in localisable_fields:
            if descriptor.array_name not in field_map:
                continue
            assert field_map[descriptor.array_name].vcf_field is None

            buff = self.init_partition_array(partition_index, descriptor.array_name)
            source = self.source.fields[descriptor.vcf_field].iter_values(
                partition.start, partition.stop
            )
            for la in core.first_dim_slice_iter(
                la_array, partition.start, partition.stop
            ):
                raw_value = next(source)
                value = descriptor.sanitise(raw_value, 2, raw_value.dtype)
                j = buff.next_buffer_row()
                buff.buff[j] = descriptor.convert(value, la)
            self.finalise_partition_array(partition_index, buff)

    def encode_alleles_partition(self, partition_index):
        alleles = self.init_partition_array(partition_index, "variant_allele")
        partition = self.metadata.partitions[partition_index]

        for value in self.source.iter_alleles(
            partition.start, partition.stop, alleles.array.shape[1]
        ):
            j = alleles.next_buffer_row()
            alleles.buff[j] = value

        self.finalise_partition_array(partition_index, alleles)

    def encode_id_partition(self, partition_index):
        vid = self.init_partition_array(partition_index, "variant_id")
        vid_mask = self.init_partition_array(partition_index, "variant_id_mask")
        partition = self.metadata.partitions[partition_index]

        for value in self.source.iter_id(partition.start, partition.stop):
            j = vid.next_buffer_row()
            k = vid_mask.next_buffer_row()
            assert j == k
            if value is not None:
                vid.buff[j] = value
                vid_mask.buff[j] = False
            else:
                vid.buff[j] = constants.STR_MISSING
                vid_mask.buff[j] = True

        self.finalise_partition_array(partition_index, vid)
        self.finalise_partition_array(partition_index, vid_mask)

    def encode_filters_partition(self, partition_index):
        var_filter = self.init_partition_array(partition_index, "variant_filter")
        partition = self.metadata.partitions[partition_index]

        for filter_values in self.source.iter_filters(partition.start, partition.stop):
            j = var_filter.next_buffer_row()
            var_filter.buff[j] = filter_values

        self.finalise_partition_array(partition_index, var_filter)

    def encode_contig_partition(self, partition_index):
        contig = self.init_partition_array(partition_index, "variant_contig")
        partition = self.metadata.partitions[partition_index]

        for contig_index in self.source.iter_contig(partition.start, partition.stop):
            j = contig.next_buffer_row()
            contig.buff[j] = contig_index

        self.finalise_partition_array(partition_index, contig)

    #######################
    # finalise
    #######################

    def finalise_array(self, name):
        logger.info(f"Finalising {name}")
        final_path = self.path / name
        if final_path.exists():
            # NEEDS TEST
            raise ValueError(f"Array {name} already exists")
        for partition in range(self.num_partitions):
            # Move all the files in partition dir to dest dir
            src = self.partition_array_path(partition, name)
            if not src.exists():
                # Needs test
                raise ValueError(f"Partition {partition} of {name} does not exist")
            dest = self.arrays_path / name
            # This is Zarr v2 specific. Chunks in v3 start with a "c" prefix.
            chunk_files = [
                path for path in src.iterdir() if not path.name.startswith(".")
            ]
            # TODO check for a count of the number of files. If we require a
            # dimension_separator of "/" then we could make stronger assertions
            # here, as we'd always have num_variant_chunks
            logger.debug(
                f"Moving {len(chunk_files)} chunks for {name} partition {partition}"
            )
            for chunk_file in chunk_files:
                os.rename(chunk_file, dest / chunk_file.name)
        # Finally, once all the chunks have moved into the arrays dir,
        # we move it out of wip
        os.rename(self.arrays_path / name, self.path / name)
        core.update_progress(1)

    def finalise(self, show_progress=False):
        self.load_metadata()

        logger.info(f"Scanning {self.num_partitions} partitions")
        missing = []
        # TODO may need a progress bar here
        for partition_id in range(self.num_partitions):
            if not self.partition_path(partition_id).exists():
                missing.append(partition_id)
        if len(missing) > 0:
            raise FileNotFoundError(f"Partitions not encoded: {missing}")

        progress_config = core.ProgressConfig(
            total=len(self.schema.fields),
            title="Finalise",
            units="array",
            show=show_progress,
        )
        # NOTE: it's not clear that adding more workers will make this quicker,
        # as it's just going to be causing contention on the file system.
        # Something to check empirically in some deployments.
        # FIXME we're just using worker_processes=0 here to hook into the
        # SynchronousExecutor which is intended for testing purposes so
        # that we get test coverage. Should fix this either by allowing
        # for multiple workers, or making a standard wrapper for tqdm
        # that allows us to have a consistent look and feel.
        with core.ParallelWorkManager(0, progress_config) as pwm:
            for field in self.schema.fields:
                pwm.submit(self.finalise_array, field.name)
        logger.debug(f"Removing {self.wip_path}")
        shutil.rmtree(self.wip_path)
        logger.info("Consolidating Zarr metadata")
        zarr.consolidate_metadata(self.path)

    #######################
    # index
    #######################

    def create_index(self):
        """Create an index to support efficient region queries."""

        indexer = VcfZarrIndexer(self.path)
        indexer.create_index()

    ######################
    # encode_all_partitions
    ######################

    def get_max_encoding_memory(self):
        """
        Return the approximate maximum memory used to encode a variant chunk.
        """
        max_encoding_mem = 0
        for array_spec in self.schema.fields:
            max_encoding_mem = max(max_encoding_mem, array_spec.variant_chunk_nbytes)
        gt_mem = 0
        if self.has_genotypes():
            gt_mem = sum(
                field.variant_chunk_nbytes
                for field in self.schema.fields
                if field.name.startswith("call_genotype")
            )
        return max(max_encoding_mem, gt_mem)

    def encode_all_partitions(
        self, *, worker_processes=1, show_progress=False, max_memory=None
    ):
        max_memory = core.parse_max_memory(max_memory)
        self.load_metadata()
        num_partitions = self.num_partitions
        per_worker_memory = self.get_max_encoding_memory()
        logger.info(
            f"Encoding Zarr over {num_partitions} partitions with "
            f"{worker_processes} workers and {core.display_size(per_worker_memory)} "
            "per worker"
        )
        # Each partition requires per_worker_memory bytes, so to prevent more than
        # max_memory being used, we clamp the number of workers
        max_num_workers = max_memory // per_worker_memory
        if max_num_workers < worker_processes:
            logger.warning(
                f"Limiting number of workers to {max_num_workers} to "
                "keep within specified memory budget of "
                f"{core.display_size(max_memory)}"
            )
        if max_num_workers <= 0:
            raise ValueError(
                f"Insufficient memory to encode a partition: "
                f"{core.display_size(per_worker_memory)} > "
                f"{core.display_size(max_memory)}"
            )
        num_workers = min(max_num_workers, worker_processes)

        total_bytes = 0
        for array_spec in self.schema.fields:
            # Open the array definition to get the total size
            total_bytes += zarr.open(self.arrays_path / array_spec.name).nbytes

        progress_config = core.ProgressConfig(
            total=total_bytes,
            title="Encode",
            units="B",
            show=show_progress,
        )
        with core.ParallelWorkManager(num_workers, progress_config) as pwm:
            for partition_index in range(num_partitions):
                pwm.submit(self.encode_partition, partition_index)

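# ---------------------------------------------------------------------------
# Editor's note: an illustrative sketch, not part of bio2zarr. The writer is
# driven in separate phases so that encoding can be distributed; the values
# and the SomeSource class below are hypothetical:
#
#     writer = VcfZarrWriter(SomeSource, "out.vcz")
#     writer.init(source, target_num_partitions=16, schema=schema)
#     writer.encode_all_partitions(worker_processes=4)
#     writer.finalise(show_progress=True)
#     writer.create_index()
#
# Alternatively, each encode_partition(i) can run as its own job on a
# separate machine, with a single finalise() call afterwards moving the
# per-partition chunks into the final arrays.
# ---------------------------------------------------------------------------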

class VcfZarr:
    def __init__(self, path):
        if not (path / ".zmetadata").exists():
            raise ValueError("Not in VcfZarr format")  # NEEDS TEST
        self.path = path
        self.root = zarr.open(path, mode="r")

    def summary_table(self):
        data = []
        arrays = [(core.du(self.path / a.basename), a) for _, a in self.root.arrays()]
        arrays.sort(key=lambda x: x[0])
        for stored, array in reversed(arrays):
            d = {
                "name": array.name,
                "dtype": str(array.dtype),
                "stored": core.display_size(stored),
                "size": core.display_size(array.nbytes),
                "ratio": core.display_number(array.nbytes / stored),
                "nchunks": str(array.nchunks),
                "chunk_size": core.display_size(array.nbytes / array.nchunks),
                "avg_chunk_stored": core.display_size(int(stored / array.nchunks)),
                "shape": str(array.shape),
                "chunk_shape": str(array.chunks),
                "compressor": str(array.compressor),
                "filters": str(array.filters),
            }
            data.append(d)
        return data


class VcfZarrIndexer:
    """
    Creates an index for efficient region queries in a VCF Zarr dataset.
    """

    def __init__(self, path):
        self.path = pathlib.Path(path)

    def create_index(self):
        """Create an index to support efficient region queries."""
        root = zarr.open_group(store=self.path, mode="r+")

        if (
            "variant_contig" not in root
            or "variant_position" not in root
            or "variant_length" not in root
        ):
            logger.warning("Cannot create index: required arrays not found")
            return

        contig = root["variant_contig"]
        pos = root["variant_position"]
        length = root["variant_length"]

        assert contig.cdata_shape == pos.cdata_shape

        index = []

        logger.info("Creating region index")
        for v_chunk in range(pos.cdata_shape[0]):
            c = contig.blocks[v_chunk]
            p = pos.blocks[v_chunk]
            e = p + length.blocks[v_chunk] - 1

            # create a row for each contig in the chunk
            d = np.diff(c, append=-1)
            c_start_idx = 0
            for c_end_idx in np.nonzero(d)[0]:
                assert c[c_start_idx] == c[c_end_idx]
                index.append(
                    (
                        v_chunk,  # chunk index
                        c[c_start_idx],  # contig ID
                        p[c_start_idx],  # start
                        p[c_end_idx],  # end
                        np.max(e[c_start_idx : c_end_idx + 1]),  # max end
                        c_end_idx - c_start_idx + 1,  # num records
                    )
                )
                c_start_idx = c_end_idx + 1

        index = np.array(index, dtype=pos.dtype)
        kwargs = {}
        if not zarr_utils.zarr_v3():
            kwargs["dimension_separator"] = "/"
        array = root.array(
            "region_index",
            data=index,
            shape=index.shape,
            chunks=index.shape,
            dtype=index.dtype,
            compressor=numcodecs.Blosc("zstd", clevel=9, shuffle=0),
            fill_value=None,
            **kwargs,
        )
        array.attrs["_ARRAY_DIMENSIONS"] = [
            "region_index_values",
            "region_index_fields",
        ]

        logger.info("Consolidating Zarr metadata")
        zarr.consolidate_metadata(self.path)
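# ---------------------------------------------------------------------------
# Editor's note: an illustrative sketch, not part of bio2zarr. Each
# region_index row holds (variant chunk, contig ID, start, end, max end,
# record count), so a region query only decompresses the variant chunks whose
# [start, max end] interval overlaps the query. A hypothetical lookup of
# contig 0, positions 1000-2000:
#
#     idx = zarr.open("out.vcz")["region_index"][:]
#     hits = idx[(idx[:, 1] == 0) & (idx[:, 2] <= 2000) & (idx[:, 4] >= 1000)]
#     chunks_to_read = np.unique(hits[:, 0])
# ---------------------------------------------------------------------------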