• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Qiskit / qiskit / 23659987974

27 Mar 2026 05:50PM UTC coverage: 87.498% (+0.2%) from 87.267%
23659987974

Pull #15488

github

web-flow
Merge 7be547e15 into 9c6baac6f
Pull Request #15488: Support commutation check between Pauli-based gates and standard gates

281 of 294 new or added lines in 5 files covered. (95.58%)

1087 existing lines in 31 files now uncovered.

104317 of 119222 relevant lines covered (87.5%)

1020836.65 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

57.3
/qiskit/visualization/state_visualization.py
1
# This code is part of Qiskit.
2
#
3
# (C) Copyright IBM 2017, 2023.
4
#
5
# This code is licensed under the Apache License, Version 2.0. You may
6
# obtain a copy of this license in the LICENSE.txt file in the root directory
7
# of this source tree or at https://www.apache.org/licenses/LICENSE-2.0.
8
#
9
# Any modifications or derivative works of this code must retain this
10
# copyright notice, and modified files need to carry a notice indicating
11
# that they have been altered from the originals.
12

13

14
"""
15
Visualization functions for quantum states.
16
"""
17

18
import math
1✔
19

20
from functools import reduce
1✔
21
import colorsys
1✔
22

23
import numpy as np
1✔
24
from qiskit import user_config
1✔
25
from qiskit.quantum_info.states.statevector import Statevector
1✔
26
from qiskit.quantum_info.operators.operator import Operator
1✔
27
from qiskit.quantum_info.operators.symplectic import PauliList, SparsePauliOp
1✔
28
from qiskit.quantum_info.states.densitymatrix import DensityMatrix
1✔
29
from qiskit.utils import optionals as _optionals
1✔
30
from qiskit.circuit.tools.pi_check import pi_check
1✔
31

32
from .array import _num_to_latex, array_to_latex
1✔
33
from .utils import matplotlib_close_if_inline
1✔
34
from .exceptions import VisualizationError
1✔
35

36

37
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_state_hinton(state, title="", figsize=None, ax_real=None, ax_imag=None, *, filename=None):
    """Plot a hinton diagram for the density matrix of a quantum state.

    The hinton diagram represents the values of a matrix using
    squares, whose size indicate the magnitude of their corresponding value
    and their color, its sign. A white square means the value is positive and
    a black one means negative.

    Args:
        state (Statevector or DensityMatrix or ndarray): An N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches. Only used when neither ``ax_real``
            nor ``ax_imag`` is given; defaults to ``(8, 5)``.
        filename (str | None): The optional file path to save image to. If not specified
            no file is created for the visualization. If this is set the return
            from this function will be ``None``.
        ax_real (matplotlib.axes.Axes): An optional Axes object used for the
            real-part plot. If neither ``ax_real`` nor ``ax_imag`` is specified a
            new matplotlib Figure with two panels will be created and used. If
            this is specified without an ``ax_imag`` only the real component plot
            will be generated.
        ax_imag (matplotlib.axes.Axes): An optional Axes object used for the
            imaginary-part plot. If neither ``ax_real`` nor ``ax_imag`` is
            specified a new matplotlib Figure with two panels will be created and
            used. If this is specified without an ``ax_real`` only the imaginary
            component plot will be generated.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            The Figure containing the visualization (when ``ax_real`` or
            ``ax_imag`` is given, this is the Figure that owns those axes), or
            ``None`` if ``filename`` is set.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        VisualizationError: Input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           import numpy as np
           from qiskit import QuantumCircuit
           from qiskit.quantum_info import DensityMatrix
           from qiskit.visualization import plot_state_hinton

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0,1)
           qc.ry(np.pi/3 , 0)
           qc.rx(np.pi/5, 1)

           state = DensityMatrix(qc)
           plot_state_hinton(state, title="New Hinton Plot")

    """
    from matplotlib import pyplot as plt

    # Figure data
    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    # Round the largest magnitude up to a power of two so every square fits in its cell.
    max_weight = 2 ** math.ceil(math.log2(np.abs(rho.data).max()))
    datareal = np.real(rho.data)
    dataimag = np.imag(rho.data)

    if figsize is None:
        figsize = (8, 5)
    if not ax_real and not ax_imag:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    else:
        fig = ax_real.get_figure() if ax_real else ax_imag.get_figure()
        ax1 = ax_real
        ax2 = ax_imag

    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    # Row 0 of the matrix is drawn at the top, so the y tick labels run in reverse.
    row_names = column_names[::-1]

    if ax1:
        _plot_hinton_component(
            ax1, datareal, max_weight, row_names, column_names, "Re[$\\rho$]"
        )
    if ax2:
        _plot_hinton_component(
            ax2, dataimag, max_weight, row_names, column_names, "Im[$\\rho$]"
        )

    fig.tight_layout()
    if title:
        fig.suptitle(title, fontsize=16)
    # Only close figures that this function created itself.
    if ax_real is None and ax_imag is None:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    return fig.savefig(filename)


def _plot_hinton_component(ax, data, max_weight, row_names, column_names, panel_title):
    """Draw one Hinton panel (real or imaginary part of a matrix) onto ``ax``.

    Args:
        ax (matplotlib.axes.Axes): Target axes.
        data (ndarray): 2D real-valued matrix to draw.
        max_weight (float): Normalization so the largest square fits one cell.
        row_names (list[str]): y tick labels, top row last.
        column_names (list[str]): x tick labels.
        panel_title (str): Title for this panel.
    """
    from matplotlib import pyplot as plt

    ly, lx = data.shape
    ax.patch.set_facecolor("gray")
    ax.set_aspect("equal", "box")
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in np.ndenumerate(data):
        # Convert from matrix coordinates (row x from the top) to plot coordinates.
        plot_x, plot_y = y, ly - x - 1
        color = "white" if w > 0 else "black"
        size = np.sqrt(np.abs(w) / max_weight)
        rect = plt.Rectangle(
            [0.5 + plot_x - size / 2, 0.5 + plot_y - size / 2],
            size,
            size,
            facecolor=color,
            edgecolor=color,
        )
        ax.add_patch(rect)

    ax.set_xticks(0.5 + np.arange(lx))
    ax.set_yticks(0.5 + np.arange(ly))
    ax.set_xlim([0, lx])
    ax.set_ylim([0, ly])
    ax.set_yticklabels(row_names, fontsize=14)
    ax.set_xticklabels(column_names, fontsize=14, rotation=90)
    ax.set_title(panel_title, fontsize=14)
186

187

188
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_bloch_vector(
    bloch, title="", ax=None, figsize=None, coord_type="cartesian", font_size=None
):
    """Plot the Bloch sphere.

    Plot a Bloch sphere with the specified coordinates, that can be given in both
    cartesian and spherical systems.

    Args:
        bloch (list[double]): array of three elements where [<x>, <y>, <z>] (Cartesian)
            or [<r>, <theta>, <phi>] (spherical in radians)
            <theta> is inclination angle from +z direction
            <phi> is azimuth from +x direction
            The input sequence is never modified.
        title (str): a string that represents the plot title
        ax (matplotlib.axes.Axes): An Axes to use for rendering the bloch
            sphere
        figsize (tuple): Figure size in inches. Has no effect if passing ``ax``.
        coord_type (str): a string that specifies coordinate type for bloch
            (Cartesian or spherical), default is Cartesian
        font_size (float): Font size.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` : A matplotlib figure instance if
        ``ax = None``, otherwise ``None``.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit.visualization import plot_bloch_vector

           plot_bloch_vector([0,1,0], title="New Bloch Sphere")

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           import numpy as np
           from qiskit.visualization import plot_bloch_vector

           # You can use spherical coordinates instead of cartesian.

           plot_bloch_vector([1, np.pi/2, np.pi/3], coord_type='spherical')

    """
    from .bloch import Bloch

    if figsize is None:
        figsize = (5, 5)
    B = Bloch(axes=ax, font_size=font_size)
    if coord_type == "spherical":
        # Build a new Cartesian vector instead of writing back into ``bloch``:
        # the caller's sequence must not be mutated, and tuples or integer
        # arrays could not hold the converted values anyway.
        r, theta, phi = bloch[0], bloch[1], bloch[2]
        bloch = [
            r * np.sin(theta) * np.cos(phi),
            r * np.sin(theta) * np.sin(phi),
            r * np.cos(theta),
        ]
    B.add_vectors(bloch)
    B.render(title=title)
    if ax is None:
        fig = B.fig
        fig.set_size_inches(figsize[0], figsize[1])
        matplotlib_close_if_inline(fig)
        return fig
    return None
255

256

257
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_bloch_multivector(
    state,
    title="",
    figsize=None,
    *,
    reverse_bits=False,
    filename=None,
    font_size=None,
    title_font_size=None,
    title_pad=1,
):
    r"""Plot a Bloch sphere for each qubit.

    Each component :math:`(x,y,z)` of the Bloch sphere labeled as 'qubit i' represents the expected
    value of the corresponding Pauli operator acting only on that qubit, that is, the expected value
    of :math:`I_{N-1} \otimes\dotsb\otimes I_{i+1}\otimes P_i \otimes I_{i-1}\otimes\dotsb\otimes
    I_0`, where :math:`N` is the number of qubits, :math:`P\in \{X,Y,Z\}` and :math:`I` is the
    identity operator.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): size of each individual Bloch sphere figure, in inches. The total
            figure width is this width multiplied by the number of qubits.
        reverse_bits (bool): If True, plots qubits following Qiskit's convention [Default:False].
        filename (str | None): The optional file path to save image to. If not specified
            no file is created for the visualization. If this is set the return
            from this function will be ``None``.
        font_size (float): Font size for the Bloch ball figures.
        title_font_size (float): Font size for the title. Falls back to ``font_size``, then 16.
        title_pad (float): Padding for the title (suptitle ``y`` position is ``0.98``
            and, when a title is given, the figure height is increased by
            ``1 + title_pad/100`` inches).

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            A matplotlib figure instance, or ``None`` if ``filename`` is set.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        VisualizationError: if input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_bloch_multivector

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.x(1)

           state = Statevector(qc)
           plot_bloch_multivector(state)

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import DensityMatrix
           from qiskit.visualization import plot_bloch_multivector

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.t(1)
           qc.s(0)
           qc.cx(0,1)

           # You can reverse the order of the qubits.
           matrix = DensityMatrix(qc)
           plot_bloch_multivector(matrix, title='My Bloch Spheres', reverse_bits=True)

    """
    from matplotlib import pyplot as plt

    # One Bloch vector per qubit, optionally in reversed qubit order.
    vectors = _bloch_multivector_data(state)
    if reverse_bits:
        vectors = vectors[::-1]
    n_spheres = len(vectors)

    # Overall figure size: one sphere-sized cell per qubit, plus head-room for a title.
    if figsize is None:
        width, height = plt.figaspect(1 / n_spheres)
    else:
        width, height = figsize
        width *= n_spheres
    if len(title) > 0:
        height += 1 + title_pad / 100

    if title_font_size is None:
        title_font_size = 16 if font_size is None else font_size

    fig = plt.figure(figsize=(width, height))
    for slot, vector in enumerate(vectors):
        # Labels keep referring to the real qubit index even when the order is reversed.
        qubit_label = n_spheres - 1 - slot if reverse_bits else slot
        sphere_ax = fig.add_subplot(1, n_spheres, slot + 1, projection="3d")
        plot_bloch_vector(
            vector, "qubit " + str(qubit_label), ax=sphere_ax, figsize=figsize, font_size=font_size
        )
    fig.suptitle(title, fontsize=title_font_size, y=0.98)
    matplotlib_close_if_inline(fig)

    if filename is not None:
        return fig.savefig(filename)
    try:
        fig.tight_layout()
    except AttributeError:
        pass
    return fig
373

374

375
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_state_city(
    state,
    title="",
    figsize=None,
    color=None,
    alpha=1,
    ax_real=None,
    ax_imag=None,
    *,
    filename=None,
):
    """Plot the cityscape of quantum state.

    Plot two 3d bar graphs (two dimensional) of the real and imaginary
    part of the density matrix rho.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches. Only used when neither ``ax_real``
            nor ``ax_imag`` is given; defaults to ``(16, 8)``.
        color (list): A list of len=2 giving colors for real and
            imaginary components of matrix elements. A ``None`` entry falls
            back to the default color.
        alpha (float): Transparency value for bars
        ax_real (matplotlib.axes.Axes): An optional 3D Axes object to be used for
            the real-part plot. If neither ``ax_real`` nor ``ax_imag`` is
            specified a new matplotlib Figure will be created and used. If this
            is specified without an ``ax_imag`` only the real component plot
            will be generated.
        ax_imag (matplotlib.axes.Axes): An optional 3D Axes object to be used for
            the imaginary-part plot. If neither ``ax_real`` nor ``ax_imag`` is
            specified a new matplotlib Figure will be created and used. If this
            is specified without an ``ax_real`` only the imaginary component plot
            will be generated.
        filename (str | None): The optional file path to save image to. If not specified
            no file is created for the visualization. If this is set the return
            from this function will be ``None``.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            The Figure containing the visualization (when ``ax_real`` or
            ``ax_imag`` is given, this is the Figure that owns those axes), or
            ``None`` if ``filename`` is set.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        ValueError: When 'color' is not a list of len=2.
        VisualizationError: if input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can choose different colors for the real and imaginary parts of the density matrix.

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import DensityMatrix
           from qiskit.visualization import plot_state_city

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = DensityMatrix(qc)
           plot_state_city(state, color=['midnightblue', 'crimson'], title="New State City")

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can make the bars more transparent to better see the ones that are behind
           # if they overlap.

           import numpy as np
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_city
           from qiskit import QuantumCircuit

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0,1)
           qc.ry(np.pi/3, 0)
           qc.rx(np.pi/5, 1)

           state = Statevector(qc)
           plot_state_city(state, alpha=0.6)

    """
    import matplotlib.colors as mcolors
    from matplotlib import pyplot as plt
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection

    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")

    # get the real and imag parts of rho
    datareal = np.real(rho.data)
    dataimag = np.imag(rho.data)

    # get the labels (rows and columns are both indexed by computational basis states)
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    row_names = [bin(i)[2:].zfill(num) for i in range(2**num)]

    ly, lx = datareal.shape[:2]
    xpos = np.arange(0, lx, 1)  # Set up a mesh of positions
    ypos = np.arange(0, ly, 1)
    xpos, ypos = np.meshgrid(xpos + 0.25, ypos + 0.25)

    xpos = xpos.flatten()
    ypos = ypos.flatten()
    zpos = np.zeros(lx * ly)

    dx = 0.5 * np.ones_like(zpos)  # width of bars
    dy = dx.copy()
    dzr = datareal.flatten()
    dzi = dataimag.flatten()

    if color is None:
        real_color, imag_color = "#648fff", "#648fff"
    else:
        if len(color) != 2:
            raise ValueError("'color' must be a list of len=2.")
        real_color = "#648fff" if color[0] is None else color[0]
        imag_color = "#648fff" if color[1] is None else color[1]

    if ax_real is None and ax_imag is None:
        # set default figure size
        if figsize is None:
            figsize = (16, 8)

        fig = plt.figure(figsize=figsize, facecolor="w")
        ax1 = fig.add_subplot(1, 2, 1, projection="3d", computed_zorder=False)
        ax2 = fig.add_subplot(1, 2, 2, projection="3d", computed_zorder=False)
    elif ax_real is not None:
        fig = ax_real.get_figure()
        ax1 = ax_real
        ax2 = ax_imag
    else:
        fig = ax_imag.get_figure()
        ax1 = None
        ax2 = ax_imag

    fig.tight_layout()

    # Both panels share the same upper z-limit; the epsilon keeps the range
    # non-degenerate when every value is zero.
    z_top = max(np.max(dzr) + 1e-9, np.max(dzi))

    # Derive font size and zoom from the figure size, since fig.tight_layout()
    # does not size these 3D panels on its own.
    fig_width, fig_height = fig.get_size_inches()
    max_plot_size = min(fig_width / 2.25, fig_height)
    max_font_size = int(3 * max_plot_size)
    max_zoom = 10 / (10 + np.sqrt(max_plot_size))

    for ax, dz, col, zlabel in (
        (ax1, dzr, real_color, "Real"),
        (ax2, dzi, imag_color, "Imaginary"),
    ):
        if ax is None:
            continue

        max_dz = np.max(dz)
        min_dz = np.min(dz)

        if isinstance(col, str) and col.startswith("#"):
            col = mcolors.to_rgba_array(col)

        # Negative bars are drawn first (lowest zorder) so they sit beneath the z=0 plane.
        dzn = dz < 0
        if np.any(dzn):
            fc = generate_facecolors(
                xpos[dzn], ypos[dzn], zpos[dzn], dx[dzn], dy[dzn], dz[dzn], col
            )
            negative_bars = ax.bar3d(
                xpos[dzn],
                ypos[dzn],
                zpos[dzn],
                dx[dzn],
                dy[dzn],
                dz[dzn],
                alpha=alpha,
                zorder=0.625,
            )
            negative_bars.set_facecolor(fc)

        # Translucent z=0 plane, only when the panel has values on both sides of zero.
        if min_dz < 0 < max_dz:
            xlim, ylim = [0, lx], [0, ly]
            verts = [list(zip(xlim + xlim[::-1], np.repeat(ylim, 2), [0] * 4))]
            plane = Poly3DCollection(verts, alpha=0.25, facecolor="k", linewidths=1)
            plane.set_zorder(0.75)
            ax.add_collection3d(plane)

        dzp = dz >= 0
        if np.any(dzp):
            fc = generate_facecolors(
                xpos[dzp], ypos[dzp], zpos[dzp], dx[dzp], dy[dzp], dz[dzp], col
            )
            positive_bars = ax.bar3d(
                xpos[dzp],
                ypos[dzp],
                zpos[dzp],
                dx[dzp],
                dy[dzp],
                dz[dzp],
                alpha=alpha,
                zorder=0.875,
            )
            positive_bars.set_facecolor(fc)

        ax.set_title(f"{zlabel} Amplitude (ρ)", fontsize=max_font_size)

        ax.set_xticks(np.arange(0.5, lx + 0.5, 1))
        ax.set_yticks(np.arange(0.5, ly + 0.5, 1))
        # A constant non-zero panel would give an empty range, so let matplotlib pick one.
        if max_dz != min_dz or min_dz == 0:
            ax.axes.set_zlim3d(min_dz, z_top)
        else:
            ax.axes.set_zlim3d(auto=True)

        ax.xaxis.set_ticklabels(
            row_names, fontsize=max_font_size, rotation=45, ha="right", va="top"
        )
        ax.yaxis.set_ticklabels(
            column_names, fontsize=max_font_size, rotation=-22.5, ha="left", va="center"
        )

        for tick in ax.zaxis.get_major_ticks():
            tick.label1.set_fontsize(max_font_size)
            tick.label1.set_horizontalalignment("left")
            tick.label1.set_verticalalignment("bottom")

        ax.set_box_aspect(aspect=(4, 4, 4), zoom=max_zoom)
        ax.set_xmargin(0)
        ax.set_ymargin(0)

    fig.suptitle(title, fontsize=max_font_size * 1.25)
    fig.subplots_adjust(top=0.9, bottom=0, left=0, right=1, hspace=0, wspace=0)
    # Only close figures that this function created itself.
    if ax_real is None and ax_imag is None:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    return fig.savefig(filename)
628

629

630
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_state_paulivec(state, title="", figsize=None, color=None, ax=None, *, filename=None):
    r"""Plot the Pauli-vector representation of a quantum state as bar graph.

    The Pauli-vector of a density matrix :math:`\rho` is defined by the expectation of each
    possible tensor product of single-qubit Pauli operators (including the identity), that is

    .. math ::

        \rho = \frac{1}{2^n} \sum_{\sigma \in \{I, X, Y, Z\}^{\otimes n}}
               \mathrm{Tr}(\sigma \rho) \sigma.

    This function plots the coefficients :math:`\mathrm{Tr}(\sigma\rho)` as bar graph.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches. Only used when ``ax`` is not given;
            defaults to ``(7, 5)``.
        color (list or str): Color of the coefficient value bars. Defaults to ``"#648fff"``.
        ax (matplotlib.axes.Axes): An optional Axes object to be used for
            the visualization output. If none is specified a new matplotlib
            Figure will be created and used.
        filename (str | None): The optional file path to save image to. If not specified
            no file is created for the visualization. If this is set the return
            from this function will be ``None``.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            The Figure containing the visualization (the Figure owning ``ax``
            when ``ax`` is given), or ``None`` if ``filename`` is set.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        VisualizationError: if input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can set a color for all the bars.

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_paulivec

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = Statevector(qc)
           plot_state_paulivec(state, color='midnightblue', title="New PauliVec plot")

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # If you introduce a list with less colors than bars, the color of the bars will
           # alternate following the sequence from the list.

           import numpy as np
           from qiskit.quantum_info import DensityMatrix
           from qiskit import QuantumCircuit
           from qiskit.visualization import plot_state_paulivec

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0, 1)
           qc.ry(np.pi/3, 0)
           qc.rx(np.pi/5, 1)

           matrix = DensityMatrix(qc)
           plot_state_paulivec(matrix, color=['crimson', 'midnightblue', 'seagreen'])
    """
    from matplotlib import pyplot as plt

    labels, values = _paulivec_data(state)

    figsize = (7, 5) if figsize is None else figsize
    color = "#648fff" if color is None else color

    positions = np.arange(len(values))  # x location of each Pauli bar
    bar_width = 0.5

    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    ax.grid(zorder=0, linewidth=1, linestyle="--")
    ax.bar(positions, values, bar_width, color=color, zorder=2)
    ax.axhline(linewidth=1, color="k")

    # Axis labels, ticks and limits.
    ax.set_ylabel("Coefficients", fontsize=14)
    ax.set_xticks(positions)
    ax.set_yticks([-1, -0.5, 0, 0.5, 1])
    ax.set_xticklabels(labels, fontsize=14, rotation=70)
    ax.set_xlabel("Pauli", fontsize=14)
    ax.set_ylim([-1, 1])
    ax.set_facecolor("#eeeeee")
    for tick in [*ax.xaxis.get_major_ticks(), *ax.yaxis.get_major_ticks()]:
        tick.label1.set_fontsize(14)
    ax.set_title(title, fontsize=16)

    # Only close figures that this function created itself.
    if owns_figure:
        matplotlib_close_if_inline(fig)

    if filename is not None:
        return fig.savefig(filename)
    try:
        fig.tight_layout()
    except AttributeError:
        pass
    return fig
751

752

753
def n_choose_k(n, k):
    """Return the number of combinations for n choose k.

    The value is computed with exact integer arithmetic, so it stays exact for
    arbitrarily large ``n`` (the previous float-based version lost precision once
    intermediate products exceeded 2**53).

    Args:
        n (int): the total number of options.
        k (int): the number of elements chosen.

    Returns:
        int: the binomial coefficient. For compatibility, ``n == 0`` returns 0 for
        every ``k`` (including ``k == 0``); ``k <= 0`` with ``n != 0`` returns 1,
        and ``k > n >= 1`` returns 0.
    """
    if n == 0:
        return 0
    # Multiply by n-k+1, ..., n and divide by 1, ..., k one step at a time. After
    # step j the running value is a product of j consecutive integers divided by j!,
    # which is always an integer, so each floor division is exact.
    return reduce(
        lambda acc, pair: acc * pair[0] // pair[1],
        zip(range(n - k + 1, n + 1), range(1, k + 1)),
        1,
    )
766

767

768
def lex_index(n, k, lst):
    """Return the lexicographic index of a combination.

    Each entry ``x`` of ``lst`` is first complemented to ``n - 1 - x``. The index is
    the sum over ``i = 0 .. k-1`` of ``n_choose_k(c, i + 1)``, where ``c`` walks the
    complemented entries from the last one to the first.

    Args:
        n (int): the total number of options.
        k (int): the number of elements; must equal ``len(lst)``.
        lst (list): positions of the chosen elements.

    Returns:
        int: the index of the combination in lex order.

    Raises:
        VisualizationError: if the length of ``lst`` is not equal to ``k``.
    """
    if len(lst) != k:
        raise VisualizationError("list should have length k")
    complemented = [n - 1 - x for x in lst]
    # reversed() pairs rank i with complemented[k - 1 - i]; len == k is checked above.
    total = sum(
        n_choose_k(value, rank + 1) for rank, value in enumerate(reversed(complemented))
    )
    return int(total)
787

788

789
def bit_string_index(s):
    """Return the lexicographic index of a bit string among strings of equal weight.

    Args:
        s (str): a string made only of the characters ``"0"`` and ``"1"``.

    Returns:
        int: the index computed by :func:`lex_index` from the positions of the ones.

    Raises:
        VisualizationError: if ``s`` contains a character other than ``0``/``1``.
    """
    length = len(s)
    num_ones = s.count("1")
    # Every character that is not a "1" must be a "0".
    if length - num_ones != s.count("0"):
        raise VisualizationError("s must be a string of 0 and 1")
    one_positions = []
    for position, char in enumerate(s):
        if char == "1":
            one_positions.append(position)
    return lex_index(length, num_ones, one_positions)
797

798

799
def phase_to_rgb(complex_number):
    """Map the phase of a complex number to an ``(r, g, b)`` color tuple.

    The phase is shifted by ``5*pi/4``, wrapped into ``[0, 2*pi)`` and then
    used as the hue (as a fraction of a full turn) of an HLS color with
    lightness ``0.5`` and saturation ``0.5``.
    """
    full_turn = np.pi * 2
    hue_angle = (np.angle(complex_number) + (np.pi * 5 / 4)) % full_turn
    return colorsys.hls_to_rgb(hue_angle / full_turn, 0.5, 0.5)
808

809

810
@_optionals.HAS_MATPLOTLIB.require_in_call
@_optionals.HAS_SEABORN.require_in_call
def plot_state_qsphere(
    state,
    figsize=None,
    ax=None,
    show_state_labels=True,
    show_state_phases=False,
    use_degrees=False,
    *,
    filename=None,
):
    """Plot the qsphere representation of a quantum state.
    Here, the size of the points is proportional to the probability
    of the corresponding term in the state and the color represents
    the phase.

    The state is converted to a density matrix and diagonalized. Every
    eigenvector whose eigenvalue exceeds ``0.001`` is drawn on the same sphere,
    starting from the largest eigenvalue. Each computational basis state is
    placed on a latitude determined by its Hamming weight: ``|0...0>`` sits at
    ``z = 1`` and ``|1...1>`` at ``z = -1``.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        figsize (tuple): Figure size in inches. Defaults to ``(7, 7)``; only
            used when a new figure is created.
        ax (matplotlib.axes.Axes): An optional Axes object whose parent Figure
            is used for the visualization output. The sphere is drawn on a new
            3D subplot added to that Figure (spanning its full grid) rather than
            on ``ax`` itself. If none is specified a new matplotlib Figure will
            be created and used.
        show_state_labels (bool): An optional boolean indicating whether to
            show labels for each basis state.
        show_state_phases (bool): An optional boolean indicating whether to
            show the phase for each basis state.
        use_degrees (bool): An optional boolean indicating whether to use
            radians or degrees for the phase values in the plot.
        filename (str | None): The optional file path to save image to. If not specified
            no file is created for the visualization. If this is set the return
            from this function will be ``None``.


    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            The Figure containing the plot when ``filename`` is not set. Note
            that the Figure is returned whether or not ``ax`` was given.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib and seaborn.
        VisualizationError: Input is not a valid N-qubit state.

        QiskitError: Input statevector does not have valid dimensions.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_qsphere

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = Statevector(qc)
           plot_state_qsphere(state)

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can show the phase of each state and use
           # degrees instead of radians

           from qiskit.quantum_info import DensityMatrix
           import numpy as np
           from qiskit import QuantumCircuit
           from qiskit.visualization import plot_state_qsphere

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0,1)
           qc.ry(np.pi/3, 0)
           qc.rx(np.pi/5, 1)
           qc.z(1)

           matrix = DensityMatrix(qc)
           plot_state_qsphere(matrix,
                show_state_phases = True, use_degrees = True)
    """
    from matplotlib import gridspec
    from matplotlib import pyplot as plt
    from matplotlib.patches import Circle
    import seaborn as sns
    from scipy import linalg
    from .bloch import Arrow3D

    # Work with a density matrix so pure and mixed inputs share one code path.
    rho = DensityMatrix(state)
    num = rho.num_qubits
    # num_qubits is presumably None when the subsystem dims are not all 2.
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    # get the eigenvectors and eigenvalues
    eigvals, eigvecs = linalg.eigh(rho.data)

    if figsize is None:
        figsize = (7, 7)

    # A caller-supplied Axes is only used to locate its Figure; return_fig
    # records whether we created the Figure ourselves (used for inline-close below).
    if ax is None:
        return_fig = True
        fig = plt.figure(figsize=figsize)
    else:
        return_fig = False
        fig = ax.get_figure()

    # 3x3 layout: the sphere spans the whole grid, the phase legend sits in the
    # bottom-right cell.
    gs = gridspec.GridSpec(nrows=3, ncols=3)

    # NOTE: this rebinds `ax` to a new 3D subplot; any caller-supplied Axes is not drawn on.
    ax = fig.add_subplot(gs[0:3, 0:3], projection="3d")
    ax.axes.set_xlim3d(-1.0, 1.0)
    ax.axes.set_ylim3d(-1.0, 1.0)
    ax.axes.set_zlim3d(-1.0, 1.0)
    ax.axes.grid(False)
    ax.view_init(elev=5, azim=275)

    # Force aspect ratio
    # MPL 3.2 or previous do not have set_box_aspect
    if hasattr(ax.axes, "set_box_aspect"):
        ax.axes.set_box_aspect((1, 1, 1))

    # start the plotting
    # Plot semi-transparent sphere
    u = np.linspace(0, 2 * np.pi, 25)
    v = np.linspace(0, np.pi, 25)
    x = np.outer(np.cos(u), np.sin(v))
    y = np.outer(np.sin(u), np.sin(v))
    z = np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(
        x, y, z, rstride=1, cstride=1, color=plt.rcParams["grid.color"], alpha=0.2, linewidth=0
    )

    # Get rid of the panes
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the spines
    ax.xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    # traversing the eigvals/vecs backward as sorted low->high
    for idx in range(eigvals.shape[0] - 1, -1, -1):
        # Eigenvalues are visited in descending order, so the first one at or
        # below the threshold ends the loop (see the `break` below).
        if eigvals[idx] > 0.001:
            # eigenvector belonging to the idx-th eigenvalue (largest first)
            state = eigvecs[:, idx]
            # Rounding to 13 decimals ignores machine epsilon noise (~1e-16)
            # from the solver, ensuring 'argmax' finds the true analytical winner.
            loc = np.round(np.absolute(state), decimals=13).argmax()
            # remove the global phase from max element
            # (after this, state[loc] is real and non-negative)
            angles = (np.angle(state[loc]) + 2 * np.pi) % (2 * np.pi)
            angleset = np.exp(-1j * angles)
            state = angleset * state

            d = num
            for i in range(2**num):
                # get x,y,z points
                element = bin(i)[2:].zfill(num)
                # Latitude is set by Hamming weight: weight 0 -> z=1, weight d -> z=-1.
                weight = element.count("1")
                zvalue = -2 * weight / d + 1
                # Basis states of equal weight are spread around that latitude,
                # ordered by their lexicographic rank among same-weight strings.
                number_of_divisions = n_choose_k(d, weight)
                weight_order = bit_string_index(element)
                angle = (float(weight) / d) * (np.pi * 2) + (
                    weight_order * 2 * (np.pi / number_of_divisions)
                )

                # Reflect the longitude for the lower half of the sphere, and for
                # the upper half of the ranks on the equator (when d is even).
                if (weight > d / 2) or (
                    (weight == d / 2) and (weight_order >= number_of_divisions / 2)
                ):
                    angle = np.pi - angle - (2 * np.pi / number_of_divisions)

                xvalue = np.sqrt(1 - zvalue**2) * np.cos(angle)
                yvalue = np.sqrt(1 - zvalue**2) * np.sin(angle)

                # get prob and angle - prob will be shade and angle color
                prob = np.real(np.dot(state[i], state[i].conj()))
                prob = min(prob, 1)  # See https://github.com/Qiskit/qiskit-terra/issues/4666
                colorstate = phase_to_rgb(state[i])

                # Fade markers with yvalue >= 0.1 (presumably those on the far
                # side of the sphere for the azim=275 view - confirm visually).
                alfa = 1
                if yvalue >= 0.1:
                    alfa = 1.0 - yvalue

                if not np.isclose(prob, 0) and show_state_labels:
                    # Labels are placed on a radius-1.3 shell, outside the unit sphere.
                    rprime = 1.3
                    angle_theta = np.arctan2(np.sqrt(1 - zvalue**2), zvalue)
                    xvalue_text = rprime * np.sin(angle_theta) * np.cos(angle)
                    yvalue_text = rprime * np.sin(angle_theta) * np.sin(angle)
                    zvalue_text = rprime * np.cos(angle_theta)
                    element_text = "$\\vert" + element + "\\rangle$"
                    if show_state_phases:
                        # Phase relative to the removed global phase, wrapped to [0, 2*pi).
                        element_angle = (np.angle(state[i]) + (np.pi * 4)) % (np.pi * 2)
                        if use_degrees:
                            element_text += f"\n${element_angle * 180 / np.pi:.1f}^\\circ$"
                        else:
                            element_angle = pi_check(element_angle, ndigits=3).replace("pi", "\\pi")
                            element_text += f"\n${element_angle}$"
                    ax.text(
                        xvalue_text,
                        yvalue_text,
                        zvalue_text,
                        element_text,
                        ha="center",
                        va="center",
                        size=12,
                    )

                # Marker size scales with the amplitude magnitude sqrt(prob).
                ax.plot(
                    [xvalue],
                    [yvalue],
                    [zvalue],
                    markerfacecolor=colorstate,
                    markeredgecolor=colorstate,
                    marker="o",
                    markersize=np.sqrt(prob) * 30,
                    alpha=alfa,
                )

                # Line from the origin to the point; its opacity is the probability.
                a = Arrow3D(
                    [0, xvalue],
                    [0, yvalue],
                    [0, zvalue],
                    mutation_scale=20,
                    alpha=prob,
                    arrowstyle="-",
                    color=colorstate,
                    lw=2,
                )
                ax.add_artist(a)

            # add weight lines
            # NOTE: these latitude circles and the centre point below are drawn
            # inside the eigenvector loop, i.e. once per plotted eigenvector.
            for weight in range(d + 1):
                theta = np.linspace(-2 * np.pi, 2 * np.pi, 100)
                z = -2 * weight / d + 1
                r = np.sqrt(1 - z**2)
                x = r * np.cos(theta)
                y = r * np.sin(theta)
                ax.plot(x, y, z, color=(0.5, 0.5, 0.5), lw=1, ls=":", alpha=0.5)

            # add center point
            ax.plot(
                [0],
                [0],
                [0],
                markerfacecolor=(0.5, 0.5, 0.5),
                markeredgecolor=(0.5, 0.5, 0.5),
                marker="o",
                markersize=3,
                alpha=1,
            )
        else:
            break

    # Phase legend: a ring of n equal wedges colored with an HLS palette.
    n = 64
    theta = np.ones(n)
    colors = sns.hls_palette(n)

    ax2 = fig.add_subplot(gs[2:, 2:])
    # Rotating the palette by 5/8 of a turn presumably lines the wheel up with
    # the 5*pi/4 hue offset used by phase_to_rgb - confirm if either changes.
    ax2.pie(theta, colors=colors[5 * n // 8 :] + colors[: 5 * n // 8], radius=0.75)
    # White disc over the pie's centre turns it into a ring.
    ax2.add_artist(Circle((0, 0), 0.5, color="white", zorder=1))
    offset = 0.95  # since radius of sphere is one.

    if use_degrees:
        labels = ["Phase\n(Deg)", "0", "90", "180   ", "270"]
    else:
        labels = ["Phase", "$0$", "$\\pi/2$", "$\\pi$", "$3\\pi/2$"]

    # Title in the centre, then the four quarter-turn labels around the ring.
    ax2.text(0, 0, labels[0], horizontalalignment="center", verticalalignment="center", fontsize=14)
    ax2.text(
        offset, 0, labels[1], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, offset, labels[2], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        -offset, 0, labels[3], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, -offset, labels[4], horizontalalignment="center", verticalalignment="center", fontsize=14
    )

    if return_fig:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    else:
        return fig.savefig(filename)
1104

1105

1106
@_optionals.HAS_MATPLOTLIB.require_in_call
def generate_facecolors(x, y, z, dx, dy, dz, color):
    """Generates shaded facecolors for shaded bars.

    This is here to work around a Matplotlib bug
    where alpha does not work in Bar3D.

    Args:
        x (array_like): The x- coordinates of the anchor point of the bars.
        y (array_like): The y- coordinates of the anchor point of the bars.
        z (array_like): The z- coordinates of the anchor point of the bars.
        dx (array_like): Width of bars.
        dy (array_like): Depth of bars.
        dz (array_like): Height of bars.
        color (color or sequence of colors): a single color specification, one
            color per bar, or one color per face (6 faces per bar).
    Returns:
        list: Shaded colors for bars.
    Raises:
        MissingOptionalLibraryError: If matplotlib is not installed
    """
    import matplotlib.colors as mcolors

    # Unit cube, one entry per face, each face given by its 4 corner vertices.
    cuboid = np.array(
        [
            # -z
            (
                (0, 0, 0),
                (0, 1, 0),
                (1, 1, 0),
                (1, 0, 0),
            ),
            # +z
            (
                (0, 0, 1),
                (1, 0, 1),
                (1, 1, 1),
                (0, 1, 1),
            ),
            # -y
            (
                (0, 0, 0),
                (1, 0, 0),
                (1, 0, 1),
                (0, 0, 1),
            ),
            # +y
            (
                (0, 1, 0),
                (0, 1, 1),
                (1, 1, 1),
                (1, 1, 0),
            ),
            # -x
            (
                (0, 0, 0),
                (0, 0, 1),
                (0, 1, 1),
                (0, 1, 0),
            ),
            # +x
            (
                (1, 0, 0),
                (1, 1, 0),
                (1, 1, 1),
                (1, 0, 1),
            ),
        ]
    )

    # indexed by [bar, face, vertex, coord]
    polys = np.empty(x.shape + cuboid.shape)
    # handle each coordinate separately
    for i, p, dp in [(0, x, dx), (1, y, dy), (2, z, dz)]:
        p = p[..., np.newaxis, np.newaxis]
        dp = dp[..., np.newaxis, np.newaxis]
        polys[..., i] = p + dp * cuboid[..., i]

    # collapse the first two axes
    polys = polys.reshape((-1,) + polys.shape[2:])

    # Normalise the color spec to a list of RGBA rows *before* comparing lengths.
    # Otherwise a single color string such as "blue" would be taken as a
    # per-bar sequence whenever len(color) == len(x), and split into characters.
    color = list(mcolors.to_rgba_array(color))
    facecolors = []
    if len(color) == len(x):
        # bar colors specified, need to expand to number of faces
        for c in color:
            facecolors.extend([c] * 6)
    else:
        # a single color specified, or face colors specified explicitly
        facecolors = color
        if len(facecolors) < len(x):
            facecolors *= 6 * len(x)

    normals = _generate_normals(polys)
    return _shade_colors(facecolors, normals)
1199

1200

1201
def _generate_normals(polygons):
1✔
1202
    """Takes a list of polygons and return an array of their normals.
1203

1204
    Normals point towards the viewer for a face with its vertices in
1205
    counterclockwise order, following the right hand rule.
1206
    Uses three points equally spaced around the polygon.
1207
    This normal of course might not make sense for polygons with more than
1208
    three points not lying in a plane, but it's a plausible and fast
1209
    approximation.
1210

1211
    Args:
1212
        polygons (list): list of (M_i, 3) array_like, or (..., M, 3) array_like
1213
            A sequence of polygons to compute normals for, which can have
1214
            varying numbers of vertices. If the polygons all have the same
1215
            number of vertices and array is passed, then the operation will
1216
            be vectorized.
1217
    Returns:
1218
        normals: (..., 3) array_like
1219
            A normal vector estimated for the polygon.
1220
    """
1221
    if isinstance(polygons, np.ndarray):
×
1222
        # optimization: polygons all have the same number of points, so can
1223
        # vectorize
UNCOV
1224
        n = polygons.shape[-2]
×
UNCOV
1225
        i1, i2, i3 = 0, n // 3, 2 * n // 3
×
UNCOV
1226
        v1 = polygons[..., i1, :] - polygons[..., i2, :]
×
UNCOV
1227
        v2 = polygons[..., i2, :] - polygons[..., i3, :]
×
1228
    else:
1229
        # The subtraction doesn't vectorize because polygons is jagged.
UNCOV
1230
        v1 = np.empty((len(polygons), 3))
×
1231
        v2 = np.empty((len(polygons), 3))
×
1232
        for poly_i, ps in enumerate(polygons):
×
UNCOV
1233
            n = len(ps)
×
1234
            i1, i2, i3 = 0, n // 3, 2 * n // 3
×
UNCOV
1235
            v1[poly_i, :] = ps[i1, :] - ps[i2, :]
×
1236
            v2[poly_i, :] = ps[i2, :] - ps[i3, :]
×
1237

1238
    return np.cross(v1, v2)
×
1239

1240

1241
def _shade_colors(color, normals, lightsource=None):
    """Shade ``color`` according to the direction of each normal in ``normals``.

    Each normal is projected onto the light direction. The projections are
    rescaled to ``[0, 1]`` and mapped to a brightness factor in ``[0.5, 1]``
    that multiplies the RGB channels; alpha is preserved. Zero-length normals
    receive the darkest shade. ``color`` may be a single color or one color
    per normal.

    Args:
        color: color specification(s) accepted by ``matplotlib.colors.to_rgba_array``.
        normals: iterable of 3-component normal vectors.
        lightsource (matplotlib.colors.LightSource | None): light to shade with;
            defaults to ``LightSource(azdeg=225, altdeg=19.4712)``.

    Returns:
        np.ndarray: the shaded RGBA colors, or a copy of ``color`` when no
        normal has a non-zero length.
    """
    from matplotlib.colors import Normalize, LightSource
    import matplotlib.colors as mcolors

    if lightsource is None:
        # chosen for backwards-compatibility
        lightsource = LightSource(azdeg=225, altdeg=19.4712)

    def _magnitude(vec):
        return np.sqrt(vec[0] ** 2 + vec[1] ** 2 + vec[2] ** 2)

    shade_values = []
    for normal in normals:
        length = _magnitude(normal)
        if length:
            shade_values.append(np.dot(normal / length, lightsource.direction))
        else:
            # Degenerate normal: no direction to shade by.
            shade_values.append(np.nan)
    shade = np.array(shade_values)
    valid = ~np.isnan(shade)

    if not valid.any():
        return np.asanyarray(color).copy()

    lit = shade[valid]
    norm = Normalize(min(lit), max(lit))
    shade[~valid] = min(lit)
    rgba = mcolors.to_rgba_array(color)
    # rgba has shape (M, 4) and shade has shape (M,), so the result is (M, 4).
    alpha = rgba[:, 3]
    colors = (0.5 + norm(shade)[:, np.newaxis] * 0.5) * rgba
    colors[:, 3] = alpha
    return colors
1275

1276

1277
def state_to_latex(
    state: Statevector | DensityMatrix, dims: bool | None = None, convention: str = "ket", **args
) -> str:
    """Return a LaTeX representation of a state.

    Qubit statevectors drawn with ``convention='ket'`` are rendered as a sum of
    kets. Everything else is rendered as an array via
    ``qiskit.visualization.array_to_latex``. When requested, the result is
    wrapped in an ``align`` environment that also lists the state's
    dimensions. Intended for use within ``state_drawer``.

    Args:
        state: State to be drawn.
        dims (bool): Whether to display the state's ``dims``. When ``None``,
            dims are shown only if some subsystem is not a qubit.
        convention (str): Either 'vector' or 'ket'. 'ket' is honoured only for
            qubit statevectors; otherwise the state is drawn as an array.
        **args: Forwarded to the ket formatter for the ket convention, and to
            ``array_to_latex`` otherwise.

    Returns:
        str: LaTeX source for the state.
    """
    if dims is None:
        # Dimensions are only worth showing when some subsystem is not a qubit.
        dims = set(state.dims()) != {2}

    if dims:
        prefix = "\\begin{align}\n"
        dims_str = state._op_shape.dims_l()
        suffix = f"\\\\\n\\text{{dims={dims_str}}}\n\\end{{align}}"
    else:
        prefix = ""
        suffix = ""

    op_shape = state._op_shape
    # A qubit statevector has no input (right) dims and only dimension-2 outputs.
    qubit_ket = len(op_shape.dims_r()) == 0 and set(op_shape.dims_l()) == {2}
    if qubit_ket and convention == "ket":
        body = _state_to_latex_ket(state._data, **args)
    else:
        body = array_to_latex(state._data, source=True, **args)
    return prefix + body + suffix
1320

1321

1322
def _numbers_to_latex_terms(numbers: list[complex], decimals: int = 10) -> list[str]:
    """Format each number as a LaTeX coefficient term.

    Only the entry at position 0 (whatever its value) is formatted with
    ``first_term=True``, which suppresses its leading ``+``. Every later entry
    is formatted as a continuation term.

    Args:
        numbers: Sequence of numbers to format.
        decimals: Number of decimal places to round to (default: 10).
    Returns:
        List of formatted terms, one per input number.
    """
    return [
        _num_to_latex(value, decimals=decimals, first_term=(position == 0), coefficient=True)
        for position, value in enumerate(numbers)
    ]
1340

1341

1342
def _state_to_latex_ket(
    data: list[complex], max_size: int = 12, prefix: str = "", decimals: int = 10
) -> str:
    """Convert state vector to latex representation

    Args:
        data: State vector; its length is assumed to be a power of two.
        max_size: Maximum number of terms in the expression, counting the
            ellipsis placeholder. If there are more non-zero amplitudes than
            this, the first ``max_size // 2`` and the last
            ``max_size - max_size // 2 - 1`` non-zero terms are kept and an
            ellipsis is inserted between them.
        prefix: Latex string to be prepended to the latex, intended for labels.
        decimals: Number of decimal places to round to (default: 10).

    Returns:
        String with LaTeX representation of the state vector
    """
    num = int(math.log2(len(data)))

    def ket_name(i):
        return bin(i)[2:].zfill(num)

    data = np.around(data, decimals)
    nonzero_indices = np.where(data != 0)[0].tolist()

    truncated = len(nonzero_indices) > max_size
    if truncated:
        # Size head and tail explicitly: a slice like ``[-max_size // 2 + 1:]``
        # degenerates to ``[0:]`` for max_size in {1, 2} and keeps every term.
        head = max(max_size // 2, 0)
        tail = max(max_size - head - 1, 0)
        kept = nonzero_indices[:head]
        if tail:
            kept = kept + nonzero_indices[-tail:]
    else:
        kept = nonzero_indices

    latex_terms = _numbers_to_latex_terms(data[kept], decimals)
    pieces = [f"{term} |{ket_name(ket_idx)}\\rangle" for term, ket_idx in zip(latex_terms, kept)]
    if truncated:
        pieces.insert(head, r" + \ldots ")
    return prefix + "".join(pieces)
1382

1383

1384
class TextMatrix:
    """Plain-text rendering of a state or operator's data array.

    Its ``__str__``/``__repr__`` make it display nicely in Jupyter notebooks.
    """

    def __init__(self, state, max_size=8, dims=None, prefix="", suffix=""):
        """Store the object and the formatting options.

        Args:
            state: object exposing a ``_data`` array (Statevector, DensityMatrix
                or Operator).
            max_size (int or tuple): summarisation threshold passed to
                ``np.array2string``. For a tuple, a DensityMatrix uses the square
                of its smallest entry and anything else uses its first entry.
            dims (bool | None): whether to append dimension info. When ``None``,
                it is shown unless the object consists only of qubits.
            prefix (str): text placed before the array.
            suffix (str): text placed after everything else.
        """
        self.state = state
        if dims is None:
            if isinstance(state, (Statevector, DensityMatrix)) and set(state.dims()) == {2}:
                dims = False
            elif (
                isinstance(state, Operator)
                and len(state.input_dims()) == len(state.output_dims())
                and set(state.input_dims()) == set(state.output_dims()) == {2}
            ):
                dims = False
            else:
                dims = True
        self.dims = dims
        self.prefix = prefix
        self.suffix = suffix
        if isinstance(max_size, int):
            self.max_size = max_size
        elif isinstance(state, DensityMatrix):
            # density matrices are square, so the summarisation threshold is
            # the shortest side squared
            self.max_size = min(max_size) ** 2
        else:
            self.max_size = max_size[0]

    def __str__(self):
        body = np.array2string(
            self.state._data, prefix=self.prefix, threshold=self.max_size, separator=","
        )
        if not self.dims:
            return self.prefix + body + self.suffix

        padding = " " * len(self.prefix)
        if isinstance(self.state, (Statevector, DensityMatrix)):
            dim_info = f"dims={self.state._op_shape.dims_l()}"
        else:
            dim_info = f"input_dims={self.state.input_dims()}, "
            dim_info += f"output_dims={self.state.output_dims()}"
        return self.prefix + body + ",\n" + padding + dim_info + self.suffix

    def __repr__(self):
        return str(self)
1431

1432

1433
def state_drawer(state, output=None, **drawer_args):
    """Returns a visualization of the state.

    **repr**: ASCII TextMatrix of the state's ``_repr_``.

    **text**: ASCII TextMatrix that can be printed in the console.

    **latex**: An IPython Latex object for displaying in Jupyter Notebooks.

    **latex_source**: Raw, uncompiled ASCII source to generate array using LaTeX.

    **qsphere**: Matplotlib figure, rendering of statevector using `plot_state_qsphere()`.

    **hinton**: Matplotlib figure, rendering of statevector using `plot_state_hinton()`.

    **bloch**: Matplotlib figure, rendering of statevector using `plot_bloch_multivector()`.

    **city**: Matplotlib figure, rendering of statevector using `plot_state_city()`.

    **paulivec**: Matplotlib figure, rendering of statevector using `plot_state_paulivec()`.

    Args:
        state: State to be drawn
        output (str): Select the output method to use for drawing the
            state (case-insensitive). Valid choices are ``repr``, ``text``,
            ``latex``, ``latex_source``, ``qsphere``, ``hinton``, ``bloch``,
            ``city`` or ``paulivec``. When ``None``, the ``state_drawer`` entry
            of the user config is used if present, otherwise ``repr``.
        drawer_args: Arguments to be passed to the relevant drawer. For
            'latex' and 'latex_source' see ``array_to_latex``

    Returns:
        :class:`matplotlib.figure` or :class:`str` or
        :class:`TextMatrix` or :class:`IPython.display.Latex`:
        Drawing of the state.

    Raises:
        MissingOptionalLibraryError: when `output` is `latex` and IPython is not installed.
            or if SymPy isn't installed and ``'latex'`` or ``'latex_source'`` is selected for
            ``output``.

        ValueError: when `output` is not a valid selection.
    """
    config = user_config.get_config()
    # Get default 'output' from config file else use 'repr'
    default_output = "repr"
    if output is None:
        if config:
            default_output = config.get("state_drawer", "repr")
        output = default_output
    output = output.lower()

    # Choose drawing backend:
    drawers = {
        "text": TextMatrix,
        "latex_source": state_to_latex,
        "qsphere": plot_state_qsphere,
        "hinton": plot_state_hinton,
        "bloch": plot_bloch_multivector,
        "city": plot_state_city,
        "paulivec": plot_state_paulivec,
    }
    if output == "latex":
        _optionals.HAS_IPYTHON.require_now("state_drawer")
        from IPython.display import Latex

        draw_func = drawers["latex_source"]
        return Latex(f"$${draw_func(state, **drawer_args)}$$")

    if output == "repr":
        return state.__repr__()

    try:
        draw_func = drawers[output]
    except KeyError as err:
        raise ValueError(
            f"'{output}' is not a valid option for drawing {type(state).__name__} "
            "objects. Please choose from: 'repr', 'text', 'latex', 'latex_source', "
            "'qsphere', 'hinton', 'bloch', 'city' or 'paulivec'."
        ) from err
    # Called outside the ``try`` so that a KeyError raised *inside* a drawer
    # propagates unchanged instead of being reported as an invalid ``output``.
    return draw_func(state, **drawer_args)
1514

1515

1516
def _bloch_multivector_data(state):
    """Return list of Bloch vectors for each qubit

    For qubit ``i`` the components are the expectation values
    ``Tr(P rho)`` of ``P = X, Y, Z`` acting on qubit ``i`` (identity elsewhere).

    Args:
        state (DensityMatrix or Statevector): an N-qubit state.

    Returns:
        list: list of Bloch vectors (x, y, z) for each qubit.

    Raises:
        VisualizationError: if input is not an N-qubit state.
    """
    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    pauli_singles = PauliList(["X", "Y", "Z"])
    bloch_data = []
    for i in range(num):
        if num > 1:
            # Identity on all other qubits, with X/Y/Z inserted at qubit i.
            paulis = PauliList.from_symplectic(
                np.zeros((3, (num - 1)), dtype=bool), np.zeros((3, (num - 1)), dtype=bool)
            ).insert(i, pauli_singles, qubit=True)
        else:
            paulis = pauli_singles
        # Tr(P @ rho) == sum_ij P[i, j] * rho[j, i]. einsum computes just that
        # sum in O(4**n) instead of forming the full O(8**n) matrix product.
        bloch_state = [
            np.real(np.einsum("ij,ji->", mat, rho.data)) for mat in paulis.matrix_iter()
        ]
        bloch_data.append(bloch_state)
    return bloch_data
1544

1545

1546
def _paulivec_data(state):
    """Return paulivec data for plotting.

    Args:
        state (DensityMatrix or Statevector): an N-qubit state.

    Returns:
        tuple: (labels, values) for Pauli vector, where ``values`` are the real
        parts of the Pauli-basis coefficients scaled by ``2**num_qubits``.

    Raises:
        VisualizationError: if input is not an N-qubit state.
    """
    rho = DensityMatrix(state)
    # Validate before the Pauli decomposition: converting a non-qubit state to a
    # SparsePauliOp fails first, which would make this check unreachable.
    if rho.num_qubits is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    pauli_op = SparsePauliOp.from_operator(rho)
    return pauli_op.paulis.to_labels(), np.real(pauli_op.coeffs * 2**pauli_op.num_qubits)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc