• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

sgkit-dev / bio2zarr / 18940786519

30 Oct 2025 12:34PM UTC coverage: 98.361% (+0.07%) from 98.292%
18940786519

Pull #422

github

web-flow
Merge b33cd4646 into 3f955cef0
Pull Request #422: Update intro.md with VCF Zarr conversion info

2880 of 2928 relevant lines covered (98.36%)

3.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

100.0
/bio2zarr/tskit.py
1
import logging
4✔
2
import pathlib
4✔
3

4
import numpy as np
4✔
5

6
from bio2zarr import constants, core, vcz
4✔
7

8
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
4✔
9

10

11
class TskitFormat(vcz.Source):
    """
    A :class:`vcz.Source` that reads variant data from a tskit tree sequence.

    Each row of the model mapping's ``individuals_nodes`` becomes one VCF
    Zarr sample, and each site of the tree sequence becomes one variant on a
    single contig.
    """

    @core.requires_optional_dependency("tskit", "tskit")
    def __init__(
        self,
        ts,
        *,
        model_mapping=None,
        contig_id=None,
        isolated_as_missing=False,
    ):
        """
        Wrap a tree sequence (or a path to one) for conversion.

        :param ts: A :class:`tskit.TreeSequence`, or a path that is passed
            to :func:`tskit.load`.
        :param model_mapping: Object exposing ``individuals_nodes`` (a 2D
            array with one row per sample; negative entries are treated as
            absent) and ``individuals_name`` (one name per row). Defaults to
            ``ts.map_to_vcf_model()``.
        :param contig_id: Contig identifier; defaults to ``"1"``.
        :param isolated_as_missing: Forwarded to ``ts.variants()``.
        :raises ValueError: If the mapping has no rows, no non-negative node
            IDs, or a number of names that differs from the number of rows.
        """
        import tskit

        self._path = None
        # Future versions here will need to deal with the complexities of
        # having lists of tree sequences for multiple chromosomes.
        if isinstance(ts, tskit.TreeSequence):
            self.ts = ts
        else:
            # input 'ts' is a path.
            self._path = ts
            logger.info(f"Loading from {ts}")
            self.ts = tskit.load(ts)
        logger.info(
            f"Input has {self.ts.num_individuals} individuals and "
            f"{self.ts.num_sites} sites"
        )

        self.contig_id = contig_id if contig_id is not None else "1"
        self.isolated_as_missing = isolated_as_missing

        self.positions = self.ts.sites_position

        if model_mapping is None:
            model_mapping = self.ts.map_to_vcf_model()

        individuals_nodes = model_mapping.individuals_nodes
        sample_ids = model_mapping.individuals_name

        self._num_samples = individuals_nodes.shape[0]
        logger.info(f"Converting for {self._num_samples} samples")
        if self._num_samples < 1:
            raise ValueError("individuals_nodes must have at least one sample")
        self.max_ploidy = individuals_nodes.shape[1]
        if len(sample_ids) != self._num_samples:
            raise ValueError(
                f"Length of sample_ids ({len(sample_ids)}) does not match "
                f"number of samples ({self._num_samples})"
            )

        self._samples = [vcz.Sample(id=sample_id) for sample_id in sample_ids]

        valid_mask = individuals_nodes >= 0
        valid_nodes = individuals_nodes[valid_mask]
        # Sorted, de-duplicated node IDs requested from tskit. The genotype
        # gather in iter_alleles_and_genotypes assumes variant.genotypes is
        # ordered like this array.
        self.tskit_samples = np.unique(valid_nodes)
        if len(self.tskit_samples) < 1:
            raise ValueError("individuals_nodes must have at least one valid sample")
        # (sample, ploidy) coordinates of every valid node, in row-major
        # order (the same order as valid_nodes).
        self.sample_indices, self.ploidy_indices = np.where(valid_mask)
        # Index of each valid node within tskit_samples. tskit_samples is
        # sorted and unique, so a binary search gives the exact position.
        self.genotype_indices = np.searchsorted(self.tskit_samples, valid_nodes)

    @property
    def path(self):
        """The path the tree sequence was loaded from, or None."""
        return self._path

    @property
    def num_records(self):
        """Number of variants, i.e. sites in the tree sequence."""
        return self.ts.num_sites

    @property
    def num_samples(self):
        """Number of output samples (rows of ``individuals_nodes``)."""
        return self._num_samples

    @property
    def samples(self):
        """List of :class:`vcz.Sample`, one per output sample."""
        return self._samples

    @property
    def root_attrs(self):
        """Extra root attributes for the store (none for tskit input)."""
        return {}

    @property
    def contigs(self):
        """The single contig all variants are placed on."""
        return [vcz.Contig(id=self.contig_id)]

    def iter_contig(self, start, stop):
        """Yield contig index 0 for every variant in ``[start, stop)``."""
        yield from (0 for _ in range(start, stop))

    def iter_field(self, field_name, shape, start, stop):
        """
        Yield values of a per-variant field for variants ``[start, stop)``.

        Only ``"position"`` is supported; positions are yielded as Python
        ints.

        :raises ValueError: For any other field name.
        """
        if field_name == "position":
            for pos in self.positions[start:stop]:
                yield int(pos)
        else:
            raise ValueError(f"Unknown field {field_name}")

    def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
        """
        Yield a :class:`vcz.VariantData` for each variant in ``[start, stop)``.

        :param shape: Shape of the per-variant genotype array; its leading
            dimensions (all but the last) give the shape of the phasing
            array.
        :param num_alleles: Size of the per-variant alleles array; every
            allele index must be below this.
        """
        if start >= stop:
            # Empty range: avoid indexing self.positions[start], which is
            # out of bounds when start == num_records (e.g. no sites).
            return
        # All genotypes in tskit are considered phased
        phased = np.ones(shape[:-1], dtype=bool)
        # Size the genotype buffer exactly like the schema's call_genotype
        # field. A hard-coded int8 would silently wrap allele indexes >= 128.
        gt_dtype = core.min_int_dtype(constants.INT_FILL, num_alleles - 1)
        logger.debug(f"Getting genotypes start={start} stop={stop}")

        for variant in self.ts.variants(
            isolated_as_missing=self.isolated_as_missing,
            left=self.positions[start],
            right=self.positions[stop] if stop < self.num_records else None,
            samples=self.tskit_samples,
            copy=False,
        ):
            gt = np.full(shape, constants.INT_FILL, dtype=gt_dtype)
            alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
            # length is the length of the REF allele unless other fields
            # are included.
            variant_length = len(variant.alleles[0])
            for i, allele in enumerate(variant.alleles):
                # None is returned by tskit in the case of a missing allele
                if allele is None:
                    continue
                assert i < num_alleles
                alleles[i] = allele
            # Scatter each node's genotype into its (sample, ploidy) slot;
            # slots with no node keep INT_FILL.
            gt[self.sample_indices, self.ploidy_indices] = variant.genotypes[
                self.genotype_indices
            ]

            yield vcz.VariantData(variant_length, alleles, gt, phased)

    def generate_schema(
        self,
        variants_chunk_size=None,
        samples_chunk_size=None,
    ):
        """
        Build the VCF Zarr schema for this tree sequence.

        Scans every site to find the maximum number of distinct states
        (ancestral plus derived), then sizes the integer dtypes of the
        position, length, contig and genotype arrays to fit.

        :param variants_chunk_size: Forwarded to
            :func:`vcz.standard_dimensions`.
        :param samples_chunk_size: Forwarded to
            :func:`vcz.standard_dimensions`.
        :return: A :class:`vcz.VcfZarrSchema` with its fields populated.
        """
        n = self.num_samples
        m = self.ts.num_sites

        # Determine max number of alleles: distinct states at each site,
        # counting the ancestral state and every derived state.
        max_alleles = 0
        for site in self.ts.sites():
            states = {site.ancestral_state}
            states.update(mut.derived_state for mut in site.mutations)
            max_alleles = max(len(states), max_alleles)

        logger.info(f"Scanned tskit with {n} samples and {m} variants")
        logger.info(
            f"Maximum ploidy: {self.max_ploidy}, maximum alleles: {max_alleles}"
        )
        dimensions = vcz.standard_dimensions(
            variants_size=m,
            variants_chunk_size=variants_chunk_size,
            samples_size=n,
            samples_chunk_size=samples_chunk_size,
            ploidy_size=self.max_ploidy,
            alleles_size=max_alleles,
        )
        schema_instance = vcz.VcfZarrSchema(
            format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
            dimensions=dimensions,
            fields=[],
        )

        logger.info(
            "Generating schema with chunks="
            f"{schema_instance.dimensions['variants'].chunk_size}, "
            f"{schema_instance.dimensions['samples'].chunk_size}"
        )

        # Range of site positions, used to pick the smallest integer dtype
        # that holds them all.
        min_position = 0
        max_position = 0
        if m > 0:
            min_position = np.min(self.positions)
            max_position = np.max(self.positions)

        # State lengths come from the ragged-column offsets: consecutive
        # offset differences give each entry's length.
        tables = self.ts.tables
        ancestral_state_offsets = tables.sites.ancestral_state_offset
        derived_state_offsets = tables.mutations.derived_state_offset
        ancestral_lengths = ancestral_state_offsets[1:] - ancestral_state_offsets[:-1]
        derived_lengths = derived_state_offsets[1:] - derived_state_offsets[:-1]
        max_variant_length = max(
            np.max(ancestral_lengths) if len(ancestral_lengths) > 0 else 0,
            np.max(derived_lengths) if len(derived_lengths) > 0 else 0,
        )

        array_specs = [
            vcz.ZarrArraySpec(
                source="position",
                name="variant_position",
                dtype=core.min_int_dtype(min_position, max_position),
                dimensions=["variants"],
                description="Position of each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="variant_allele",
                dtype="O",
                dimensions=["variants", "alleles"],
                description="Alleles for each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="variant_length",
                dtype=core.min_int_dtype(0, max_variant_length),
                dimensions=["variants"],
                description="Length of each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="variant_contig",
                dtype=core.min_int_dtype(0, len(self.contigs)),
                dimensions=["variants"],
                description="Contig/chromosome index for each variant",
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="call_genotype_phased",
                dtype="bool",
                dimensions=["variants", "samples"],
                description="Whether the genotype is phased",
                compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="call_genotype",
                dtype=core.min_int_dtype(constants.INT_FILL, max_alleles - 1),
                dimensions=["variants", "samples", "ploidy"],
                description="Genotype for each variant and sample",
                compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),
            ),
            vcz.ZarrArraySpec(
                source=None,
                name="call_genotype_mask",
                dtype="bool",
                dimensions=["variants", "samples", "ploidy"],
                description="Mask for each genotype call",
                compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
            ),
        ]
        schema_instance.fields = array_specs
        return schema_instance
248

249

250
def convert(
    ts_or_path,
    vcz_path,
    *,
    model_mapping=None,
    contig_id=None,
    isolated_as_missing=False,
    variants_chunk_size=None,
    samples_chunk_size=None,
    worker_processes=core.DEFAULT_WORKER_PROCESSES,
    show_progress=False,
):
    """
    Convert a :class:`tskit.TreeSequence` (or path to a tree sequence
    file) to VCF Zarr format stored at the specified path.

    :param ts_or_path: A tree sequence object, or a path to a tree sequence
        file.
    :param vcz_path: Destination of the VCF Zarr store (converted to a
        :class:`pathlib.Path`).
    :param model_mapping: Optional VCF model mapping; when None the tree
        sequence's own ``map_to_vcf_model()`` is used.
    :param contig_id: Contig identifier; ``"1"`` when None.
    :param isolated_as_missing: Forwarded to ``TreeSequence.variants()``
        when extracting genotypes.
    :param variants_chunk_size: Chunk size along the variants dimension,
        forwarded to schema generation.
    :param samples_chunk_size: Chunk size along the samples dimension,
        forwarded to schema generation.
    :param worker_processes: Number of worker processes used to encode
        partitions. The store is split into four partitions per worker
        (at least one).
    :param show_progress: Whether to show progress while encoding and
        finalising.
    """
    # FIXME: there are some tricky details in how we handle parallelism
    # that we'll need to tackle properly, and we may want to review the
    # current structures. It looks like the format object is pickled and
    # unpickled when there are multiple workers, so several copies of the
    # tree sequence get passed around. This is usually fine, but it uses
    # a lot of memory for really massive files.
    # See https://github.com/sgkit-dev/bio2zarr/issues/403
    source = TskitFormat(
        ts_or_path,
        model_mapping=model_mapping,
        contig_id=contig_id,
        isolated_as_missing=isolated_as_missing,
    )
    schema = source.generate_schema(
        variants_chunk_size=variants_chunk_size,
        samples_chunk_size=samples_chunk_size,
    )
    writer = vcz.VcfZarrWriter(TskitFormat, pathlib.Path(vcz_path))
    # Rough heuristic: several partitions per worker keeps utilisation high.
    num_partitions = max(1, 4 * worker_processes)
    writer.init(source, target_num_partitions=num_partitions, schema=schema)
    writer.encode_all_partitions(
        worker_processes=worker_processes, show_progress=show_progress
    )
    writer.finalise(show_progress)
    writer.create_index()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc