Qiskit / qiskit / 22309958047

23 Feb 2026 02:17PM UTC coverage: 87.87% (+0.03%) from 87.843%
22309958047 · push · github · web-flow
Pivot to using ruff for all linting (#15603)

* Pivot to using ruff for all linting

This commit switches us to fully using ruff for all the linting in CI
and locally. The primary motivation for this change is to improve
productivity because ruff is significantly faster. Pylint is slow in
general, and especially so compared to ruff. For example, on my laptop
ruff takes 0.04 seconds to run on the qiskit/ subdirectory of the source
tree after clearing the cache (0.025 seconds with the cache populated),
while running pylint on the same path took 70 seconds. This slowness
leads people to skip linting locally, which in turn causes churn in CI.

We had started experimenting with ruff in the past and
used it for a small set of rules, but were still using pylint for
the bulk of the linting in the repo. The concern at the time was a loss
of lint coverage or a lot of code churn caused by migrating to a new tool.
Specifically, pylint does more type inference and checking than ruff
does. However, since we started the experiment, one major change in
Qiskit is how much of the work now happens in Rust rather than Python. At this
point any loss in lint coverage is unlikely to cause a significant
problem in practice, and we'll make real productivity gains by making
this change.

* Remove out of date comment

* Update makefile

* Enable more rules

* Revert lambda autofixes

* Fix new rules

* Add bandit rules

* Enable Ruff native rules

* Remove pylint disable comments

* Enable docstring rules

This commit adds the docstring rules to the repo. This involves a few
more changes than previous commits because there are a lot of formatting
consistency rules that needed to be auto-applied. The checking also
found several instances where there was missing documentation that
should have been included.

* Add categories from old ruff config

* Fix from rebase

* Add flake8-raise rule

* Add flake8-pie rules

* Add implicit namespace rules

* Fix deprecation decorator error on impo... (continued)

624 of 677 new or added lines in 198 files covered. (92.17%)

16 existing lines in 9 files now uncovered.

100211 of 114044 relevant lines covered (87.87%)

1151735.18 hits per line

Source File

57.3
/qiskit/visualization/state_visualization.py
1
# This code is part of Qiskit.
2
#
3
# (C) Copyright IBM 2017, 2023.
4
#
5
# This code is licensed under the Apache License, Version 2.0. You may
6
# obtain a copy of this license in the LICENSE.txt file in the root directory
7
# of this source tree or at https://www.apache.org/licenses/LICENSE-2.0.
8
#
9
# Any modifications or derivative works of this code must retain this
10
# copyright notice, and modified files need to carry a notice indicating
11
# that they have been altered from the originals.
12

13

14
"""
15
Visualization functions for quantum states.
16
"""
17

18
import math
1✔
19

20
from functools import reduce
1✔
21
import colorsys
1✔
22

23
import numpy as np
1✔
24
from qiskit import user_config
1✔
25
from qiskit.quantum_info.states.statevector import Statevector
1✔
26
from qiskit.quantum_info.operators.operator import Operator
1✔
27
from qiskit.quantum_info.operators.symplectic import PauliList, SparsePauliOp
1✔
28
from qiskit.quantum_info.states.densitymatrix import DensityMatrix
1✔
29
from qiskit.utils import optionals as _optionals
1✔
30
from qiskit.circuit.tools.pi_check import pi_check
1✔
31

32
from .array import _num_to_latex, array_to_latex
1✔
33
from .utils import matplotlib_close_if_inline
1✔
34
from .exceptions import VisualizationError
1✔
35

36

37
@_optionals.HAS_MATPLOTLIB.require_in_call
1✔
38
def plot_state_hinton(state, title="", figsize=None, ax_real=None, ax_imag=None, *, filename=None):
1✔
39
    """Plot a hinton diagram for the density matrix of a quantum state.
40

41
    The hinton diagram represents the values of a matrix using
42
    squares, whose size indicates the magnitude of the corresponding value
42
    and whose color indicates its sign. A white square means the value is positive and
44
    a black one means negative.
45

46
    Args:
47
        state (Statevector or DensityMatrix or ndarray): An N-qubit quantum state.
48
        title (str): a string that represents the plot title
49
        figsize (tuple): Figure size in inches.
50
        filename (str): file path to save image to.
51
        ax_real (matplotlib.axes.Axes): An optional Axes object to be used for
52
            the visualization output. If none is specified a new matplotlib
53
            Figure will be created and used. If this is specified without an
54
            ax_imag only the real component plot will be generated.
55
            Additionally, if specified there will be no returned Figure since
56
            it is redundant.
57
        ax_imag (matplotlib.axes.Axes): An optional Axes object to be used for
58
            the visualization output. If none is specified a new matplotlib
59
            Figure will be created and used. If this is specified without an
60
            ax_real only the imaginary component plot will be generated.
61
            Additionally, if specified there will be no returned Figure since
62
            it is redundant.
63

64
    Returns:
65
        :class:`matplotlib:matplotlib.figure.Figure` :
66
            The matplotlib.Figure of the visualization if
67
            neither ax_real nor ax_imag is set.
68

69
    Raises:
70
        MissingOptionalLibraryError: Requires matplotlib.
71
        VisualizationError: Input is not a valid N-qubit state.
72

73
    Examples:
74
        .. plot::
75
           :alt: Output from the previous code.
76
           :include-source:
77

78
            import numpy as np
79
            from qiskit import QuantumCircuit
80
            from qiskit.quantum_info import DensityMatrix
81
            from qiskit.visualization import plot_state_hinton
82

83
            qc = QuantumCircuit(2)
84
            qc.h([0, 1])
85
            qc.cz(0,1)
86
            qc.ry(np.pi/3, 0)
87
            qc.rx(np.pi/5, 1)
88

89
            state = DensityMatrix(qc)
90
            plot_state_hinton(state, title="New Hinton Plot")
91

92
    """
93
    from matplotlib import pyplot as plt
1✔
94

95
    # Figure data
96
    rho = DensityMatrix(state)
1✔
97
    num = rho.num_qubits
1✔
98
    if num is None:
1✔
99
        raise VisualizationError("Input is not a multi-qubit quantum state.")
×
100
    max_weight = 2 ** math.ceil(math.log2(np.abs(rho.data).max()))
1✔
101
    datareal = np.real(rho.data)
1✔
102
    dataimag = np.imag(rho.data)
1✔
103

104
    if figsize is None:
1✔
105
        figsize = (8, 5)
1✔
106
    if not ax_real and not ax_imag:
1✔
107
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
1✔
108
    else:
109
        if ax_real:
×
110
            fig = ax_real.get_figure()
×
111
        else:
112
            fig = ax_imag.get_figure()
×
113
        ax1 = ax_real
×
114
        ax2 = ax_imag
×
115
    # Reversal is to account for Qiskit's endianness.
116
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
1✔
117
    row_names = [bin(i)[2:].zfill(num) for i in range(2**num)][::-1]
1✔
118
    ly, lx = datareal.shape
1✔
119
    # Real
120
    if ax1:
1✔
121
        ax1.patch.set_facecolor("gray")
1✔
122
        ax1.set_aspect("equal", "box")
1✔
123
        ax1.xaxis.set_major_locator(plt.NullLocator())
1✔
124
        ax1.yaxis.set_major_locator(plt.NullLocator())
1✔
125

126
        for (x, y), w in np.ndenumerate(datareal):
1✔
127
            # Convert from matrix coordinates to plot coordinates.
128
            plot_x, plot_y = y, lx - x - 1
1✔
129
            color = "white" if w > 0 else "black"
1✔
130
            size = np.sqrt(np.abs(w) / max_weight)
1✔
131
            rect = plt.Rectangle(
1✔
132
                [0.5 + plot_x - size / 2, 0.5 + plot_y - size / 2],
133
                size,
134
                size,
135
                facecolor=color,
136
                edgecolor=color,
137
            )
138
            ax1.add_patch(rect)
1✔
139

140
        ax1.set_xticks(0.5 + np.arange(lx))
1✔
141
        ax1.set_yticks(0.5 + np.arange(ly))
1✔
142
        ax1.set_xlim([0, lx])
1✔
143
        ax1.set_ylim([0, ly])
1✔
144
        ax1.set_yticklabels(row_names, fontsize=14)
1✔
145
        ax1.set_xticklabels(column_names, fontsize=14, rotation=90)
1✔
146
        ax1.set_title("Re[$\\rho$]", fontsize=14)
1✔
147
    # Imaginary
148
    if ax2:
1✔
149
        ax2.patch.set_facecolor("gray")
1✔
150
        ax2.set_aspect("equal", "box")
1✔
151
        ax2.xaxis.set_major_locator(plt.NullLocator())
1✔
152
        ax2.yaxis.set_major_locator(plt.NullLocator())
1✔
153

154
        for (x, y), w in np.ndenumerate(dataimag):
1✔
155
            # Convert from matrix coordinates to plot coordinates.
156
            plot_x, plot_y = y, lx - x - 1
1✔
157
            color = "white" if w > 0 else "black"
1✔
158
            size = np.sqrt(np.abs(w) / max_weight)
1✔
159
            rect = plt.Rectangle(
1✔
160
                [0.5 + plot_x - size / 2, 0.5 + plot_y - size / 2],
161
                size,
162
                size,
163
                facecolor=color,
164
                edgecolor=color,
165
            )
166
            ax2.add_patch(rect)
1✔
167

168
        ax2.set_xticks(0.5 + np.arange(lx))
1✔
169
        ax2.set_yticks(0.5 + np.arange(ly))
1✔
170
        ax2.set_xlim([0, lx])
1✔
171
        ax2.set_ylim([0, ly])
1✔
172
        ax2.set_yticklabels(row_names, fontsize=14)
1✔
173
        ax2.set_xticklabels(column_names, fontsize=14, rotation=90)
1✔
174
        ax2.set_title("Im[$\\rho$]", fontsize=14)
1✔
175
    fig.tight_layout()
1✔
176
    if title:
1✔
177
        fig.suptitle(title, fontsize=16)
×
178
    if ax_real is None and ax_imag is None:
1✔
179
        matplotlib_close_if_inline(fig)
1✔
180
    if filename is None:
1✔
181
        return fig
1✔
182
    else:
183
        return fig.savefig(filename)
×
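The square sizing in `plot_state_hinton` can be reproduced with NumPy alone. The sketch below uses two hypothetical helpers (`hinton_square_sizes` and `hinton_labels`, not part of Qiskit) that mirror the listed logic: `max_weight` is the smallest power of two that is at least the largest entry magnitude, and each square's side is `sqrt(|w| / max_weight)`, so no square exceeds a side of 1. The tick labels are the bit strings of the basis indices, with the row labels reversed as in the listing.

```python
import math

import numpy as np


def hinton_square_sizes(matrix):
    """Hypothetical helper mirroring the square sizing used by plot_state_hinton."""
    data = np.asarray(matrix)
    # Smallest power of two >= the largest magnitude (same formula as the listing).
    max_weight = 2 ** math.ceil(math.log2(np.abs(data).max()))
    real_sizes = np.sqrt(np.abs(np.real(data)) / max_weight)
    imag_sizes = np.sqrt(np.abs(np.imag(data)) / max_weight)
    return max_weight, real_sizes, imag_sizes


def hinton_labels(num_qubits):
    """Hypothetical helper: bit-string tick labels, rows reversed as in the listing."""
    column_names = [bin(i)[2:].zfill(num_qubits) for i in range(2**num_qubits)]
    row_names = column_names[::-1]
    return column_names, row_names


# Density matrix of the Bell state (|00> + |11>)/sqrt(2), written out exactly.
rho = 0.5 * np.array([[1, 0, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 1]])

max_weight, real_sizes, imag_sizes = hinton_square_sizes(rho)
cols, rows = hinton_labels(2)
print(max_weight, real_sizes[0, 0])
print(cols, rows)
```

Because the largest entry here is exactly 0.5, `max_weight` is 0.5 and the four nonzero entries are drawn as full-size squares; the imaginary part is all zeros, so that panel would contain only zero-size squares.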
184

185

186
@_optionals.HAS_MATPLOTLIB.require_in_call
1✔
187
def plot_bloch_vector(
1✔
188
    bloch, title="", ax=None, figsize=None, coord_type="cartesian", font_size=None
189
):
190
    """Plot the Bloch sphere.
191

192
    Plot a Bloch sphere with the specified coordinates, that can be given in both
193
    cartesian and spherical systems.
194

195
    Args:
196
        bloch (list[double]): array of three elements, either [<x>, <y>, <z>] (Cartesian)
197
            or [<r>, <theta>, <phi>] (spherical, in radians), where
198
            <theta> is the inclination angle from the +z direction and
199
            <phi> is the azimuth from the +x direction.
200
        title (str): a string that represents the plot title
201
        ax (matplotlib.axes.Axes): An Axes to use for rendering the Bloch
202
            sphere.
203
        figsize (tuple): Figure size in inches. Has no effect if passing ``ax``.
204
        coord_type (str): the coordinate system used by ``bloch``, either
205
            ``"cartesian"`` or ``"spherical"``. Defaults to ``"cartesian"``.
206
        font_size (float): Font size.
207

208
    Returns:
209
        :class:`matplotlib:matplotlib.figure.Figure` : A matplotlib figure instance if ``ax = None``.
210

211
    Raises:
212
        MissingOptionalLibraryError: Requires matplotlib.
213

214
    Examples:
215
        .. plot::
216
           :alt: Output from the previous code.
217
           :include-source:
218

219
           from qiskit.visualization import plot_bloch_vector
220

221
           plot_bloch_vector([0,1,0], title="New Bloch Sphere")
222

223
        .. plot::
224
           :alt: Output from the previous code.
225
           :include-source:
226

227
           import numpy as np
228
           from qiskit.visualization import plot_bloch_vector
229

230
           # You can use spherical coordinates instead of cartesian.
231

232
           plot_bloch_vector([1, np.pi/2, np.pi/3], coord_type='spherical')
233

234
    """
235
    from .bloch import Bloch
1✔
236

237
    if figsize is None:
1✔
238
        figsize = (5, 5)
1✔
239
    B = Bloch(axes=ax, font_size=font_size)
1✔
240
    if coord_type == "spherical":
1✔
241
        r, theta, phi = bloch[0], bloch[1], bloch[2]
×
242
        bloch[0] = r * np.sin(theta) * np.cos(phi)
×
243
        bloch[1] = r * np.sin(theta) * np.sin(phi)
×
244
        bloch[2] = r * np.cos(theta)
×
245
    B.add_vectors(bloch)
1✔
246
    B.render(title=title)
1✔
247
    if ax is None:
1✔
248
        fig = B.fig
×
249
        fig.set_size_inches(figsize[0], figsize[1])
×
250
        matplotlib_close_if_inline(fig)
×
251
        return fig
×
252
    return None
1✔
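When `coord_type="spherical"`, the listed code converts `[r, theta, phi]` with the standard formulas x = r sin(theta) cos(phi), y = r sin(theta) sin(phi), z = r cos(theta), but it writes the results back into `bloch`. As a side effect the caller's list is modified, and a tuple input would raise `TypeError`. The hypothetical helper below (`spherical_to_cartesian`, not part of Qiskit) performs the same conversion and returns a new list instead. It is a sketch of the idea, not a proposed patch.

```python
import numpy as np


def spherical_to_cartesian(bloch):
    """Hypothetical helper: convert [r, theta, phi] to [x, y, z] without mutating input.

    theta is the inclination from +z and phi the azimuth from +x, both in radians,
    matching the convention described in the plot_bloch_vector docstring.
    """
    r, theta, phi = bloch
    return [
        r * np.sin(theta) * np.cos(phi),
        r * np.sin(theta) * np.sin(phi),
        r * np.cos(theta),
    ]


spherical = [1, np.pi / 2, np.pi / 3]  # same input as the docstring example
cartesian = spherical_to_cartesian(spherical)
print(cartesian)  # roughly [0.5, 0.866, 0.0]
print(spherical)  # the input list is left unchanged
```

The result could then be passed to `plot_bloch_vector` with the default Cartesian `coord_type`.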
253

254

255
@_optionals.HAS_MATPLOTLIB.require_in_call
1✔
256
def plot_bloch_multivector(
1✔
257
    state,
258
    title="",
259
    figsize=None,
260
    *,
261
    reverse_bits=False,
262
    filename=None,
263
    font_size=None,
264
    title_font_size=None,
265
    title_pad=1,
266
):
267
    r"""Plot a Bloch sphere for each qubit.
268

269
    Each component :math:`(x,y,z)` of the Bloch sphere labeled as 'qubit i' represents the expected
270
    value of the corresponding Pauli operator acting only on that qubit, that is, the expected value
271
    of :math:`I_{N-1} \otimes\dotsb\otimes I_{i+1}\otimes P_i \otimes I_{i-1}\otimes\dotsb\otimes
272
    I_0`, where :math:`N` is the number of qubits, :math:`P\in \{X,Y,Z\}` and :math:`I` is the
273
    identity operator.
274

275
    Args:
276
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
277
        title (str): a string that represents the plot title
278
        figsize (tuple): size of each individual Bloch sphere figure, in inches.
279
        reverse_bits (bool): If True, plots qubits following Qiskit's convention [Default:False].
280
        font_size (float): Font size for the Bloch ball figures.
281
        title_font_size (float): Font size for the title.
282
        title_pad (float): Padding for the title (suptitle ``y`` position is ``0.98``
283
            and the image height will be extended by ``1 + title_pad/100``).
284

285
    Returns:
286
        :class:`matplotlib:matplotlib.figure.Figure` :
287
            A matplotlib figure instance.
288

289
    Raises:
290
        MissingOptionalLibraryError: Requires matplotlib.
291
        VisualizationError: if input is not a valid N-qubit state.
292

293
    Examples:
294
        .. plot::
295
           :alt: Output from the previous code.
296
           :include-source:
297

298
            from qiskit import QuantumCircuit
299
            from qiskit.quantum_info import Statevector
300
            from qiskit.visualization import plot_bloch_multivector
301

302
            qc = QuantumCircuit(2)
303
            qc.h(0)
304
            qc.x(1)
305

306
            state = Statevector(qc)
307
            plot_bloch_multivector(state)
308

309
        .. plot::
310
           :alt: Output from the previous code.
311
           :include-source:
312

313
           from qiskit import QuantumCircuit
314
           from qiskit.quantum_info import Statevector
315
           from qiskit.visualization import plot_bloch_multivector
316

317
           qc = QuantumCircuit(2)
318
           qc.h(0)
319
           qc.x(1)
320

321
           # You can reverse the order of the qubits.
322

323
           from qiskit.quantum_info import DensityMatrix
324

325
           qc = QuantumCircuit(2)
326
           qc.h([0, 1])
327
           qc.t(1)
328
           qc.s(0)
329
           qc.cx(0,1)
330

331
           matrix = DensityMatrix(qc)
332
           plot_bloch_multivector(matrix, title='My Bloch Spheres', reverse_bits=True)
333

334
    """
335
    from matplotlib import pyplot as plt
1✔
336

337
    # Data
338
    bloch_data = (
1✔
339
        _bloch_multivector_data(state)[::-1] if reverse_bits else _bloch_multivector_data(state)
340
    )
341
    num = len(bloch_data)
1✔
342
    if figsize is not None:
1✔
343
        width, height = figsize
×
344
        width *= num
×
345
    else:
346
        width, height = plt.figaspect(1 / num)
1✔
347
    if len(title) > 0:
1✔
348
        height += 1 + title_pad / 100  # additional space for the title
×
349
    default_title_font_size = font_size if font_size is not None else 16
1✔
350
    title_font_size = title_font_size if title_font_size is not None else default_title_font_size
1✔
351
    fig = plt.figure(figsize=(width, height))
1✔
352
    for i in range(num):
1✔
353
        pos = num - 1 - i if reverse_bits else i
1✔
354
        ax = fig.add_subplot(1, num, i + 1, projection="3d")
1✔
355
        plot_bloch_vector(
1✔
356
            bloch_data[i], "qubit " + str(pos), ax=ax, figsize=figsize, font_size=font_size
357
        )
358
    fig.suptitle(title, fontsize=title_font_size, y=0.98)
1✔
359
    matplotlib_close_if_inline(fig)
1✔
360
    if filename is None:
1✔
361
        try:
1✔
362
            fig.tight_layout()
1✔
363
        except AttributeError:
×
364
            pass
×
365
        return fig
1✔
366
    else:
367
        return fig.savefig(filename)
×
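The `plot_bloch_multivector` docstring defines each Bloch component of qubit i as the expectation value of I_{N-1} ⊗ ... ⊗ P_i ⊗ ... ⊗ I_0. The sketch below evaluates that definition directly with NumPy for a pure state given as an amplitude vector in Qiskit's little-endian ordering (qubit 0 is the rightmost tensor factor). It is an illustrative re-derivation, not the `_bloch_multivector_data` helper the listing calls (its implementation is not shown here), and `bloch_components` is a hypothetical name. The example state matches the first docstring example (`h` on qubit 0, `x` on qubit 1).

```python
from functools import reduce

import numpy as np

PAULIS = {
    "X": np.array([[0, 1], [1, 0]], dtype=complex),
    "Y": np.array([[0, -1j], [1j, 0]], dtype=complex),
    "Z": np.array([[1, 0], [0, -1]], dtype=complex),
}
IDENTITY = np.eye(2, dtype=complex)


def bloch_components(psi):
    """Hypothetical helper: return [<X>, <Y>, <Z>] for each qubit of a pure state."""
    psi = np.asarray(psi, dtype=complex)
    num_qubits = int(np.log2(psi.size))
    result = []
    for i in range(num_qubits):
        vector = []
        for label in ("X", "Y", "Z"):
            # Build I_{N-1} (x) ... (x) P_i (x) ... (x) I_0; leftmost factor is the highest qubit.
            factors = [PAULIS[label] if q == i else IDENTITY for q in reversed(range(num_qubits))]
            operator = reduce(np.kron, factors)
            vector.append(np.vdot(psi, operator @ psi).real)
        result.append(vector)
    return result


# h(0) then x(1): qubit 0 is in |+>, qubit 1 is in |1>.
# Little-endian amplitudes are nonzero at indices 2 (|10>) and 3 (|11>).
psi = np.array([0, 0, 1, 1]) / np.sqrt(2)
for qubit, vec in enumerate(bloch_components(psi)):
    print(f"qubit {qubit}: {np.round(vec, 6)}")
```

Qubit 0 comes out along +x and qubit 1 along -z. For an entangled state such as (|00> + |11>)/sqrt(2), every single-qubit expectation value is zero, so each sphere would show a zero-length vector.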
368

369

370
@_optionals.HAS_MATPLOTLIB.require_in_call
1✔
371
def plot_state_city(
1✔
372
    state,
373
    title="",
374
    figsize=None,
375
    color=None,
376
    alpha=1,
377
    ax_real=None,
378
    ax_imag=None,
379
    *,
380
    filename=None,
381
):
382
    """Plot the cityscape of a quantum state.
383

384
    Plot two 3D bar graphs (bars arranged on a two-dimensional grid) of the
385
    real and imaginary parts of the density matrix rho.
386

387
    Args:
388
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
389
        title (str): a string that represents the plot title
390
        figsize (tuple): Figure size in inches.
391
        color (list): A list of len=2 giving colors for real and
392
            imaginary components of matrix elements.
393
        alpha (float): Transparency value for bars
394
        ax_real (matplotlib.axes.Axes): An optional Axes object to be used for
395
            the visualization output. If none is specified a new matplotlib
396
            Figure will be created and used. If this is specified without an
397
            ax_imag only the real component plot will be generated.
398
            Additionally, if specified there will be no returned Figure since
399
            it is redundant.
400
        ax_imag (matplotlib.axes.Axes): An optional Axes object to be used for
401
            the visualization output. If none is specified a new matplotlib
402
            Figure will be created and used. If this is specified without an
403
            ax_real only the imaginary component plot will be generated.
404
            Additionally, if specified there will be no returned Figure since
405
            it is redundant.
406

407
    Returns:
408
        :class:`matplotlib:matplotlib.figure.Figure` :
409
            The matplotlib.Figure of the visualization if the
410
            ``ax_real`` and ``ax_imag`` kwargs are not set
411

412
    Raises:
413
        MissingOptionalLibraryError: Requires matplotlib.
414
        ValueError: When 'color' is not a list of len=2.
415
        VisualizationError: if input is not a valid N-qubit state.
416

417
    Examples:
418
        .. plot::
419
           :alt: Output from the previous code.
420
           :include-source:
421

422
           # You can choose different colors for the real and imaginary parts of the density matrix.
423

424
           from qiskit import QuantumCircuit
425
           from qiskit.quantum_info import DensityMatrix
426
           from qiskit.visualization import plot_state_city
427

428
           qc = QuantumCircuit(2)
429
           qc.h(0)
430
           qc.cx(0, 1)
431

432
           state = DensityMatrix(qc)
433
           plot_state_city(state, color=['midnightblue', 'crimson'], title="New State City")
434

435
        .. plot::
436
           :alt: Output from the previous code.
437
           :include-source:
438

439
           # You can make the bars more transparent to better see the ones that are behind
440
           # if they overlap.
441

442
           import numpy as np
443
           from qiskit.quantum_info import Statevector
444
           from qiskit.visualization import plot_state_city
445
           from qiskit import QuantumCircuit
446

447
           qc = QuantumCircuit(2)
448
           qc.h(0)
449
           qc.cx(0, 1)
450

451

452
           qc = QuantumCircuit(2)
453
           qc.h([0, 1])
454
           qc.cz(0,1)
455
           qc.ry(np.pi/3, 0)
456
           qc.rx(np.pi/5, 1)
457

458
           state = Statevector(qc)
459
           plot_state_city(state, alpha=0.6)
460

461
    """
462
    import matplotlib.colors as mcolors
×
463
    from matplotlib import pyplot as plt
×
464
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection
×
465

466
    rho = DensityMatrix(state)
×
467
    num = rho.num_qubits
×
468
    if num is None:
×
469
        raise VisualizationError("Input is not a multi-qubit quantum state.")
×
470

471
    # get the real and imag parts of rho
472
    datareal = np.real(rho.data)
×
473
    dataimag = np.imag(rho.data)
×
474

475
    # get the labels
476
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
×
477
    row_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
×
478

479
    ly, lx = datareal.shape[:2]
×
480
    xpos = np.arange(0, lx, 1)  # Set up a mesh of positions
×
481
    ypos = np.arange(0, ly, 1)
×
482
    xpos, ypos = np.meshgrid(xpos + 0.25, ypos + 0.25)
×
483

484
    xpos = xpos.flatten()
×
485
    ypos = ypos.flatten()
×
486
    zpos = np.zeros(lx * ly)
×
487

488
    dx = 0.5 * np.ones_like(zpos)  # width of bars
×
489
    dy = dx.copy()
×
490
    dzr = datareal.flatten()
×
491
    dzi = dataimag.flatten()
×
492

493
    if color is None:
×
494
        real_color, imag_color = "#648fff", "#648fff"
×
495
    else:
496
        if len(color) != 2:
×
497
            raise ValueError("'color' must be a list of len=2.")
×
498
        real_color = "#648fff" if color[0] is None else color[0]
×
499
        imag_color = "#648fff" if color[1] is None else color[1]
×
500
    if ax_real is None and ax_imag is None:
×
501
        # set default figure size
502
        if figsize is None:
×
503
            figsize = (16, 8)
×
504

505
        fig = plt.figure(figsize=figsize, facecolor="w")
×
506
        ax1 = fig.add_subplot(1, 2, 1, projection="3d", computed_zorder=False)
×
507
        ax2 = fig.add_subplot(1, 2, 2, projection="3d", computed_zorder=False)
×
508

509
    elif ax_real is not None:
×
510
        fig = ax_real.get_figure()
×
511
        ax1 = ax_real
×
512
        ax2 = ax_imag
×
513
    else:
514
        fig = ax_imag.get_figure()
×
515
        ax1 = None
×
516
        ax2 = ax_imag
×
517

518
    fig.tight_layout()
×
519

520
    max_dzr = np.max(dzr)
×
521
    max_dzi = np.max(dzi)
×
522

523
    # Figure scaling variables since fig.tight_layout won't work
524
    fig_width, fig_height = fig.get_size_inches()
×
525
    max_plot_size = min(fig_width / 2.25, fig_height)
×
526
    max_font_size = int(3 * max_plot_size)
×
527
    max_zoom = 10 / (10 + np.sqrt(max_plot_size))
×
528

529
    for ax, dz, col, zlabel in (
×
530
        (ax1, dzr, real_color, "Real"),
531
        (ax2, dzi, imag_color, "Imaginary"),
532
    ):
533

534
        if ax is None:
×
535
            continue
×
536

537
        max_dz = np.max(dz)
×
538
        min_dz = np.min(dz)
×
539

540
        if isinstance(col, str) and col.startswith("#"):
×
541
            col = mcolors.to_rgba_array(col)
×
542

543
        dzn = dz < 0
×
544
        if np.any(dzn):
×
545
            fc = generate_facecolors(
×
546
                xpos[dzn], ypos[dzn], zpos[dzn], dx[dzn], dy[dzn], dz[dzn], col
547
            )
548
            negative_bars = ax.bar3d(
×
549
                xpos[dzn],
550
                ypos[dzn],
551
                zpos[dzn],
552
                dx[dzn],
553
                dy[dzn],
554
                dz[dzn],
555
                alpha=alpha,
556
                zorder=0.625,
557
            )
558
            negative_bars.set_facecolor(fc)
×
559

560
        if min_dz < 0 < max_dz:
×
561
            xlim, ylim = [0, lx], [0, ly]
×
562
            verts = [list(zip(xlim + xlim[::-1], np.repeat(ylim, 2), [0] * 4))]
×
563
            plane = Poly3DCollection(verts, alpha=0.25, facecolor="k", linewidths=1)
×
564
            plane.set_zorder(0.75)
×
565
            ax.add_collection3d(plane)
×
566

567
        dzp = dz >= 0
×
568
        if np.any(dzp):
×
569
            fc = generate_facecolors(
×
570
                xpos[dzp], ypos[dzp], zpos[dzp], dx[dzp], dy[dzp], dz[dzp], col
571
            )
572
            positive_bars = ax.bar3d(
×
573
                xpos[dzp],
574
                ypos[dzp],
575
                zpos[dzp],
576
                dx[dzp],
577
                dy[dzp],
578
                dz[dzp],
579
                alpha=alpha,
580
                zorder=0.875,
581
            )
582
            positive_bars.set_facecolor(fc)
×
583

584
        ax.set_title(f"{zlabel} Amplitude (ρ)", fontsize=max_font_size)
×
585

586
        ax.set_xticks(np.arange(0.5, lx + 0.5, 1))
×
587
        ax.set_yticks(np.arange(0.5, ly + 0.5, 1))
×
588
        if max_dz != min_dz:
×
589
            ax.axes.set_zlim3d(min_dz, max(max_dzr + 1e-9, max_dzi))
×
NEW
590
        elif min_dz == 0:
×
NEW
591
            ax.axes.set_zlim3d(min_dz, max(max_dzr + 1e-9, max_dzi))
×
592
        else:
NEW
593
            ax.axes.set_zlim3d(auto=True)
×
594
        ax.get_autoscalez_on()
×
595

596
        ax.xaxis.set_ticklabels(
×
597
            row_names, fontsize=max_font_size, rotation=45, ha="right", va="top"
598
        )
599
        ax.yaxis.set_ticklabels(
×
600
            column_names, fontsize=max_font_size, rotation=-22.5, ha="left", va="center"
601
        )
602

603
        for tick in ax.zaxis.get_major_ticks():
×
604
            tick.label1.set_fontsize(max_font_size)
×
605
            tick.label1.set_horizontalalignment("left")
×
606
            tick.label1.set_verticalalignment("bottom")
×
607

608
        ax.set_box_aspect(aspect=(4, 4, 4), zoom=max_zoom)
×
609
        ax.set_xmargin(0)
×
610
        ax.set_ymargin(0)
×
611

612
    fig.suptitle(title, fontsize=max_font_size * 1.25)
×
613
    fig.subplots_adjust(top=0.9, bottom=0, left=0, right=1, hspace=0, wspace=0)
×
614
    if ax_real is None and ax_imag is None:
×
615
        matplotlib_close_if_inline(fig)
×
616
    if filename is None:
×
617
        return fig
×
618
    else:
619
        return fig.savefig(filename)
×
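In `plot_state_city`, every density-matrix element becomes one 3D bar. The listing places bar corners on an integer grid offset by 0.25 with width and depth 0.5, so each bar is centered at `index + 0.5`, where the ticks are set, and it splits the bars into negative and positive groups so they can be drawn with different z-orders. The NumPy-only sketch below reproduces that geometry without Matplotlib; `city_bar_layout` is a hypothetical name, not a Qiskit function.

```python
import numpy as np


def city_bar_layout(matrix):
    """Hypothetical helper mirroring the bar geometry used by plot_state_city."""
    data = np.asarray(matrix)
    ly, lx = data.shape[:2]
    # Bar corners: integer grid shifted by 0.25, flattened row-major as in the listing.
    xpos, ypos = np.meshgrid(np.arange(lx) + 0.25, np.arange(ly) + 0.25)
    xpos, ypos = xpos.flatten(), ypos.flatten()
    width = 0.5 * np.ones(lx * ly)  # same value is used for dx and dy
    dz_real = np.real(data).flatten()
    dz_imag = np.imag(data).flatten()
    return xpos, ypos, width, dz_real, dz_imag


# Density matrix of (|0> + i|1>)/sqrt(2).
rho = np.array([[0.5, -0.5j], [0.5j, 0.5]])
xpos, ypos, width, dz_real, dz_imag = city_bar_layout(rho)

centers_x = xpos + width / 2  # bar centers line up with ticks at 0.5, 1.5, ...
negative = dz_imag < 0        # drawn with zorder 0.625 in the listing
positive = dz_imag >= 0       # drawn with zorder 0.875 in the listing
print(centers_x, negative, positive)
```

Only the imaginary panel has a bar below zero here, which is also the case where the listing adds the translucent z = 0 plane, since that panel's minimum is negative and its maximum positive.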
620

621

622
@_optionals.HAS_MATPLOTLIB.require_in_call
1✔
623
def plot_state_paulivec(state, title="", figsize=None, color=None, ax=None, *, filename=None):
1✔
624
    r"""Plot the Pauli-vector representation of a quantum state as a bar graph.
625

626
    The Pauli-vector of a density matrix :math:`\rho` is defined by the expectation of each
627
    possible tensor product of single-qubit Pauli operators (including the identity), that is
628

629
    .. math ::
630

631
        \rho = \frac{1}{2^n} \sum_{\sigma \in \{I, X, Y, Z\}^{\otimes n}}
632
               \mathrm{Tr}(\sigma \rho) \sigma.
633

634
    This function plots the coefficients :math:`\mathrm{Tr}(\sigma\rho)` as a bar graph.
635

636
    Args:
637
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
638
        title (str): a string that represents the plot title
639
        figsize (tuple): Figure size in inches.
640
        color (list or str): Color of the coefficient value bars.
641
        ax (matplotlib.axes.Axes): An optional Axes object to be used for
642
            the visualization output. If none is specified a new matplotlib
643
            Figure will be created and used. Additionally, if specified there
644
            will be no returned Figure since it is redundant.
645

646
    Returns:
647
        :class:`matplotlib:matplotlib.figure.Figure` :
648
            The matplotlib.Figure of the visualization if the
649
            ``ax`` kwarg is not set
650

651
    Raises:
652
        MissingOptionalLibraryError: Requires matplotlib.
653
        VisualizationError: if input is not a valid N-qubit state.
654

655
    Examples:
656
        .. plot::
657
           :alt: Output from the previous code.
658
           :include-source:
659

660
           # You can set a color for all the bars.
661

662
           from qiskit import QuantumCircuit
663
           from qiskit.quantum_info import Statevector
664
           from qiskit.visualization import plot_state_paulivec
665

666
           qc = QuantumCircuit(2)
667
           qc.h(0)
668
           qc.cx(0, 1)
669

670
           state = Statevector(qc)
671
           plot_state_paulivec(state, color='midnightblue', title="New PauliVec plot")
672

673
        .. plot::
674
           :alt: Output from the previous code.
675
           :include-source:
676

677
           # If you pass a list with fewer colors than bars, the bars will cycle
678
           # through the colors in the list, in order.
679

680
           import numpy as np
681
           from qiskit.quantum_info import DensityMatrix
682
           from qiskit import QuantumCircuit
683
           from qiskit.visualization import plot_state_paulivec
684

685
           qc = QuantumCircuit(2)
686
           qc.h(0)
687
           qc.cx(0, 1)
688

689
           qc = QuantumCircuit(2)
690
           qc.h([0, 1])
691
           qc.cz(0, 1)
692
           qc.ry(np.pi/3, 0)
693
           qc.rx(np.pi/5, 1)
694

695
           matrix = DensityMatrix(qc)
696
           plot_state_paulivec(matrix, color=['crimson', 'midnightblue', 'seagreen'])
697
    """
698
    from matplotlib import pyplot as plt
×
699

700
    labels, values = _paulivec_data(state)
×
701
    numelem = len(values)
×
702

703
    if figsize is None:
×
704
        figsize = (7, 5)
×
705
    if color is None:
×
706
        color = "#648fff"
×
707

708
    ind = np.arange(numelem)  # the x locations for the groups
×
709
    width = 0.5  # the width of the bars
×
710
    if ax is None:
×
711
        return_fig = True
×
712
        fig, ax = plt.subplots(figsize=figsize)
×
713
    else:
714
        return_fig = False
×
715
        fig = ax.get_figure()
×
716
    ax.grid(zorder=0, linewidth=1, linestyle="--")
×
717
    ax.bar(ind, values, width, color=color, zorder=2)
×
718
    ax.axhline(linewidth=1, color="k")
×
719
    # add some text for labels, title, and axes ticks
720
    ax.set_ylabel("Coefficients", fontsize=14)
×
721
    ax.set_xticks(ind)
×
722
    ax.set_yticks([-1, -0.5, 0, 0.5, 1])
×
723
    ax.set_xticklabels(labels, fontsize=14, rotation=70)
×
724
    ax.set_xlabel("Pauli", fontsize=14)
×
725
    ax.set_ylim([-1, 1])
×
726
    ax.set_facecolor("#eeeeee")
×
727
    for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
×
728
        tick.label1.set_fontsize(14)
×
729
    ax.set_title(title, fontsize=16)
×
730
    if return_fig:
×
731
        matplotlib_close_if_inline(fig)
×
732
    if filename is None:
×
733
        try:
×
734
            fig.tight_layout()
×
735
        except AttributeError:
×
736
            pass
×
737
        return fig
×
738
    else:
739
        return fig.savefig(filename)
×
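The Pauli-vector expansion in the `plot_state_paulivec` docstring can be checked numerically. The sketch below computes Tr(sigma rho) for every n-qubit Pauli string with NumPy and then rebuilds rho from the same formula. It is an independent illustration, not the `_paulivec_data` helper the listing calls (its implementation and label ordering are not shown here); `pauli_matrix` and `pauli_coefficients` are hypothetical names. Labels follow Qiskit's convention that the leftmost character acts on the highest-index qubit.

```python
from functools import reduce
from itertools import product

import numpy as np

PAULI_MATRICES = {
    "I": np.eye(2, dtype=complex),
    "X": np.array([[0, 1], [1, 0]], dtype=complex),
    "Y": np.array([[0, -1j], [1j, 0]], dtype=complex),
    "Z": np.array([[1, 0], [0, -1]], dtype=complex),
}


def pauli_matrix(label):
    # Leftmost label character becomes the leftmost Kronecker factor (highest qubit).
    return reduce(np.kron, [PAULI_MATRICES[c] for c in label])


def pauli_coefficients(rho):
    """Hypothetical helper: map each Pauli label to Tr(sigma @ rho)."""
    num_qubits = int(np.log2(rho.shape[0]))
    labels = ("".join(p) for p in product("IXYZ", repeat=num_qubits))
    # Tr(sigma rho) is real for Hermitian rho, so the imaginary part is dropped.
    return {label: np.trace(pauli_matrix(label) @ rho).real for label in labels}


# Bell state |Phi+> = (|00> + |11>)/sqrt(2).
rho = 0.5 * np.array([[1, 0, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 1]], dtype=complex)
coeffs = pauli_coefficients(rho)
print({k: float(v) for k, v in coeffs.items() if abs(v) > 1e-12})

# Rebuild rho = (1 / 2**n) * sum_sigma Tr(sigma rho) sigma; 2**n is the matrix dimension.
rebuilt = sum(c * pauli_matrix(label) for label, c in coeffs.items()) / rho.shape[0]
print(np.allclose(rebuilt, rho))
```

For this state only II, XX, YY and ZZ are nonzero (with YY equal to -1), which is why its Pauli-vector plot would show four bars.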
740

741

742
def n_choose_k(n, k):
1✔
743
    """Return the number of combinations for n choose k.
744

745
    Args:
746
        n (int): the total number of options.
747
        k (int): The number of elements.
748

749
    Returns:
750
        int: returns the binomial coefficient
751
    """
752
    if n == 0:
1✔
753
        return 0
1✔
754
    return reduce(lambda x, y: x * y[0] / y[1], zip(range(n - k + 1, n + 1), range(1, k + 1)), 1)
1✔
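`n_choose_k` evaluates the multiplicative formula C(n, k) = prod_{i=1..k} (n - k + i) / i with a `reduce` over true division, so it returns a float even though the docstring says `int`, and it returns 0 when n = 0 even for k = 0, where C(0, 0) = 1. The sketch below copies the listed expression under the hypothetical name `n_choose_k_float` and compares it with the standard library's exact `math.comb`. For large arguments the float result may lose precision, in which case `math.comb` is the exact alternative.

```python
from functools import reduce
from math import comb


def n_choose_k_float(n, k):
    """Copy of the listed n_choose_k, renamed for illustration."""
    if n == 0:
        return 0
    # For 0 <= k <= n, after step i the accumulator equals C(n - k + i, i).
    return reduce(lambda x, y: x * y[0] / y[1], zip(range(n - k + 1, n + 1), range(1, k + 1)), 1)


# Compare against the exact integer result from math.comb for small inputs.
for n, k in [(4, 2), (5, 2), (10, 3), (2, 3), (0, 0)]:
    print(n, k, n_choose_k_float(n, k), comb(n, k))
```

The two agree on every row except (0, 0), the special case noted above.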
755

756

757
def lex_index(n, k, lst):
1✔
758
    """Return the lex index of a combination.
759

760
    Args:
761
        n (int): the total number of options .
762
        k (int): The number of elements.
763
        lst (list): list
764

765
    Returns:
766
        int: returns int index for lex order
767

768
    Raises:
769
        VisualizationError: if length of list is not equal to k
770
    """
771
    if len(lst) != k:
1✔
772
        raise VisualizationError("list should have length k")
×
773
    comb = [n - 1 - x for x in lst]
1✔
774
    dualm = sum(n_choose_k(comb[k - 1 - i], i + 1) for i in range(k))
1✔
775
    return int(dualm)
1✔
776

777

778
def bit_string_index(s):
    """Return the index of a string of 0s and 1s among strings of the same length and weight."""
    n = len(s)
    k = s.count("1")
    if s.count("0") != n - k:
        raise VisualizationError("s must be a string of 0 and 1")
    ones = [pos for pos, char in enumerate(s) if char == "1"]
    return lex_index(n, k, ones)

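`lex_index` and `bit_string_index` rank a bit string among all strings with the same length and Hamming weight, which `plot_state_qsphere` uses to spread basis states of equal weight around a latitude ring. Complementing each position (n - 1 - p) and summing binomial coefficients is the combinatorial number system, so every combination gets a distinct rank in range(C(n, k)). A minimal standalone sketch, assuming `math.comb` in place of `n_choose_k` (they agree here because the second argument is always at least 1); `combination_rank` and `bit_string_rank` are hypothetical names:

```python
from itertools import combinations
from math import comb


def combination_rank(n: int, positions: list[int]) -> int:
    """Rank ascending positions the same way lex_index does (hypothetical helper)."""
    k = len(positions)
    dual = [n - 1 - p for p in positions]  # complemented positions, strictly decreasing
    return sum(comb(dual[k - 1 - i], i + 1) for i in range(k))


def bit_string_rank(bits: str) -> int:
    """Rank a '0'/'1' string among strings with the same length and weight."""
    if set(bits) - {"0", "1"}:
        raise ValueError("bits must contain only '0' and '1'")
    ones = [pos for pos, char in enumerate(bits) if char == "1"]
    return combination_rank(len(bits), ones)


# Every 2-of-4 combination gets a distinct rank in range(C(4, 2)).
ranks = sorted(combination_rank(4, list(c)) for c in combinations(range(4), 2))
```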
def phase_to_rgb(complex_number):
    """Map the phase of a complex number to a color in (r, g, b).

    The phase of ``complex_number`` is first offset by 5*pi/4 and wrapped to an
    angle in the range [0, 2*pi), which is then mapped to a hue on the HLS
    color wheel.
    """
    angles = (np.angle(complex_number) + (np.pi * 5 / 4)) % (np.pi * 2)
    rgb = colorsys.hls_to_rgb(angles / (np.pi * 2), 0.5, 0.5)
    return rgb

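The color wheel used by `phase_to_rgb` is rotated: a phase of 0 lands at hue 5/8 rather than at hue 0 (red). The sketch below isolates that hue computation so the offset can be checked directly; `phase_to_hue` and `phase_to_rgb_sketch` are hypothetical helpers, and the lightness and saturation of 0.5 are the values hard-coded in `phase_to_rgb`.

```python
import colorsys

import numpy as np


def phase_to_hue(z: complex) -> float:
    """Hue in [0, 1) that phase_to_rgb assigns to the phase of z."""
    angle = (np.angle(z) + 5 * np.pi / 4) % (2 * np.pi)
    return angle / (2 * np.pi)


def phase_to_rgb_sketch(z: complex) -> tuple[float, float, float]:
    # Same lightness and saturation (0.5, 0.5) as phase_to_rgb.
    return colorsys.hls_to_rgb(phase_to_hue(z), 0.5, 0.5)
```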
@_optionals.HAS_MATPLOTLIB.require_in_call
@_optionals.HAS_SEABORN.require_in_call
def plot_state_qsphere(
    state,
    figsize=None,
    ax=None,
    show_state_labels=True,
    show_state_phases=False,
    use_degrees=False,
    *,
    filename=None,
):
    """Plot the qsphere representation of a quantum state.

    Here, the size of the points is proportional to the probability
    of the corresponding term in the state and the color represents
    the phase.

    Args:
        state (Statevector or DensityMatrix or ndarray): An N-qubit quantum state.
        figsize (tuple): Figure size in inches.
        ax (matplotlib.axes.Axes): An optional Axes object whose Figure is used
            for the visualization output. If none is specified a new matplotlib
            Figure will be created and used.
        show_state_labels (bool): An optional boolean indicating whether to
            show labels for each basis state.
        show_state_phases (bool): An optional boolean indicating whether to
            show the phase for each basis state.
        use_degrees (bool): An optional boolean indicating whether to show the
            phase values in degrees (``True``) or radians (``False``).
        filename (str): An optional path to save the figure to. If ``None``,
            the figure is returned instead.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            The matplotlib figure containing the plot, if ``filename`` is not set.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib and seaborn.
        VisualizationError: Input is not a valid N-qubit state.
        QiskitError: Input statevector does not have valid dimensions.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_qsphere

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = Statevector(qc)
           plot_state_qsphere(state)

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can show the phase of each state and use
           # degrees instead of radians

           from qiskit.quantum_info import DensityMatrix
           import numpy as np
           from qiskit import QuantumCircuit
           from qiskit.visualization import plot_state_qsphere

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0, 1)
           qc.ry(np.pi / 3, 0)
           qc.rx(np.pi / 5, 1)
           qc.z(1)

           matrix = DensityMatrix(qc)
           plot_state_qsphere(matrix, show_state_phases=True, use_degrees=True)
    """
    from matplotlib import gridspec
    from matplotlib import pyplot as plt
    from matplotlib.patches import Circle
    import seaborn as sns
    from scipy import linalg
    from .bloch import Arrow3D

    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    # get the eigenvectors and eigenvalues
    eigvals, eigvecs = linalg.eigh(rho.data)

    if figsize is None:
        figsize = (7, 7)

    if ax is None:
        return_fig = True
        fig = plt.figure(figsize=figsize)
    else:
        return_fig = False
        fig = ax.get_figure()

    gs = gridspec.GridSpec(nrows=3, ncols=3)

    ax = fig.add_subplot(gs[0:3, 0:3], projection="3d")
    ax.axes.set_xlim3d(-1.0, 1.0)
    ax.axes.set_ylim3d(-1.0, 1.0)
    ax.axes.set_zlim3d(-1.0, 1.0)
    ax.axes.grid(False)
    ax.view_init(elev=5, azim=275)

    # Force aspect ratio
    # Matplotlib 3.2 and earlier do not have set_box_aspect
    if hasattr(ax.axes, "set_box_aspect"):
        ax.axes.set_box_aspect((1, 1, 1))

    # start the plotting
    # Plot semi-transparent sphere
    u = np.linspace(0, 2 * np.pi, 25)
    v = np.linspace(0, np.pi, 25)
    x = np.outer(np.cos(u), np.sin(v))
    y = np.outer(np.sin(u), np.sin(v))
    z = np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(
        x, y, z, rstride=1, cstride=1, color=plt.rcParams["grid.color"], alpha=0.2, linewidth=0
    )

    # Get rid of the panes
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the spines
    ax.xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    # traversing the eigvals/vecs backward as sorted low->high
    for idx in range(eigvals.shape[0] - 1, -1, -1):
        if eigvals[idx] > 0.001:
            # get the eigenvector for this eigenvalue
            state = eigvecs[:, idx]
            # Rounding to 13 decimals ignores machine epsilon noise (~1e-16)
            # from the solver, ensuring 'argmax' finds the true analytical winner.
            loc = np.round(np.absolute(state), decimals=13).argmax()
            # remove the global phase using the largest-magnitude element
            angles = (np.angle(state[loc]) + 2 * np.pi) % (2 * np.pi)
            angleset = np.exp(-1j * angles)
            state = angleset * state

            d = num
            for i in range(2**num):
                # get x,y,z points
                element = bin(i)[2:].zfill(num)
                weight = element.count("1")
                zvalue = -2 * weight / d + 1
                number_of_divisions = n_choose_k(d, weight)
                weight_order = bit_string_index(element)
                angle = (float(weight) / d) * (np.pi * 2) + (
                    weight_order * 2 * (np.pi / number_of_divisions)
                )

                if (weight > d / 2) or (
                    (weight == d / 2) and (weight_order >= number_of_divisions / 2)
                ):
                    angle = np.pi - angle - (2 * np.pi / number_of_divisions)

                xvalue = np.sqrt(1 - zvalue**2) * np.cos(angle)
                yvalue = np.sqrt(1 - zvalue**2) * np.sin(angle)

                # get prob and angle - prob will be shade and angle color
                prob = np.real(np.dot(state[i], state[i].conj()))
                prob = min(prob, 1)  # See https://github.com/Qiskit/qiskit-terra/issues/4666
                colorstate = phase_to_rgb(state[i])

                alfa = 1
                if yvalue >= 0.1:
                    alfa = 1.0 - yvalue

                if not np.isclose(prob, 0) and show_state_labels:
                    rprime = 1.3
                    angle_theta = np.arctan2(np.sqrt(1 - zvalue**2), zvalue)
                    xvalue_text = rprime * np.sin(angle_theta) * np.cos(angle)
                    yvalue_text = rprime * np.sin(angle_theta) * np.sin(angle)
                    zvalue_text = rprime * np.cos(angle_theta)
                    element_text = "$\\vert" + element + "\\rangle$"
                    if show_state_phases:
                        element_angle = (np.angle(state[i]) + (np.pi * 4)) % (np.pi * 2)
                        if use_degrees:
                            element_text += f"\n${element_angle * 180 / np.pi:.1f}^\\circ$"
                        else:
                            element_angle = pi_check(element_angle, ndigits=3).replace("pi", "\\pi")
                            element_text += f"\n${element_angle}$"
                    ax.text(
                        xvalue_text,
                        yvalue_text,
                        zvalue_text,
                        element_text,
                        ha="center",
                        va="center",
                        size=12,
                    )

                ax.plot(
                    [xvalue],
                    [yvalue],
                    [zvalue],
                    markerfacecolor=colorstate,
                    markeredgecolor=colorstate,
                    marker="o",
                    markersize=np.sqrt(prob) * 30,
                    alpha=alfa,
                )

                a = Arrow3D(
                    [0, xvalue],
                    [0, yvalue],
                    [0, zvalue],
                    mutation_scale=20,
                    alpha=prob,
                    arrowstyle="-",
                    color=colorstate,
                    lw=2,
                )
                ax.add_artist(a)

            # add weight lines
            for weight in range(d + 1):
                theta = np.linspace(-2 * np.pi, 2 * np.pi, 100)
                z = -2 * weight / d + 1
                r = np.sqrt(1 - z**2)
                x = r * np.cos(theta)
                y = r * np.sin(theta)
                ax.plot(x, y, z, color=(0.5, 0.5, 0.5), lw=1, ls=":", alpha=0.5)

            # add center point
            ax.plot(
                [0],
                [0],
                [0],
                markerfacecolor=(0.5, 0.5, 0.5),
                markeredgecolor=(0.5, 0.5, 0.5),
                marker="o",
                markersize=3,
                alpha=1,
            )
        else:
            break

    n = 64
    theta = np.ones(n)
    colors = sns.hls_palette(n)

    ax2 = fig.add_subplot(gs[2:, 2:])
    ax2.pie(theta, colors=colors[5 * n // 8 :] + colors[: 5 * n // 8], radius=0.75)
    ax2.add_artist(Circle((0, 0), 0.5, color="white", zorder=1))
    offset = 0.95  # since radius of sphere is one.

    if use_degrees:
        labels = ["Phase\n(Deg)", "0", "90", "180   ", "270"]
    else:
        labels = ["Phase", "$0$", "$\\pi/2$", "$\\pi$", "$3\\pi/2$"]

    ax2.text(0, 0, labels[0], horizontalalignment="center", verticalalignment="center", fontsize=14)
    ax2.text(
        offset, 0, labels[1], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, offset, labels[2], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        -offset, 0, labels[3], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, -offset, labels[4], horizontalalignment="center", verticalalignment="center", fontsize=14
    )

    if return_fig:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    else:
        return fig.savefig(filename)

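`plot_state_qsphere` places each computational basis state on the sphere by its Hamming weight (latitude) and its rank among strings of that weight (longitude), after rotating away the global phase so the largest amplitude is real and positive. The sketch below reproduces only that geometry for a pure statevector, without any plotting. `qsphere_points` and `_rank` are hypothetical names, `math.comb` stands in for `n_choose_k`, and mixed states (which the real function handles by looping over eigenvectors of the density matrix) are out of scope.

```python
import math

import numpy as np


def _rank(n: int, positions: list[int]) -> int:
    # Same combinatorial-number-system rank as lex_index (hypothetical helper).
    k = len(positions)
    dual = [n - 1 - p for p in positions]
    return sum(math.comb(dual[k - 1 - i], i + 1) for i in range(k))


def qsphere_points(statevector):
    """Return (bits, (x, y, z), probability, phase) for each basis state."""
    state = np.asarray(statevector, dtype=complex)
    d = int(math.log2(len(state)))
    # Rotate away the global phase so the largest amplitude is real and positive.
    loc = np.round(np.abs(state), decimals=13).argmax()
    state = state * np.exp(-1j * np.angle(state[loc]))
    points = []
    for i, amp in enumerate(state):
        bits = format(i, f"0{d}b")
        weight = bits.count("1")
        z = -2 * weight / d + 1  # Hamming weight fixes the latitude
        divisions = math.comb(d, weight)
        order = _rank(d, [p for p, c in enumerate(bits) if c == "1"])
        # Spread strings of equal weight evenly around their latitude ring.
        angle = (weight / d) * 2 * np.pi + order * 2 * np.pi / divisions
        if weight > d / 2 or (weight == d / 2 and order >= divisions / 2):
            angle = np.pi - angle - 2 * np.pi / divisions
        r = np.sqrt(1 - z**2)
        xyz = (r * np.cos(angle), r * np.sin(angle), z)
        points.append((bits, xyz, abs(amp) ** 2, np.angle(amp)))
    return points
```

Every point lies on the unit sphere, with |0...0> at the north pole and |1...1> at the south pole, which is why the weight lines drawn by the plotting code are circles of constant z.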
@_optionals.HAS_MATPLOTLIB.require_in_call
def generate_facecolors(x, y, z, dx, dy, dz, color):
    """Generates shaded facecolors for shaded bars.

    This is here to work around a Matplotlib bug
    where alpha does not work in Bar3D.

    Args:
        x (array_like): The x-coordinates of the anchor point of the bars.
        y (array_like): The y-coordinates of the anchor point of the bars.
        z (array_like): The z-coordinates of the anchor point of the bars.
        dx (array_like): Width of bars.
        dy (array_like): Depth of bars.
        dz (array_like): Height of bars.
        color (array_like): Sequence of valid color specifications, optional.

    Returns:
        list: Shaded colors for bars.

    Raises:
        MissingOptionalLibraryError: If matplotlib is not installed.
    """
    import matplotlib.colors as mcolors

    cuboid = np.array(
        [
            # -z
            (
                (0, 0, 0),
                (0, 1, 0),
                (1, 1, 0),
                (1, 0, 0),
            ),
            # +z
            (
                (0, 0, 1),
                (1, 0, 1),
                (1, 1, 1),
                (0, 1, 1),
            ),
            # -y
            (
                (0, 0, 0),
                (1, 0, 0),
                (1, 0, 1),
                (0, 0, 1),
            ),
            # +y
            (
                (0, 1, 0),
                (0, 1, 1),
                (1, 1, 1),
                (1, 1, 0),
            ),
            # -x
            (
                (0, 0, 0),
                (0, 0, 1),
                (0, 1, 1),
                (0, 1, 0),
            ),
            # +x
            (
                (1, 0, 0),
                (1, 1, 0),
                (1, 1, 1),
                (1, 0, 1),
            ),
        ]
    )

    # indexed by [bar, face, vertex, coord]
    polys = np.empty(x.shape + cuboid.shape)
    # handle each coordinate separately
    for i, p, dp in [(0, x, dx), (1, y, dy), (2, z, dz)]:
        p = p[..., np.newaxis, np.newaxis]
        dp = dp[..., np.newaxis, np.newaxis]
        polys[..., i] = p + dp * cuboid[..., i]

    # collapse the first two axes
    polys = polys.reshape((-1,) + polys.shape[2:])

    facecolors = []
    if len(color) == len(x):
        # bar colors specified, need to expand to number of faces
        for c in color:
            facecolors.extend([c] * 6)
    else:
        # a single color specified, or face colors specified explicitly
        facecolors = list(mcolors.to_rgba_array(color))
        if len(facecolors) < len(x):
            facecolors *= 6 * len(x)

    normals = _generate_normals(polys)
    return _shade_colors(facecolors, normals)

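The core of `generate_facecolors` turns every bar (anchor x, y, z plus extents dx, dy, dz) into six quadrilateral faces by scaling and shifting a unit cube one coordinate axis at a time, relying on NumPy broadcasting. A standalone sketch of just that step, assuming array inputs of equal shape; `UNIT_CUBE` and `bar_polygons` are hypothetical names, and the face order matches the `cuboid` literal above.

```python
import numpy as np

# Unit-cube faces indexed [face, vertex, coord], same order as `cuboid` above.
UNIT_CUBE = np.array(
    [
        [(0, 0, 0), (0, 1, 0), (1, 1, 0), (1, 0, 0)],  # -z
        [(0, 0, 1), (1, 0, 1), (1, 1, 1), (0, 1, 1)],  # +z
        [(0, 0, 0), (1, 0, 0), (1, 0, 1), (0, 0, 1)],  # -y
        [(0, 1, 0), (0, 1, 1), (1, 1, 1), (1, 1, 0)],  # +y
        [(0, 0, 0), (0, 0, 1), (0, 1, 1), (0, 1, 0)],  # -x
        [(1, 0, 0), (1, 1, 0), (1, 1, 1), (1, 0, 1)],  # +x
    ],
    dtype=float,
)


def bar_polygons(x, y, z, dx, dy, dz):
    """Return an array of shape (6 * n_bars, 4, 3): one quad per bar face."""
    anchors = [np.asarray(a, dtype=float) for a in (x, y, z)]
    sizes = [np.asarray(s, dtype=float) for s in (dx, dy, dz)]
    polys = np.empty(anchors[0].shape + UNIT_CUBE.shape)  # [bar, face, vertex, coord]
    for i, (p, dp) in enumerate(zip(anchors, sizes)):
        # (n_bars, 1, 1) broadcasts against the (6, 4) slice of the cube.
        p = p[..., np.newaxis, np.newaxis]
        dp = dp[..., np.newaxis, np.newaxis]
        polys[..., i] = p + dp * UNIT_CUBE[..., i]
    # Collapse [bar, face] into one axis, bar-major.
    return polys.reshape((-1,) + polys.shape[2:])
```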
def _generate_normals(polygons):
    """Takes a list of polygons and returns an array of their normals.

    Normals point towards the viewer for a face with its vertices in
    counterclockwise order, following the right hand rule.
    Uses three points equally spaced around the polygon.
    This normal of course might not make sense for polygons with more than
    three points not lying in a plane, but it's a plausible and fast
    approximation.

    Args:
        polygons (list): list of (M_i, 3) array_like, or (..., M, 3) array_like
            A sequence of polygons to compute normals for, which can have
            varying numbers of vertices. If the polygons all have the same
            number of vertices and an array is passed, then the operation will
            be vectorized.

    Returns:
        normals: (..., 3) array_like
            A normal vector estimated for the polygon.
    """
    if isinstance(polygons, np.ndarray):
        # optimization: polygons all have the same number of points, so can
        # vectorize
        n = polygons.shape[-2]
        i1, i2, i3 = 0, n // 3, 2 * n // 3
        v1 = polygons[..., i1, :] - polygons[..., i2, :]
        v2 = polygons[..., i2, :] - polygons[..., i3, :]
    else:
        # The subtraction doesn't vectorize because polygons is jagged.
        v1 = np.empty((len(polygons), 3))
        v2 = np.empty((len(polygons), 3))
        for poly_i, ps in enumerate(polygons):
            n = len(ps)
            i1, i2, i3 = 0, n // 3, 2 * n // 3
            v1[poly_i, :] = ps[i1, :] - ps[i2, :]
            v2[poly_i, :] = ps[i2, :] - ps[i3, :]

    return np.cross(v1, v2)

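`_generate_normals` estimates a face normal as the cross product of two edge vectors taken between three vertices spread around the polygon. A quick check of that rule on the two z-faces of the cuboid used by `generate_facecolors`, reimplementing only the vectorized (equal vertex count) branch; `face_normals` is a hypothetical name.

```python
import numpy as np


def face_normals(polygons: np.ndarray) -> np.ndarray:
    """Vectorized branch of _generate_normals for an (..., M, 3) array."""
    n = polygons.shape[-2]
    i1, i2, i3 = 0, n // 3, 2 * n // 3
    v1 = polygons[..., i1, :] - polygons[..., i2, :]
    v2 = polygons[..., i2, :] - polygons[..., i3, :]
    return np.cross(v1, v2)


faces = np.array(
    [
        [(0, 0, 0), (0, 1, 0), (1, 1, 0), (1, 0, 0)],  # -z face of the cuboid
        [(0, 0, 1), (1, 0, 1), (1, 1, 1), (0, 1, 1)],  # +z face of the cuboid
    ],
    dtype=float,
)
normals = face_normals(faces)
```

With the vertex order used in `cuboid`, both normals point out of the cube, which is what the shading step expects.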
def _shade_colors(color, normals, lightsource=None):
    """
    Shade *color* using normal vectors given by *normals*.
    *color* can also be an array of the same length as *normals*.
    """
    from matplotlib.colors import Normalize, LightSource
    import matplotlib.colors as mcolors

    if lightsource is None:
        # chosen for backwards-compatibility
        lightsource = LightSource(azdeg=225, altdeg=19.4712)

    def mod(v):
        return np.sqrt(v[0] ** 2 + v[1] ** 2 + v[2] ** 2)

    shade = np.array(
        [np.dot(n / mod(n), lightsource.direction) if mod(n) else np.nan for n in normals]
    )
    mask = ~np.isnan(shade)

    if mask.any():
        norm = Normalize(min(shade[mask]), max(shade[mask]))
        shade[~mask] = min(shade[mask])
        color = mcolors.to_rgba_array(color)
        # shape of color should be (M, 4) (where M is number of faces)
        # shape of shade should be (M,)
        # colors should have final shape of (M, 4)
        alpha = color[:, 3]
        colors = (0.5 + norm(shade)[:, np.newaxis] * 0.5) * color
        colors[:, 3] = alpha
    else:
        colors = np.asanyarray(color).copy()

    return colors

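`_shade_colors` scales each face's RGB by a factor between 0.5 and 1, depending on how directly its unit normal faces the light after min-max normalization, and then restores the original alpha. The sketch below restates that rule with plain NumPy, taking the light direction as an explicit unit vector instead of a Matplotlib `LightSource` (an assumption made to keep it dependency-free). `shade_rgba` is a hypothetical name; when all faces share one shade it returns half-brightness colors, matching what a degenerate min-max normalization to 0 would give.

```python
import numpy as np


def shade_rgba(rgba, normals, light_dir) -> np.ndarray:
    """Darken RGBA face colors by 0.5..1 according to normal/light alignment."""
    rgba = np.asarray(rgba, dtype=float)
    normals = np.asarray(normals, dtype=float)
    lengths = np.linalg.norm(normals, axis=1)
    with np.errstate(invalid="ignore", divide="ignore"):
        shade = (normals @ np.asarray(light_dir, dtype=float)) / lengths  # NaN if length 0
    mask = ~np.isnan(shade)
    if not mask.any():
        return rgba.copy()
    lo, hi = shade[mask].min(), shade[mask].max()
    shade[~mask] = lo  # degenerate faces get the darkest shade
    norm = (shade - lo) / (hi - lo) if hi > lo else np.zeros_like(shade)
    out = (0.5 + 0.5 * norm[:, np.newaxis]) * rgba
    out[:, 3] = rgba[:, 3]  # keep the original alpha
    return out
```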
def state_to_latex(
    state: Statevector | DensityMatrix, dims: bool | None = None, convention: str = "ket", **args
) -> str:
    """Return a LaTeX representation of a state.

    Wrapper function for `qiskit.visualization.array_to_latex` for the 'vector'
    convention. Adds dims if necessary. Intended for use within `state_drawer`.

    Args:
        state: State to be drawn.
        dims (bool): Whether to display the state's `dims`. If ``None``, dims are
            shown only when the state is not made up entirely of qubits.
        convention (str): Either 'vector' or 'ket'. For 'ket', plot the state in
            ket notation (applied only to qubit statevectors). Otherwise plot as a vector.
        **args: Arguments passed to the underlying formatter: the ket formatter when
            the 'ket' convention applies, otherwise `array_to_latex`.

    Returns:
        Latex representation of the state.

    Raises:
        MissingOptionalLibraryError: If SymPy isn't installed and ``'latex'`` or
            ``'latex_source'`` is selected for ``output``.
    """
    if dims is None:  # show dims if state is not only qubits
        if set(state.dims()) == {2}:
            dims = False
        else:
            dims = True

    prefix = ""
    suffix = ""
    if dims:
        prefix = "\\begin{align}\n"
        dims_str = state._op_shape.dims_l()
        suffix = f"\\\\\n\\text{{dims={dims_str}}}\n\\end{{align}}"

    operator_shape = state._op_shape
    # we only use the ket convention for qubit statevectors
    # this means the operator shape should have no input dimensions and all output dimensions equal to 2
    is_qubit_statevector = len(operator_shape.dims_r()) == 0 and set(operator_shape.dims_l()) == {2}
    if convention == "ket" and is_qubit_statevector:
        latex_str = _state_to_latex_ket(state._data, **args)
    else:
        latex_str = array_to_latex(state._data, source=True, **args)
    return prefix + latex_str + suffix

def _numbers_to_latex_terms(numbers: list[complex], decimals: int = 10) -> list[str]:
    """Convert a list of numbers to LaTeX-formatted terms.

    The first non-zero term is treated differently. For this term a leading + is suppressed.

    Args:
        numbers: List of numbers to format.
        decimals: Number of decimal places to round to (default: 10).

    Returns:
        List of formatted terms.
    """
    first_term = True
    terms = []
    for number in numbers:
        term = _num_to_latex(number, decimals=decimals, first_term=first_term, coefficient=True)
        terms.append(term)
        first_term = False
    return terms

def _state_to_latex_ket(
    data: list[complex], max_size: int = 12, prefix: str = "", decimals: int = 10
) -> str:
    """Convert a state vector to a LaTeX representation in ket notation.

    Args:
        data: State vector.
        max_size: Maximum number of non-zero terms in the expression. If the number of
            non-zero terms is larger than ``max_size``, the representation is truncated:
            the leading and trailing terms are kept with an ellipsis in between.
        prefix: LaTeX string to be prepended to the output, intended for labels.
        decimals: Number of decimal places to round to (default: 10).

    Returns:
        String with LaTeX representation of the state vector.
    """
    num = int(math.log2(len(data)))

    def ket_name(i):
        return bin(i)[2:].zfill(num)

    data = np.around(data, decimals)
    nonzero_indices = np.where(data != 0)[0].tolist()
    if len(nonzero_indices) > max_size:
        nonzero_indices = (
            nonzero_indices[: max_size // 2] + [0] + nonzero_indices[-max_size // 2 + 1 :]
        )
        latex_terms = _numbers_to_latex_terms(data[nonzero_indices], decimals)
        nonzero_indices[max_size // 2] = None
    else:
        latex_terms = _numbers_to_latex_terms(data[nonzero_indices], decimals)

    latex_str = ""
    for idx, ket_idx in enumerate(nonzero_indices):
        if ket_idx is None:
            latex_str += r" + \ldots "
        else:
            term = latex_terms[idx]
            ket = ket_name(ket_idx)
            latex_str += f"{term} |{ket}\\rangle"
    return prefix + latex_str

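When a state has more than `max_size` non-zero amplitudes, `_state_to_latex_ket` keeps the first `max_size // 2` terms and the trailing ones, with `\ldots` in between, so for `max_size` of 3 or more exactly `max_size - 1` kets are shown. The sketch below mirrors only that index selection and the bit-string ket labels, with a deliberately simplified coefficient format in place of `_num_to_latex`; `ket_string` is a hypothetical name. It computes the tail length explicitly, which also behaves sensibly for `max_size` of 1 or 2, where the negative-slice expression `-max_size // 2 + 1` evaluates to 0 and keeps the whole list.

```python
import math

import numpy as np


def ket_string(amplitudes, max_size: int = 12, decimals: int = 10) -> str:
    """Simplified ket-notation string using the same truncation rule (hypothetical)."""
    data = np.around(np.asarray(amplitudes, dtype=complex), decimals)
    num = int(math.log2(len(data)))
    nonzero = np.flatnonzero(data).tolist()
    if len(nonzero) > max_size:
        head = max_size // 2
        tail = max_size - head - 1  # one slot is taken by the ellipsis
        nonzero = nonzero[:head] + [None] + nonzero[len(nonzero) - tail :]
    parts = []
    for idx in nonzero:
        if idx is None:
            parts.append(r"\ldots")
        else:
            coeff = complex(data[idx])
            parts.append(f"({coeff.real:g}{coeff.imag:+g}i)|{idx:0{num}b}\\rangle")
    return " + ".join(parts)
```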
class TextMatrix:
    """Text representation of an array, with `__str__` method so it
    displays nicely in Jupyter notebooks"""

    def __init__(self, state, max_size=8, dims=None, prefix="", suffix=""):
        self.state = state
        self.max_size = max_size
        if dims is None:  # show dims if state is not only qubits
            if (isinstance(state, (Statevector, DensityMatrix)) and set(state.dims()) == {2}) or (
                isinstance(state, Operator)
                and len(state.input_dims()) == len(state.output_dims())
                and set(state.input_dims()) == set(state.output_dims()) == {2}
            ):
                dims = False
            else:
                dims = True
        self.dims = dims
        self.prefix = prefix
        self.suffix = suffix
        if isinstance(max_size, int):
            self.max_size = max_size
        elif isinstance(state, DensityMatrix):
            # density matrices are square, so threshold for
            # summarization is shortest side squared
            self.max_size = min(max_size) ** 2
        else:
            self.max_size = max_size[0]

    def __str__(self):
        threshold = self.max_size
        data = np.array2string(
            self.state._data, prefix=self.prefix, threshold=threshold, separator=","
        )
        dimstr = ""
        if self.dims:
            data += ",\n"
            dimstr += " " * len(self.prefix)
            if isinstance(self.state, (Statevector, DensityMatrix)):
                dimstr += f"dims={self.state._op_shape.dims_l()}"
            else:
                dimstr += f"input_dims={self.state.input_dims()}, "
                dimstr += f"output_dims={self.state.output_dims()}"

        return self.prefix + data + dimstr + self.suffix

    def __repr__(self):
        return self.__str__()

def state_drawer(state, output=None, **drawer_args):
    """Returns a visualization of the state.

    **repr**: The string returned by the state's ``__repr__``.

    **text**: ASCII TextMatrix that can be printed in the console.

    **latex**: An IPython Latex object for displaying in Jupyter Notebooks.

    **latex_source**: Raw, uncompiled ASCII source to generate array using LaTeX.

    **qsphere**: Matplotlib figure, rendering of statevector using `plot_state_qsphere()`.

    **hinton**: Matplotlib figure, rendering of statevector using `plot_state_hinton()`.

    **bloch**: Matplotlib figure, rendering of statevector using `plot_bloch_multivector()`.

    **city**: Matplotlib figure, rendering of statevector using `plot_state_city()`.

    **paulivec**: Matplotlib figure, rendering of statevector using `plot_state_paulivec()`.

    Args:
        state (Statevector or DensityMatrix): The state to draw.
        output (str): Select the output method to use for drawing the
            state. Valid choices are ``repr``, ``text``, ``latex``, ``latex_source``,
            ``qsphere``, ``hinton``, ``bloch``, ``city`` or ``paulivec``.
            Defaults to the ``state_drawer`` setting in the user config file if
            present, otherwise ``repr``.
        drawer_args: Arguments to be passed to the relevant drawer. For
            'latex' and 'latex_source' see ``array_to_latex``.

    Returns:
        :class:`matplotlib.figure` or :class:`str` or
        :class:`TextMatrix` or :class:`IPython.display.Latex`:
        Drawing of the state.

    Raises:
        MissingOptionalLibraryError: When `output` is `latex` and IPython is not installed,
            or if SymPy isn't installed and ``'latex'`` or ``'latex_source'`` is selected for
            ``output``.
        ValueError: When `output` is not a valid selection.
    """
    config = user_config.get_config()
    # Get default 'output' from config file else use 'repr'
    default_output = "repr"
    if output is None:
        if config:
            default_output = config.get("state_drawer", "repr")
        output = default_output
    output = output.lower()

    # Choose drawing backend:
    drawers = {
        "text": TextMatrix,
        "latex_source": state_to_latex,
        "qsphere": plot_state_qsphere,
        "hinton": plot_state_hinton,
        "bloch": plot_bloch_multivector,
        "city": plot_state_city,
        "paulivec": plot_state_paulivec,
    }
    if output == "latex":
        _optionals.HAS_IPYTHON.require_now("state_drawer")
        from IPython.display import Latex

        draw_func = drawers["latex_source"]
        return Latex(f"$${draw_func(state, **drawer_args)}$$")

    if output == "repr":
        return state.__repr__()

    # Only the lookup is guarded, so a KeyError raised inside a drawer is not
    # misreported as an invalid ``output`` value.
    try:
        draw_func = drawers[output]
    except KeyError as err:
        raise ValueError(
            f"'{output}' is not a valid option for drawing {type(state).__name__} "
            "objects. Please choose from: 'repr', 'text', 'latex', 'latex_source', "
            "'qsphere', 'hinton', 'bloch', 'city' or 'paulivec'."
        ) from err
    return draw_func(state, **drawer_args)

def _bloch_multivector_data(state):
    """Return list of Bloch vectors for each qubit.

    Args:
        state (DensityMatrix or Statevector): an N-qubit state.

    Returns:
        list: list of Bloch vectors (x, y, z) for each qubit.

    Raises:
        VisualizationError: if input is not an N-qubit state.
    """
    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    pauli_singles = PauliList(["X", "Y", "Z"])
    bloch_data = []
    for i in range(num):
        if num > 1:
            paulis = PauliList.from_symplectic(
                np.zeros((3, (num - 1)), dtype=bool), np.zeros((3, (num - 1)), dtype=bool)
            ).insert(i, pauli_singles, qubit=True)
        else:
            paulis = pauli_singles
        bloch_state = [np.real(np.trace(np.dot(mat, rho.data))) for mat in paulis.matrix_iter()]
        bloch_data.append(bloch_state)
    return bloch_data

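`_bloch_multivector_data` computes each qubit's Bloch vector as (Tr(X rho), Tr(Y rho), Tr(Z rho)), where the Pauli acts on that qubit and identities act on the others. A dense-matrix sketch of the same quantity, assuming Qiskit's little-endian convention (qubit 0 is the rightmost Kronecker factor); `bloch_vectors` is a hypothetical name, and because it builds full 2^n x 2^n operators it is only practical for a few qubits.

```python
from functools import reduce

import numpy as np

PAULIS = {
    "X": np.array([[0, 1], [1, 0]], dtype=complex),
    "Y": np.array([[0, -1j], [1j, 0]], dtype=complex),
    "Z": np.array([[1, 0], [0, -1]], dtype=complex),
}
IDENTITY = np.eye(2, dtype=complex)


def bloch_vectors(rho) -> list[list[float]]:
    """Return [Tr(X rho), Tr(Y rho), Tr(Z rho)] for each qubit of rho."""
    rho = np.asarray(rho, dtype=complex)
    num = int(np.log2(rho.shape[0]))
    vectors = []
    for qubit in range(num):
        vec = []
        for label in "XYZ":
            # Little-endian: list factors from qubit num-1 down to qubit 0.
            factors = [PAULIS[label] if q == qubit else IDENTITY for q in reversed(range(num))]
            op = reduce(np.kron, factors)
            vec.append(float(np.real(np.trace(op @ rho))))
        vectors.append(vec)
    return vectors
```

For example, the two-qubit basis state with amplitude 1 at index 2 (bit string "10") has qubit 0 in |0> and qubit 1 in |1>, so its Bloch vectors are (0, 0, 1) and (0, 0, -1).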
def _paulivec_data(state):
    """Return paulivec data for plotting.

    Args:
        state (DensityMatrix or Statevector): an N-qubit state.

    Returns:
        tuple: (labels, values) for Pauli vector.

    Raises:
        VisualizationError: if input is not an N-qubit state.
    """
    rho = SparsePauliOp.from_operator(DensityMatrix(state))
    if rho.num_qubits is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    return rho.paulis.to_labels(), np.real(rho.coeffs * 2**rho.num_qubits)
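`_paulivec_data` rescales the Pauli-basis coefficients of rho by 2^n. Because rho = sum_P c_P P with c_P = Tr(P rho) / 2^n, the plotted values are the Pauli expectation values Tr(P rho). A brute-force NumPy sketch of that identity, assuming the label convention where the leftmost character acts on the highest-index qubit; `pauli_expectations` is a hypothetical name. It enumerates all 4^n labels, whereas the `SparsePauliOp` version may omit zero terms and order labels differently.

```python
from functools import reduce
from itertools import product

import numpy as np

PAULI = {
    "I": np.eye(2, dtype=complex),
    "X": np.array([[0, 1], [1, 0]], dtype=complex),
    "Y": np.array([[0, -1j], [1j, 0]], dtype=complex),
    "Z": np.array([[1, 0], [0, -1]], dtype=complex),
}


def pauli_expectations(rho) -> dict[str, float]:
    """Map every n-qubit Pauli label to Tr(P rho)."""
    rho = np.asarray(rho, dtype=complex)
    num = int(np.log2(rho.shape[0]))
    values = {}
    for letters in product("IXYZ", repeat=num):
        label = "".join(letters)
        # Leftmost letter is the leftmost Kronecker factor (highest qubit index).
        op = reduce(np.kron, [PAULI[c] for c in label])
        values[label] = float(np.real(np.trace(op @ rho)))
    return values
```

For the Bell state (|00> + |11>)/sqrt(2) this gives 1 for II, XX and ZZ and -1 for YY, with the remaining labels at 0.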