
qiskit-community / qiskit-machine-learning / 23496395751

24 Mar 2026 03:01PM UTC coverage: 89.643% (-1.0%) from 90.603%

push · github · web-flow
Add step to upload coverage to Coveralls (#1025)

* Add step to upload coverage to Coveralls

* Update Coveralls upload step in workflow

* Remove redundant lines in Coveralls action

* Modify CI workflow for coverage and dependencies

Updated pip install command to include toml support and improved coverage reporting steps.

---------

Co-authored-by: M. Emre Sahin <40424147+OkuyanBoga@users.noreply.github.com>

5167 of 5764 relevant lines covered (89.64%)

0.9 hits per line
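The summary figures can be cross-checked with a little arithmetic. The sketch below recomputes the percentage from the reported line counts and the change relative to the previous build's 90.603%. It assumes that the value in parentheses is an absolute difference in percentage points; a relative change would round to -1.1% instead.

```python
# Recompute the build's headline coverage from the line counts reported above.
covered_lines = 5167
relevant_lines = 5764
previous_coverage = 90.603  # percent, coverage of the build this one is compared with

coverage = 100 * covered_lines / relevant_lines
change = coverage - previous_coverage  # difference in percentage points

print(f"coverage: {coverage:.3f}%")  # coverage: 89.643%
print(f"change: {change:+.1f}%")     # change: -1.0%
```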

Source File

89.86%
/qiskit_machine_learning/algorithms/inference/qbayesian.py
1
# This code is part of a Qiskit project.
2
#
3
# (C) Copyright IBM 2023, 2025.
4
#
5
# This code is licensed under the Apache License, Version 2.0. You may
6
# obtain a copy of this license in the LICENSE.txt file in the root directory
7
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
8
#
9
# Any modifications or derivative works of this code must retain this
10
# copyright notice, and modified files need to carry a notice indicating
11
# that they have been altered from the originals.
12
"""Quantum Bayesian Inference"""
13

14
from __future__ import annotations
1✔
15

16
import copy
1✔
17
from typing import Set
1✔
18

19
from qiskit import ClassicalRegister, QuantumCircuit
1✔
20
from qiskit.circuit import Qubit
1✔
21
from qiskit.circuit.library import grover_operator
1✔
22
from qiskit.primitives import BaseSamplerV2
1✔
23
from qiskit.quantum_info import Statevector
1✔
24
from qiskit.transpiler.passmanager import BasePassManager
1✔
25

26
from qiskit_machine_learning.primitives import QMLSampler as Sampler
1✔
27

28

29
class QBayesian:
1✔
30
    r"""
31
    Implements the quantum Bayesian inference (QBI) algorithm developed in [1]. The
32
    Bayesian network must be based on binary random variables (0/1) and represented by a quantum
33
    circuit. The quantum circuit can be passed in various forms as long as it represents the joint
34
    probability distribution of the network.
35

36
    For Bayesian networks with random variables that have more than two states, see for example [2].
37

38
    Note that ``QBayesian`` defines an order for the qubits in the circuit. The last qubit in the
39
    circuit will correspond to the most significant bit in the joint probability distribution. For
40
    example, if the random variables A, B, and C are entered into the circuit in this order with
41
    (A=1, B=0, and C=0), the probability is given by the squared amplitude of the quantum
42
    state 001.
43

44
    **Example**
45

46
    .. code-block:: python
47

48
        qc = QuantumCircuit(...)
49

50
        qb = QBayesian(qc)
51
        result = qb.inference(query={...}, evidence={...})
52
        print("Probability of query given evidence: ", result)
53

54
    **References**
55
        [1]: Low, Guang Hao, Theodore J. Yoder, and Isaac L. Chuang. "Quantum inference on Bayesian
56
        networks." Physical Review A 89.6 (2014): 062315.
57

58
        [2]: Borujeni, Sima E., et al. "Quantum circuit representation of Bayesian networks."
59
        Expert Systems with Applications 176 (2021): 114768.
60
    """
61

62
    # Discrete quantum Bayesian network
63
    def __init__(
1✔
64
        self,
65
        circuit: QuantumCircuit,
66
        *,
67
        limit: int = 10,
68
        threshold: float = 0.9,
69
        sampler: BaseSamplerV2 | None = None,
70
        pass_manager: BasePassManager | None = None,
71
    ):
72
        """
73
        Args:
74
            circuit: The quantum circuit that represents the Bayesian network. Each random variable
75
                should be assigned to exactly one register of one qubit. The last qubit in the
76
                circuit corresponds to the most significant bit in the binary string, which
77
                represents the measured quantum state.
78
            limit: The maximum exponent k for applying the Grover operator 2^k times.
79
            threshold: The threshold to accept the evidence. An evidence qubit is read as 1 if
80
                it is measured as 1 in at least this fraction of shots (e.g. 90% for 0.9), and
81
                as 0 otherwise; the evidence is accepted when every value read matches it.
82
            sampler: The sampler primitive used to compute the Bayesian inference.
83
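                If ``None`` is given, a default instance of the sampler defined
84
                by :class:`~qiskit_machine_learning.primitives.QMLSampler` will be used.
            pass_manager: An optional pass manager used to transpile each circuit before it
                is run by the sampler.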
85
        Raises:
86
            ValueError: If any register in the circuit is not mapped to exactly one qubit.
87
        """
88
        # Test valid input
89
        for qrg in circuit.qregs:
1✔
90
            if qrg.size > 1:
1✔
91
                raise ValueError("Every register needs to be mapped to exactly one unique qubit.")
1✔
92

93
        # Initialize parameters
94
        self._circ = circuit
1✔
95
        self._limit = limit
1✔
96
        self._threshold = threshold
1✔
97
        if sampler is None:
1✔
98
            sampler = Sampler()
1✔
99

100
        self._sampler = sampler
1✔
101

102
        if hasattr(circuit.layout, "_input_qubit_count"):
1✔
103
            self.num_virtual_qubits = circuit.layout._input_qubit_count
×
104
        else:
105
            if pass_manager is None:
1✔
106
                self.num_virtual_qubits = circuit.num_qubits
1✔
107
            else:
108
                circuit = pass_manager.run(circuit)
1✔
109
                self.num_virtual_qubits = circuit.layout._input_qubit_count
1✔
110
        self._pass_manager = pass_manager
1✔
111

112
        # Label of register mapped to its qubit
113
        self._label2qubit = {qrg.name: qrg[0] for qrg in self._circ.qregs}
1✔
114
        # Label of register mapped to its position in the measured bit string (0 = last qubit)
115
        self._label2qidx = {
1✔
116
            qrg.name: self._circ.num_qubits - idx - 1 for idx, qrg in enumerate(self._circ.qregs)
117
        }
118
        # Distribution of samples from rejection sampling
119
        self._samples: dict[str, float] = {}
1✔
120
        # True if rejection sampling converged within the limit
121
        self._converged = False
1✔
122

123
    def _get_grover_op(self, evidence: dict[str, int]) -> QuantumCircuit:
1✔
124
        """
125
        Constructs a Grover operator based on the provided evidence. The evidence is used to
126
        determine the "good states" that the Grover operator will amplify.
127

128
        Args:
129
            evidence: A dictionary representing the evidence with keys as variable labels
130
                and values as states.
131
        Returns:
132
            The Grover operator as a :class:`~qiskit.circuit.QuantumCircuit`.
133
        """
134
        # Map evidence labels to bit-string positions, sorted in ascending order
135
        num_qubits = self._circ.num_qubits
1✔
136
        e2idx = sorted(
1✔
137
            [(self._label2qidx[e_key], e_val) for e_key, e_val in evidence.items()],
138
            key=lambda x: x[0],
139
        )
140
        # All bit strings over the non-evidence variables
141
        num_evd = len(e2idx)
1✔
142
        bin_str = [
1✔
143
            format(i, f"0{(num_qubits - num_evd)}b") for i in range(2 ** (num_qubits - num_evd))
144
        ]
145
        # Get good states
146
        good_states = []
1✔
147
        for b in bin_str:
1✔
148
            for e_idx, e_val in e2idx:
1✔
149
                b = b[:e_idx] + str(e_val) + b[e_idx:]
1✔
150
            good_states.append(b)
1✔
151
        # Oracle statevector: amplitude 1 for every good state and 0 otherwise
152
        oracle = Statevector(
1✔
153
            [int(format(i, f"0{num_qubits}b") in good_states) for i in range(2**num_qubits)]
154
        )
155
        return grover_operator(oracle, state_preparation=self._circ)
1✔
156

157
    def _run_circuit(self, circuit: QuantumCircuit) -> dict[str, float]:
1✔
158
        """Run the quantum circuit with the sampler and return P(bitstring) with fixed width."""
159
        if self._pass_manager is not None:
1✔
160
            circuit = self._pass_manager.run(circuit)
1✔
161

162
        job = self._sampler.run([circuit])
1✔
163
        res = job.result()
1✔
164
        pub = res[0]
1✔
165

166
        # Default
167
        bit_counts = {}
1✔
168

169
        # 1) Prefer robust, register-agnostic access (no try/except: guards only)
170
        join_data = getattr(pub, "join_data", None)
1✔
171
        if callable(join_data):
1✔
172
            joined = join_data()
1✔
173
            get_counts = getattr(joined, "get_counts", None)
1✔
174
            if callable(get_counts):
1✔
175
                bit_counts = get_counts()
1✔
176

177
        # 2) Fallback: the first register, in sorted order, that provides counts
178
        if not bit_counts:
1✔
179
            data = getattr(pub, "data", None)
×
180
            if data is not None:
×
181
                # dict-like (fast + deterministic)
182
                if isinstance(data, dict):
×
183
                    for reg_name in sorted(data):
×
184
                        reg = data[reg_name]
×
185
                        gc = getattr(reg, "get_counts", None)
×
186
                        if callable(gc):
×
187
                            bit_counts = gc()
×
188
                            break
×
189

190
                # object-like container (deterministic by sorted dir)
191
                else:
192
                    for reg_name in sorted(n for n in dir(data) if not n.startswith("_")):
×
193
                        reg = getattr(data, reg_name, None)
×
194
                        gc = getattr(reg, "get_counts", None)
×
195
                        if callable(gc):
×
196
                            bit_counts = gc()
×
197
                            break
×
198

199
        total = sum(bit_counts.values())
1✔
200
        if total == 0:
1✔
201
            return {}
×
202

203
        width = circuit.num_clbits  # number of measured classical bits in this circuit instance
1✔
204

205
        out: dict[str, float] = {}
1✔
206

207
        def _to_bin_key(k) -> str:
1✔
208
            if isinstance(k, int):
1✔
209
                return format(int(k), f"0{width}b")
×
210
            ks = str(k).replace(" ", "")
1✔
211
            if ks.startswith(("0b", "0B")):
1✔
212
                return format(int(ks, 2), f"0{width}b")
×
213
            if ks.startswith(("0x", "0X")):
1✔
214
                return format(int(ks, 16), f"0{width}b")
×
215
            if set(ks) <= {"0", "1"} and len(ks) <= width:
1✔
216
                return ks.zfill(width)
1✔
217
            # decimal string
218
            return format(int(ks), f"0{width}b")
×
219

220
        for k, v in bit_counts.items():
1✔
221
            out[_to_bin_key(k)] = out.get(_to_bin_key(k), 0.0) + v / total
1✔
222

223
        return out
1✔
224

225
    def __power_grover(
1✔
226
        self, grover_op: QuantumCircuit, evidence: dict[str, int], k: int
227
    ) -> tuple[QuantumCircuit, Set[tuple[Qubit, int]]]:
228
        """
229
        Applies the Grover operator to the quantum circuit 2^k times, measures the evidence qubits,
230
        and returns the circuit (without measurements) together with a set of ``(qubit, value)``
231
        pairs holding the thresholded measurement result of each evidence qubit.
232

233
        Args:
234
            grover_op: The Grover operator to be applied.
235
            evidence: A dictionary representing the evidence.
236
            k: The exponent; the Grover operator is applied 2^k times.
237
        Returns:
238
            tuple: The circuit without measurements and a set of ``(qubit, value)`` pairs with
239
                the thresholded measurement result of each evidence qubit.
240
        """
241
        # Create circuit
242
        qc = QuantumCircuit(*self._circ.qregs)
1✔
243
        qc.append(self._circ, self._circ.qregs)
1✔
244
        # Apply Grover operator 2^k times
245
        qc_grover = QuantumCircuit(*self._circ.qregs)
1✔
246
        qc_grover.append(grover_op, self._circ.qregs)
1✔
247
        qc_grover = qc_grover.power(2**k)
1✔
248
        qc.append(qc_grover, self._circ.qregs)
1✔
249
        # Copy the circuit and add measurements of the evidence qubits
250
        qc_measure = QuantumCircuit(*self._circ.qregs)
1✔
251
        qc_measure.append(qc, self._circ.qregs)
1✔
252
        # Create a classical register with the size of the evidence
253
        measurement_ecr = ClassicalRegister(len(evidence))
1✔
254
        qc_measure.add_register(measurement_ecr)
1✔
255
        # Map the evidence qubits to the classical bits and measure them
256
        evidence_qubits = [self._label2qubit[e_key] for e_key in evidence]
1✔
257
        qc_measure.measure(evidence_qubits, measurement_ecr)
1✔
258
        # Run the circuit with the Grover operator and measurements
259
        e_samples = self._run_circuit(qc_measure)
1✔
260
        e_count = {self._label2qubit[e]: 0.0 for e in evidence}
1✔
261
        for e_sample_key, e_sample_val in e_samples.items():
1✔
262
            # Iterate over the reversed key so that bit i belongs to evidence_qubits[i]
263
            for i, char in enumerate(e_sample_key[::-1]):
1✔
264
                if int(char) == 1:
1✔
265
                    e_count[evidence_qubits[i]] += e_sample_val
1✔
266
        # Read each evidence qubit as 1 if P(1) >= threshold, otherwise as 0
267
        e_meas = {
1✔
268
            (e_count_key, int(e_count_val >= self._threshold))
269
            for e_count_key, e_count_val in e_count.items()
270
        }
271
        return qc, e_meas
1✔
272

273
    def _format_samples(self, samples: dict[str, float], evidence: list[str]) -> dict[str, float]:
1✔
274
        """Transform sample keys back to their variable names."""
275
        f_samples: dict[str, float] = {}
1✔
276
        for smpl_key, smpl_val in samples.items():
1✔
277
            q_str, e_str = "", ""
1✔
278
            for var_name, var_idx in sorted(self._label2qidx.items(), key=lambda x: -x[1]):
1✔
279
                if var_name in evidence:
1✔
280
                    e_str += f"{var_name}={smpl_key[var_idx]},"
1✔
281
                else:
282
                    q_str += f"{var_name}={smpl_key[var_idx]},"
1✔
283
            if evidence:
1✔
284
                f_samples[f"P({q_str[:-1]}|{e_str[:-1]})"] = smpl_val
1✔
285
            else:
286
                f_samples[f"P({q_str[:-1]})"] = smpl_val
1✔
287
        return f_samples
1✔
288

289
    def rejection_sampling(
1✔
290
        self, evidence: dict[str, int], format_res: bool = False
291
    ) -> dict[str, float]:
292
        """
293
        Performs quantum rejection sampling given the evidence.
294

295
        Args:
296
            evidence: The keys of the dictionary are the evidence variables that are linked to the
297
                corresponding quantum register with their names and values (0/1). If evidence is
298
                empty, it measures all qubits. If evidence is given, it uses the Grover operator for
299
                amplitude amplification until the evidence matches or the limit is reached.
300
            format_res: If true, maps the output back to variable names. For example, the output
301
                {'100': 0.23} with evidence A=0, B=0 will be mapped to {'P(C=1|A=0,B=0)': 0.23}.
302
        Returns:
303
            A dictionary with the probability distribution of the samples given the evidence, where
304
            the keys are bit strings of the variable values. Note that the value of the last variable
305
            appears as the first character of the key. If ``format_res`` is true, the output will be
306
            mapped back to the variable names, for example {'P(C=1|A=0,B=0)': 0.23}.
307
        """
308
        # If evidence is empty
309
        if len(evidence) == 0:
1✔
310
            # Create circuit
311
            qc = QuantumCircuit(*self._circ.qregs)
1✔
312
            qc.append(self._circ, self._circ.qregs)
1✔
313
            # Measure
314
            qc.measure_all()
1✔
315
            # Run circuit
316
            self._samples = self._run_circuit(qc)
1✔
317
        else:
318
            # Get Grover operator if evidence not empty
319
            grover_op = self._get_grover_op(evidence)
1✔
320
            # Amplitude amplification
321
            true_e = {(self._label2qubit[e_key], e_val) for e_key, e_val in evidence.items()}
1✔
322
            meas_e: Set[tuple[Qubit, int]] = set()
1✔
323
            best_qc, best_inter = QuantumCircuit(), -1
1✔
324
            self._converged = False
1✔
325
            k = -1
1✔
326
            # Stop as soon as the thresholded evidence measurement matches the evidence
327
            while (true_e != meas_e) and (k < self._limit):
1✔
328
                # Increment power
329
                k += 1
1✔
330
                # Create a circuit that applies the Grover operator 2^k times
331
                qc, meas_e = self.__power_grover(grover_op=grover_op, evidence=evidence, k=k)
1✔
332
                # Keep the circuit whose measurement matches the most evidence values
333
                if (num_matches := len(true_e.intersection(meas_e))) > best_inter:
1✔
334
                    best_qc, best_inter = qc, num_matches
1✔
335
            if true_e == meas_e:
1✔
336
                self._converged = True
1✔
337
            # Create a classical register for the query (non-evidence) variables
338
            best_qc_meas = QuantumCircuit(*self._circ.qregs)
1✔
339
            best_qc_meas.append(best_qc, self._circ.qregs)
1✔
340
            measurement_qcr = ClassicalRegister(self._circ.num_qubits - len(evidence))
1✔
341
            best_qc_meas.add_register(measurement_qcr)
1✔
342
            # Map the query qubits to the classical bits and measure them
343
            query_qubits = [
1✔
344
                (label, self._label2qidx[label], qubit)
345
                for label, qubit in self._label2qubit.items()
346
                if label not in evidence
347
            ]
348
            query_qubits_sorted = sorted(query_qubits, key=lambda x: x[1], reverse=True)
1✔
349
            # Measure the query qubits in circuit order (first query qubit -> classical bit 0)
350
            best_qc_meas.measure([q[2] for q in query_qubits_sorted], measurement_qcr)
1✔
351
            # Run circuit
352
            counts = self._run_circuit(best_qc_meas)
1✔
353
            # Template bit string: evidence values fixed, 'q' placeholders for query bits
354
            query_string = ""
1✔
355
            var_idx_sorted = [
1✔
356
                label for label, _ in sorted(self._label2qidx.items(), key=lambda x: x[1])
357
            ]
358
            for var in var_idx_sorted:
1✔
359
                if var in evidence:
1✔
360
                    query_string += str(evidence[var])
1✔
361
                else:
362
                    query_string += "q"
1✔
363
            # Retrieve valid samples
364
            self._samples = {}
1✔
365
            # Replace placeholder q with query variables from samples
366
            for key, val in counts.items():
1✔
367
                query = query_string
1✔
368
                for char in key:
1✔
369
                    query = query.replace("q", char, 1)
1✔
370
                self._samples[query] = val
1✔
371
        if not format_res:
1✔
372
            return copy.deepcopy(self._samples)
1✔
373
        else:
374
            return self._format_samples(self._samples, list(evidence.keys()))
1✔
375

376
    def inference(
1✔
377
        self,
378
        query: dict[str, int],
379
        evidence: dict[str, int] | None = None,
380
    ) -> float:
381
        """
382
        Performs quantum inference for the query variables given the evidence. It uses quantum
383
        rejection sampling if evidence is given and calculates the probability of the query.
384

385
        Args:
386
            query: The keys of the dictionary are the query variables that are linked to the
387
                corresponding quantum registers with their names and values (0/1). If the query
388
                variables are a proper subset of the non-evidence variables, the others are
389
                marginalized out.
390
            evidence: The evidence variables. If evidence is a dictionary, the rejection sampling is
391
                executed with the keys representing the variables linked to the corresponding
392
                quantum register by their names and values (0/1). If evidence is ``None``, the
393
                default, then samples from the previous rejection sampling are used.
394
        Returns:
395
            The probability of the query given the evidence.
396
        Raises:
397
            ValueError: If ``evidence`` is ``None`` and no previous samples are available.
398
        """
399
        if evidence is not None:
1✔
400
            self.rejection_sampling(evidence)
1✔
401
        elif not self._samples:
1✔
402
            raise ValueError("Provide evidence or indicate no evidence with an empty dictionary")
1✔
403
        # Map query variables to their bit-string positions
404
        query_indices_rev = [(self._label2qidx[q_key], q_val) for q_key, q_val in query.items()]
1✔
405
        # Get probability of query
406
        res = 0.0
1✔
407
        for sample_key, sample_val in self._samples.items():
1✔
408
            add = True
1✔
409
            for q_idx, q_val in query_indices_rev:
1✔
410
                if int(sample_key[q_idx]) != q_val:
1✔
411
                    add = False
1✔
412
                    break
1✔
413
            if add:
1✔
414
                res += sample_val
1✔
415
        return res
1✔
416

417
    @property
1✔
418
    def converged(self) -> bool:
1✔
419
        """Returns ``True`` if a solution for the evidence with the given threshold was found
420
        within the maximum number of Grover operator applications (2^limit)."""
421
        return self._converged
1✔
422

423
    @property
1✔
424
    def samples(self) -> dict[str, float]:
1✔
425
        """Returns the samples generated from the rejection sampling."""
426
        return self._samples
1✔
427

428
    @property
1✔
429
    def limit(self) -> int:
1✔
430
        """Returns the maximum number of times the Grover operator can be applied (2^limit)."""
431
        return self._limit
1✔
432

433
    @limit.setter
1✔
434
    def limit(self, limit: int):
1✔
435
        """Set the maximum number of times the Grover operator can be applied (2^limit)."""
436
        self._limit = limit
1✔
437

438
    @property
1✔
439
    def sampler(self) -> BaseSamplerV2:
1✔
440
        """Returns the sampler primitive used to compute the samples."""
441
        return self._sampler
1✔
442

443
    @sampler.setter
1✔
444
    def sampler(self, sampler: BaseSamplerV2):
1✔
445
        """Set the sampler primitive used to compute the samples."""
446
        self._sampler = sampler
1✔
447

448
    @property
1✔
449
    def threshold(self) -> float:
1✔
450
        """Returns the threshold to accept the evidence."""
451
        return self._threshold
1✔
452

453
    @threshold.setter
1✔
454
    def threshold(self, threshold: float):
1✔
455
        """Set the threshold to accept the evidence."""
456
        self._threshold = threshold
1✔
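The good-state enumeration in `_get_grover_op` above can be checked without Qiskit. The following is a minimal pure-Python sketch; `good_states` and `oracle_mask` are hypothetical helper names, not part of `qiskit_machine_learning`. They follow the module's bit-string convention, in which position 0 is the most significant bit (the last qubit), and they produce the 0/1 vector that the module wraps in a `Statevector` before calling `grover_operator`.

```python
def good_states(num_qubits: int, evidence_idx: dict[int, int]) -> list[str]:
    """Enumerate the bit strings that agree with the evidence.

    evidence_idx maps a bit-string position (0 = most significant bit, i.e. the
    last qubit, as in ``_label2qidx``) to the required evidence value (0 or 1).
    """
    e2idx = sorted(evidence_idx.items())  # insert in ascending position order
    free = num_qubits - len(e2idx)  # number of non-evidence variables
    states = []
    for i in range(2**free):
        # Guard the no-free-variable case: a zero-width format still emits one digit.
        b = format(i, f"0{free}b") if free else ""
        for pos, val in e2idx:
            # Ascending order means every later insert only shifts free bits.
            b = b[:pos] + str(val) + b[pos:]
        states.append(b)
    return states


def oracle_mask(num_qubits: int, states: list[str]) -> list[int]:
    """0/1 vector over basis indices: 1 wherever the big-endian label is a good state."""
    good = set(states)  # set lookup instead of scanning a list for every index
    return [int(format(i, f"0{num_qubits}b") in good) for i in range(2**num_qubits)]


# Variables A, B, C on qubits 0, 1, 2, so B sits at string position 1.
states = good_states(3, {1: 1})  # evidence B=1
print(states)                    # ['010', '011', '110', '111']
print(oracle_mask(3, states))    # [0, 0, 1, 1, 0, 0, 1, 1]
```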
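`_run_circuit` above turns raw sampler counts into probabilities keyed by fixed-width bit strings. The sketch below isolates that normalization so it can be tested without a sampler. `to_bin_key` and `normalize_counts` are hypothetical standalone names that mirror the nested `_to_bin_key` helper and the loop that follows it; unlike the listing, the sketch computes each key only once.

```python
def to_bin_key(key, width: int) -> str:
    """Normalize a counts key (int, '0b...', '0x...', binary or decimal string) to width bits."""
    if isinstance(key, int):
        return format(key, f"0{width}b")
    ks = str(key).replace(" ", "")  # drop spaces, e.g. between register groups
    if ks.startswith(("0b", "0B")):
        return format(int(ks, 2), f"0{width}b")
    if ks.startswith(("0x", "0X")):
        return format(int(ks, 16), f"0{width}b")
    if set(ks) <= {"0", "1"} and len(ks) <= width:
        return ks.zfill(width)
    return format(int(ks), f"0{width}b")  # anything else is treated as a decimal string


def normalize_counts(counts: dict, width: int) -> dict[str, float]:
    """Turn raw counts into probabilities keyed by fixed-width bit strings."""
    total = sum(counts.values())
    if total == 0:
        return {}
    out: dict[str, float] = {}
    for k, v in counts.items():
        key = to_bin_key(k, width)
        out[key] = out.get(key, 0.0) + v / total  # merge keys that normalize to the same string
    return out


print(normalize_counts({"01": 30, 3: 10, "0x0": 60}, width=2))
# {'01': 0.3, '11': 0.1, '00': 0.6}
```

Because the binary check comes before the decimal fallback, a decimal string made only of 0s and 1s, such as `"10"`, is read as binary.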
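The thresholding step at the end of `__power_grover` above reduces the evidence-register distribution to one 0/1 value per evidence variable. Here is a minimal sketch of that step. `threshold_evidence` is a hypothetical name, and it keys results by variable label, whereas the module keys them by `Qubit` objects.

```python
def threshold_evidence(
    e_samples: dict[str, float], labels: list[str], threshold: float
) -> set[tuple[str, int]]:
    """Reduce an evidence-register distribution to one 0/1 value per evidence variable.

    labels[i] is the variable measured into classical bit i; bit i is the i-th
    character from the right of each key, hence the reversal below.
    """
    p_one = {label: 0.0 for label in labels}  # probability of reading 1 per variable
    for key, prob in e_samples.items():
        for i, char in enumerate(key[::-1]):
            if char == "1":
                p_one[labels[i]] += prob
    # A variable is read as 1 only if P(1) >= threshold; otherwise it is read as 0.
    return {(label, int(p >= threshold)) for label, p in p_one.items()}


# Evidence A=1, B=0 with A on classical bit 0 and B on classical bit 1.
measured = threshold_evidence({"01": 0.92, "11": 0.03, "00": 0.05}, ["A", "B"], 0.9)
print(measured == {("A", 1), ("B", 0)})  # True: P(A=1) = 0.95, P(B=1) = 0.03
```

This rule is asymmetric. An evidence value of 1 must be read in at least `threshold` of the shots, but a value of 0 is accepted whenever P(1) stays below `threshold`. For example, a qubit read as 1 in half the shots still counts as 0.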
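`rejection_sampling` above applies the Grover operator 2^k times for k = 0, 1, ..., `limit`. The sketch below uses the textbook amplitude-amplification formula to show why more iterations do not always help. It assumes an ideal, noiseless Grover operator whose oracle marks exactly the evidence states. Under that idealization, if sin²θ = P(evidence), the probability of measuring the evidence after m iterations is sin²((2m + 1)θ). This is a model of the idealized algorithm, not output from the module.

```python
import math


def success_probability(p_evidence: float, iterations: int) -> float:
    """Ideal probability of measuring the evidence after `iterations` Grover iterations.

    Textbook amplitude amplification: with sin^2(theta) = p_evidence, the
    probability after m iterations is sin^2((2m + 1) * theta).
    """
    theta = math.asin(math.sqrt(p_evidence))
    return math.sin((2 * iterations + 1) * theta) ** 2


# The doubling schedule used by rejection_sampling: 2**k iterations for k = 0..limit.
limit = 4
for k in range(limit + 1):
    m = 2**k
    print(f"k={k}  iterations={m}  P(evidence)={success_probability(0.05, m):.3f}")
```

With P(evidence) = 0.25, one iteration reaches the evidence with certainty, but two iterations fall back to 0.25. Under these assumptions, a larger power can therefore overshoot, which is one way `converged` can remain `False`.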
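`_format_samples` above rewrites bit-string keys as conditional-probability labels, as in the `{'100': 0.23}` to `{'P(C=1|A=0,B=0)': 0.23}` example from the `rejection_sampling` docstring. The hypothetical standalone `format_samples` below reproduces that mapping. It takes the label-to-position map as an argument instead of reading `self._label2qidx`, and it joins the parts with `str.join` rather than trimming a trailing comma.

```python
def format_samples(
    samples: dict[str, float], label2qidx: dict[str, int], evidence: list[str]
) -> dict[str, float]:
    """Rewrite bit-string keys as P(query|evidence) labels, variables in circuit order."""
    formatted: dict[str, float] = {}
    for key, prob in samples.items():
        query, given = [], []
        # Descending string position = ascending circuit order (first qubit first).
        for name, idx in sorted(label2qidx.items(), key=lambda x: -x[1]):
            (given if name in evidence else query).append(f"{name}={key[idx]}")
        label = ",".join(query)
        if evidence:
            label += "|" + ",".join(given)
        formatted[f"P({label})"] = prob
    return formatted


# A, B, C are the first, second and third qubits, so C is the most significant bit.
label2qidx = {"A": 2, "B": 1, "C": 0}
print(format_samples({"100": 0.23}, label2qidx, ["A", "B"]))
# {'P(C=1|A=0,B=0)': 0.23}
```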
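The last part of `rejection_sampling` above expands the query-register counts into full bit strings, and `inference` then sums the samples that agree with the query. The sketch below chains both steps on a three-variable example. `fill_template` and `marginal` are hypothetical helper names, and the query-register distribution is invented for illustration.

```python
def fill_template(
    label2qidx: dict[str, int], evidence: dict[str, int], counts: dict[str, float]
) -> dict[str, float]:
    """Expand query-register counts into full bit strings with the evidence filled in."""
    # Template in bit-string order (most significant first): evidence bits fixed,
    # 'q' as a placeholder for every query bit.
    template = "".join(
        str(evidence[name]) if name in evidence else "q"
        for name, _ in sorted(label2qidx.items(), key=lambda x: x[1])
    )
    samples: dict[str, float] = {}
    for key, prob in counts.items():
        full = template
        for char in key:  # query bits also arrive most significant first
            full = full.replace("q", char, 1)
        samples[full] = prob
    return samples


def marginal(samples: dict[str, float], label2qidx: dict[str, int], query: dict[str, int]) -> float:
    """Sum the probability of every sample that agrees with all query values."""
    return sum(
        prob
        for key, prob in samples.items()
        if all(int(key[label2qidx[name]]) == val for name, val in query.items())
    )


label2qidx = {"A": 2, "B": 1, "C": 0}  # A is the first qubit, C the last
# Invented query-register distribution for evidence B=0 (keys read C, then A).
counts = {"10": 0.7, "01": 0.2, "00": 0.1}
samples = fill_template(label2qidx, {"B": 0}, counts)
print(samples)                                  # {'100': 0.7, '001': 0.2, '000': 0.1}
print(marginal(samples, label2qidx, {"A": 1}))  # 0.2
```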