• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Qiskit / qiskit / 25626761124

10 May 2026 10:49AM UTC coverage: 87.667% (+0.2%) from 87.467%
25626761124

Pull #15974

github

web-flow
Merge 3036ed9bb into c25216340
Pull Request #15974: [WIP] Extend litinski transformation for all Pauli rotations

592 of 608 new or added lines in 4 files covered. (97.37%)

1404 existing lines in 39 files now uncovered.

107556 of 122687 relevant lines covered (87.67%)

956445.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

57.5
/qiskit/visualization/state_visualization.py
1
# This code is part of Qiskit.
2
#
3
# (C) Copyright IBM 2017, 2023.
4
#
5
# This code is licensed under the Apache License, Version 2.0. You may
6
# obtain a copy of this license in the LICENSE.txt file in the root directory
7
# of this source tree or at https://www.apache.org/licenses/LICENSE-2.0.
8
#
9
# Any modifications or derivative works of this code must retain this
10
# copyright notice, and modified files need to carry a notice indicating
11
# that they have been altered from the originals.
12

13

14
"""
15
Visualization functions for quantum states.
16
"""
17

18
import math
1✔
19

20
from functools import reduce
1✔
21
import colorsys
1✔
22

23
import numpy as np
1✔
24
from qiskit import user_config
1✔
25
from qiskit.quantum_info.states.statevector import Statevector
1✔
26
from qiskit.quantum_info.operators.operator import Operator
1✔
27
from qiskit.quantum_info.operators.symplectic import PauliList, SparsePauliOp
1✔
28
from qiskit.quantum_info.states.densitymatrix import DensityMatrix
1✔
29
from qiskit.utils import optionals as _optionals
1✔
30
from qiskit.circuit.tools.pi_check import pi_check
1✔
31

32
from .array import _num_to_latex, array_to_latex
1✔
33
from .utils import matplotlib_close_if_inline
1✔
34
from .exceptions import VisualizationError
1✔
35

36

37
def _draw_hinton_panel(ax, data, max_weight, row_names, column_names, panel_title):
    """Draw one Hinton panel for a real-valued matrix on ``ax``.

    Args:
        ax (matplotlib.axes.Axes): Axes to draw on.
        data (ndarray): 2D real array (the real or the imaginary part of the
            density matrix).
        max_weight (float): Normalization constant; an entry whose magnitude
            equals ``max_weight`` fills its whole cell.
        row_names (list[str]): y tick labels, in tick order (bottom to top).
        column_names (list[str]): x tick labels, in tick order (left to right).
        panel_title (str): Title placed above the panel.
    """
    from matplotlib import pyplot as plt

    n_rows, n_cols = data.shape
    ax.patch.set_facecolor("gray")
    ax.set_aspect("equal", "box")
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (row, col), weight in np.ndenumerate(data):
        # Convert from matrix coordinates to plot coordinates: row 0 of the
        # matrix is drawn at the top of the panel.
        plot_x, plot_y = col, n_rows - row - 1
        color = "white" if weight > 0 else "black"
        # The square's *area* is proportional to the magnitude of the entry.
        size = np.sqrt(np.abs(weight) / max_weight)
        rect = plt.Rectangle(
            [0.5 + plot_x - size / 2, 0.5 + plot_y - size / 2],
            size,
            size,
            facecolor=color,
            edgecolor=color,
        )
        ax.add_patch(rect)

    ax.set_xticks(0.5 + np.arange(n_cols))
    ax.set_yticks(0.5 + np.arange(n_rows))
    ax.set_xlim([0, n_cols])
    ax.set_ylim([0, n_rows])
    ax.set_yticklabels(row_names, fontsize=14)
    ax.set_xticklabels(column_names, fontsize=14, rotation=90)
    ax.set_title(panel_title, fontsize=14)


@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_state_hinton(state, title="", figsize=None, ax_real=None, ax_imag=None, *, filename=None):
    """Plot a hinton diagram for the density matrix of a quantum state.

    The hinton diagram represents the values of a matrix using
    squares, whose size indicates the magnitude of their corresponding value
    and their color, its sign. A white square means the value is positive and
    a black one means negative.

    Args:
        state (Statevector or DensityMatrix or ndarray): An N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches. Only used when a new figure is
            created, that is when neither ``ax_real`` nor ``ax_imag`` is given.
        ax_real (matplotlib.axes.Axes): An optional Axes object to be used for
            the real component. If neither ``ax_real`` nor ``ax_imag`` is
            specified a new matplotlib Figure will be created and used. If this
            is specified without an ``ax_imag`` only the real component plot
            will be generated.
        ax_imag (matplotlib.axes.Axes): An optional Axes object to be used for
            the imaginary component. If neither ``ax_real`` nor ``ax_imag`` is
            specified a new matplotlib Figure will be created and used. If this
            is specified without an ``ax_real`` only the imaginary component
            plot will be generated.
        filename (str | None): The optional file path to save image to. If not specified
            no file is created for the visualization. If this is set the figure
            is saved and the result of ``Figure.savefig`` is returned instead of
            the figure.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            When ``filename`` is ``None``, the figure that was drawn on: the
            newly created one, or the figure owning the supplied axes.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        VisualizationError: Input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           import numpy as np
           from qiskit import QuantumCircuit
           from qiskit.quantum_info import DensityMatrix
           from qiskit.visualization import plot_state_hinton

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0,1)
           qc.ry(np.pi/3 , 0)
           qc.rx(np.pi/5, 1)

           state = DensityMatrix(qc)
           plot_state_hinton(state, title="New Hinton Plot")

    """
    from matplotlib import pyplot as plt

    # Figure data
    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    # Normalize by the largest magnitude rounded up to a power of two. An
    # all-zero matrix would make ``log2`` raise, so fall back to 1 (every
    # square then has size zero).
    peak = np.abs(rho.data).max()
    max_weight = 2 ** math.ceil(math.log2(peak)) if peak > 0 else 1

    if figsize is None:
        figsize = (8, 5)
    creates_figure = ax_real is None and ax_imag is None
    if creates_figure:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    else:
        owner = ax_real if ax_real is not None else ax_imag
        fig = owner.get_figure()
        ax1, ax2 = ax_real, ax_imag

    # Rows are labelled in reverse so that, with the first y tick at the
    # bottom, the top row of the plot carries the label of matrix row 0.
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    row_names = column_names[::-1]

    if ax1 is not None:
        _draw_hinton_panel(
            ax1, np.real(rho.data), max_weight, row_names, column_names, "Re[$\\rho$]"
        )
    if ax2 is not None:
        _draw_hinton_panel(
            ax2, np.imag(rho.data), max_weight, row_names, column_names, "Im[$\\rho$]"
        )

    fig.tight_layout()
    if title:
        fig.suptitle(title, fontsize=16)
    if creates_figure:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    return fig.savefig(filename)
×
186

187

188
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_bloch_vector(
    bloch, title="", ax=None, figsize=None, coord_type="cartesian", font_size=None
):
    """Plot a single vector on the Bloch sphere.

    The vector may be supplied either in Cartesian or in spherical coordinates;
    spherical input is converted to Cartesian before it is drawn.

    Args:
        bloch (tuple[float, float, float]): three elements, read as
            (<x>, <y>, <z>) for Cartesian input or as (<r>, <theta>, <phi>) for
            spherical input (angles in radians), where <theta> is the
            inclination angle from the +z direction and <phi> is the azimuth
            from the +x direction.
        title (str): a string that represents the plot title
        ax (matplotlib.axes.Axes): An Axes to use for rendering the bloch
            sphere
        figsize (tuple): Figure size in inches. Has no effect if passing ``ax``.
        coord_type (Literal["cartesian", "spherical"]): Either ``"cartesian"`` or ``"spherical"``
            depending on whether the input is given in Cartesian or spherical coordinates.
        font_size (float): Font size.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` : A matplotlib figure instance if ``ax = None``,
        otherwise ``None``.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit.visualization import plot_bloch_vector

           plot_bloch_vector([0,1,0], title="New Bloch Sphere")

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           import numpy as np
           from qiskit.visualization import plot_bloch_vector

           # You can use spherical coordinates instead of cartesian.

           plot_bloch_vector([1, np.pi/2, np.pi/3], coord_type='spherical')

    """
    from .bloch import Bloch

    size_inches = (5, 5) if figsize is None else figsize
    sphere = Bloch(axes=ax, font_size=font_size)

    if coord_type == "spherical":
        # Spherical -> Cartesian conversion.
        radius, inclination, azimuth = bloch
        bloch = (
            radius * math.sin(inclination) * math.cos(azimuth),
            radius * math.sin(inclination) * math.sin(azimuth),
            radius * math.cos(inclination),
        )

    sphere.add_vectors(bloch)
    sphere.render(title=title)

    # The caller already owns the figure when it supplied the axes.
    if ax is not None:
        return None

    figure = sphere.fig
    figure.set_size_inches(size_inches[0], size_inches[1])
    matplotlib_close_if_inline(figure)
    return figure
1✔
257

258

259
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_bloch_multivector(
    state,
    title="",
    figsize=None,
    *,
    reverse_bits=False,
    filename=None,
    font_size=None,
    title_font_size=None,
    title_pad=1,
):
    r"""Plot one Bloch sphere per qubit of a quantum state.

    Each component :math:`(x,y,z)` of the Bloch sphere labeled as 'qubit i' represents the expected
    value of the corresponding Pauli operator acting only on that qubit, that is, the expected value
    of :math:`I_{N-1} \otimes\dotsb\otimes I_{i+1}\otimes P_i \otimes I_{i-1}\otimes\dotsb\otimes
    I_0`, where :math:`N` is the number of qubits, :math:`P\in \{X,Y,Z\}` and :math:`I` is the
    identity operator.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): size of each individual Bloch sphere figure, in inches.
        reverse_bits (bool): If True, the spheres are drawn in reversed qubit
            order (highest qubit index first) [Default:False].
        filename (str | None): The optional file path to save image to. If not specified
            no file is created for the visualization. If this is set the figure
            is saved and the result of ``Figure.savefig`` is returned instead of
            the figure.
        font_size (float): Font size for the Bloch ball figures.
        title_font_size (float): Font size for the title. Defaults to
            ``font_size`` when that is given, otherwise to 16.
        title_pad (float): Padding for the title (suptitle ``y`` position is ``0.98``
            and, when a title is given, the image height is extended by
            ``1 + title_pad/100``).

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            A matplotlib figure instance (when ``filename`` is ``None``).

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        VisualizationError: if input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_bloch_multivector

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.x(1)

           state = Statevector(qc)
           plot_bloch_multivector(state)

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can reverse the order of the qubits.

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import DensityMatrix
           from qiskit.visualization import plot_bloch_multivector

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.t(1)
           qc.s(0)
           qc.cx(0,1)

           matrix = DensityMatrix(qc)
           plot_bloch_multivector(matrix, title='My Bloch Spheres', reverse_bits=True)

    """
    from matplotlib import pyplot as plt

    # One Bloch vector per qubit, optionally in reversed order.
    vectors = _bloch_multivector_data(state)
    if reverse_bits:
        vectors = vectors[::-1]
    count = len(vectors)

    # Overall canvas: ``figsize`` describes a single sphere, so the width is
    # multiplied by the number of spheres.
    if figsize is None:
        width, height = plt.figaspect(1 / count)
    else:
        width, height = figsize
        width *= count
    if len(title) > 0:
        height += 1 + title_pad / 100  # additional space for the title

    if title_font_size is None:
        title_font_size = 16 if font_size is None else font_size

    fig = plt.figure(figsize=(width, height))
    for slot, vector in enumerate(vectors):
        qubit = count - 1 - slot if reverse_bits else slot
        sphere_ax = fig.add_subplot(1, count, slot + 1, projection="3d")
        plot_bloch_vector(
            vector, "qubit " + str(qubit), ax=sphere_ax, figsize=figsize, font_size=font_size
        )
    fig.suptitle(title, fontsize=title_font_size, y=0.98)
    matplotlib_close_if_inline(fig)

    if filename is not None:
        return fig.savefig(filename)
    try:
        fig.tight_layout()
    except AttributeError:
        pass
    return fig
×
375

376

377
def _draw_city_bars(ax, mask, xpos, ypos, zpos, dx, dy, dz, color, alpha, zorder):
    """Draw the 3D bars selected by the boolean ``mask`` on ``ax``.

    Nothing is drawn when ``mask`` selects no bar. Face colors are computed
    with ``generate_facecolors`` and applied to the bars after creation.
    """
    if not np.any(mask):
        return
    face_colors = generate_facecolors(
        xpos[mask], ypos[mask], zpos[mask], dx[mask], dy[mask], dz[mask], color
    )
    bars = ax.bar3d(
        xpos[mask],
        ypos[mask],
        zpos[mask],
        dx[mask],
        dy[mask],
        dz[mask],
        alpha=alpha,
        zorder=zorder,
    )
    bars.set_facecolor(face_colors)


@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_state_city(
    state,
    title="",
    figsize=None,
    color=None,
    alpha=1,
    ax_real=None,
    ax_imag=None,
    *,
    filename=None,
):
    """Plot the cityscape of quantum state.

    Plot two 3d bar graphs (two dimensional) of the real and imaginary
    part of the density matrix rho.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches. Only used when a new figure is
            created, that is when neither ``ax_real`` nor ``ax_imag`` is given
            (default ``(16, 8)``).
        color (list): A list of len=2 giving colors for real and
            imaginary components of matrix elements. A ``None`` entry selects
            the default color for that component.
        alpha (float): Transparency value for bars
        ax_real (matplotlib.axes.Axes): An optional Axes object to be used for
            the real component. If neither ``ax_real`` nor ``ax_imag`` is
            specified a new matplotlib Figure will be created and used. If this
            is specified without an ``ax_imag`` only the real component plot
            will be generated.
        ax_imag (matplotlib.axes.Axes): An optional Axes object to be used for
            the imaginary component. If neither ``ax_real`` nor ``ax_imag`` is
            specified a new matplotlib Figure will be created and used. If this
            is specified without an ``ax_real`` only the imaginary component
            plot will be generated.
        filename (str | None): The optional file path to save image to. If not specified
            no file is created for the visualization. If this is set the figure
            is saved and the result of ``Figure.savefig`` is returned instead of
            the figure.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            When ``filename`` is ``None``, the figure that was drawn on: the
            newly created one, or the figure owning the supplied axes.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        ValueError: When 'color' is not a list of len=2.
        VisualizationError: if input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can choose different colors for the real and imaginary parts of the density matrix.

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import DensityMatrix
           from qiskit.visualization import plot_state_city

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = DensityMatrix(qc)
           plot_state_city(state, color=['midnightblue', 'crimson'], title="New State City")

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can make the bars more transparent to better see the ones that are behind
           # if they overlap.

           import numpy as np
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_city
           from qiskit import QuantumCircuit

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0,1)
           qc.ry(np.pi/3, 0)
           qc.rx(np.pi/5, 1)

           state = Statevector(qc)
           plot_state_city(state, alpha=0.6)

    """
    import matplotlib.colors as mcolors
    from matplotlib import pyplot as plt
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection

    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")

    # get the real and imag parts of rho
    datareal = np.real(rho.data)
    dataimag = np.imag(rho.data)

    # get the labels (rows and columns share the same basis-state labels)
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    row_names = list(column_names)

    ly, lx = datareal.shape[:2]
    # Set up a mesh of bar positions; each bar is 0.5 wide and centered in
    # its unit cell, hence the 0.25 offset.
    xpos, ypos = np.meshgrid(np.arange(0, lx, 1) + 0.25, np.arange(0, ly, 1) + 0.25)
    xpos = xpos.flatten()
    ypos = ypos.flatten()
    zpos = np.zeros(lx * ly)

    dx = 0.5 * np.ones_like(zpos)  # width of bars
    dy = dx.copy()
    dzr = datareal.flatten()
    dzi = dataimag.flatten()

    default_color = "#648fff"
    if color is None:
        real_color, imag_color = default_color, default_color
    else:
        if len(color) != 2:
            raise ValueError("'color' must be a list of len=2.")
        real_color = default_color if color[0] is None else color[0]
        imag_color = default_color if color[1] is None else color[1]

    creates_figure = ax_real is None and ax_imag is None
    if creates_figure:
        # set default figure size
        if figsize is None:
            figsize = (16, 8)

        fig = plt.figure(figsize=figsize, facecolor="w")
        ax1 = fig.add_subplot(1, 2, 1, projection="3d", computed_zorder=False)
        ax2 = fig.add_subplot(1, 2, 2, projection="3d", computed_zorder=False)
    elif ax_real is not None:
        fig = ax_real.get_figure()
        ax1 = ax_real
        ax2 = ax_imag
    else:
        fig = ax_imag.get_figure()
        ax1 = None
        ax2 = ax_imag

    fig.tight_layout()

    # Both panels share the same upper z limit so their heights are comparable.
    z_top = max(np.max(dzr) + 1e-9, np.max(dzi))

    # Figure scaling variables since fig.tight_layout won't work
    fig_width, fig_height = fig.get_size_inches()
    max_plot_size = min(fig_width / 2.25, fig_height)
    max_font_size = int(3 * max_plot_size)
    max_zoom = 10 / (10 + np.sqrt(max_plot_size))

    for ax, dz, col, zlabel in (
        (ax1, dzr, real_color, "Real"),
        (ax2, dzi, imag_color, "Imaginary"),
    ):
        if ax is None:
            continue

        max_dz = np.max(dz)
        min_dz = np.min(dz)

        if isinstance(col, str) and col.startswith("#"):
            col = mcolors.to_rgba_array(col)

        # Draw order and explicit zorder: negative bars, then the z=0 plane,
        # then positive bars on top.
        _draw_city_bars(ax, dz < 0, xpos, ypos, zpos, dx, dy, dz, col, alpha, 0.625)

        if min_dz < 0 < max_dz:
            # Translucent plane at z=0 separating negative from positive bars.
            xlim, ylim = [0, lx], [0, ly]
            verts = [list(zip(xlim + xlim[::-1], np.repeat(ylim, 2), [0] * 4))]
            plane = Poly3DCollection(verts, alpha=0.25, facecolor="k", linewidths=1)
            plane.set_zorder(0.75)
            ax.add_collection3d(plane)

        _draw_city_bars(ax, dz >= 0, xpos, ypos, zpos, dx, dy, dz, col, alpha, 0.875)

        ax.set_title(f"{zlabel} Amplitude (ρ)", fontsize=max_font_size)

        ax.set_xticks(np.arange(0.5, lx + 0.5, 1))
        ax.set_yticks(np.arange(0.5, ly + 0.5, 1))
        if max_dz != min_dz or min_dz == 0:
            ax.axes.set_zlim3d(min_dz, z_top)
        else:
            # Every bar has the same non-zero height: let matplotlib choose.
            ax.axes.set_zlim3d(auto=True)

        ax.xaxis.set_ticklabels(
            row_names, fontsize=max_font_size, rotation=45, ha="right", va="top"
        )
        ax.yaxis.set_ticklabels(
            column_names, fontsize=max_font_size, rotation=-22.5, ha="left", va="center"
        )

        for tick in ax.zaxis.get_major_ticks():
            tick.label1.set_fontsize(max_font_size)
            tick.label1.set_horizontalalignment("left")
            tick.label1.set_verticalalignment("bottom")

        ax.set_box_aspect(aspect=(4, 4, 4), zoom=max_zoom)
        ax.set_xmargin(0)
        ax.set_ymargin(0)

    fig.suptitle(title, fontsize=max_font_size * 1.25)
    fig.subplots_adjust(top=0.9, bottom=0, left=0, right=1, hspace=0, wspace=0)
    if creates_figure:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    return fig.savefig(filename)
×
630

631

632
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_state_paulivec(state, title="", figsize=None, color=None, ax=None, *, filename=None):
    r"""Plot the Pauli-vector representation of a quantum state as bar graph.

    The Pauli-vector of a density matrix :math:`\rho` is defined by the expectation of each
    possible tensor product of single-qubit Pauli operators (including the identity), that is

    .. math ::

        \rho = \frac{1}{2^n} \sum_{\sigma \in \{I, X, Y, Z\}^{\otimes n}}
               \mathrm{Tr}(\sigma \rho) \sigma.

    This function plots the coefficients :math:`\mathrm{Tr}(\sigma\rho)` as bar graph.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches. Only used when ``ax`` is not
            given (default ``(7, 5)``).
        color (list or str): Color of the coefficient value bars.
        ax (matplotlib.axes.Axes): An optional Axes object to be used for
            the visualization output. If none is specified a new matplotlib
            Figure will be created and used.
        filename (str | None): The optional file path to save image to. If not specified
            no file is created for the visualization. If this is set the figure
            is saved and the result of ``Figure.savefig`` is returned instead of
            the figure.

    Returns:
         :class:`matplotlib:matplotlib.figure.Figure` :
            When ``filename`` is ``None``, the figure that was drawn on: the
            newly created one, or the figure owning ``ax``.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        VisualizationError: if input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can set a color for all the bars.

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_paulivec

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = Statevector(qc)
           plot_state_paulivec(state, color='midnightblue', title="New PauliVec plot")

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # If you introduce a list with fewer colors than bars, the color of the bars will
           # alternate following the sequence from the list.

           import numpy as np
           from qiskit.quantum_info import DensityMatrix
           from qiskit import QuantumCircuit
           from qiskit.visualization import plot_state_paulivec

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0, 1)
           qc.ry(np.pi/3, 0)
           qc.rx(np.pi/5, 1)

           matrix = DensityMatrix(qc)
           plot_state_paulivec(matrix, color=['crimson', 'midnightblue', 'seagreen'])
    """
    from matplotlib import pyplot as plt

    labels, values = _paulivec_data(state)

    bar_positions = np.arange(len(values))  # the x locations for the groups
    bar_width = 0.5  # the width of the bars
    bar_color = "#648fff" if color is None else color

    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=(7, 5) if figsize is None else figsize)
    else:
        fig = ax.get_figure()

    # Bars sit above the dashed grid (zorder 2 vs 0), with a solid zero line.
    ax.grid(zorder=0, linewidth=1, linestyle="--")
    ax.bar(bar_positions, values, bar_width, color=bar_color, zorder=2)
    ax.axhline(linewidth=1, color="k")

    # add some text for labels, title, and axes ticks
    ax.set_ylabel("Coefficients", fontsize=14)
    ax.set_xticks(bar_positions)
    ax.set_yticks([-1, -0.5, 0, 0.5, 1])
    ax.set_xticklabels(labels, fontsize=14, rotation=70)
    ax.set_xlabel("Pauli", fontsize=14)
    ax.set_ylim([-1, 1])
    ax.set_facecolor("#eeeeee")
    major_ticks = [*ax.xaxis.get_major_ticks(), *ax.yaxis.get_major_ticks()]
    for tick in major_ticks:
        tick.label1.set_fontsize(14)
    ax.set_title(title, fontsize=16)

    if owns_figure:
        matplotlib_close_if_inline(fig)
    if filename is not None:
        return fig.savefig(filename)
    try:
        fig.tight_layout()
    except AttributeError:
        pass
    return fig
×
753

754

755
def n_choose_k(n, k):
    """Return the number of combinations for n choose k.

    The result is computed with exact integer arithmetic, so it stays an
    ``int`` and does not lose precision for large arguments.

    Args:
        n (int): the total number of options.
        k (int): the number of elements.

    Returns:
        int: the binomial coefficient. By convention of this helper the result
        is ``0`` whenever ``n == 0`` (for any ``k``).
    """
    if n == 0:
        return 0
    # After the i-th step the accumulator equals C(n - k + i, i), which is an
    # integer, so the floor division below is always exact.
    return reduce(
        lambda acc, pair: acc * pair[0] // pair[1],
        zip(range(n - k + 1, n + 1), range(1, k + 1)),
        1,
    )
1✔
768

769

770
def lex_index(n, k, lst):
    """Return the lex index of a combination.

    Args:
        n (int): the total number of options.
        k (int): the number of elements.
        lst (list): the combination, given as a list of ``k`` positions.

    Returns:
        int: the index of the combination in lex order.

    Raises:
        VisualizationError: if length of list is not equal to k
    """
    if len(lst) != k:
        raise VisualizationError("list should have length k")
    # Walk the positions from last to first; the i-th visited position
    # contributes the binomial coefficient C(n - 1 - position, i).
    total = 0
    for size, position in enumerate(reversed(lst), start=1):
        total += n_choose_k(n - 1 - position, size)
    return int(total)
1✔
789

790

791
def bit_string_index(s):
    """Return the lex index of a bit string among strings with the same number of 1s.

    Args:
        s (str): A string made only of the characters ``"0"`` and ``"1"``.

    Returns:
        int: The value of ``lex_index`` for the positions of the ``"1"`` characters.

    Raises:
        VisualizationError: If ``s`` contains anything other than ``"0"`` and ``"1"``.
    """
    length = len(s)
    num_ones = s.count("1")
    # Every character that is not a "1" has to be a "0".
    if length - num_ones != s.count("0"):
        raise VisualizationError("s must be a string of 0 and 1")
    one_positions = []
    for position, char in enumerate(s):
        if char == "1":
            one_positions.append(position)
    return lex_index(length, num_ones, one_positions)
1✔
799

800

801
def phase_to_rgb(complex_number):
    """Map the phase of a complex number to a color in (r, g, b).

    The phase is shifted by 5*pi/4, wrapped into the range [0, 2*pi) and then
    used as the hue on the HLS color wheel (lightness and saturation are 0.5).
    """
    full_turn = 2 * np.pi
    shifted_phase = np.angle(complex_number) + 5 * np.pi / 4
    hue = (shifted_phase % full_turn) / full_turn
    return colorsys.hls_to_rgb(hue, 0.5, 0.5)
1✔
810

811

812
@_optionals.HAS_MATPLOTLIB.require_in_call
@_optionals.HAS_SEABORN.require_in_call
def plot_state_qsphere(
    state,
    figsize=None,
    ax=None,
    show_state_labels=True,
    show_state_phases=False,
    use_degrees=False,
    *,
    filename=None,
):
    """Plot the qsphere representation of a quantum state.

    Here, the size of the points is proportional to the probability
    of the corresponding term in the state and the color represents
    the phase.

    The state is converted to a density matrix and one set of points is drawn
    for every eigenvector whose eigenvalue is larger than ``0.001``.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        figsize (tuple): Figure size in inches. Defaults to ``(7, 7)``.
        ax (matplotlib.axes.Axes): An optional Axes object to be used for
            the visualization output. If none is specified a new matplotlib
            Figure will be created and used. Additionally, if specified there
            will be no returned Figure since it is redundant.
        show_state_labels (bool): An optional boolean indicating whether to
            show labels for each basis state.
        show_state_phases (bool): An optional boolean indicating whether to
            show the phase for each basis state.
        use_degrees (bool): An optional boolean indicating whether to use
            radians or degrees for the phase values in the plot.
        filename (str | None): The optional file path to save image to. If not specified
            no file is created for the visualization. If this is set the return
            from this function will be ``None``.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            A matplotlib figure instance if the ``ax`` kwarg is not set

    Raises:
        MissingOptionalLibraryError: Requires matplotlib and seaborn.
        VisualizationError: Input is not a valid N-qubit state.
        QiskitError: Input statevector does not have valid dimensions.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_qsphere

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = Statevector(qc)
           plot_state_qsphere(state)

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can show the phase of each state and use
           # degrees instead of radians

           from qiskit.quantum_info import DensityMatrix
           import numpy as np
           from qiskit import QuantumCircuit
           from qiskit.visualization import plot_state_qsphere

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0,1)
           qc.ry(np.pi/3, 0)
           qc.rx(np.pi/5, 1)
           qc.z(1)

           matrix = DensityMatrix(qc)
           plot_state_qsphere(matrix,
                show_state_phases = True, use_degrees = True)
    """
    from matplotlib import gridspec
    from matplotlib import pyplot as plt
    from matplotlib.patches import Circle
    import seaborn as sns
    from scipy import linalg
    from .bloch import Arrow3D

    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    # get the eigenvectors and eigenvalues
    eigvals, eigvecs = linalg.eigh(rho.data)

    if figsize is None:
        figsize = (7, 7)

    if ax is None:
        return_fig = True
        fig = plt.figure(figsize=figsize)
    else:
        return_fig = False
        fig = ax.get_figure()

    gs = gridspec.GridSpec(nrows=3, ncols=3)

    # NOTE(review): ``ax`` is rebound here, so a caller-supplied Axes only
    # contributes its parent figure; the sphere is drawn on a new 3D subplot.
    ax = fig.add_subplot(gs[0:3, 0:3], projection="3d")
    ax.axes.set_xlim3d(-1.0, 1.0)
    ax.axes.set_ylim3d(-1.0, 1.0)
    ax.axes.set_zlim3d(-1.0, 1.0)
    ax.axes.grid(False)
    ax.view_init(elev=5, azim=275)

    # Force aspect ratio
    # MPL 3.2 or previous do not have set_box_aspect
    if hasattr(ax.axes, "set_box_aspect"):
        ax.axes.set_box_aspect((1, 1, 1))

    # start the plotting
    # Plot semi-transparent sphere
    u = np.linspace(0, 2 * np.pi, 25)
    v = np.linspace(0, np.pi, 25)
    x = np.outer(np.cos(u), np.sin(v))
    y = np.outer(np.sin(u), np.sin(v))
    z = np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(
        x, y, z, rstride=1, cstride=1, color=plt.rcParams["grid.color"], alpha=0.2, linewidth=0
    )

    # Get rid of the panes
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the spines
    ax.xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    # traversing the eigvals/vecs backward as sorted low->high
    for idx in range(eigvals.shape[0] - 1, -1, -1):
        if eigvals[idx] > 0.001:
            # eigenvector belonging to the current (largest remaining) eigenvalue
            state = eigvecs[:, idx]
            # Rounding to 13 decimals ignores machine epsilon noise (~1e-16)
            # from the solver, ensuring 'argmax' finds the true analytical winner.
            loc = np.round(np.absolute(state), decimals=13).argmax()
            # remove the global phase from max element
            angles = (np.angle(state[loc]) + 2 * np.pi) % (2 * np.pi)
            angleset = np.exp(-1j * angles)
            state = angleset * state

            d = num
            for i in range(2**num):
                # get x,y,z points
                element = bin(i)[2:].zfill(num)
                # the number of 1s in the bit string fixes the height on the sphere
                weight = element.count("1")
                zvalue = -2 * weight / d + 1
                # strings of equal weight share one circle and are spread around it
                number_of_divisions = n_choose_k(d, weight)
                weight_order = bit_string_index(element)
                angle = (float(weight) / d) * (np.pi * 2) + (
                    weight_order * 2 * (np.pi / number_of_divisions)
                )

                if (weight > d / 2) or (
                    (weight == d / 2) and (weight_order >= number_of_divisions / 2)
                ):
                    angle = np.pi - angle - (2 * np.pi / number_of_divisions)

                xvalue = np.sqrt(1 - zvalue**2) * np.cos(angle)
                yvalue = np.sqrt(1 - zvalue**2) * np.sin(angle)

                # get prob and angle - prob will be shade and angle color
                prob = np.real(np.dot(state[i], state[i].conj()))
                prob = min(prob, 1)  # See https://github.com/Qiskit/qiskit-terra/issues/4666
                colorstate = phase_to_rgb(state[i])

                # markers with yvalue >= 0.1 are drawn more transparent
                alfa = 1
                if yvalue >= 0.1:
                    alfa = 1.0 - yvalue

                if not np.isclose(prob, 0) and show_state_labels:
                    # labels sit on a sphere of radius 1.3, outside the unit sphere
                    rprime = 1.3
                    angle_theta = np.arctan2(np.sqrt(1 - zvalue**2), zvalue)
                    xvalue_text = rprime * np.sin(angle_theta) * np.cos(angle)
                    yvalue_text = rprime * np.sin(angle_theta) * np.sin(angle)
                    zvalue_text = rprime * np.cos(angle_theta)
                    element_text = "$\\vert" + element + "\\rangle$"
                    if show_state_phases:
                        element_angle = (np.angle(state[i]) + (np.pi * 4)) % (np.pi * 2)
                        if use_degrees:
                            element_text += f"\n${element_angle * 180 / np.pi:.1f}^\\circ$"
                        else:
                            element_angle = pi_check(element_angle, ndigits=3).replace("pi", "\\pi")
                            element_text += f"\n${element_angle}$"
                    ax.text(
                        xvalue_text,
                        yvalue_text,
                        zvalue_text,
                        element_text,
                        ha="center",
                        va="center",
                        size=12,
                    )

                # marker for this basis state; its size scales with sqrt(prob)
                ax.plot(
                    [xvalue],
                    [yvalue],
                    [zvalue],
                    markerfacecolor=colorstate,
                    markeredgecolor=colorstate,
                    marker="o",
                    markersize=np.sqrt(prob) * 30,
                    alpha=alfa,
                )

                # line from the center to the marker; its opacity is the probability
                a = Arrow3D(
                    [0, xvalue],
                    [0, yvalue],
                    [0, zvalue],
                    mutation_scale=20,
                    alpha=prob,
                    arrowstyle="-",
                    color=colorstate,
                    lw=2,
                )
                ax.add_artist(a)

            # add weight lines
            for weight in range(d + 1):
                theta = np.linspace(-2 * np.pi, 2 * np.pi, 100)
                z = -2 * weight / d + 1
                r = np.sqrt(1 - z**2)
                x = r * np.cos(theta)
                y = r * np.sin(theta)
                ax.plot(x, y, z, color=(0.5, 0.5, 0.5), lw=1, ls=":", alpha=0.5)

            # add center point
            ax.plot(
                [0],
                [0],
                [0],
                markerfacecolor=(0.5, 0.5, 0.5),
                markeredgecolor=(0.5, 0.5, 0.5),
                marker="o",
                markersize=3,
                alpha=1,
            )
        else:
            # eigenvalues are visited in descending order, so the rest are smaller too
            break

    # phase legend: a 64-segment color wheel in the lower-right grid cell
    n = 64
    theta = np.ones(n)
    colors = sns.hls_palette(n)

    ax2 = fig.add_subplot(gs[2:, 2:])
    # the palette is rotated by 5/8 of a turn before being drawn as pie wedges
    ax2.pie(theta, colors=colors[5 * n // 8 :] + colors[: 5 * n // 8], radius=0.75)
    # a white disc over the center turns the pie into a ring
    ax2.add_artist(Circle((0, 0), 0.5, color="white", zorder=1))
    offset = 0.95  # since radius of sphere is one.

    if use_degrees:
        labels = ["Phase\n(Deg)", "0", "90", "180   ", "270"]
    else:
        labels = ["Phase", "$0$", "$\\pi/2$", "$\\pi$", "$3\\pi/2$"]

    ax2.text(0, 0, labels[0], horizontalalignment="center", verticalalignment="center", fontsize=14)
    ax2.text(
        offset, 0, labels[1], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, offset, labels[2], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        -offset, 0, labels[3], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, -offset, labels[4], horizontalalignment="center", verticalalignment="center", fontsize=14
    )

    if return_fig:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    else:
        return fig.savefig(filename)
×
1106

1107

1108
@_optionals.HAS_MATPLOTLIB.require_in_call
def generate_facecolors(x, y, z, dx, dy, dz, color):
    """Generates shaded facecolors for shaded bars.

    This is here to work around a Matplotlib bug
    where alpha does not work in Bar3D.

    Args:
        x (array_like): The x- coordinates of the anchor point of the bars.
        y (array_like): The y- coordinates of the anchor point of the bars.
        z (array_like): The z- coordinates of the anchor point of the bars.
        dx (array_like): Width of bars.
        dy (array_like): Depth of bars.
        dz (array_like): Height of bars.
        color (array_like): sequence of valid color specifications, optional
    Returns:
        list: Shaded colors for bars.
    Raises:
        MissingOptionalLibraryError: If matplotlib is not installed
    """
    import matplotlib.colors as mcolors

    # The six faces of a unit cube, four corner points each.
    face_minus_z = ((0, 0, 0), (0, 1, 0), (1, 1, 0), (1, 0, 0))
    face_plus_z = ((0, 0, 1), (1, 0, 1), (1, 1, 1), (0, 1, 1))
    face_minus_y = ((0, 0, 0), (1, 0, 0), (1, 0, 1), (0, 0, 1))
    face_plus_y = ((0, 1, 0), (0, 1, 1), (1, 1, 1), (1, 1, 0))
    face_minus_x = ((0, 0, 0), (0, 0, 1), (0, 1, 1), (0, 1, 0))
    face_plus_x = ((1, 0, 0), (1, 1, 0), (1, 1, 1), (1, 0, 1))
    unit_faces = np.array(
        [face_minus_z, face_plus_z, face_minus_y, face_plus_y, face_minus_x, face_plus_x]
    )

    # Indexed by [bar, face, vertex, coord]; each coordinate axis is filled in turn
    # as anchor + extent * unit-cube coordinate.
    corners = np.empty(x.shape + unit_faces.shape)
    for axis, (anchor, extent) in enumerate(((x, dx), (y, dy), (z, dz))):
        corners[..., axis] = (
            anchor[..., np.newaxis, np.newaxis]
            + extent[..., np.newaxis, np.newaxis] * unit_faces[..., axis]
        )

    # Merge the bar and face axes so that every row is one polygon.
    faces = corners.reshape((-1,) + corners.shape[2:])

    if len(color) != len(x):
        # a single color specified, or face colors specified explicitly
        shades = list(mcolors.to_rgba_array(color))
        if len(shades) < len(x):
            shades *= 6 * len(x)
    else:
        # one color per bar: repeat it for each of the six faces
        shades = [bar_color for bar_color in color for _ in range(6)]

    return _shade_colors(shades, _generate_normals(faces))
×
1201

1202

1203
def _generate_normals(polygons):
    """Compute one normal vector per polygon.

    Normals point towards the viewer for a face with its vertices in
    counterclockwise order, following the right hand rule.
    Three vertices spaced roughly evenly around each polygon are used, so for
    polygons with more than three points that do not lie in a plane the result
    is only a plausible and fast approximation.

    Args:
        polygons (list): list of (M_i, 3) array_like, or (..., M, 3) array_like
            A sequence of polygons to compute normals for, which can have
            varying numbers of vertices. If the polygons all have the same
            number of vertices and array is passed, then the operation will
            be vectorized.
    Returns:
        normals: (..., 3) array_like
            A normal vector estimated for the polygon.
    """

    def _sample_indices(count):
        # First vertex plus the ones one third and two thirds of the way around.
        return 0, count // 3, 2 * count // 3

    if not isinstance(polygons, np.ndarray):
        # Jagged input: vertex counts may differ, so handle polygons one by one.
        edge_a = np.empty((len(polygons), 3))
        edge_b = np.empty((len(polygons), 3))
        for row, vertices in enumerate(polygons):
            first, second, third = _sample_indices(len(vertices))
            edge_a[row, :] = vertices[first, :] - vertices[second, :]
            edge_b[row, :] = vertices[second, :] - vertices[third, :]
        return np.cross(edge_a, edge_b)

    # Uniform vertex count: subtract whole slices at once.
    first, second, third = _sample_indices(polygons.shape[-2])
    edge_a = polygons[..., first, :] - polygons[..., second, :]
    edge_b = polygons[..., second, :] - polygons[..., third, :]
    return np.cross(edge_a, edge_b)
×
1241

1242

1243
def _shade_colors(color, normals, lightsource=None):
    """
    Shade *color* using normal vectors given by *normals*.
    *color* can also be an array of the same length as *normals*.

    When *lightsource* is not given, a ``LightSource(azdeg=225, altdeg=19.4712)``
    is used. If no normal has a non-zero length, a copy of *color* is returned
    unshaded.
    """
    from matplotlib.colors import Normalize, LightSource
    import matplotlib.colors as mcolors

    if lightsource is None:
        # chosen for backwards-compatibility
        lightsource = LightSource(azdeg=225, altdeg=19.4712)

    def _cosine(normal):
        # Dot product of the unit normal with the light direction; NaN marks a
        # zero-length normal that cannot be normalized.
        length = np.sqrt(normal[0] ** 2 + normal[1] ** 2 + normal[2] ** 2)
        if not length:
            return np.nan
        return np.dot(normal / length, lightsource.direction)

    intensity = np.array([_cosine(normal) for normal in normals])
    valid = ~np.isnan(intensity)

    if not valid.any():
        return np.asanyarray(color).copy()

    lowest = min(intensity[valid])
    highest = max(intensity[valid])
    scale = Normalize(lowest, highest)
    # Faces without a usable normal get the darkest shade.
    intensity[~valid] = lowest
    rgba = mcolors.to_rgba_array(color)
    # shape of rgba should be (M, 4) (where M is number of faces)
    # shape of intensity should be (M,)
    # shaded should have final shape of (M, 4)
    shaded = (0.5 + scale(intensity)[:, np.newaxis] * 0.5) * rgba
    # Shading only affects the color channels; restore the original alpha.
    shaded[:, 3] = rgba[:, 3]
    return shaded
×
1277

1278

1279
def state_to_latex(
    state: Statevector | DensityMatrix, dims: bool | None = None, convention: str = "ket", **args
) -> str:
    """Return a Latex representation of a state.

    Qubit statevectors are rendered in ket notation when ``convention`` is
    ``'ket'``; everything else is rendered through
    `qiskit.visualization.array_to_latex`. Adds dims if necessary.
    Intended for use within `state_drawer`.

    Args:
        state: State to be drawn
        dims (bool): Whether to display the state's `dims`. If ``None``, dims are
            shown only when the state is not made up solely of qubits.
        convention (str): Either 'vector' or 'ket'. For 'ket' plot the state in the ket-notation.
                Otherwise plot as a vector
        **args: Arguments to be passed directly to the function that renders the state

    Returns:
        Latex representation of the state

    Raises:
        MissingOptionalLibrary: If SymPy isn't installed and ``'latex'`` or
            ``'latex_source'`` is selected for ``output``.
    """
    if dims is None:
        # show dims if state is not only qubits
        dims = set(state.dims()) != {2}

    if dims:
        prefix = "\\begin{align}\n"
        dims_str = state._op_shape.dims_l()
        suffix = f"\\\\\n\\text{{dims={dims_str}}}\n\\end{{align}}"
    else:
        prefix = suffix = ""

    shape = state._op_shape
    # The ket convention is only used for qubit statevectors, i.e. the shape has
    # no input dimensions and every output dimension equals 2.
    qubit_statevector = len(shape.dims_r()) == 0 and set(shape.dims_l()) == {2}
    use_ket = convention == "ket" and qubit_statevector
    if use_ket:
        body = _state_to_latex_ket(state._data, **args)
    else:
        body = array_to_latex(state._data, source=True, **args)
    return prefix + body + suffix
1✔
1322

1323

1324
def _numbers_to_latex_terms(numbers: list[complex], decimals: int = 10) -> list[str]:
    """Convert a list of numbers to latex formatted terms

    Only the first entry of ``numbers`` is formatted with ``first_term=True``;
    every later entry is formatted with ``first_term=False``.

    Args:
        numbers: List of numbers to format
        decimals: Number of decimal places to round to (default: 10).
    Returns:
        List of formatted terms
    """
    return [
        _num_to_latex(number, decimals=decimals, first_term=(position == 0), coefficient=True)
        for position, number in enumerate(numbers)
    ]
1✔
1342

1343

1344
def _state_to_latex_ket(
    data: list[complex], max_size: int = 12, prefix: str = "", decimals: int = 10
) -> str:
    """Convert state vector to latex representation

    Args:
        data: State vector
        max_size: Maximum number of non-zero terms in the expression. If the number of
                 non-zero terms is larger than the max_size, then the representation is truncated:
                 the first ``max_size // 2`` terms are kept, followed by an ellipsis and the
                 trailing terms, for ``max_size`` entries in total (ellipsis included).
        prefix: Latex string to be prepended to the latex, intended for labels.
        decimals: Number of decimal places to round to (default: 10).

    Returns:
        String with LaTeX representation of the state vector
    """
    num = int(math.log2(len(data)))

    def ket_name(i):
        # zero-padded binary label of basis state ``i``
        return bin(i)[2:].zfill(num)

    data = np.around(data, decimals)
    nonzero_indices = np.where(data != 0)[0].tolist()
    if len(nonzero_indices) > max_size:
        head_count = max(max_size // 2, 0)
        # One slot is taken by the ellipsis. Counting the tail explicitly avoids the
        # ``seq[-0:]`` pitfall, which for max_size <= 2 selected *every* term
        # instead of none.
        tail_count = max(max_size - head_count - 1, 0)
        head = nonzero_indices[:head_count]
        tail = nonzero_indices[-tail_count:] if tail_count else []
        # Index 0 is a temporary stand-in for the ellipsis so that the terms list
        # stays aligned with the indices; it is replaced by None right after.
        nonzero_indices = head + [0] + tail
        latex_terms = _numbers_to_latex_terms(data[nonzero_indices], decimals)
        nonzero_indices[head_count] = None
    else:
        latex_terms = _numbers_to_latex_terms(data[nonzero_indices], decimals)

    latex_str = ""
    for idx, ket_idx in enumerate(nonzero_indices):
        if ket_idx is None:
            latex_str += r" + \ldots "
        else:
            term = latex_terms[idx]
            ket = ket_name(ket_idx)
            latex_str += f"{term} |{ket}\\rangle"
    return prefix + latex_str
1✔
1384

1385

1386
class TextMatrix:
    """Text representation of an array, with `__str__` method so it
    displays nicely in Jupyter notebooks"""

    def __init__(self, state, max_size=8, dims=None, prefix="", suffix=""):
        self.state = state
        if dims is None:
            # show dims if state is not only qubits
            dims = not (
                (isinstance(state, (Statevector, DensityMatrix)) and set(state.dims()) == {2})
                or (
                    isinstance(state, Operator)
                    and len(state.input_dims()) == len(state.output_dims())
                    and set(state.input_dims()) == set(state.output_dims()) == {2}
                )
            )
        self.dims = dims
        self.prefix = prefix
        self.suffix = suffix
        if isinstance(max_size, int):
            threshold = max_size
        elif isinstance(state, DensityMatrix):
            # density matrices are square, so threshold for
            # summarization is shortest side squared
            threshold = min(max_size) ** 2
        else:
            threshold = max_size[0]
        self.max_size = threshold

    def __str__(self):
        rendered = np.array2string(
            self.state._data, prefix=self.prefix, threshold=self.max_size, separator=","
        )
        if not self.dims:
            return self.prefix + rendered + self.suffix

        # The dims line goes on its own row, indented to line up under the data.
        indent = " " * len(self.prefix)
        if isinstance(self.state, (Statevector, DensityMatrix)):
            dims_text = f"dims={self.state._op_shape.dims_l()}"
        else:
            dims_text = f"input_dims={self.state.input_dims()}, "
            dims_text += f"output_dims={self.state.output_dims()}"
        return self.prefix + rendered + ",\n" + indent + dims_text + self.suffix

    def __repr__(self):
        return self.__str__()
1433

1434

1435
def state_drawer(state, output=None, **drawer_args):
    """Returns a visualization of the state.

    **repr**: The state's ``__repr__()`` string.

    **text**: ASCII TextMatrix that can be printed in the console.

    **latex**: An IPython Latex object for displaying in Jupyter Notebooks.

    **latex_source**: Raw, uncompiled ASCII source to generate array using LaTeX.

    **qsphere**: Matplotlib figure, rendering of statevector using `plot_state_qsphere()`.

    **hinton**: Matplotlib figure, rendering of statevector using `plot_state_hinton()`.

    **bloch**: Matplotlib figure, rendering of statevector using `plot_bloch_multivector()`.

    **city**: Matplotlib figure, rendering of statevector using `plot_state_city()`.

    **paulivec**: Matplotlib figure, rendering of statevector using `plot_state_paulivec()`.

    Args:
        state: State to be drawn
        output (str): Select the output method to use for drawing the
            circuit. Valid choices are ``repr``, ``text``, ``latex``, ``latex_source``,
            ``qsphere``, ``hinton``, ``bloch``, ``city`` or ``paulivec``.
            If ``None``, the ``state_drawer`` entry of the user config is used,
            falling back to ``'repr'``.
        drawer_args: Arguments to be passed to the relevant drawer. For
            'latex' and 'latex_source' see ``array_to_latex``

    Returns:
        :class:`matplotlib.figure` or :class:`str` or
        :class:`TextMatrix` or :class:`IPython.display.Latex`:
        Drawing of the state.

    Raises:
        MissingOptionalLibraryError: when `output` is `latex` and IPython is not installed.
            or if SymPy isn't installed and ``'latex'`` or ``'latex_source'`` is selected for
            ``output``.

        ValueError: when `output` is not a valid selection.
    """
    config = user_config.get_config()
    # Get default 'output' from config file else use 'repr'
    default_output = "repr"
    if output is None:
        if config:
            default_output = config.get("state_drawer", "repr")
        output = default_output
    output = output.lower()

    # Choose drawing backend:
    drawers = {
        "text": TextMatrix,
        "latex_source": state_to_latex,
        "qsphere": plot_state_qsphere,
        "hinton": plot_state_hinton,
        "bloch": plot_bloch_multivector,
        "city": plot_state_city,
        "paulivec": plot_state_paulivec,
    }
    if output == "latex":
        _optionals.HAS_IPYTHON.require_now("state_drawer")
        from IPython.display import Latex

        draw_func = drawers["latex_source"]
        return Latex(f"$${draw_func(state, **drawer_args)}$$")

    if output == "repr":
        return state.__repr__()

    # Only the lookup is guarded: a KeyError raised *inside* a drawer must
    # propagate unchanged instead of being reported as an invalid ``output``.
    try:
        draw_func = drawers[output]
    except KeyError as err:
        raise ValueError(
            f"""'{output}' is not a valid option for drawing {type(state).__name__}
             objects. Please choose from:
            'text', 'latex', 'latex_source', 'qsphere', 'hinton',
            'bloch', 'city' or 'paulivec'."""
        ) from err
    return draw_func(state, **drawer_args)
1516

1517

1518
def _bloch_multivector_data(state):
    """Return list of Bloch vectors for each qubit

    Args:
        state (DensityMatrix or Statevector): an N-qubit state.

    Returns:
        list: list of Bloch vectors (x, y, z) for each qubit.

    Raises:
        VisualizationError: if input is not an N-qubit state.
    """
    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    pauli_singles = PauliList(["X", "Y", "Z"])

    def _paulis_on(qubit):
        # X, Y and Z on ``qubit``; for a single-qubit state no padding is needed.
        if num <= 1:
            return pauli_singles
        padding = PauliList.from_symplectic(
            np.zeros((3, (num - 1)), dtype=bool), np.zeros((3, (num - 1)), dtype=bool)
        )
        return padding.insert(qubit, pauli_singles, qubit=True)

    bloch_data = []
    for qubit in range(num):
        # Each component is the real part of trace(P @ rho).
        components = [
            np.real(np.trace(np.dot(mat, rho.data))) for mat in _paulis_on(qubit).matrix_iter()
        ]
        bloch_data.append(components)
    return bloch_data
1✔
1546

1547

1548
def _paulivec_data(state):
    """Return paulivec data for plotting.

    Args:
        state (DensityMatrix or Statevector): an N-qubit state.

    Returns:
        tuple: (labels, values) for Pauli vector.

    Raises:
        VisualizationError: if input is not an N-qubit state.
    """
    pauli_op = SparsePauliOp.from_operator(DensityMatrix(state))
    if pauli_op.num_qubits is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    labels = pauli_op.paulis.to_labels()
    # Coefficients are scaled by 2**num_qubits and only the real part is kept.
    values = np.real(pauli_op.coeffs * 2**pauli_op.num_qubits)
    return labels, values
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc