
rmcar17 / cogent3, Coveralls build 16431585733 (push, GitHub, web-flow)
16 Jul 2025 07:02AM UTC coverage: 90.819% (+0.004%) from 90.815%

Merge pull request #2403 from GavinHuttley/develop
DEV: bump version to 2025.7.10a3

1 of 1 new or added line in 1 file covered (100.0%).
498 existing lines in 32 files now uncovered.
30123 of 33168 relevant lines covered (90.82%).
5.45 hits per line.

Source file: /src/cogent3/evolve/coevolution.py (93.65% covered)
import enum
import itertools
import typing

import numba
import numpy
from numpy import (
    nan,
)

from cogent3.core import alphabet as c3_alphabet
from cogent3.core.moltype import IUPAC_gap, IUPAC_missing
from cogent3.util import dict_array
from cogent3.util import parallel as PAR
from cogent3.util import progress_display as UI

DEFAULT_EXCLUDES = f"{IUPAC_gap}{IUPAC_missing}"
DEFAULT_NULL_VALUE = nan


class MI_METHODS(enum.Enum):
    mi = "mi"
    nmi = "nmi"
    rmi = "rmi"


# Comments on design
# The revised mutual information calculations are based on a sequence alignment
# represented as a numpy uint8 array.

# The resampled mutual information calculation uses a cache of all entropy terms
# from the independent and joint positions, produced by _calc_entropy_components().
# For each combination of alternate possible states, there are two values in the
# entropy terms that need to be modified. These calculations are done by
# _calc_updated_entropy() and _calc_temp_entropy(). This caching, plus the
# numba.jit compilation, makes rmi competitive performance-wise with the other
# methods.


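The design comments above take for granted that the alignment is already a 2D uint8 matrix (sequences by positions). As an illustration only, with a made-up state ordering rather than the moltype's real alphabet, such an encoding could be built like this:

import numpy

# hypothetical state order; the real mapping comes from the moltype alphabet
state_order = "TCAG-?"
to_index = {char: idx for idx, char in enumerate(state_order)}

seqs = ["TCAG", "TCA-", "TAAG"]
alignment = numpy.array(
    [[to_index[char] for char in seq] for seq in seqs],
    dtype=numpy.uint8,
)
# shape (3, 4); with num_states = 4, values >= 4 (gap "-", missing "?")
# are treated as non-canonical and skipped by the counting routines below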
@numba.jit
def _count_states(
    state_vector: numpy.ndarray,
    num_states: int,
    counts: numpy.ndarray | None = None,
) -> numpy.ndarray:  # pragma: no cover
    """computes counts from a single vector of states"""
    if counts is None:
        counts = numpy.empty(num_states, dtype=numpy.int64)

    counts.fill(0)
    for state in state_vector:
        if state < num_states:
            counts[state] += 1

    return counts


@numba.jit
def _vector_entropy(counts: numpy.ndarray) -> float:  # pragma: no cover
    """computes entropy for a single vector of integers"""
    total = counts.sum()
    if total <= 1:
        return 0.0 if total else numpy.nan

    entropy = 0.0
    for count in counts:
        if count > 0:
            prob = count / total
            entropy += prob * -numpy.log2(prob)

    return entropy


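_vector_entropy is the usual Shannon entropy over the state counts of one column, H = -sum(p * log2(p)). A quick hand check with plain numpy (not the jitted path), using illustrative counts:

import numpy

counts = numpy.array([2, 2, 0, 0])            # e.g. a column with 2 x T and 2 x C
probs = counts[counts > 0] / counts.sum()
entropy = -(probs * numpy.log2(probs)).sum()
assert numpy.isclose(entropy, 1.0)            # two equally frequent states -> 1 bit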
@numba.jit
def _count_joint_states(
    joint_states: numpy.ndarray,
    num_states: int,
    counts: numpy.ndarray | None = None,
) -> numpy.ndarray:  # pragma: no cover
    if counts is None:
        counts = numpy.empty((num_states, num_states), dtype=numpy.int64)

    counts.fill(0)
    for joint_state in joint_states:
        i, j = joint_state
        if i >= num_states or j >= num_states:
            continue

        counts[i, j] += 1
    return counts


@numba.jit
def _calc_joint_entropy(counts: numpy.ndarray) -> float:  # pragma: no cover
    entropy = 0.0
    total_counts = counts.sum()
    for count in counts.flatten():
        if count > 0:
            prob = count / total_counts
            entropy += prob * -numpy.log2(prob)
    return entropy


@numba.jit
def _calc_column_entropies(
    columns: numpy.ndarray,
    num_states: int,
) -> numpy.ndarray:  # pragma: no cover
    """
    Calculate the entropy for each column in the input array.

    Parameters:
    columns (numpy.ndarray): Input array of unsigned 8-bit integers.
    num_states (int): The number of canonical states.

    Returns:
    numpy.ndarray: Array of entropy values for each column.
    """
    n_cols = columns.shape[1]
    entropies = numpy.zeros(n_cols, dtype=numpy.float64)
    counts = numpy.empty(num_states, dtype=numpy.int64)
    for col in range(n_cols):
        counts = _count_states(columns[:, col], num_states, counts)
        entropies[col] = _vector_entropy(counts)

    return entropies


@numba.jit
def _make_weights(
    counts: numpy.ndarray,
    weights: numpy.ndarray | None = None,
) -> numpy.ndarray:  # pragma: no cover
    """Return the weights for replacement states for each possible character.
    We compute the weight as the normalized frequency of the replacement state
    divided by 2*n."""
    zeroes = counts == 0
    total = counts.sum()
    char_prob = counts.astype(numpy.float64) / total
    if weights is None:
        weights = numpy.empty((counts.shape[0], counts.shape[0]), dtype=numpy.float64)

    weights.fill(0.0)
    denom = 2 * total
    for i in range(counts.shape[0]):
        if zeroes[i]:
            continue
        diag = char_prob[i]
        weights[i, :] = char_prob / (1 - diag) / denom
        weights[i, i] = 0.0

    return weights


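To see what _make_weights produces, the same arithmetic can be re-derived with plain numpy for a toy column (the counts are illustrative only, not taken from the library's tests):

import numpy

counts = numpy.array([3, 1], dtype=numpy.int64)   # 3 of state 0, 1 of state 1
total = counts.sum()
char_prob = counts / total
weights = numpy.zeros((2, 2))
for i in range(2):
    if counts[i] == 0:
        continue
    weights[i] = char_prob / (1 - char_prob[i]) / (2 * total)
    weights[i, i] = 0.0
# weights[0, 1] == weights[1, 0] == 0.125: the normalised frequency of the
# replacement state divided by 2n, as described in the docstring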
@numba.jit
def _calc_entropy_components(
    counts: numpy.ndarray,
    total: int,
) -> tuple[float, numpy.ndarray, numpy.ndarray]:  # pragma: no cover
    """Return the entropy and arrays of entropy components and non-zero status"""
    non_zero = counts != 0
    assert counts.ndim == 1, "designed for 1D arrays"
    freqs = counts.astype(numpy.float64) / total
    log2 = numpy.zeros(counts.shape, dtype=numpy.float64)
    et = numpy.zeros(counts.shape, dtype=numpy.float64)
    log2[non_zero] = numpy.log2(freqs[non_zero])
    et[non_zero] = (
        -log2[non_zero] * freqs[non_zero]
    )  # the terms in the entropy equation
    h = numpy.sum(et)  # entropy of the original data
    return h, et, non_zero


@numba.jit
def _calc_temp_entropy(
    entropy: float,
    counts: numpy.ndarray,
    entropy_terms: numpy.ndarray,
    total: int,
    index: int,
) -> float:  # pragma: no cover
    # compute the intermediate column 1 entropy term
    new_count = counts[index] - 1
    orig_term = entropy_terms[index]
    if new_count > 0:
        freq = new_count / total
        log2 = -numpy.log2(freq)
        new_term = freq * log2
    else:
        new_term = 0.0

    return entropy - orig_term + new_term


@numba.jit
def _calc_updated_entropy(
    temp_entropy: float,
    counts: numpy.ndarray,
    entropy_terms: numpy.ndarray,
    total: int,
    index: int,
) -> float:  # pragma: no cover
    new_count = counts[index] + 1
    orig_term = entropy_terms[index]
    freq = new_count / total
    log2 = -numpy.log2(freq)
    new_term = freq * log2
    return temp_entropy - orig_term + new_term


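The two helpers above implement an incremental entropy update: _calc_temp_entropy removes one observation of a state while the total stays fixed, and _calc_updated_entropy adds one observation of a different state. A plain-numpy cross-check of that identity (illustrative counts, not the jitted code path):

import numpy

def h_terms(counts, total):
    """per-state -p * log2(p) terms, zero where the count is zero"""
    terms = numpy.zeros(len(counts))
    for i, count in enumerate(counts):
        if count:
            p = count / total
            terms[i] = -p * numpy.log2(p)
    return terms

def term(count, total):
    p = count / total
    return -p * numpy.log2(p) if count else 0.0

counts = numpy.array([3, 2, 1])
total = counts.sum()
terms = h_terms(counts, total)
h = terms.sum()                                    # entropy of the original counts

# move one observation from state 0 to state 2, total unchanged
tmp = h - terms[0] + term(counts[0] - 1, total)    # mirrors _calc_temp_entropy
new = tmp - terms[2] + term(counts[2] + 1, total)  # mirrors _calc_updated_entropy

assert numpy.isclose(new, h_terms(numpy.array([2, 2, 2]), total).sum())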
@numba.jit
def _calc_pair_scale(
    counts_12: numpy.ndarray,
    counts_1: numpy.ndarray,
    counts_2: numpy.ndarray,
    weights_1: numpy.ndarray,
    weights_2: numpy.ndarray,
    states_12: numpy.ndarray,
    states_1: numpy.ndarray,
    states_2: numpy.ndarray,
    coeffs: numpy.ndarray,
) -> tuple[float, numpy.ndarray, numpy.ndarray]:  # pragma: no cover
    """Return entropies and weights for comparable alignment.
    A comparable alignment is one in which, for each paired state ij, all
    alternate observable paired symbols are created. For instance, let the
    symbols {A,C} be observed at position i and {A,C} at position j. If we
    observe the paired types {AC, AA}, a comparable alignment would involve
    replacing an AC pair with a CC pair."""

    # break down the joint entropy into individual terms so we can easily adjust
    # for the different combinations
    total = counts_1.sum()
    counts_12 = counts_12.flatten()
    je_orig, jet, j_nz = _calc_entropy_components(counts_12, total)

    # individual entropy components for column 1
    orig_e_1, et_1, nz_1 = _calc_entropy_components(counts_1, total)

    # individual entropy components for column 2
    orig_e_2, et_2, nz_2 = _calc_entropy_components(counts_2, total)
    orig_mi = orig_e_1 + orig_e_2 - je_orig

    num_scales = j_nz.sum() * nz_1.sum() + j_nz.sum() * nz_2.sum()
    scales = numpy.zeros((num_scales, 2), dtype=numpy.float64)
    pairs = numpy.zeros((num_scales, 2), dtype=numpy.uint64)
    new_coord = numpy.zeros(2, dtype=numpy.int64)
    if orig_e_1 == 0.0 or orig_e_2 == 0.0 or je_orig == 0.0:
        return 0.0, pairs, scales

    n = 0
    for pair in states_12.flatten()[j_nz]:
        i, j = c3_alphabet.index_to_coord(pair, coeffs=coeffs)
        if counts_12[pair] == 0:
            continue

        # compute the intermediate column 1 entropy term
        tmp_e_1 = _calc_temp_entropy(orig_e_1, counts_1, et_1, total, i)

        # compute the intermediate joint entropy term
        tmp_je = _calc_temp_entropy(je_orig, counts_12, jet, total, pair)

        for k in states_1:
            if k == i or counts_1[k] == 0:
                continue
            # compute the new entropy for column 1
            new_e_1 = _calc_updated_entropy(tmp_e_1, counts_1, et_1, total, k)

            # compute the new joint-entropy
            new_coord[:] = k, j
            new_index = c3_alphabet.coord_to_index(new_coord, coeffs=coeffs)
            n_je = _calc_updated_entropy(tmp_je, counts_12, jet, total, new_index)

            # the weight
            w = weights_1[i, k]
            new_mi = new_e_1 + orig_e_2 - n_je
            scales[n][0] = new_mi
            scales[n][1] = w
            pairs[n][0] = i
            pairs[n][1] = j
            n += 1

        # compute the intermediate column 2 entropy term
        tmp_e_2 = _calc_temp_entropy(orig_e_2, counts_2, et_2, total, j)
        for k in states_2:
            if k == j or counts_2[k] == 0:
                continue

            # compute the new entropy for column 2
            new_e_2 = _calc_updated_entropy(tmp_e_2, counts_2, et_2, total, k)

            # compute the new joint-entropy
            new_coord[:] = i, k
            new_index = c3_alphabet.coord_to_index(new_coord, coeffs=coeffs)
            n_je = _calc_updated_entropy(tmp_je, counts_12, jet, total, new_index)

            # the weight
            w = weights_2[j, k]
            new_mi = orig_e_1 + new_e_2 - n_je
            scales[n][0] = new_mi
            scales[n][1] = w
            pairs[n][0] = i
            pairs[n][1] = j
            n += 1

    return orig_mi, pairs, scales


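The docstring's notion of a comparable alignment can be made concrete with a small pure-Python sketch: for one observed paired state, enumerate the single-site replacements drawn from the states seen at each position (illustrative only; the jitted code works on integer state indices and cached entropy terms rather than characters):

observed_i = {"A", "C"}      # states seen at position i
observed_j = {"A", "C"}      # states seen at position j
pair = ("A", "C")            # one observed paired state, "AC"

alternates = {(x, pair[1]) for x in observed_i if x != pair[0]} | {
    (pair[0], y) for y in observed_j if y != pair[1]
}
# alternates == {("C", "C"), ("A", "A")}: the AC pair can be replaced by CC
# (swap at position i) or AA (swap at position j)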
@numba.jit
def _count_col_joint(
    pos_12: numpy.ndarray,
    counts_12: numpy.ndarray,
    counts_1: numpy.ndarray,
    counts_2: numpy.ndarray,
    num_states: int,
) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:  # pragma: no cover
    counts_12.fill(0)
    counts_1.fill(0)
    counts_2.fill(0)
    for pos_1, pos_2 in pos_12:
        if pos_1 >= num_states or pos_2 >= num_states:
            continue
        counts_12[pos_1, pos_2] += 1
        counts_1[pos_1] += 1
        counts_2[pos_2] += 1

    return counts_12, counts_1, counts_2


@numba.jit
def _rmi_calc(
    positions: numpy.ndarray,
    alignment: numpy.ndarray,
    num_states: numpy.uint8,
) -> tuple[numpy.ndarray, numpy.ndarray]:  # pragma: no cover
    coeffs = c3_alphabet.coord_conversion_coeffs(num_states, 2, dtype=numpy.int64)
    weights_1 = numpy.empty((num_states, num_states), dtype=float)
    weights_2 = numpy.empty((num_states, num_states), dtype=float)

    n_seqs = alignment.shape[0]
    stats = numpy.empty(len(positions), dtype=numpy.float64)
    stats.fill(numpy.nan)

    counts_1 = numpy.empty(num_states, dtype=numpy.int64)
    counts_2 = numpy.empty(num_states, dtype=numpy.int64)
    counts_12 = numpy.zeros((num_states, num_states), dtype=numpy.int64)
    joint_states = numpy.empty((n_seqs, 2), dtype=numpy.uint8)

    # the state indices
    states_12 = numpy.arange(counts_12.size).reshape(counts_12.shape)
    states_1 = numpy.arange(counts_1.size)
    states_2 = numpy.arange(counts_2.size)

    for pair in range(len(positions)):
        i, j = positions[pair]
        joint_states[:, 0] = alignment[:, i]
        joint_states[:, 1] = alignment[:, j]

        counts_12, counts_1, counts_2 = _count_col_joint(
            joint_states,
            counts_12,
            counts_1,
            counts_2,
            num_states,
        )

        weights_1 = _make_weights(counts_1, weights_1)
        weights_2 = _make_weights(counts_2, weights_2)
        entropy, pairs, scales = _calc_pair_scale(
            counts_12,
            counts_1,
            counts_2,
            weights_1,
            weights_2,
            states_12,
            states_1,
            states_2,
            coeffs,
        )
        if entropy == 0.0:
            stats[pair] = 0.0
        else:
            stat = 0.0
            for i in range(scales.shape[0]):
                e, w = scales[i]
                # we round the revised entropy to avoid floating point errors
                # this is in effect a more stringent condition
                if entropy > numpy.round(e, 10):
                    continue

                p1, p2 = pairs[i]
                stat += w * counts_12[p1, p2]

            stats[pair] = 1 - stat

    return positions, stats


@numba.jit
def _calc_all_entropies(
    joint_states: numpy.ndarray,
    joint_counts: numpy.ndarray,
    counts_1: numpy.ndarray,
    counts_2: numpy.ndarray,
    num_states: int,
) -> tuple[float, float, float]:  # pragma: no cover
    joint_counts, counts_1, counts_2 = _count_col_joint(
        joint_states,
        joint_counts,
        counts_1,
        counts_2,
        num_states,
    )
    entropy_1 = _vector_entropy(counts_1)
    entropy_2 = _vector_entropy(counts_2)
    if entropy_1 == 0.0 or entropy_2 == 0.0:
        return entropy_1, entropy_2, 0.0

    joint_entropy = _calc_joint_entropy(joint_counts)
    return entropy_1, entropy_2, joint_entropy


@numba.jit
def _general_mi_calc(
    positions: numpy.ndarray,
    canonical_pos: numpy.ndarray,
    alignment: numpy.ndarray,
    entropies: numpy.ndarray,
    num_states: numpy.uint8,
    metric_id: int = MI_METHODS.mi,
) -> tuple[numpy.ndarray, numpy.ndarray]:  # pragma: no cover
    n_seqs = alignment.shape[0]
    stats = numpy.empty(len(positions), dtype=numpy.float64)
    stats.fill(numpy.nan)

    counts_1 = numpy.empty(num_states, dtype=numpy.int64)
    counts_2 = numpy.empty(num_states, dtype=numpy.int64)
    joint_counts = numpy.zeros((num_states, num_states), dtype=numpy.int64)
    joint_states = numpy.empty((n_seqs, 2), dtype=numpy.uint8)
    for pair in range(len(positions)):
        i, j = positions[pair]
        joint_states[:, 0] = alignment[:, i]
        joint_states[:, 1] = alignment[:, j]
        if canonical_pos[i] and canonical_pos[j]:
            entropy_i = entropies[i]
            entropy_j = entropies[j]
            joint_counts = _count_joint_states(joint_states, num_states, joint_counts)
            joint_entropy = _calc_joint_entropy(
                joint_counts,
            )
        else:
            # we need to compute all entropies
            # cases where either position has a non-canonical
            # state are omitted
            entropy_i, entropy_j, joint_entropy = _calc_all_entropies(
                joint_states,
                joint_counts,
                counts_1,
                counts_2,
                num_states,
            )

        # MI
        stat = entropy_i + entropy_j - joint_entropy
        if metric_id == 2 and joint_entropy != 0.0:
            # normalised MI
            stat /= joint_entropy

        stats[pair] = stat
    return positions, stats


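_general_mi_calc applies the textbook identity MI(i, j) = H(i) + H(j) - H(i, j), with NMI dividing by the joint entropy when metric_id == 2. A plain-Python cross-check on two toy, independent columns (illustrative values, not the jitted path):

from collections import Counter

import numpy

def shannon(values):
    counts = numpy.array(list(Counter(values).values()), dtype=float)
    probs = counts / counts.sum()
    return float(-(probs * numpy.log2(probs)).sum())

col_i = [0, 0, 1, 1]
col_j = [0, 1, 0, 1]                        # independent of col_i
h_i, h_j = shannon(col_i), shannon(col_j)
h_ij = shannon(list(zip(col_i, col_j)))     # entropy over paired states

mi = h_i + h_j - h_ij                       # 0.0 for these independent columns
nmi = mi / h_ij if h_ij else mi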
def _gen_combinations(num_pos: int, chunk_size: int) -> typing.Iterator[numpy.ndarray]:
    combs = itertools.combinations(range(num_pos), 2)

    while True:
        if chunk := list(itertools.islice(combs, chunk_size)):
            yield numpy.array(chunk)
        else:
            break


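_gen_combinations streams the upper-triangle position pairs in fixed-size chunks so that each worker receives an array of pairs rather than one pair at a time. For example, with 4 positions and a chunk size of 4:

chunks = list(_gen_combinations(num_pos=4, chunk_size=4))
# chunks[0] -> array([[0, 1], [0, 2], [0, 3], [1, 2]])
# chunks[1] -> array([[1, 3], [2, 3]])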
class calc_mi:
    """calculator for mutual information or normalised mutual information

    callable with positions to calculate the statistic for.
    """

    def __init__(
        self,
        data: numpy.ndarray,
        num_states: int,
        metric_id: enum.Enum,
    ) -> None:
        """
        Parameters
        ----------
        data
            a 2D numpy array of uint8 values representing a multiple sequence
            alignment where sequences are the first dimension and positions are
            the second dimension.
        num_states
            the number of canonical states in the moltype. Sequence elements that
            exceed this value are not included in the calculation.
        metric_id
            either 1 (for MI) or 2 (for NMI)
        """
        self._data = data
        self._num_states = num_states
        self._metric_id = metric_id
        self._canonical_pos = numpy.all(self._data < self._num_states, axis=0)
        self._entropies = _calc_column_entropies(self._data, self._num_states)

    def __call__(self, positions: numpy.ndarray) -> tuple[numpy.ndarray, numpy.ndarray]:
        return _general_mi_calc(
            positions,
            self._canonical_pos,
            self._data,
            self._entropies,
            self._num_states,
            metric_id=self._metric_id,
        )


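A usage sketch for this helper (the toy array and values are assumptions, not a library example): encode a small alignment as uint8, build the calculator, then pass an array of position pairs.

import numpy

data = numpy.array(
    [[0, 1, 2, 2],
     [0, 1, 2, 3],
     [0, 2, 1, 3]],
    dtype=numpy.uint8,
)                                                  # 3 sequences x 4 positions
calc = calc_mi(data, num_states=4, metric_id=2)    # 2 selects normalised MI
pairs = numpy.array([[0, 1], [2, 3]])
positions, stats = calc(pairs)                     # stats[k] is the NMI for pairs[k]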
class calc_rmi:
    """calculator for resampled mutual information

    Callable with positions to calculate the statistic for.
    When called, it returns the positions and their corresponding statistic.
    """

    def __init__(self, data: numpy.ndarray, num_states: int) -> None:
        """
        Parameters
        ----------
        data
            a 2D numpy array of uint8 values representing a multiple sequence
            alignment where sequences are the first dimension and positions are
            the second dimension.
        num_states
            the number of canonical states in the moltype. Sequence elements that
            exceed this value are not included in the calculation.
        """
        self._data = data
        self._num_states = num_states

    def __call__(self, positions: numpy.ndarray) -> tuple[numpy.ndarray, numpy.ndarray]:
        return _rmi_calc(positions, self._data, self._num_states)


@UI.display_wrap
def coevolution_matrix(
    *,
    alignment: "Alignment",
    positions: list[int] | None = None,
    stat: str = "nmi",
    parallel: bool = False,
    par_kw: dict | None = None,
    show_progress: bool = False,
    ui=None,
) -> dict_array.DictArray:
    """measure pairwise coevolution

    Parameters
    ----------
    alignment
        sequence alignment
    stat
        either 'mi' (mutual information), 'nmi' (normalised MI) or 'rmi' (resampled MI)
    parallel
        run in parallel on your machine
    par_kw
        providing {'max_workers': 6} defines the number of workers to use, see
        arguments for cogent3.util.parallel.as_completed()
    show_progress
        displays a progress bar

    Returns
    -------
    Returns a DictArray with the pairwise coevolution values as a lower triangle.
    The other values are nan.
    """
    stat = {"mi": 1, "nmi": 2, "rmi": 3}[MI_METHODS(stat).name]
    num_states = len(alignment.moltype.alphabet)

    data = numpy.array(alignment)

    if positions:
        # the next two lines are flagged as uncovered in this build's report
        positions = list(itertools.chain(*positions))
        data = data[:, tuple(positions)]
    else:
        positions = range(data.shape[1])

    num_pos = data.shape[1]

    calc = calc_rmi(data, num_states) if stat == 3 else calc_mi(data, num_states, stat)

    mutual_info = numpy.empty((num_pos, num_pos), dtype=numpy.float64)
    mutual_info.fill(numpy.nan)

    # we generate the positions as a numpy.array of tuples
    chunk_size = 10_000
    num_chunks = num_pos * (num_pos - 1) // 2 // chunk_size

    position_combinations = _gen_combinations(num_pos, chunk_size)
    if parallel:
        # the next two lines are flagged as uncovered in this build's report
        par_kw = par_kw or {}
        to_do = PAR.as_completed(calc, position_combinations, **par_kw)
    else:
        to_do = map(calc, position_combinations)

    for pos_pairs, stats in ui.series(
        to_do,
        noun="Sets of pairwise positions",
        count=num_chunks + 1,
    ):
        indices = numpy.ravel_multi_index(pos_pairs.T[::-1], (num_pos, num_pos))
        mutual_info.put(indices, stats)

    positions = list(positions)
    return dict_array.DictArray.from_array_names(mutual_info, positions, positions)
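A minimal end-to-end sketch of the public entry point (the file name is a placeholder; load options follow the cogent3 documentation):

from cogent3 import load_aligned_seqs
from cogent3.evolve.coevolution import coevolution_matrix

aln = load_aligned_seqs("my_alignment.fasta", moltype="protein")
result = coevolution_matrix(alignment=aln, stat="nmi")   # or "mi" / "rmi"
# result is a DictArray keyed by alignment position; values fill the lower
# triangle and the remaining cells are nan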