• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Qiskit / qiskit / 14034596400

24 Mar 2025 11:24AM UTC coverage: 88.076% (+0.008%) from 88.068%
14034596400

push

github

web-flow
Fixes to various graphs in the docs (#14055)

* Fixes to various graphs in the docs

* Fix a bug caused by `tight_layout` crashing on some images

* Apply the fix to the timeline plotter

* Remove release note and fix bloch sphere visualization

* Additional bloch fix

* Slightly relax image comparison for the Bloch sphere due to small discrepancies between the server and local versions.

6 of 22 new or added lines in 4 files covered. (27.27%)

4 existing lines in 3 files now uncovered.

72629 of 82462 relevant lines covered (88.08%)

371867.75 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

57.37
/qiskit/visualization/state_visualization.py
1
# This code is part of Qiskit.
2
#
3
# (C) Copyright IBM 2017, 2023.
4
#
5
# This code is licensed under the Apache License, Version 2.0. You may
6
# obtain a copy of this license in the LICENSE.txt file in the root directory
7
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
8
#
9
# Any modifications or derivative works of this code must retain this
10
# copyright notice, and modified files need to carry a notice indicating
11
# that they have been altered from the originals.
12

13
# pylint: disable=invalid-name
14
# pylint: disable=missing-param-doc,missing-type-doc,unused-argument
15

16
"""
17
Visualization functions for quantum states.
18
"""
19

20
import math
1✔
21
from typing import List, Union
1✔
22
from functools import reduce
1✔
23
import colorsys
1✔
24

25
import numpy as np
1✔
26
from qiskit import user_config
1✔
27
from qiskit.quantum_info.states.statevector import Statevector
1✔
28
from qiskit.quantum_info.operators.operator import Operator
1✔
29
from qiskit.quantum_info.operators.symplectic import PauliList, SparsePauliOp
1✔
30
from qiskit.quantum_info.states.densitymatrix import DensityMatrix
1✔
31
from qiskit.utils import optionals as _optionals
1✔
32
from qiskit.circuit.tools.pi_check import pi_check
1✔
33

34
from .array import _num_to_latex, array_to_latex
1✔
35
from .utils import matplotlib_close_if_inline
1✔
36
from .exceptions import VisualizationError
1✔
37

38

39
def _plot_hinton_axis(ax, data, max_weight, row_names, column_names, title):
    """Draw a single Hinton diagram (one real-valued matrix) onto ``ax``.

    Every matrix element becomes a square centred in its unit cell. The side
    length of the square is ``sqrt(abs(value) / max_weight)``. Strictly positive
    values are drawn white and all other values black, on a gray background.

    Args:
        ax (matplotlib.axes.Axes): Axes to draw on.
        data (np.ndarray): 2D real-valued array to visualize.
        max_weight (float): Normalization constant for the square sizes.
        row_names (list[str]): y tick labels, listed from the bottom row up.
        column_names (list[str]): x tick labels, listed left to right.
        title (str): Title of the Axes.
    """
    from matplotlib import pyplot as plt

    ly, lx = data.shape
    ax.patch.set_facecolor("gray")
    ax.set_aspect("equal", "box")
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in np.ndenumerate(data):
        # Convert from matrix co-ordinates (row x, column y) to plot
        # co-ordinates so that matrix row 0 is drawn at the top.
        plot_x, plot_y = y, lx - x - 1
        color = "white" if w > 0 else "black"
        size = np.sqrt(np.abs(w) / max_weight)
        rect = plt.Rectangle(
            [0.5 + plot_x - size / 2, 0.5 + plot_y - size / 2],
            size,
            size,
            facecolor=color,
            edgecolor=color,
        )
        ax.add_patch(rect)

    ax.set_xticks(0.5 + np.arange(lx))
    ax.set_yticks(0.5 + np.arange(ly))
    ax.set_xlim([0, lx])
    ax.set_ylim([0, ly])
    ax.set_yticklabels(row_names, fontsize=14)
    ax.set_xticklabels(column_names, fontsize=14, rotation=90)
    ax.set_title(title, fontsize=14)


@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_state_hinton(state, title="", figsize=None, ax_real=None, ax_imag=None, *, filename=None):
    """Plot a hinton diagram for the density matrix of a quantum state.

    The hinton diagram represents the values of a matrix using
    squares, whose size indicate the magnitude of their corresponding value
    and their color, its sign. A white square means the value is positive and
    a black one means negative.

    Args:
        state (Statevector or DensityMatrix or ndarray): An N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches.
        filename (str): file path to save image to.
        ax_real (matplotlib.axes.Axes): An optional Axes object to be used for
            the visualization output. If none is specified a new matplotlib
            Figure will be created and used. If this is specified without an
            ax_imag only the real component plot will be generated.
        ax_imag (matplotlib.axes.Axes): An optional Axes object to be used for
            the visualization output. If none is specified a new matplotlib
            Figure will be created and used. If this is specified without an
            ax_real only the imaginary component plot will be generated.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            The matplotlib.Figure containing the visualization. When ``ax_real``
            or ``ax_imag`` is given, this is the figure owning that Axes. If
            ``filename`` is given, the figure is saved there and the return
            value of ``savefig`` is returned instead.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        VisualizationError: Input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

            import numpy as np
            from qiskit import QuantumCircuit
            from qiskit.quantum_info import DensityMatrix
            from qiskit.visualization import plot_state_hinton

            qc = QuantumCircuit(2)
            qc.h([0, 1])
            qc.cz(0,1)
            qc.ry(np.pi/3 , 0)
            qc.rx(np.pi/5, 1)

            state = DensityMatrix(qc)
            plot_state_hinton(state, title="New Hinton Plot")

    """
    from matplotlib import pyplot as plt

    # Figure data
    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    # Smallest power of two that is >= the largest |rho_ij|; shared by both
    # panels so real and imaginary squares are drawn on the same scale.
    max_weight = 2 ** math.ceil(math.log2(np.abs(rho.data).max()))
    datareal = np.real(rho.data)
    dataimag = np.imag(rho.data)

    if figsize is None:
        figsize = (8, 5)
    if not ax_real and not ax_imag:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    else:
        if ax_real:
            fig = ax_real.get_figure()
        else:
            fig = ax_imag.get_figure()
        ax1 = ax_real
        ax2 = ax_imag
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    # Matrix row 0 is drawn at the top while y ticks run bottom-up, so the
    # row labels are listed in reverse order.
    row_names = column_names[::-1]
    # Real
    if ax1:
        _plot_hinton_axis(ax1, datareal, max_weight, row_names, column_names, "Re[$\\rho$]")
    # Imaginary
    if ax2:
        _plot_hinton_axis(ax2, dataimag, max_weight, row_names, column_names, "Im[$\\rho$]")
    fig.tight_layout()
    if title:
        fig.suptitle(title, fontsize=16)
    if ax_real is None and ax_imag is None:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    else:
        return fig.savefig(filename)
186

187

188
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_bloch_vector(
    bloch, title="", ax=None, figsize=None, coord_type="cartesian", font_size=None
):
    """Plot the Bloch sphere.

    Plot a Bloch sphere with the specified coordinates, that can be given in both
    cartesian and spherical systems.

    Args:
        bloch (list[double]): array of three elements where [<x>, <y>, <z>] (Cartesian)
            or [<r>, <theta>, <phi>] (spherical in radians)
            <theta> is inclination angle from +z direction
            <phi> is azimuth from +x direction. The input sequence is not
            modified.
        title (str): a string that represents the plot title
        ax (matplotlib.axes.Axes): An Axes to use for rendering the bloch
            sphere
        figsize (tuple): Figure size in inches. Has no effect if passing ``ax``.
        coord_type (str): a string that specifies coordinate type for bloch
            (Cartesian or spherical), default is Cartesian
        font_size (float): Font size.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` : A matplotlib figure instance if
        ``ax = None``, otherwise ``None``.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit.visualization import plot_bloch_vector

           plot_bloch_vector([0,1,0], title="New Bloch Sphere")

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           import numpy as np
           from qiskit.visualization import plot_bloch_vector

           # You can use spherical coordinates instead of cartesian.

           plot_bloch_vector([1, np.pi/2, np.pi/3], coord_type='spherical')

    """
    from .bloch import Bloch

    if figsize is None:
        figsize = (5, 5)
    B = Bloch(axes=ax, font_size=font_size)
    if coord_type == "spherical":
        # Build a new Cartesian vector instead of writing back into ``bloch``:
        # in-place assignment would mutate the caller's data, fail for tuples,
        # and truncate the results if an integer-typed array was passed.
        r, theta, phi = bloch[0], bloch[1], bloch[2]
        bloch = [
            r * np.sin(theta) * np.cos(phi),
            r * np.sin(theta) * np.sin(phi),
            r * np.cos(theta),
        ]
    B.add_vectors(bloch)
    B.render(title=title)
    if ax is None:
        fig = B.fig
        fig.set_size_inches(figsize[0], figsize[1])
        matplotlib_close_if_inline(fig)
        return fig
    return None
255

256

257
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_bloch_multivector(
    state,
    title="",
    figsize=None,
    *,
    reverse_bits=False,
    filename=None,
    font_size=None,
    title_font_size=None,
    title_pad=1,
):
    r"""Plot a Bloch sphere for each qubit.

    Each component :math:`(x,y,z)` of the Bloch sphere labeled as 'qubit i' represents the expected
    value of the corresponding Pauli operator acting only on that qubit, that is, the expected value
    of :math:`I_{N-1} \otimes\dotsb\otimes I_{i+1}\otimes P_i \otimes I_{i-1}\otimes\dotsb\otimes
    I_0`, where :math:`N` is the number of qubits, :math:`P\in \{X,Y,Z\}` and :math:`I` is the
    identity operator.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): size of each individual Bloch sphere figure, in inches.
        reverse_bits (bool): If True, plots qubits following Qiskit's convention [Default:False].
        filename (str): file path to save image to.
        font_size (float): Font size for the Bloch ball figures.
        title_font_size (float): Font size for the title.
        title_pad (float): Padding for the title (suptitle `y` position is `0.98`
            and the image height will be increased by `1 + title_pad/100` inches).

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            A matplotlib figure instance, or the return value of ``savefig``
            if ``filename`` is given.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        VisualizationError: if input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

            from qiskit import QuantumCircuit
            from qiskit.quantum_info import Statevector
            from qiskit.visualization import plot_bloch_multivector

            qc = QuantumCircuit(2)
            qc.h(0)
            qc.x(1)

            state = Statevector(qc)
            plot_bloch_multivector(state)

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import DensityMatrix
           from qiskit.visualization import plot_bloch_multivector

           # You can reverse the order of the qubits.

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.t(1)
           qc.s(0)
           qc.cx(0,1)

           matrix = DensityMatrix(qc)
           plot_bloch_multivector(matrix, title='My Bloch Spheres', reverse_bits=True)

    """
    from matplotlib import pyplot as plt

    # Data
    bloch_data = _bloch_multivector_data(state)
    if reverse_bits:
        bloch_data = bloch_data[::-1]
    num = len(bloch_data)
    if figsize is not None:
        width, height = figsize
        width *= num
    else:
        width, height = plt.figaspect(1 / num)
    if len(title) > 0:
        height += 1 + title_pad / 100  # additional space for the title
    default_title_font_size = font_size if font_size is not None else 16
    title_font_size = title_font_size if title_font_size is not None else default_title_font_size
    fig = plt.figure(figsize=(width, height))
    for i in range(num):
        # With reverse_bits the data was reversed above, so subplot i shows
        # qubit num - 1 - i.
        pos = num - 1 - i if reverse_bits else i
        ax = fig.add_subplot(1, num, i + 1, projection="3d")
        plot_bloch_vector(
            bloch_data[i], "qubit " + str(pos), ax=ax, figsize=figsize, font_size=font_size
        )
    fig.suptitle(title, fontsize=title_font_size, y=0.98)
    # Lay out the figure before returning *or* saving it, so a file written via
    # ``filename`` matches the returned figure. tight_layout can raise
    # AttributeError on some figures; the layout is cosmetic, so keep the
    # default layout in that case instead of failing the whole plot.
    try:
        fig.tight_layout()
    except AttributeError:
        pass
    matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    return fig.savefig(filename)
370

371

372
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_state_city(
    state,
    title="",
    figsize=None,
    color=None,
    alpha=1,
    ax_real=None,
    ax_imag=None,
    *,
    filename=None,
):
    """Plot the cityscape of quantum state.

    Plot two 3d bar graphs (two dimensional) of the real and imaginary
    part of the density matrix rho.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches.
        color (list): A list of len=2 giving colors for real and
            imaginary components of matrix elements.
        alpha (float): Transparency value for bars
        ax_real (matplotlib.axes.Axes): An optional Axes object to be used for
            the visualization output. If none is specified a new matplotlib
            Figure will be created and used. If this is specified without an
            ax_imag only the real component plot will be generated.
        ax_imag (matplotlib.axes.Axes): An optional Axes object to be used for
            the visualization output. If none is specified a new matplotlib
            Figure will be created and used. If this is specified without an
            ax_real only the imaginary component plot will be generated.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            The matplotlib.Figure containing the visualization. When ``ax_real``
            and/or ``ax_imag`` are given, this is the figure owning them. If
            ``filename`` is given, the figure is saved there and the return
            value of ``savefig`` is returned instead.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        ValueError: When 'color' is not a list of len=2.
        VisualizationError: if input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can choose different colors for the real and imaginary parts of the density matrix.

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import DensityMatrix
           from qiskit.visualization import plot_state_city

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = DensityMatrix(qc)
           plot_state_city(state, color=['midnightblue', 'crimson'], title="New State City")

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can make the bars more transparent to better see the ones that are behind
           # if they overlap.

           import numpy as np
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_city
           from qiskit import QuantumCircuit

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0,1)
           qc.ry(np.pi/3, 0)
           qc.rx(np.pi/5, 1)

           state = Statevector(qc)
           plot_state_city(state, alpha=0.6)

    """
    import matplotlib.colors as mcolors
    from matplotlib import pyplot as plt
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection

    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")

    # get the real and imag parts of rho
    datareal = np.real(rho.data)
    dataimag = np.imag(rho.data)

    # get the labels
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    row_names = [bin(i)[2:].zfill(num) for i in range(2**num)]

    ly, lx = datareal.shape[:2]
    xpos = np.arange(0, lx, 1)  # Set up a mesh of positions
    ypos = np.arange(0, ly, 1)
    # Offset by 0.25 so each 0.5-wide bar sits in the middle of its unit cell.
    xpos, ypos = np.meshgrid(xpos + 0.25, ypos + 0.25)

    xpos = xpos.flatten()
    ypos = ypos.flatten()
    zpos = np.zeros(lx * ly)

    dx = 0.5 * np.ones_like(zpos)  # width of bars
    dy = dx.copy()
    dzr = datareal.flatten()
    dzi = dataimag.flatten()

    if color is None:
        real_color, imag_color = "#648fff", "#648fff"
    else:
        if len(color) != 2:
            raise ValueError("'color' must be a list of len=2.")
        real_color = "#648fff" if color[0] is None else color[0]
        imag_color = "#648fff" if color[1] is None else color[1]
    if ax_real is None and ax_imag is None:
        # set default figure size
        if figsize is None:
            figsize = (16, 8)

        fig = plt.figure(figsize=figsize, facecolor="w")
        ax1 = fig.add_subplot(1, 2, 1, projection="3d", computed_zorder=False)
        ax2 = fig.add_subplot(1, 2, 2, projection="3d", computed_zorder=False)

    elif ax_real is not None:
        fig = ax_real.get_figure()
        ax1 = ax_real
        ax2 = ax_imag
    else:
        fig = ax_imag.get_figure()
        ax1 = None
        ax2 = ax_imag

    # tight_layout can raise AttributeError on some figures; it is only a
    # best-effort first pass (the final spacing is set by subplots_adjust
    # below), so don't let it abort the plot.
    try:
        fig.tight_layout()
    except AttributeError:
        pass

    max_dzr = np.max(dzr)
    max_dzi = np.max(dzi)

    # Figure scaling variables since fig.tight_layout won't work
    fig_width, fig_height = fig.get_size_inches()
    max_plot_size = min(fig_width / 2.25, fig_height)
    max_font_size = int(3 * max_plot_size)
    max_zoom = 10 / (10 + np.sqrt(max_plot_size))

    for ax, dz, col, zlabel in (
        (ax1, dzr, real_color, "Real"),
        (ax2, dzi, imag_color, "Imaginary"),
    ):

        if ax is None:
            continue

        max_dz = np.max(dz)
        min_dz = np.min(dz)

        if isinstance(col, str) and col.startswith("#"):
            col = mcolors.to_rgba_array(col)

        # Explicit zorders: negative bars (0.625) < z=0 plane (0.75) <
        # positive bars (0.875).
        dzn = dz < 0
        if np.any(dzn):
            fc = generate_facecolors(
                xpos[dzn], ypos[dzn], zpos[dzn], dx[dzn], dy[dzn], dz[dzn], col
            )
            negative_bars = ax.bar3d(
                xpos[dzn],
                ypos[dzn],
                zpos[dzn],
                dx[dzn],
                dy[dzn],
                dz[dzn],
                alpha=alpha,
                zorder=0.625,
            )
            negative_bars.set_facecolor(fc)

        if min_dz < 0 < max_dz:
            # Translucent z=0 plane, drawn only when both signs are present.
            xlim, ylim = [0, lx], [0, ly]
            verts = [list(zip(xlim + xlim[::-1], np.repeat(ylim, 2), [0] * 4))]
            plane = Poly3DCollection(verts, alpha=0.25, facecolor="k", linewidths=1)
            plane.set_zorder(0.75)
            ax.add_collection3d(plane)

        dzp = dz >= 0
        if np.any(dzp):
            fc = generate_facecolors(
                xpos[dzp], ypos[dzp], zpos[dzp], dx[dzp], dy[dzp], dz[dzp], col
            )
            positive_bars = ax.bar3d(
                xpos[dzp],
                ypos[dzp],
                zpos[dzp],
                dx[dzp],
                dy[dzp],
                dz[dzp],
                alpha=alpha,
                zorder=0.875,
            )
            positive_bars.set_facecolor(fc)

        ax.set_title(f"{zlabel} Amplitude (ρ)", fontsize=max_font_size)

        ax.set_xticks(np.arange(0.5, lx + 0.5, 1))
        ax.set_yticks(np.arange(0.5, ly + 0.5, 1))
        # Both panels share the same upper z limit (the larger of the real and
        # imaginary maxima). Autoscale only when this component is a constant
        # non-zero value.
        if max_dz != min_dz or min_dz == 0:
            ax.axes.set_zlim3d(min_dz, max(max_dzr + 1e-9, max_dzi))
        else:
            ax.axes.set_zlim3d(auto=True)

        ax.xaxis.set_ticklabels(
            row_names, fontsize=max_font_size, rotation=45, ha="right", va="top"
        )
        ax.yaxis.set_ticklabels(
            column_names, fontsize=max_font_size, rotation=-22.5, ha="left", va="center"
        )

        for tick in ax.zaxis.get_major_ticks():
            tick.label1.set_fontsize(max_font_size)
            tick.label1.set_horizontalalignment("left")
            tick.label1.set_verticalalignment("bottom")

        ax.set_box_aspect(aspect=(4, 4, 4), zoom=max_zoom)
        ax.set_xmargin(0)
        ax.set_ymargin(0)

    fig.suptitle(title, fontsize=max_font_size * 1.25)
    fig.subplots_adjust(top=0.9, bottom=0, left=0, right=1, hspace=0, wspace=0)
    if ax_real is None and ax_imag is None:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    else:
        return fig.savefig(filename)
623

624

625
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_state_paulivec(state, title="", figsize=None, color=None, ax=None, *, filename=None):
    r"""Plot the Pauli-vector representation of a quantum state as bar graph.

    The Pauli-vector of a density matrix :math:`\rho` is defined by the expectation of each
    possible tensor product of single-qubit Pauli operators (including the identity), that is

    .. math ::

        \rho = \frac{1}{2^n} \sum_{\sigma \in \{I, X, Y, Z\}^{\otimes n}}
               \mathrm{Tr}(\sigma \rho) \sigma.

    This function plots the coefficients :math:`\mathrm{Tr}(\sigma\rho)` as bar graph.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches.
        color (list or str): Color of the coefficient value bars.
        ax (matplotlib.axes.Axes): An optional Axes object to be used for
            the visualization output. If none is specified a new matplotlib
            Figure will be created and used.
        filename (str): file path to save image to.

    Returns:
         :class:`matplotlib:matplotlib.figure.Figure` :
            The matplotlib.Figure containing the visualization (the figure owning
            ``ax`` when it is given), or the return value of ``savefig`` if
            ``filename`` is given.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        VisualizationError: if input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can set a color for all the bars.

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_paulivec

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = Statevector(qc)
           plot_state_paulivec(state, color='midnightblue', title="New PauliVec plot")

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # If you introduce a list with fewer colors than bars, the color of the bars will
           # alternate following the sequence from the list.

           import numpy as np
           from qiskit.quantum_info import DensityMatrix
           from qiskit import QuantumCircuit
           from qiskit.visualization import plot_state_paulivec

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0, 1)
           qc.ry(np.pi/3, 0)
           qc.rx(np.pi/5, 1)

           matrix = DensityMatrix(qc)
           plot_state_paulivec(matrix, color=['crimson', 'midnightblue', 'seagreen'])
    """
    from matplotlib import pyplot as plt

    labels, values = _paulivec_data(state)
    numelem = len(values)

    if figsize is None:
        figsize = (7, 5)
    if color is None:
        color = "#648fff"

    ind = np.arange(numelem)  # the x locations for the groups
    width = 0.5  # the width of the bars
    if ax is None:
        return_fig = True
        fig, ax = plt.subplots(figsize=figsize)
    else:
        return_fig = False
        fig = ax.get_figure()
    ax.grid(zorder=0, linewidth=1, linestyle="--")
    ax.bar(ind, values, width, color=color, zorder=2)
    ax.axhline(linewidth=1, color="k")
    # add some text for labels, title, and axes ticks
    ax.set_ylabel("Coefficients", fontsize=14)
    ax.set_xticks(ind)
    ax.set_yticks([-1, -0.5, 0, 0.5, 1])
    ax.set_xticklabels(labels, fontsize=14, rotation=70)
    ax.set_xlabel("Pauli", fontsize=14)
    ax.set_ylim([-1, 1])
    ax.set_facecolor("#eeeeee")
    for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
        tick.label1.set_fontsize(14)
    ax.set_title(title, fontsize=16)
    # Lay out the figure before returning *or* saving it, so a file written via
    # ``filename`` matches the returned figure (the rotated tick labels need the
    # extra room). tight_layout can raise AttributeError on some figures; the
    # layout is cosmetic, so keep the default layout in that case.
    try:
        fig.tight_layout()
    except AttributeError:
        pass
    if return_fig:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    return fig.savefig(filename)
743

744

745
def n_choose_k(n, k):
    """Return the number of combinations for n choose k.

    The product ``n (n-1) ... (n-k+1) / k!`` is accumulated one factor at a
    time using exact integer arithmetic. Each partial product is itself a
    binomial coefficient, so the floor division never truncates. This keeps the
    result an exact ``int`` even for large ``n``, where float division would
    lose precision.

    Args:
        n (int): the total number of options.
        k (int): The number of elements.

    Returns:
        int: the binomial coefficient. For ``n == 0`` this returns ``0``
        (even for ``k == 0``), matching the historical behaviour of this
        function. For ``n > 0`` and ``k > n`` the result is ``0``.
    """
    if n == 0:
        return 0
    result = 1
    for numerator, denominator in zip(range(n - k + 1, n + 1), range(1, k + 1)):
        # Multiply before dividing so every intermediate value stays an exact integer.
        result = result * numerator // denominator
    return result
758

759

760
def lex_index(n, k, lst):
    """Compute the lexicographic index of a combination.

    Args:
        n (int): the total number of options.
        k (int): The number of elements.
        lst (list): the chosen positions; must contain exactly ``k`` entries.

    Returns:
        int: the index of the combination in lex order.

    Raises:
        VisualizationError: if length of list is not equal to k
    """
    if len(lst) != k:
        raise VisualizationError("list should have length k")
    # Work with the "dual" positions n - 1 - pos.
    complement = [n - 1 - pos for pos in lst]
    total = 0
    for i in range(k):
        total += n_choose_k(complement[k - 1 - i], i + 1)
    return int(total)
779

780

781
def bit_string_index(s):
    """Return the index of a string of 0s and 1s.

    The index is :func:`lex_index` applied to the positions of the ``"1"``
    characters in ``s``.

    Args:
        s (str): a string made up only of ``"0"`` and ``"1"`` characters.

    Returns:
        int: the lex index of the set of ``"1"`` positions.

    Raises:
        VisualizationError: if ``s`` contains any other character.
    """
    length = len(s)
    weight = s.count("1")
    # Every character must be either "0" or "1".
    if s.count("0") + weight != length:
        raise VisualizationError("s must be a string of 0 and 1")
    one_positions = [idx for idx, bit in enumerate(s) if bit == "1"]
    return lex_index(length, weight, one_positions)
789

790

791
def phase_to_rgb(complex_number):
    """Map the phase of a complex number to an ``(r, g, b)`` color.

    The phase is shifted by ``5*pi/4``, wrapped modulo ``2*pi``, and
    normalized by ``2*pi``. The result is used as the hue of an HLS color whose
    lightness and saturation are both ``0.5``.
    """
    full_turn = np.pi * 2
    hue = (np.angle(complex_number) + (np.pi * 5 / 4)) % full_turn / full_turn
    return colorsys.hls_to_rgb(hue, 0.5, 0.5)
800

801

802
@_optionals.HAS_MATPLOTLIB.require_in_call
@_optionals.HAS_SEABORN.require_in_call
def plot_state_qsphere(
    state,
    figsize=None,
    ax=None,
    show_state_labels=True,
    show_state_phases=False,
    use_degrees=False,
    *,
    filename=None,
):
    """Plot the qsphere representation of a quantum state.
    Here, the size of the points is proportional to the probability
    of the corresponding term in the state and the color represents
    the phase.

    Each computational basis state is placed at a height set by its Hamming
    weight (number of ``1`` bits): ``z = 1 - 2 * weight / num_qubits``. For a
    mixed state, one set of points is drawn for every eigenvector of the
    density matrix whose eigenvalue exceeds ``0.001``.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        figsize (tuple): Figure size in inches. Defaults to ``(7, 7)``.
        ax (matplotlib.axes.Axes): An optional Axes object. If specified, the
            Figure that owns it (``ax.get_figure()``) is reused: a new 3D
            subplot and a phase-legend subplot are added to that figure, and
            ``ax`` itself is not drawn into. If none is specified a new
            matplotlib Figure will be created and used.
        show_state_labels (bool): An optional boolean indicating whether to
            show labels for each basis state.
        show_state_phases (bool): An optional boolean indicating whether to
            show the phase for each basis state.
        use_degrees (bool): An optional boolean indicating whether to use
            radians or degrees for the phase values in the plot.
        filename (str): If given, the figure is saved with
            ``fig.savefig(filename)`` and that call's result is returned.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            The figure used for the plot (returned whether or not ``ax`` was
            set), unless ``filename`` is given, in which case the return value
            of ``fig.savefig(filename)`` is returned instead.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib and seaborn.
        VisualizationError: Input is not a valid N-qubit state.

        QiskitError: Input statevector does not have valid dimensions.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_qsphere

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = Statevector(qc)
           plot_state_qsphere(state)

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can show the phase of each state and use
           # degrees instead of radians

           from qiskit.quantum_info import DensityMatrix
           import numpy as np
           from qiskit import QuantumCircuit
           from qiskit.visualization import plot_state_qsphere

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0,1)
           qc.ry(np.pi/3, 0)
           qc.rx(np.pi/5, 1)
           qc.z(1)

           matrix = DensityMatrix(qc)
           plot_state_qsphere(matrix,
                show_state_phases = True, use_degrees = True)
    """
    from matplotlib import gridspec
    from matplotlib import pyplot as plt
    from matplotlib.patches import Circle
    import seaborn as sns
    from scipy import linalg
    from .bloch import Arrow3D

    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    # get the eigenvectors and eigenvalues (eigh returns eigenvalues in ascending order)
    eigvals, eigvecs = linalg.eigh(rho.data)

    if figsize is None:
        figsize = (7, 7)

    if ax is None:
        return_fig = True
        fig = plt.figure(figsize=figsize)
    else:
        # Only the owning figure of the supplied Axes is used; new subplots
        # are added to it below.
        return_fig = False
        fig = ax.get_figure()

    # 3x3 grid: the sphere spans all cells, the phase legend sits in the
    # bottom-right cell.
    gs = gridspec.GridSpec(nrows=3, ncols=3)

    ax = fig.add_subplot(gs[0:3, 0:3], projection="3d")
    ax.axes.set_xlim3d(-1.0, 1.0)
    ax.axes.set_ylim3d(-1.0, 1.0)
    ax.axes.set_zlim3d(-1.0, 1.0)
    ax.axes.grid(False)
    ax.view_init(elev=5, azim=275)

    # Force aspect ratio
    # MPL 3.2 or previous do not have set_box_aspect
    if hasattr(ax.axes, "set_box_aspect"):
        ax.axes.set_box_aspect((1, 1, 1))

    # start the plotting
    # Plot semi-transparent sphere
    u = np.linspace(0, 2 * np.pi, 25)
    v = np.linspace(0, np.pi, 25)
    x = np.outer(np.cos(u), np.sin(v))
    y = np.outer(np.sin(u), np.sin(v))
    z = np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(
        x, y, z, rstride=1, cstride=1, color=plt.rcParams["grid.color"], alpha=0.2, linewidth=0
    )

    # Get rid of the panes
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the spines
    ax.xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    # traversing the eigvals/vecs backward as sorted low->high
    for idx in range(eigvals.shape[0] - 1, -1, -1):
        if eigvals[idx] > 0.001:
            # eigenvector of the largest remaining eigenvalue
            state = eigvecs[:, idx]
            loc = np.absolute(state).argmax()
            # remove the global phase from max element (it becomes real and positive)
            angles = (np.angle(state[loc]) + 2 * np.pi) % (2 * np.pi)
            angleset = np.exp(-1j * angles)
            state = angleset * state

            d = num
            for i in range(2**num):
                # get x,y,z points
                element = bin(i)[2:].zfill(num)
                weight = element.count("1")
                # Height on the sphere is fixed by the Hamming weight.
                zvalue = -2 * weight / d + 1
                # States of equal weight are spread evenly around that latitude.
                number_of_divisions = n_choose_k(d, weight)
                weight_order = bit_string_index(element)
                angle = (float(weight) / d) * (np.pi * 2) + (
                    weight_order * 2 * (np.pi / number_of_divisions)
                )

                # Mirror the azimuth for the lower half of the sphere (and the
                # second half of the equator when the weight is exactly d/2).
                if (weight > d / 2) or (
                    (weight == d / 2) and (weight_order >= number_of_divisions / 2)
                ):
                    angle = np.pi - angle - (2 * np.pi / number_of_divisions)

                xvalue = np.sqrt(1 - zvalue**2) * np.cos(angle)
                yvalue = np.sqrt(1 - zvalue**2) * np.sin(angle)

                # get prob and angle - prob will be shade and angle color
                prob = np.real(np.dot(state[i], state[i].conj()))
                prob = min(prob, 1)  # See https://github.com/Qiskit/qiskit-terra/issues/4666
                colorstate = phase_to_rgb(state[i])

                # Fade points as their y coordinate grows past 0.1.
                alfa = 1
                if yvalue >= 0.1:
                    alfa = 1.0 - yvalue

                if not np.isclose(prob, 0) and show_state_labels:
                    # Labels sit on a larger sphere (radius 1.3) in the same direction.
                    rprime = 1.3
                    angle_theta = np.arctan2(np.sqrt(1 - zvalue**2), zvalue)
                    xvalue_text = rprime * np.sin(angle_theta) * np.cos(angle)
                    yvalue_text = rprime * np.sin(angle_theta) * np.sin(angle)
                    zvalue_text = rprime * np.cos(angle_theta)
                    element_text = "$\\vert" + element + "\\rangle$"
                    if show_state_phases:
                        element_angle = (np.angle(state[i]) + (np.pi * 4)) % (np.pi * 2)
                        if use_degrees:
                            element_text += f"\n${element_angle * 180 / np.pi:.1f}^\\circ$"
                        else:
                            element_angle = pi_check(element_angle, ndigits=3).replace("pi", "\\pi")
                            element_text += f"\n${element_angle}$"
                    ax.text(
                        xvalue_text,
                        yvalue_text,
                        zvalue_text,
                        element_text,
                        ha="center",
                        va="center",
                        size=12,
                    )

                # Marker size scales with the amplitude magnitude sqrt(prob).
                ax.plot(
                    [xvalue],
                    [yvalue],
                    [zvalue],
                    markerfacecolor=colorstate,
                    markeredgecolor=colorstate,
                    marker="o",
                    markersize=np.sqrt(prob) * 30,
                    alpha=alfa,
                )

                # Line from the origin to the point; its opacity is the probability.
                a = Arrow3D(
                    [0, xvalue],
                    [0, yvalue],
                    [0, zvalue],
                    mutation_scale=20,
                    alpha=prob,
                    arrowstyle="-",
                    color=colorstate,
                    lw=2,
                )
                ax.add_artist(a)

            # add weight lines
            # NOTE(review): these and the center point are drawn once per
            # eigenvector above the threshold, i.e. repeated for mixed states.
            for weight in range(d + 1):
                theta = np.linspace(-2 * np.pi, 2 * np.pi, 100)
                z = -2 * weight / d + 1
                r = np.sqrt(1 - z**2)
                x = r * np.cos(theta)
                y = r * np.sin(theta)
                ax.plot(x, y, z, color=(0.5, 0.5, 0.5), lw=1, ls=":", alpha=0.5)

            # add center point
            ax.plot(
                [0],
                [0],
                [0],
                markerfacecolor=(0.5, 0.5, 0.5),
                markeredgecolor=(0.5, 0.5, 0.5),
                marker="o",
                markersize=3,
                alpha=1,
            )
        else:
            # Remaining eigenvalues are smaller still (ascending order).
            break

    # Phase legend: a ring of 64 hues rotated to match phase_to_rgb's offset.
    n = 64
    theta = np.ones(n)
    colors = sns.hls_palette(n)

    ax2 = fig.add_subplot(gs[2:, 2:])
    ax2.pie(theta, colors=colors[5 * n // 8 :] + colors[: 5 * n // 8], radius=0.75)
    ax2.add_artist(Circle((0, 0), 0.5, color="white", zorder=1))
    offset = 0.95  # since radius of sphere is one.

    if use_degrees:
        labels = ["Phase\n(Deg)", "0", "90", "180   ", "270"]
    else:
        labels = ["Phase", "$0$", "$\\pi/2$", "$\\pi$", "$3\\pi/2$"]

    ax2.text(0, 0, labels[0], horizontalalignment="center", verticalalignment="center", fontsize=14)
    ax2.text(
        offset, 0, labels[1], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, offset, labels[2], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        -offset, 0, labels[3], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, -offset, labels[4], horizontalalignment="center", verticalalignment="center", fontsize=14
    )

    if return_fig:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    else:
        return fig.savefig(filename)
1090

1091

1092
@_optionals.HAS_MATPLOTLIB.require_in_call
def generate_facecolors(x, y, z, dx, dy, dz, color):
    """Generates shaded facecolors for shaded bars.

    This is here to work around a Matplotlib bug
    where alpha does not work in Bar3D.

    Args:
        x (array_like): The x- coordinates of the anchor point of the bars.
        y (array_like): The y- coordinates of the anchor point of the bars.
        z (array_like): The z- coordinates of the anchor point of the bars.
        dx (array_like): Width of bars.
        dy (array_like): Depth of bars.
        dz (array_like): Height of bars.
        color (color or array_like): a single color, one color per bar, or
            one color per bar face, in any form accepted by
            ``matplotlib.colors.to_rgba_array``.
    Returns:
        numpy.ndarray: Shaded RGBA colors for the bar faces.
    Raises:
        MissingOptionalLibraryError: If matplotlib is not installed
    """
    import matplotlib.colors as mcolors

    # Unit cube, one entry per face, each face as 4 vertices.
    cuboid = np.array(
        [
            # -z
            (
                (0, 0, 0),
                (0, 1, 0),
                (1, 1, 0),
                (1, 0, 0),
            ),
            # +z
            (
                (0, 0, 1),
                (1, 0, 1),
                (1, 1, 1),
                (0, 1, 1),
            ),
            # -y
            (
                (0, 0, 0),
                (1, 0, 0),
                (1, 0, 1),
                (0, 0, 1),
            ),
            # +y
            (
                (0, 1, 0),
                (0, 1, 1),
                (1, 1, 1),
                (1, 1, 0),
            ),
            # -x
            (
                (0, 0, 0),
                (0, 0, 1),
                (0, 1, 1),
                (0, 1, 0),
            ),
            # +x
            (
                (1, 0, 0),
                (1, 1, 0),
                (1, 1, 1),
                (1, 0, 1),
            ),
        ]
    )

    # indexed by [bar, face, vertex, coord]
    polys = np.empty(x.shape + cuboid.shape)
    # handle each coordinate separately
    for i, p, dp in [(0, x, dx), (1, y, dy), (2, z, dz)]:
        p = p[..., np.newaxis, np.newaxis]
        dp = dp[..., np.newaxis, np.newaxis]
        polys[..., i] = p + dp * cuboid[..., i]

    # collapse the first two axes
    polys = polys.reshape((-1,) + polys.shape[2:])

    # Normalize the color specification to RGBA rows *before* comparing its
    # length with the number of bars (as matplotlib's own bar3d does).
    # Otherwise a single color such as "blue" or an RGBA tuple whose length
    # happens to equal the number of bars would be split into characters or
    # components and treated as per-bar colors.
    color = list(mcolors.to_rgba_array(color))
    facecolors = []
    if len(color) == len(x):
        # bar colors specified, need to expand to number of faces
        for c in color:
            facecolors.extend([c] * 6)
    else:
        # a single color specified, or face colors specified explicitly
        facecolors = color
        if len(facecolors) < len(x):
            facecolors *= 6 * len(x)

    normals = _generate_normals(polys)
    return _shade_colors(facecolors, normals)
1185

1186

1187
def _generate_normals(polygons):
    """Estimate a normal vector for each polygon.

    For a face whose vertices are listed counterclockwise, the normal points
    toward the viewer (right-hand rule). Three vertices spaced roughly evenly
    around each polygon (indices ``0``, ``n // 3`` and ``2 * n // 3``) are
    used. For non-planar polygons with more than three vertices this is only
    a fast approximation.

    Args:
        polygons (list or numpy.ndarray): either a ``(..., M, 3)`` array, where
            all polygons share ``M`` vertices and the computation is
            vectorized, or a sequence of ``(M_i, 3)`` arrays with varying
            vertex counts.

    Returns:
        numpy.ndarray: ``(..., 3)`` estimated normal vectors, one per polygon.
    """

    def _anchor_indices(num_vertices):
        # Three vertices spread around the polygon.
        return 0, num_vertices // 3, 2 * num_vertices // 3

    if isinstance(polygons, np.ndarray):
        # All polygons have the same vertex count, so slice them all at once.
        first, second, third = _anchor_indices(polygons.shape[-2])
        edge_a = polygons[..., first, :] - polygons[..., second, :]
        edge_b = polygons[..., second, :] - polygons[..., third, :]
        return np.cross(edge_a, edge_b)

    # Jagged input: fill the edge vectors one polygon at a time.
    count = len(polygons)
    edge_a = np.empty((count, 3))
    edge_b = np.empty((count, 3))
    for row, vertices in enumerate(polygons):
        first, second, third = _anchor_indices(len(vertices))
        edge_a[row, :] = vertices[first, :] - vertices[second, :]
        edge_b[row, :] = vertices[second, :] - vertices[third, :]
    return np.cross(edge_a, edge_b)
1225

1226

1227
def _shade_colors(color, normals, lightsource=None):
    """
    Shade *color* using normal vectors given by *normals*.
    *color* can also be an array of the same length as *normals*.

    Each face is darkened according to the dot product of its unit normal
    with ``lightsource.direction``, rescaled so shading factors lie in
    ``[0.5, 1]``. Alpha values are left unchanged. Faces with a zero-length
    normal receive the darkest shade. If no normal is usable, a copy of
    *color* is returned unchanged.
    """
    from matplotlib.colors import Normalize, LightSource
    import matplotlib.colors as mcolors

    # chosen for backwards-compatibility
    light = LightSource(azdeg=225, altdeg=19.4712) if lightsource is None else lightsource

    def _length(vec):
        return np.sqrt(vec[0] ** 2 + vec[1] ** 2 + vec[2] ** 2)

    raw_shades = []
    for normal in normals:
        length = _length(normal)
        raw_shades.append(np.dot(normal / length, light.direction) if length else np.nan)
    shade = np.array(raw_shades)
    usable = ~np.isnan(shade)

    if not usable.any():
        return np.asanyarray(color).copy()

    darkest = min(shade[usable])
    brightest = max(shade[usable])
    norm = Normalize(darkest, brightest)
    shade[~usable] = darkest
    rgba = mcolors.to_rgba_array(color)
    # rgba is (M, 4) for M faces, shade is (M,); the result is (M, 4).
    alpha = rgba[:, 3]
    shaded = (0.5 + norm(shade)[:, np.newaxis] * 0.5) * rgba
    shaded[:, 3] = alpha
    return shaded
1261

1262

1263
def state_to_latex(
    state: Union[Statevector, DensityMatrix], dims: bool = None, convention: str = "ket", **args
) -> str:
    """Return a Latex representation of a state.

    Wrapper around ``qiskit.visualization.array_to_latex`` (or the ket
    formatter for qubit statevectors) that adds the state's dimensions when
    requested. Intended for use within ``state_drawer``.

    Args:
        state: State to be drawn
        dims (bool): Whether to display the state's `dims`. If ``None``, they
            are shown only when some subsystem is not a qubit.
        convention (str): Either 'vector' or 'ket'. With 'ket', qubit
            statevectors are written in ket notation; anything else is
            written as an array.
        **args: Arguments forwarded to the underlying formatter.

    Returns:
        Latex representation of the state

    Raises:
        MissingOptionalLibrary: If SymPy isn't installed and ``'latex'`` or
            ``'latex_source'`` is selected for ``output``.
    """
    if dims is None:
        # Only annotate the dimensions when some subsystem is not a qubit.
        dims = set(state.dims()) != {2}

    shape = state._op_shape
    if dims:
        header = "\\begin{align}\n"
        footer = f"\\\\\n\\text{{dims={shape.dims_l()}}}\n\\end{{align}}"
    else:
        header, footer = "", ""

    # The ket notation is only used for qubit statevectors: an operator shape
    # with no input dimensions whose output dimensions are all 2.
    qubit_statevector = len(shape.dims_r()) == 0 and set(shape.dims_l()) == {2}
    if convention == "ket" and qubit_statevector:
        body = _state_to_latex_ket(state._data, **args)
    else:
        body = array_to_latex(state._data, source=True, **args)
    return header + body + footer
1306

1307

1308
def _numbers_to_latex_terms(numbers: List[complex], decimals: int = 10) -> List[str]:
    """Convert a list of numbers to latex formatted terms.

    The first number (whatever its value) is formatted with
    ``first_term=True`` so that it gets no leading ``+``; every later number
    uses ``first_term=False``.

    Args:
        numbers: List of numbers to format
        decimals: Number of decimal places to round to (default: 10).
    Returns:
        List of formatted terms, in the same order as ``numbers``.
    """
    return [
        _num_to_latex(value, decimals=decimals, first_term=(position == 0), coefficient=True)
        for position, value in enumerate(numbers)
    ]
1326

1327

1328
def _state_to_latex_ket(
    data: List[complex], max_size: int = 12, prefix: str = "", decimals: int = 10
) -> str:
    """Convert state vector to latex representation

    Args:
        data: State vector
        max_size: Maximum number of slots in the expression. If the number of
                 non-zero terms is larger than ``max_size``, the representation
                 is truncated to the leading and trailing non-zero terms with an
                 ellipsis between them; the ellipsis occupies one of the
                 ``max_size`` slots.
        prefix: Latex string to be prepended to the latex, intended for labels.
        decimals: Number of decimal places to round to (default: 10).

    Returns:
        String with LaTeX representation of the state vector
    """
    num = int(math.log2(len(data)))

    def ket_name(i):
        return bin(i)[2:].zfill(num)

    data = np.around(data, decimals)
    nonzero_indices = np.where(data != 0)[0].tolist()
    if len(nonzero_indices) > max_size:
        # Keep ``head`` leading and ``tail`` trailing terms plus one slot for
        # the ellipsis. The tail is sliced from an explicit start index,
        # because a computed negative start of 0 (``lst[-0:]``) would select
        # the whole list and silently disable truncation for small max_size.
        head = max(max_size // 2, 0)
        tail = max(max_size - head - 1, 0)
        nonzero_indices = (
            nonzero_indices[:head] + [0] + nonzero_indices[len(nonzero_indices) - tail :]
        )
        # Index 0 is only a placeholder; its term is replaced by the ellipsis.
        latex_terms = _numbers_to_latex_terms(data[nonzero_indices], decimals)
        nonzero_indices[head] = None
    else:
        latex_terms = _numbers_to_latex_terms(data[nonzero_indices], decimals)

    latex_str = ""
    for idx, ket_idx in enumerate(nonzero_indices):
        if ket_idx is None:
            latex_str += r" + \ldots "
        else:
            term = latex_terms[idx]
            ket = ket_name(ket_idx)
            latex_str += f"{term} |{ket}\\rangle"
    return prefix + latex_str
1368

1369

1370
class TextMatrix:
    """Text representation of an array, with `__str__` method so it
    displays nicely in Jupyter notebooks.

    Args:
        state: the object to render; its ``_data`` array is printed.
        max_size (int or tuple): summarization threshold for
            ``np.array2string``. A non-int value is reduced to
            ``min(max_size) ** 2`` for a ``DensityMatrix`` and to
            ``max_size[0]`` otherwise.
        dims (bool or None): whether to append the dimensions. ``None`` means
            "only when the state or operator is not made purely of qubits".
        prefix (str): text placed before the array.
        suffix (str): text placed after the array (and dimensions).
    """

    def __init__(self, state, max_size=8, dims=None, prefix="", suffix=""):
        self.state = state
        if dims is None:
            # Show dimensions unless every subsystem is a qubit.
            if isinstance(state, (Statevector, DensityMatrix)):
                qubits_only = set(state.dims()) == {2}
            elif isinstance(state, Operator):
                qubits_only = (
                    len(state.input_dims()) == len(state.output_dims())
                    and set(state.input_dims()) == set(state.output_dims()) == {2}
                )
            else:
                qubits_only = False
            dims = not qubits_only
        self.dims = dims
        self.prefix = prefix
        self.suffix = suffix

        if isinstance(max_size, int):
            threshold = max_size
        elif isinstance(state, DensityMatrix):
            # density matrices are square, so threshold for
            # summarization is shortest side squared
            threshold = min(max_size) ** 2
        else:
            threshold = max_size[0]
        self.max_size = threshold

    def __str__(self):
        # ``prefix`` is only used by array2string to indent wrapped lines; the
        # prefix text itself is prepended below.
        body = np.array2string(
            self.state._data, prefix=self.prefix, threshold=self.max_size, separator=","
        )
        if not self.dims:
            return self.prefix + body + self.suffix

        indent = " " * len(self.prefix)
        if isinstance(self.state, (Statevector, DensityMatrix)):
            dims_text = f"dims={self.state._op_shape.dims_l()}"
        else:
            dims_text = (
                f"input_dims={self.state.input_dims()}, "
                f"output_dims={self.state.output_dims()}"
            )
        return self.prefix + body + ",\n" + indent + dims_text + self.suffix

    def __repr__(self):
        return self.__str__()
1417

1418

1419
def state_drawer(state, output=None, **drawer_args):
    """Returns a visualization of the state.

    **repr**: ASCII TextMatrix of the state's ``_repr_``.

    **text**: ASCII TextMatrix that can be printed in the console.

    **latex**: An IPython Latex object for displaying in Jupyter Notebooks.

    **latex_source**: Raw, uncompiled ASCII source to generate array using LaTeX.

    **qsphere**: Matplotlib figure, rendering of statevector using `plot_state_qsphere()`.

    **hinton**: Matplotlib figure, rendering of statevector using `plot_state_hinton()`.

    **bloch**: Matplotlib figure, rendering of statevector using `plot_bloch_multivector()`.

    **city**: Matplotlib figure, rendering of statevector using `plot_state_city()`.

    **paulivec**: Matplotlib figure, rendering of statevector using `plot_state_paulivec()`.

    Args:
        state: The state to draw.
        output (str): Select the output method to use for drawing the
            state. Valid choices are ``repr``, ``text``, ``latex``,
            ``latex_source``, ``qsphere``, ``hinton``, ``bloch``, ``city`` or
            ``paulivec`` (case-insensitive). If ``None``, the ``state_drawer``
            entry of the user config file is used when a config is loaded,
            otherwise ``'repr'``.
        drawer_args: Arguments to be passed to the relevant drawer. For
            'latex' and 'latex_source' see ``array_to_latex``

    Returns:
        :class:`matplotlib.figure` or :class:`str` or
        :class:`TextMatrix` or :class:`IPython.display.Latex`:
        Drawing of the state.

    Raises:
        MissingOptionalLibraryError: when `output` is `latex` and IPython is not installed.
            or if SymPy isn't installed and ``'latex'`` or ``'latex_source'`` is selected for
            ``output``.

        ValueError: when `output` is not a valid selection.
    """
    config = user_config.get_config()
    # Get default 'output' from config file else use 'repr'
    default_output = "repr"
    if output is None:
        if config:
            default_output = config.get("state_drawer", "repr")
        output = default_output
    output = output.lower()

    # Choose drawing backend:
    drawers = {
        "text": TextMatrix,
        "latex_source": state_to_latex,
        "qsphere": plot_state_qsphere,
        "hinton": plot_state_hinton,
        "bloch": plot_bloch_multivector,
        "city": plot_state_city,
        "paulivec": plot_state_paulivec,
    }
    if output == "latex":
        _optionals.HAS_IPYTHON.require_now("state_drawer")
        from IPython.display import Latex

        draw_func = drawers["latex_source"]
        return Latex(f"$${draw_func(state, **drawer_args)}$$")

    if output == "repr":
        return state.__repr__()

    # Only the lookup belongs in the ``try``: a KeyError raised inside the
    # drawer itself must propagate instead of being reported as an invalid
    # ``output`` value.
    try:
        draw_func = drawers[output]
    except KeyError as err:
        raise ValueError(
            f"""'{output}' is not a valid option for drawing {type(state).__name__}
             objects. Please choose from:
            'text', 'latex', 'latex_source', 'qsphere', 'hinton',
            'bloch', 'city' or 'paulivec'."""
        ) from err
    return draw_func(state, **drawer_args)
1499

1500

1501
def _bloch_multivector_data(state):
    """Return the Bloch vector of every qubit of ``state``.

    Component ``j`` of the vector for qubit ``i`` is ``Re(Tr(P @ rho))``, where
    ``P`` is X, Y or Z on qubit ``i`` (identity on the other qubits) and
    ``rho`` is the density matrix of ``state``.

    Args:
        state (DensityMatrix or Statevector): an N-qubit state.

    Returns:
        list: one ``[x, y, z]`` list per qubit, in qubit order.

    Raises:
        VisualizationError: if input is not an N-qubit state.
    """
    rho = DensityMatrix(state)
    num_qubits = rho.num_qubits
    if num_qubits is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    xyz = PauliList(["X", "Y", "Z"])

    def _embedded_paulis(qubit):
        # X, Y and Z acting on ``qubit``, padded with identities elsewhere.
        if num_qubits <= 1:
            return xyz
        identity_z = np.zeros((3, num_qubits - 1), dtype=bool)
        identity_x = np.zeros((3, num_qubits - 1), dtype=bool)
        return PauliList.from_symplectic(identity_z, identity_x).insert(
            qubit, xyz, qubit=True
        )

    return [
        [
            np.real(np.trace(np.dot(matrix, rho.data)))
            for matrix in _embedded_paulis(qubit).matrix_iter()
        ]
        for qubit in range(num_qubits)
    ]
1529

1530

1531
def _paulivec_data(state):
    """Return paulivec data for plotting.

    Args:
        state (DensityMatrix or Statevector): an N-qubit state.

    Returns:
        tuple: (labels, values) for Pauli vector, where ``values`` are the
        real parts of the Pauli-basis coefficients scaled by ``2**num_qubits``.

    Raises:
        VisualizationError: if input is not an N-qubit state.
    """
    rho = DensityMatrix(state)
    # Validate before converting: SparsePauliOp.from_operator rejects
    # non-qubit input on its own, which would otherwise preempt this error.
    if rho.num_qubits is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    pauli_op = SparsePauliOp.from_operator(rho)
    return pauli_op.paulis.to_labels(), np.real(pauli_op.coeffs * 2**pauli_op.num_qubits)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc