• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Qiskit / qiskit / 13813956451

12 Mar 2025 02:32PM UTC coverage: 88.112% (+1.0%) from 87.154%
13813956451

Pull #13961

github

web-flow
Merge b15c3408b into abb0cf9db
Pull Request #13961: Add explicit tests for MCX synthesis algorithms

72657 of 82460 relevant lines covered (88.11%)

511547.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

57.87
/qiskit/visualization/state_visualization.py
1
# This code is part of Qiskit.
2
#
3
# (C) Copyright IBM 2017, 2023.
4
#
5
# This code is licensed under the Apache License, Version 2.0. You may
6
# obtain a copy of this license in the LICENSE.txt file in the root directory
7
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
8
#
9
# Any modifications or derivative works of this code must retain this
10
# copyright notice, and modified files need to carry a notice indicating
11
# that they have been altered from the originals.
12

13
# pylint: disable=invalid-name
14
# pylint: disable=missing-param-doc,missing-type-doc,unused-argument
15

16
"""
17
Visualization functions for quantum states.
18
"""
19

20
import math
1✔
21
from typing import List, Union
1✔
22
from functools import reduce
1✔
23
import colorsys
1✔
24

25
import numpy as np
1✔
26
from qiskit import user_config
1✔
27
from qiskit.quantum_info.states.statevector import Statevector
1✔
28
from qiskit.quantum_info.operators.operator import Operator
1✔
29
from qiskit.quantum_info.operators.symplectic import PauliList, SparsePauliOp
1✔
30
from qiskit.quantum_info.states.densitymatrix import DensityMatrix
1✔
31
from qiskit.utils import optionals as _optionals
1✔
32
from qiskit.circuit.tools.pi_check import pi_check
1✔
33

34
from .array import _num_to_latex, array_to_latex
1✔
35
from .utils import matplotlib_close_if_inline
1✔
36
from .exceptions import VisualizationError
1✔
37

38

39
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_state_hinton(state, title="", figsize=None, ax_real=None, ax_imag=None, *, filename=None):
    """Plot a hinton diagram for the density matrix of a quantum state.

    The hinton diagram represents the values of a matrix using squares whose
    size indicates the magnitude of the corresponding value and whose color
    indicates its sign. A white square means the value is positive and a black
    one means it is negative.

    Args:
        state (Statevector or DensityMatrix or ndarray): An N-qubit quantum state.
        title (str): a string that represents the plot title (drawn as the
            figure's suptitle when non-empty).
        figsize (tuple): Figure size in inches. Defaults to ``(8, 5)``.
        filename (str): file path to save image to.
        ax_real (matplotlib.axes.Axes): An optional Axes object to be used for
            the real-part visualization. If neither ``ax_real`` nor ``ax_imag``
            is specified a new matplotlib Figure will be created and used. If
            this is specified without an ax_imag only the real component plot
            will be generated.
        ax_imag (matplotlib.axes.Axes): An optional Axes object to be used for
            the imaginary-part visualization. If neither ``ax_real`` nor
            ``ax_imag`` is specified a new matplotlib Figure will be created and
            used. If this is specified without an ax_real only the imaginary
            component plot will be generated.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            The matplotlib Figure containing the plot (the parent Figure of the
            given axes when ``ax_real``/``ax_imag`` is set), or ``None`` when
            ``filename`` is given, in which case the figure is saved there.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        VisualizationError: Input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           import numpy as np
           from qiskit import QuantumCircuit
           from qiskit.quantum_info import DensityMatrix
           from qiskit.visualization import plot_state_hinton

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0,1)
           qc.ry(np.pi/3 , 0)
           qc.rx(np.pi/5, 1)

           state = DensityMatrix(qc)
           plot_state_hinton(state, title="New Hinton Plot")

    """
    from matplotlib import pyplot as plt

    # Figure data
    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    max_abs = np.abs(rho.data).max()
    # Round the largest magnitude up to a power of two so every square fits in
    # its unit cell. An all-zero matrix would make log2(0) raise, so fall back
    # to 1 there (every square then has zero size).
    if max_abs > 0:
        max_weight = 2 ** math.ceil(math.log2(max_abs))
    else:
        max_weight = 1
    datareal = np.real(rho.data)
    dataimag = np.imag(rho.data)

    if figsize is None:
        figsize = (8, 5)
    if not ax_real and not ax_imag:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    else:
        fig = ax_real.get_figure() if ax_real else ax_imag.get_figure()
        ax1 = ax_real
        ax2 = ax_imag
    # Reversal is to account for Qiskit's endianness.
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    row_names = column_names[::-1]

    if ax1:
        _draw_hinton_panel(ax1, datareal, max_weight, row_names, column_names, "Re[$\\rho$]")
    if ax2:
        _draw_hinton_panel(ax2, dataimag, max_weight, row_names, column_names, "Im[$\\rho$]")

    fig.tight_layout()
    if title:
        fig.suptitle(title, fontsize=16)
    if ax_real is None and ax_imag is None:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    else:
        return fig.savefig(filename)


def _draw_hinton_panel(ax, data, max_weight, row_names, column_names, panel_title):
    """Draw one Hinton panel (real or imaginary part of a matrix) onto ``ax``.

    Each entry becomes a square centred in its unit cell with side
    ``sqrt(abs(value) / max_weight)``; positive entries are white, the rest
    black, on a gray background. Row 0 of ``data`` is drawn at the top.
    """
    from matplotlib import pyplot as plt

    ax.patch.set_facecolor("gray")
    ax.set_aspect("equal", "box")
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    ly, lx = data.shape
    for (x, y), w in np.ndenumerate(data):
        # Convert from matrix co-ordinates to plot co-ordinates (flip rows).
        plot_x, plot_y = y, ly - x - 1
        color = "white" if w > 0 else "black"
        size = np.sqrt(np.abs(w) / max_weight)
        rect = plt.Rectangle(
            [0.5 + plot_x - size / 2, 0.5 + plot_y - size / 2],
            size,
            size,
            facecolor=color,
            edgecolor=color,
        )
        ax.add_patch(rect)

    ax.set_xticks(0.5 + np.arange(lx))
    ax.set_yticks(0.5 + np.arange(ly))
    ax.set_xlim([0, lx])
    ax.set_ylim([0, ly])
    ax.set_yticklabels(row_names, fontsize=14)
    ax.set_xticklabels(column_names, fontsize=14, rotation=90)
    ax.set_title(panel_title, fontsize=14)
186

187

188
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_bloch_vector(
    bloch, title="", ax=None, figsize=None, coord_type="cartesian", font_size=None
):
    """Plot the Bloch sphere.

    Plot a Bloch sphere with the specified coordinates, that can be given in both
    cartesian and spherical systems.

    Args:
        bloch (list[double]): array of three elements where [<x>, <y>, <z>] (Cartesian)
            or [<r>, <theta>, <phi>] (spherical in radians)
            <theta> is inclination angle from +z direction
            <phi> is azimuth from +x direction.
            The input sequence is never modified.
        title (str): a string that represents the plot title
        ax (matplotlib.axes.Axes): An Axes to use for rendering the bloch
            sphere
        figsize (tuple): Figure size in inches. Has no effect if passing ``ax``.
        coord_type (str): ``"spherical"`` to interpret ``bloch`` as spherical
            coordinates; any other value (default ``"cartesian"``) treats it
            as Cartesian.
        font_size (float): Font size.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` : A matplotlib figure instance if
        ``ax = None``, otherwise ``None``.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit.visualization import plot_bloch_vector

           plot_bloch_vector([0,1,0], title="New Bloch Sphere")

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           import numpy as np
           from qiskit.visualization import plot_bloch_vector

           # You can use spherical coordinates instead of cartesian.

           plot_bloch_vector([1, np.pi/2, np.pi/3], coord_type='spherical')

    """
    from .bloch import Bloch

    if figsize is None:
        figsize = (5, 5)
    B = Bloch(axes=ax, font_size=font_size)
    if coord_type == "spherical":
        r, theta, phi = bloch[0], bloch[1], bloch[2]
        # Build a fresh Cartesian vector rather than writing into ``bloch``:
        # the caller's sequence must not be mutated, and tuples or integer
        # arrays could not hold the converted float coordinates anyway.
        bloch = [
            r * np.sin(theta) * np.cos(phi),
            r * np.sin(theta) * np.sin(phi),
            r * np.cos(theta),
        ]
    B.add_vectors(bloch)
    B.render(title=title)
    if ax is None:
        fig = B.fig
        fig.set_size_inches(figsize[0], figsize[1])
        matplotlib_close_if_inline(fig)
        return fig
    return None
255

256

257
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_bloch_multivector(
    state,
    title="",
    figsize=None,
    *,
    reverse_bits=False,
    filename=None,
    font_size=None,
    title_font_size=None,
    title_pad=1,
):
    r"""Plot a Bloch sphere for each qubit.

    Each component :math:`(x,y,z)` of the Bloch sphere labeled as 'qubit i' represents the expected
    value of the corresponding Pauli operator acting only on that qubit, that is, the expected value
    of :math:`I_{N-1} \otimes\dotsb\otimes I_{i+1}\otimes P_i \otimes I_{i-1}\otimes\dotsb\otimes
    I_0`, where :math:`N` is the number of qubits, :math:`P\in \{X,Y,Z\}` and :math:`I` is the
    identity operator.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): size of each individual Bloch sphere figure, in inches.
        reverse_bits (bool): If True, plots qubits following Qiskit's convention [Default:False].
        filename (str): file path to save image to. If given, the figure is saved
            there and ``None`` is returned instead of the figure.
        font_size (float): Font size for the Bloch ball figures.
        title_font_size (float): Font size for the title. Defaults to ``font_size``,
            or 16 if that is not set either.
        title_pad (float): Padding for the title (suptitle `y` position is `y=1+title_pad/100`).

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            A matplotlib figure instance, or ``None`` if ``filename`` is given.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        VisualizationError: if input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_bloch_multivector

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.x(1)

           state = Statevector(qc)
           plot_bloch_multivector(state)

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can reverse the order of the qubits.

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import DensityMatrix
           from qiskit.visualization import plot_bloch_multivector

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.t(1)
           qc.s(0)
           qc.cx(0,1)

           matrix = DensityMatrix(qc)
           plot_bloch_multivector(matrix, title='My Bloch Spheres', reverse_bits=True)

    """
    from matplotlib import pyplot as plt

    # Data: one Bloch vector per qubit; reverse_bits flips the left-to-right order.
    bloch_data = _bloch_multivector_data(state)
    if reverse_bits:
        bloch_data = bloch_data[::-1]
    num = len(bloch_data)
    if figsize is not None:
        # figsize describes a single sphere, so the total width scales with num.
        width, height = figsize
        width *= num
    else:
        width, height = plt.figaspect(1 / num)
    if title_font_size is None:
        title_font_size = font_size if font_size is not None else 16
    fig = plt.figure(figsize=(width, height))
    for i in range(num):
        # Label each sphere with its real qubit index, undoing any reversal.
        pos = num - 1 - i if reverse_bits else i
        ax = fig.add_subplot(1, num, i + 1, projection="3d")
        plot_bloch_vector(
            bloch_data[i], "qubit " + str(pos), ax=ax, figsize=figsize, font_size=font_size
        )
    fig.suptitle(title, fontsize=title_font_size, y=1.0 + title_pad / 100)
    matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    else:
        return fig.savefig(filename)
363

364

365
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_state_city(
    state,
    title="",
    figsize=None,
    color=None,
    alpha=1,
    ax_real=None,
    ax_imag=None,
    *,
    filename=None,
):
    """Plot the cityscape of quantum state.

    Plot two 3d bar graphs (two dimensional) of the real and imaginary
    part of the density matrix rho.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches. Defaults to ``(16, 8)`` when a
            new figure is created; ignored when axes are passed.
        color (list): A list of len=2 giving colors for real and
            imaginary components of matrix elements. Either entry may be
            ``None`` to use the default color ``"#648fff"``.
        alpha (float): Transparency value for bars
        ax_real (matplotlib.axes.Axes): An optional Axes object to be used for
            the visualization output. If none is specified a new matplotlib
            Figure will be created and used. If this is specified without an
            ax_imag only the real component plot will be generated.
        ax_imag (matplotlib.axes.Axes): An optional Axes object to be used for
            the visualization output. If none is specified a new matplotlib
            Figure will be created and used. If this is specified without an
            ax_real only the imaginary component plot will be generated.
        filename (str): file path to save image to.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            The matplotlib Figure containing the plot (the parent Figure of the
            given axes when ``ax_real``/``ax_imag`` is set), or ``None`` when
            ``filename`` is given, in which case the figure is saved there.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        ValueError: When 'color' is not a list of len=2.
        VisualizationError: if input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can choose different colors for the real and imaginary parts of the density matrix.

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import DensityMatrix
           from qiskit.visualization import plot_state_city

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = DensityMatrix(qc)
           plot_state_city(state, color=['midnightblue', 'crimson'], title="New State City")

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can make the bars more transparent to better see the ones that are behind
           # if they overlap.

           import numpy as np
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_city
           from qiskit import QuantumCircuit

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0,1)
           qc.ry(np.pi/3, 0)
           qc.rx(np.pi/5, 1)

           state = Statevector(qc)
           plot_state_city(state, alpha=0.6)

    """
    import matplotlib.colors as mcolors
    from matplotlib import pyplot as plt
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection

    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")

    # get the real and imag parts of rho
    datareal = np.real(rho.data)
    dataimag = np.imag(rho.data)

    # get the labels (columns and rows share the same computational basis)
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    row_names = [bin(i)[2:].zfill(num) for i in range(2**num)]

    ly, lx = datareal.shape[:2]
    # Set up a mesh of bar positions. After row-major flattening, xpos holds
    # the column index and ypos the row index of each matrix element, offset
    # by 0.25 so the 0.5-wide bars sit centred in their unit cells.
    xpos, ypos = np.meshgrid(np.arange(0, lx, 1) + 0.25, np.arange(0, ly, 1) + 0.25)
    xpos = xpos.flatten()
    ypos = ypos.flatten()
    zpos = np.zeros(lx * ly)

    dx = 0.5 * np.ones_like(zpos)  # width of bars
    dy = dx.copy()
    dzr = datareal.flatten()
    dzi = dataimag.flatten()

    default_color = "#648fff"
    if color is None:
        real_color, imag_color = default_color, default_color
    else:
        if len(color) != 2:
            raise ValueError("'color' must be a list of len=2.")
        real_color = default_color if color[0] is None else color[0]
        imag_color = default_color if color[1] is None else color[1]

    if ax_real is None and ax_imag is None:
        # set default figure size
        if figsize is None:
            figsize = (16, 8)

        fig = plt.figure(figsize=figsize, facecolor="w")
        ax1 = fig.add_subplot(1, 2, 1, projection="3d", computed_zorder=False)
        ax2 = fig.add_subplot(1, 2, 2, projection="3d", computed_zorder=False)

    elif ax_real is not None:
        fig = ax_real.get_figure()
        ax1 = ax_real
        ax2 = ax_imag
    else:
        fig = ax_imag.get_figure()
        ax1 = None
        ax2 = ax_imag

    fig.tight_layout()

    max_dzr = np.max(dzr)
    max_dzi = np.max(dzi)

    # Figure scaling variables since fig.tight_layout won't work
    fig_width, fig_height = fig.get_size_inches()
    max_plot_size = min(fig_width / 2.25, fig_height)
    max_font_size = int(3 * max_plot_size)
    max_zoom = 10 / (10 + np.sqrt(max_plot_size))

    for ax, dz, col, zlabel in (
        (ax1, dzr, real_color, "Real"),
        (ax2, dzi, imag_color, "Imaginary"),
    ):

        if ax is None:
            continue

        max_dz = np.max(dz)
        min_dz = np.min(dz)

        if isinstance(col, str) and col.startswith("#"):
            col = mcolors.to_rgba_array(col)

        # Draw order is fixed by explicit zorders: negative bars (0.625), then
        # the translucent z=0 plane (0.75), then positive bars (0.875).
        dzn = dz < 0
        if np.any(dzn):
            fc = generate_facecolors(
                xpos[dzn], ypos[dzn], zpos[dzn], dx[dzn], dy[dzn], dz[dzn], col
            )
            negative_bars = ax.bar3d(
                xpos[dzn],
                ypos[dzn],
                zpos[dzn],
                dx[dzn],
                dy[dzn],
                dz[dzn],
                alpha=alpha,
                zorder=0.625,
            )
            negative_bars.set_facecolor(fc)

        if min_dz < 0 < max_dz:
            xlim, ylim = [0, lx], [0, ly]
            verts = [list(zip(xlim + xlim[::-1], np.repeat(ylim, 2), [0] * 4))]
            plane = Poly3DCollection(verts, alpha=0.25, facecolor="k", linewidths=1)
            plane.set_zorder(0.75)
            ax.add_collection3d(plane)

        dzp = dz >= 0
        if np.any(dzp):
            fc = generate_facecolors(
                xpos[dzp], ypos[dzp], zpos[dzp], dx[dzp], dy[dzp], dz[dzp], col
            )
            positive_bars = ax.bar3d(
                xpos[dzp],
                ypos[dzp],
                zpos[dzp],
                dx[dzp],
                dy[dzp],
                dz[dzp],
                alpha=alpha,
                zorder=0.875,
            )
            positive_bars.set_facecolor(fc)

        ax.set_title(f"{zlabel} Amplitude (ρ)", fontsize=max_font_size)

        ax.set_xticks(np.arange(0.5, lx + 0.5, 1))
        ax.set_yticks(np.arange(0.5, ly + 0.5, 1))
        if max_dz != min_dz or min_dz == 0:
            # Share the upper z limit between the real and imaginary panels;
            # the 1e-9 keeps the range non-degenerate for all-zero data.
            ax.axes.set_zlim3d(min_dz, max(max_dzr + 1e-9, max_dzi))
        else:
            # Constant non-zero data: let matplotlib pick the z range.
            ax.axes.set_zlim3d(auto=True)

        # x runs along matrix columns and y along matrix rows.
        ax.xaxis.set_ticklabels(
            column_names, fontsize=max_font_size, rotation=45, ha="right", va="top"
        )
        ax.yaxis.set_ticklabels(
            row_names, fontsize=max_font_size, rotation=-22.5, ha="left", va="center"
        )

        for tick in ax.zaxis.get_major_ticks():
            tick.label1.set_fontsize(max_font_size)
            tick.label1.set_horizontalalignment("left")
            tick.label1.set_verticalalignment("bottom")

        ax.set_box_aspect(aspect=(4, 4, 4), zoom=max_zoom)
        ax.set_xmargin(0)
        ax.set_ymargin(0)

    fig.suptitle(title, fontsize=max_font_size * 1.25)
    fig.subplots_adjust(top=0.9, bottom=0, left=0, right=1, hspace=0, wspace=0)
    if ax_real is None and ax_imag is None:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    else:
        return fig.savefig(filename)
616

617

618
@_optionals.HAS_MATPLOTLIB.require_in_call
def plot_state_paulivec(state, title="", figsize=None, color=None, ax=None, *, filename=None):
    r"""Plot the Pauli-vector representation of a quantum state as bar graph.

    The Pauli-vector of a density matrix :math:`\rho` is defined by the expectation of each
    possible tensor product of single-qubit Pauli operators (including the identity), that is

    .. math ::

        \rho = \frac{1}{2^n} \sum_{\sigma \in \{I, X, Y, Z\}^{\otimes n}}
               \mathrm{Tr}(\sigma \rho) \sigma.

    This function plots the coefficients :math:`\mathrm{Tr}(\sigma\rho)` as bar graph.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches. Defaults to ``(7, 5)``.
        color (list or str): Color of the coefficient value bars. Defaults to
            ``"#648fff"``.
        ax (matplotlib.axes.Axes): An optional Axes object to be used for
            the visualization output. If none is specified a new matplotlib
            Figure will be created and used.
        filename (str): file path to save image to.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            The matplotlib Figure containing the plot (the parent Figure of
            ``ax`` when it is given), or ``None`` when ``filename`` is given, in
            which case the figure is saved there.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib.
        VisualizationError: if input is not a valid N-qubit state.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can set a color for all the bars.

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_paulivec

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = Statevector(qc)
           plot_state_paulivec(state, color='midnightblue', title="New PauliVec plot")

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # If you introduce a list with less colors than bars, the color of the bars will
           # alternate following the sequence from the list.

           import numpy as np
           from qiskit.quantum_info import DensityMatrix
           from qiskit import QuantumCircuit
           from qiskit.visualization import plot_state_paulivec

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0, 1)
           qc.ry(np.pi/3, 0)
           qc.rx(np.pi/5, 1)

           matrix = DensityMatrix(qc)
           plot_state_paulivec(matrix, color=['crimson', 'midnightblue', 'seagreen'])
    """
    from matplotlib import pyplot as plt

    labels, values = _paulivec_data(state)
    numelem = len(values)

    if figsize is None:
        figsize = (7, 5)
    if color is None:
        color = "#648fff"

    ind = np.arange(numelem)  # the x locations for the groups
    width = 0.5  # the width of the bars
    # Only a figure we created ourselves is closed for inline backends; a
    # caller-supplied axes keeps its figure untouched.
    if ax is None:
        return_fig = True
        fig, ax = plt.subplots(figsize=figsize)
    else:
        return_fig = False
        fig = ax.get_figure()
    ax.grid(zorder=0, linewidth=1, linestyle="--")
    ax.bar(ind, values, width, color=color, zorder=2)
    ax.axhline(linewidth=1, color="k")
    # add some text for labels, title, and axes ticks
    ax.set_ylabel("Coefficients", fontsize=14)
    ax.set_xticks(ind)
    ax.set_yticks([-1, -0.5, 0, 0.5, 1])
    ax.set_xticklabels(labels, fontsize=14, rotation=70)
    ax.set_xlabel("Pauli", fontsize=14)
    ax.set_ylim([-1, 1])
    ax.set_facecolor("#eeeeee")
    for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks():
        tick.label1.set_fontsize(14)
    ax.set_title(title, fontsize=16)
    if return_fig:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    else:
        return fig.savefig(filename)
732

733

734
def n_choose_k(n, k):
    """Return the number of combinations for n choose k.

    Args:
        n (int): the total number of options.
        k (int): the number of elements chosen.

    Returns:
        int: the exact binomial coefficient ``C(n, k)``; ``0`` when ``k < 0`` or
        ``k > n``. For backward compatibility ``n == 0`` always yields ``0``
        (even for ``k == 0``).

    Raises:
        ValueError: if ``n`` is negative.
    """
    # Historical quirk kept for compatibility: callers have always received 0 here.
    if n == 0:
        return 0
    if k < 0:
        return 0
    # math.comb is exact for arbitrarily large values, unlike the previous
    # running float product, and returns 0 when k > n.
    return math.comb(n, k)
747

748

749
def lex_index(n, k, lst):
    """Return the lexicographic index of a combination.

    Args:
        n (int): the total number of options.
        k (int): the number of elements in the combination.
        lst (list): the ``k`` chosen positions, each in ``range(n)``.

    Returns:
        int: the index of the combination in lexicographic order.

    Raises:
        VisualizationError: if the length of ``lst`` is not equal to ``k``.
    """
    if k != len(lst):
        raise VisualizationError("list should have length k")
    # Work with the "dual" positions n - 1 - x and sum binomial coefficients,
    # pairing the last dual position with C(., 1), the one before with C(., 2), ...
    dual_positions = [n - 1 - position for position in lst]
    total = 0
    for offset, dual in enumerate(reversed(dual_positions)):
        total += n_choose_k(dual, offset + 1)
    return int(total)
768

769

770
def bit_string_index(s):
    """Return the index of a string of 0s and 1s.

    The index is the lexicographic index of the set of ``"1"`` positions among
    all strings of the same length with the same number of ones.

    Args:
        s (str): a string made only of the characters ``"0"`` and ``"1"``.

    Returns:
        int: the index of ``s``.

    Raises:
        VisualizationError: if ``s`` contains characters other than 0 and 1.
    """
    length = len(s)
    weight = s.count("1")
    # Every character must be either a "0" or a "1".
    if weight + s.count("0") != length:
        raise VisualizationError("s must be a string of 0 and 1")
    one_positions = [idx for idx, bit in enumerate(s) if bit == "1"]
    return lex_index(length, weight, one_positions)
778

779

780
def phase_to_rgb(complex_number):
    """Map the phase of a complex number to a color in (r, g, b).

    The phase is shifted by 5*pi/4, wrapped into the range [0, 2*pi), scaled to
    a hue in [0, 1) and converted via the HLS color wheel with lightness and
    saturation both fixed at 0.5.
    """
    full_turn = np.pi * 2
    shifted_phase = np.angle(complex_number) + (np.pi * 5 / 4)
    hue = (shifted_phase % full_turn) / full_turn
    return colorsys.hls_to_rgb(hue, 0.5, 0.5)
789

790

791
@_optionals.HAS_MATPLOTLIB.require_in_call
@_optionals.HAS_SEABORN.require_in_call
def plot_state_qsphere(
    state,
    figsize=None,
    ax=None,
    show_state_labels=True,
    show_state_phases=False,
    use_degrees=False,
    *,
    filename=None,
):
    """Plot the qsphere representation of a quantum state.
    Here, the size of the points is proportional to the probability
    of the corresponding term in the state and the color represents
    the phase.

    The state is converted to a density matrix and diagonalized. Every
    eigenvector whose eigenvalue exceeds 0.001 is drawn onto the same sphere,
    largest eigenvalue first. Each basis state sits on a latitude ring set by
    its Hamming weight: all zeros at the top (z = 1), all ones at the bottom
    (z = -1). It is spread around that ring by its index among basis states of
    equal weight. A phase color wheel is drawn in the lower-right corner.

    Args:
        state (Statevector or DensityMatrix or ndarray): an N-qubit quantum state.
        figsize (tuple): Figure size in inches. Defaults to ``(7, 7)``; only
            used when ``ax`` is not given.
        ax (matplotlib.axes.Axes): An optional Axes object whose figure is used
            for the visualization output. If none is specified a new matplotlib
            Figure will be created and used.
            NOTE(review): the sphere is drawn on a new 3D subplot spanning the
            whole figure of ``ax`` rather than on ``ax`` itself. That figure is
            also still returned when ``filename`` is None, although earlier docs
            stated no Figure is returned when ``ax`` is given. Confirm the
            intended behavior.
        show_state_labels (bool): An optional boolean indicating whether to
            show labels for each basis state.
        show_state_phases (bool): An optional boolean indicating whether to
            show the phase for each basis state.
        use_degrees (bool): An optional boolean indicating whether to use
            radians or degrees for the phase values in the plot.
        filename (str): If given, the figure is saved with ``fig.savefig`` and
            the return value of that call is returned instead of the figure.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure` :
            The figure containing the plot, when ``filename`` is None.

    Raises:
        MissingOptionalLibraryError: Requires matplotlib and seaborn.
        VisualizationError: Input is not a valid N-qubit state.
        QiskitError: Input statevector does not have valid dimensions.

    Examples:
        .. plot::
           :alt: Output from the previous code.
           :include-source:

           from qiskit import QuantumCircuit
           from qiskit.quantum_info import Statevector
           from qiskit.visualization import plot_state_qsphere

           qc = QuantumCircuit(2)
           qc.h(0)
           qc.cx(0, 1)

           state = Statevector(qc)
           plot_state_qsphere(state)

        .. plot::
           :alt: Output from the previous code.
           :include-source:

           # You can show the phase of each state and use
           # degrees instead of radians

           from qiskit.quantum_info import DensityMatrix
           import numpy as np
           from qiskit import QuantumCircuit
           from qiskit.visualization import plot_state_qsphere

           qc = QuantumCircuit(2)
           qc.h([0, 1])
           qc.cz(0,1)
           qc.ry(np.pi/3, 0)
           qc.rx(np.pi/5, 1)
           qc.z(1)

           matrix = DensityMatrix(qc)
           plot_state_qsphere(matrix,
                show_state_phases = True, use_degrees = True)
    """
    from matplotlib import gridspec
    from matplotlib import pyplot as plt
    from matplotlib.patches import Circle
    import seaborn as sns
    from scipy import linalg
    from .bloch import Arrow3D

    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    # get the eigenvectors and eigenvalues
    eigvals, eigvecs = linalg.eigh(rho.data)

    if figsize is None:
        figsize = (7, 7)

    if ax is None:
        return_fig = True
        fig = plt.figure(figsize=figsize)
    else:
        # Only the figure that owns ``ax`` is reused; the sphere itself goes on
        # a new subplot created below.
        return_fig = False
        fig = ax.get_figure()

    gs = gridspec.GridSpec(nrows=3, ncols=3)

    # The 3D sphere occupies the full 3x3 grid; the phase wheel added at the
    # end is placed on top of its bottom-right cell.
    ax = fig.add_subplot(gs[0:3, 0:3], projection="3d")
    ax.axes.set_xlim3d(-1.0, 1.0)
    ax.axes.set_ylim3d(-1.0, 1.0)
    ax.axes.set_zlim3d(-1.0, 1.0)
    ax.axes.grid(False)
    ax.view_init(elev=5, azim=275)

    # Force aspect ratio
    # MPL 3.2 or previous do not have set_box_aspect
    if hasattr(ax.axes, "set_box_aspect"):
        ax.axes.set_box_aspect((1, 1, 1))

    # start the plotting
    # Plot semi-transparent sphere
    u = np.linspace(0, 2 * np.pi, 25)
    v = np.linspace(0, np.pi, 25)
    x = np.outer(np.cos(u), np.sin(v))
    y = np.outer(np.sin(u), np.sin(v))
    z = np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(
        x, y, z, rstride=1, cstride=1, color=plt.rcParams["grid.color"], alpha=0.2, linewidth=0
    )

    # Get rid of the panes
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the spines
    ax.xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))

    # Get rid of the ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    # traversing the eigvals/vecs backward as sorted low->high
    for idx in range(eigvals.shape[0] - 1, -1, -1):
        # Eigenvalues are visited in descending order, so the first one at or
        # below the 0.001 threshold ends the loop (see the ``break`` below).
        if eigvals[idx] > 0.001:
            # eigenvector of the current (next-largest) eigenvalue
            state = eigvecs[:, idx]
            loc = np.absolute(state).argmax()
            # remove the global phase from max element
            # (the largest-magnitude amplitude becomes real and non-negative)
            angles = (np.angle(state[loc]) + 2 * np.pi) % (2 * np.pi)
            angleset = np.exp(-1j * angles)
            state = angleset * state

            d = num
            for i in range(2**num):
                # get x,y,z points
                element = bin(i)[2:].zfill(num)
                weight = element.count("1")
                # Hamming weight picks the latitude: weight 0 -> z = 1, weight d -> z = -1
                zvalue = -2 * weight / d + 1
                # number of basis states sharing this weight, i.e. slots on the ring
                number_of_divisions = n_choose_k(d, weight)
                # position of this bit string among the states of equal weight
                weight_order = bit_string_index(element)
                angle = (float(weight) / d) * (np.pi * 2) + (
                    weight_order * 2 * (np.pi / number_of_divisions)
                )

                # Lower half of the sphere (and the second half of the equator
                # ring when d is even) uses a mirrored angle.
                if (weight > d / 2) or (
                    (weight == d / 2) and (weight_order >= number_of_divisions / 2)
                ):
                    angle = np.pi - angle - (2 * np.pi / number_of_divisions)

                xvalue = np.sqrt(1 - zvalue**2) * np.cos(angle)
                yvalue = np.sqrt(1 - zvalue**2) * np.sin(angle)

                # get prob and angle - prob will be shade and angle color
                prob = np.real(np.dot(state[i], state[i].conj()))
                prob = min(prob, 1)  # See https://github.com/Qiskit/qiskit-terra/issues/4666
                colorstate = phase_to_rgb(state[i])

                # Fade points with y >= 0.1; presumably these lie on the far
                # side of the sphere for the azim=275 camera set above (verify).
                alfa = 1
                if yvalue >= 0.1:
                    alfa = 1.0 - yvalue

                # Labels are drawn only for terms with non-negligible probability,
                # placed along the same direction at radius 1.3.
                if not np.isclose(prob, 0) and show_state_labels:
                    rprime = 1.3
                    angle_theta = np.arctan2(np.sqrt(1 - zvalue**2), zvalue)
                    xvalue_text = rprime * np.sin(angle_theta) * np.cos(angle)
                    yvalue_text = rprime * np.sin(angle_theta) * np.sin(angle)
                    zvalue_text = rprime * np.cos(angle_theta)
                    element_text = "$\\vert" + element + "\\rangle$"
                    if show_state_phases:
                        # phase of this amplitude wrapped into [0, 2*pi)
                        element_angle = (np.angle(state[i]) + (np.pi * 4)) % (np.pi * 2)
                        if use_degrees:
                            element_text += f"\n${element_angle * 180 / np.pi:.1f}^\\circ$"
                        else:
                            element_angle = pi_check(element_angle, ndigits=3).replace("pi", "\\pi")
                            element_text += f"\n${element_angle}$"
                    ax.text(
                        xvalue_text,
                        yvalue_text,
                        zvalue_text,
                        element_text,
                        ha="center",
                        va="center",
                        size=12,
                    )

                # Marker diameter scales with sqrt(prob), so its area scales with prob.
                ax.plot(
                    [xvalue],
                    [yvalue],
                    [zvalue],
                    markerfacecolor=colorstate,
                    markeredgecolor=colorstate,
                    marker="o",
                    markersize=np.sqrt(prob) * 30,
                    alpha=alfa,
                )

                # Line from the origin to the point; its opacity is the probability.
                a = Arrow3D(
                    [0, xvalue],
                    [0, yvalue],
                    [0, zvalue],
                    mutation_scale=20,
                    alpha=prob,
                    arrowstyle="-",
                    color=colorstate,
                    lw=2,
                )
                ax.add_artist(a)

            # add weight lines
            # (one dotted latitude circle per Hamming weight; drawn again for
            # every eigenvector that passes the threshold)
            for weight in range(d + 1):
                theta = np.linspace(-2 * np.pi, 2 * np.pi, 100)
                z = -2 * weight / d + 1
                r = np.sqrt(1 - z**2)
                x = r * np.cos(theta)
                y = r * np.sin(theta)
                ax.plot(x, y, z, color=(0.5, 0.5, 0.5), lw=1, ls=":", alpha=0.5)

            # add center point
            ax.plot(
                [0],
                [0],
                [0],
                markerfacecolor=(0.5, 0.5, 0.5),
                markeredgecolor=(0.5, 0.5, 0.5),
                marker="o",
                markersize=3,
                alpha=1,
            )
        else:
            break

    # Phase legend: a ring of n equal wedges colored from a seaborn HLS palette.
    # The palette is rotated by 5/8 of a turn, presumably so the wedge colors
    # line up with the 5*pi/4 hue offset used by phase_to_rgb (verify).
    n = 64
    theta = np.ones(n)
    colors = sns.hls_palette(n)

    ax2 = fig.add_subplot(gs[2:, 2:])
    ax2.pie(theta, colors=colors[5 * n // 8 :] + colors[: 5 * n // 8], radius=0.75)
    # white disc turns the pie into a ring
    ax2.add_artist(Circle((0, 0), 0.5, color="white", zorder=1))
    offset = 0.95  # since radius of sphere is one.

    # labels[0] is the center caption; labels[1:] go right, top, left, bottom
    if use_degrees:
        labels = ["Phase\n(Deg)", "0", "90", "180   ", "270"]
    else:
        labels = ["Phase", "$0$", "$\\pi/2$", "$\\pi$", "$3\\pi/2$"]

    ax2.text(0, 0, labels[0], horizontalalignment="center", verticalalignment="center", fontsize=14)
    ax2.text(
        offset, 0, labels[1], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, offset, labels[2], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        -offset, 0, labels[3], horizontalalignment="center", verticalalignment="center", fontsize=14
    )
    ax2.text(
        0, -offset, labels[4], horizontalalignment="center", verticalalignment="center", fontsize=14
    )

    if return_fig:
        matplotlib_close_if_inline(fig)
    if filename is None:
        return fig
    else:
        return fig.savefig(filename)
1079

1080

1081
@_optionals.HAS_MATPLOTLIB.require_in_call
def generate_facecolors(x, y, z, dx, dy, dz, color):
    """Generates shaded facecolors for shaded bars.

    This is here to work around a Matplotlib bug
    where alpha does not work in Bar3D.

    Each bar is modeled as an axis-aligned cuboid with six faces. Every face
    gets a color, which is then shaded according to the face's normal.

    Args:
        x (array_like): The x- coordinates of the anchor point of the bars.
        y (array_like): The y- coordinates of the anchor point of the bars.
        z (array_like): The z- coordinates of the anchor point of the bars.
        dx (array_like): Width of bars.
        dy (array_like): Depth of bars.
        dz (array_like): Height of bars.
            The six coordinate/size arguments are broadcast against each other
            (scalars are accepted) and should describe a 1-D sequence of bars.
        color (array_like): sequence of valid color specifications, optional.
            Either one color per bar (a non-string sequence whose length equals
            the number of bars), which is repeated for all six faces, or any
            single color / explicit per-face colors accepted by
            :func:`matplotlib.colors.to_rgba_array`.
    Returns:
        Shaded colors for the bar faces, as returned by ``_shade_colors``.
    Raises:
        MissingOptionalLibraryError: If matplotlib is not installed
    """
    import matplotlib.colors as mcolors

    # Accept scalars, lists and arrays alike. The arithmetic below relies on
    # ``.shape`` and ``[..., np.newaxis]`` indexing, which plain Python
    # sequences and scalars do not support.
    x, y, z, dx, dy, dz = np.broadcast_arrays(np.atleast_1d(x), y, z, dx, dy, dz)

    cuboid = np.array(
        [
            # -z
            (
                (0, 0, 0),
                (0, 1, 0),
                (1, 1, 0),
                (1, 0, 0),
            ),
            # +z
            (
                (0, 0, 1),
                (1, 0, 1),
                (1, 1, 1),
                (0, 1, 1),
            ),
            # -y
            (
                (0, 0, 0),
                (1, 0, 0),
                (1, 0, 1),
                (0, 0, 1),
            ),
            # +y
            (
                (0, 1, 0),
                (0, 1, 1),
                (1, 1, 1),
                (1, 1, 0),
            ),
            # -x
            (
                (0, 0, 0),
                (0, 0, 1),
                (0, 1, 1),
                (0, 1, 0),
            ),
            # +x
            (
                (1, 0, 0),
                (1, 1, 0),
                (1, 1, 1),
                (1, 0, 1),
            ),
        ]
    )

    # indexed by [bar, face, vertex, coord]
    polys = np.empty(x.shape + cuboid.shape)
    # handle each coordinate separately
    for i, p, dp in [(0, x, dx), (1, y, dy), (2, z, dz)]:
        p = p[..., np.newaxis, np.newaxis]
        dp = dp[..., np.newaxis, np.newaxis]
        polys[..., i] = p + dp * cuboid[..., i]

    # collapse the first two axes
    polys = polys.reshape((-1,) + polys.shape[2:])

    facecolors = []
    # A color *string* is a single color even if its length happens to equal
    # the number of bars (e.g. "blue" with 4 bars); never split it into chars.
    if not isinstance(color, str) and len(color) == len(x):
        # bar colors specified, need to expand to number of faces
        for c in color:
            facecolors.extend([c] * 6)
    else:
        # a single color specified, or face colors specified explicitly
        facecolors = list(mcolors.to_rgba_array(color))
        if len(facecolors) < len(x):
            facecolors *= 6 * len(x)

    normals = _generate_normals(polys)
    return _shade_colors(facecolors, normals)
1174

1175

1176
def _generate_normals(polygons):
1✔
1177
    """Takes a list of polygons and return an array of their normals.
1178

1179
    Normals point towards the viewer for a face with its vertices in
1180
    counterclockwise order, following the right hand rule.
1181
    Uses three points equally spaced around the polygon.
1182
    This normal of course might not make sense for polygons with more than
1183
    three points not lying in a plane, but it's a plausible and fast
1184
    approximation.
1185

1186
    Args:
1187
        polygons (list): list of (M_i, 3) array_like, or (..., M, 3) array_like
1188
            A sequence of polygons to compute normals for, which can have
1189
            varying numbers of vertices. If the polygons all have the same
1190
            number of vertices and array is passed, then the operation will
1191
            be vectorized.
1192
    Returns:
1193
        normals: (..., 3) array_like
1194
            A normal vector estimated for the polygon.
1195
    """
1196
    if isinstance(polygons, np.ndarray):
×
1197
        # optimization: polygons all have the same number of points, so can
1198
        # vectorize
1199
        n = polygons.shape[-2]
×
1200
        i1, i2, i3 = 0, n // 3, 2 * n // 3
×
1201
        v1 = polygons[..., i1, :] - polygons[..., i2, :]
×
1202
        v2 = polygons[..., i2, :] - polygons[..., i3, :]
×
1203
    else:
1204
        # The subtraction doesn't vectorize because polygons is jagged.
1205
        v1 = np.empty((len(polygons), 3))
×
1206
        v2 = np.empty((len(polygons), 3))
×
1207
        for poly_i, ps in enumerate(polygons):
×
1208
            n = len(ps)
×
1209
            i1, i2, i3 = 0, n // 3, 2 * n // 3
×
1210
            v1[poly_i, :] = ps[i1, :] - ps[i2, :]
×
1211
            v2[poly_i, :] = ps[i2, :] - ps[i3, :]
×
1212

1213
    return np.cross(v1, v2)
×
1214

1215

1216
def _shade_colors(color, normals, lightsource=None):
    """Shade *color* according to the normal vectors in *normals*.

    *color* may be a single color or one color per normal. Each normal is
    normalized and projected onto the light direction. Normals of zero length
    get the smallest projection found. The projections are rescaled to [0, 1]
    and scale the color channels by a factor between 0.5 and 1, while the
    alpha channel is kept. If no normal has non-zero length, *color* is
    returned unshaded as an array copy.
    """
    from matplotlib.colors import Normalize, LightSource
    import matplotlib.colors as mcolors

    if lightsource is None:
        # chosen for backwards-compatibility
        lightsource = LightSource(azdeg=225, altdeg=19.4712)

    def _length(vec):
        return np.sqrt(vec[0] ** 2 + vec[1] ** 2 + vec[2] ** 2)

    projections = []
    for normal in normals:
        length = _length(normal)
        projections.append(np.dot(normal / length, lightsource.direction) if length else np.nan)
    shade = np.array(projections)
    valid = ~np.isnan(shade)

    if not valid.any():
        return np.asanyarray(color).copy()

    lowest = min(shade[valid])
    norm = Normalize(lowest, max(shade[valid]))
    shade[~valid] = lowest
    rgba = mcolors.to_rgba_array(color)
    # rgba is (M, 4) and shade is (M,), where M is the number of faces
    alpha = rgba[:, 3]
    shaded = (0.5 + norm(shade)[:, np.newaxis] * 0.5) * rgba
    shaded[:, 3] = alpha
    return shaded
1250

1251

1252
def state_to_latex(
    state: Union[Statevector, DensityMatrix], dims: bool = None, convention: str = "ket", **args
) -> str:
    """Return a LaTeX representation of a state.

    Qubit statevectors drawn with the ``'ket'`` convention are rendered in ket
    notation; everything else goes through
    :func:`qiskit.visualization.array_to_latex`. When requested, the state's
    ``dims`` are appended inside an ``align`` environment. Intended for use
    within ``state_drawer``.

    Args:
        state: State to be drawn.
        dims (bool): Whether to display the state's ``dims``. If ``None``, the
            dims are shown only when some subsystem is not a qubit.
        convention (str): Either ``'vector'`` or ``'ket'``. ``'ket'`` uses ket
            notation, but only for statevectors made entirely of qubits; any
            other input is drawn as a vector/matrix.
        **args: Extra keyword arguments, forwarded to the ket renderer when ket
            notation is used and to ``array_to_latex`` otherwise.

    Returns:
        str: LaTeX source for the state.

    Raises:
        MissingOptionalLibraryError: May be raised by the underlying LaTeX
            helpers if SymPy isn't installed and ``'latex'`` or
            ``'latex_source'`` is selected for ``output``.
    """
    if dims is None:
        # only annotate dims when at least one subsystem is not a qubit
        dims = set(state.dims()) != {2}

    op_shape = state._op_shape
    # Ket notation is reserved for qubit statevectors: no input dimensions and
    # every output dimension equal to 2.
    is_qubit_statevector = len(op_shape.dims_r()) == 0 and set(op_shape.dims_l()) == {2}
    if convention == "ket" and is_qubit_statevector:
        body = _state_to_latex_ket(state._data, **args)
    else:
        body = array_to_latex(state._data, source=True, **args)

    if not dims:
        return body
    dims_str = state._op_shape.dims_l()
    return "\\begin{align}\n" + body + f"\\\\\n\\text{{dims={dims_str}}}\n\\end{{align}}"
1295

1296

1297
def _numbers_to_latex_terms(numbers: List[complex], decimals: int = 10) -> List[str]:
    """Format each number as a LaTeX coefficient term.

    Only the first entry (regardless of its value) is formatted as the leading
    term, for which a leading ``+`` is suppressed. Every later entry is
    formatted as a continuation term.

    Args:
        numbers: Numbers to format, in display order.
        decimals: Number of decimal places to round to (default: 10).
    Returns:
        One LaTeX term per input number, in the same order.
    """
    return [
        _num_to_latex(value, decimals=decimals, first_term=(position == 0), coefficient=True)
        for position, value in enumerate(numbers)
    ]
1315

1316

1317
def _state_to_latex_ket(
    data: List[complex], max_size: int = 12, prefix: str = "", decimals: int = 10
) -> str:
    """Convert state vector to latex representation

    Args:
        data: State vector
        max_size: Maximum number of displayed slots. If the number of non-zero
                 terms (after rounding) is larger than ``max_size``, only the
                 first ``max_size // 2`` and the last
                 ``max_size - max_size // 2 - 1`` terms are shown, separated by
                 an ellipsis.
        prefix: Latex string to be prepended to the latex, intended for labels.
        decimals: Number of decimal places to round to (default: 10).

    Returns:
        String with LaTeX representation of the state vector
    """
    num = int(math.log2(len(data)))

    def ket_name(i):
        return bin(i)[2:].zfill(num)

    data = np.around(data, decimals)
    nonzero_indices = np.where(data != 0)[0].tolist()
    if len(nonzero_indices) > max_size:
        head_size = max_size // 2
        # ``max_size`` slots = head terms + one ellipsis + tail terms. The tail
        # count is computed explicitly: a negative-slice start such as
        # ``[-max_size // 2 + 1:]`` evaluates to ``[0:]`` (the whole list) when
        # max_size <= 2.
        tail_size = max(max_size - head_size - 1, 0)
        tail = nonzero_indices[len(nonzero_indices) - tail_size :] if tail_size else []
        # Index 0 is a placeholder for the ellipsis slot, so the term list stays
        # aligned with ``nonzero_indices``; it is replaced by None below.
        nonzero_indices = nonzero_indices[:head_size] + [0] + tail
        latex_terms = _numbers_to_latex_terms(data[nonzero_indices], decimals)
        nonzero_indices[head_size] = None
    else:
        latex_terms = _numbers_to_latex_terms(data[nonzero_indices], decimals)

    latex_str = ""
    for idx, ket_idx in enumerate(nonzero_indices):
        if ket_idx is None:
            latex_str += r" + \ldots "
        else:
            term = latex_terms[idx]
            ket = ket_name(ket_idx)
            latex_str += f"{term} |{ket}\\rangle"
    return prefix + latex_str
1357

1358

1359
class TextMatrix:
    """Text representation of a state or operator array.

    ``__str__`` (and ``__repr__``) render ``state._data`` with NumPy, so the
    object displays nicely in Jupyter notebooks and consoles."""

    def __init__(self, state, max_size=8, dims=None, prefix="", suffix=""):
        """Create the text wrapper.

        Args:
            state: Object to render; its ``_data`` array is printed.
            max_size (int or tuple): Summarization threshold passed to
                ``np.array2string``. Any integer (Python or NumPy) is used as
                is. Otherwise, for a ``DensityMatrix`` the smallest entry
                squared is used, and for any other state the first entry.
            dims (bool): Whether to append the dims. If ``None``, dims are
                shown unless the state is made only of qubits.
            prefix (str): Text prepended to the rendering. It is also passed
                as ``prefix`` to ``np.array2string``.
            suffix (str): Text appended to the rendering.
        """
        self.state = state
        if dims is None:  # show dims if state is not only qubits
            if (isinstance(state, (Statevector, DensityMatrix)) and set(state.dims()) == {2}) or (
                isinstance(state, Operator)
                and len(state.input_dims()) == len(state.output_dims())
                and set(state.input_dims()) == set(state.output_dims()) == {2}
            ):
                dims = False
            else:
                dims = True
        self.dims = dims
        self.prefix = prefix
        self.suffix = suffix
        # NumPy integers are not ``int`` instances; without np.integer they
        # would fall through to ``max_size[0]`` and raise TypeError.
        if isinstance(max_size, (int, np.integer)):
            self.max_size = max_size
        elif isinstance(state, DensityMatrix):
            # density matrices are square, so threshold for
            # summarization is shortest side squared
            self.max_size = min(max_size) ** 2
        else:
            self.max_size = max_size[0]

    def __str__(self):
        """Render ``state._data`` framed by prefix/suffix, plus dims if enabled."""
        threshold = self.max_size
        data = np.array2string(
            self.state._data, prefix=self.prefix, threshold=threshold, separator=","
        )
        dimstr = ""
        if self.dims:
            data += ",\n"
            # align the dims line under the array body
            dimstr += " " * len(self.prefix)
            if isinstance(self.state, (Statevector, DensityMatrix)):
                dimstr += f"dims={self.state._op_shape.dims_l()}"
            else:
                dimstr += f"input_dims={self.state.input_dims()}, "
                dimstr += f"output_dims={self.state.output_dims()}"

        return self.prefix + data + dimstr + self.suffix

    def __repr__(self):
        return self.__str__()
1406

1407

1408
def state_drawer(state, output=None, **drawer_args):
    """Returns a visualization of the state.

    **repr**: ASCII TextMatrix of the state's ``_repr_``.

    **text**: ASCII TextMatrix that can be printed in the console.

    **latex**: An IPython Latex object for displaying in Jupyter Notebooks.

    **latex_source**: Raw, uncompiled ASCII source to generate array using LaTeX.

    **qsphere**: Matplotlib figure, rendering of statevector using `plot_state_qsphere()`.

    **hinton**: Matplotlib figure, rendering of statevector using `plot_state_hinton()`.

    **bloch**: Matplotlib figure, rendering of statevector using `plot_bloch_multivector()`.

    **city**: Matplotlib figure, rendering of statevector using `plot_state_city()`.

    **paulivec**: Matplotlib figure, rendering of statevector using `plot_state_paulivec()`.

    Args:
        output (str): Select the output method to use for drawing the
            state. Valid choices are ``repr``, ``text``, ``latex``, ``latex_source``,
            ``qsphere``, ``hinton``, ``bloch``, ``city`` or ``paulivec``
            (case-insensitive). If ``None``, the ``state_drawer`` entry of the
            user config file is used, falling back to ``'repr'``.
        drawer_args: Arguments to be passed to the relevant drawer. For
            'latex' and 'latex_source' see ``array_to_latex``

    Returns:
        :class:`matplotlib.figure` or :class:`str` or
        :class:`TextMatrix` or :class:`IPython.display.Latex`:
        Drawing of the state.

    Raises:
        MissingOptionalLibraryError: when `output` is `latex` and IPython is not installed.
            or if SymPy isn't installed and ``'latex'`` or ``'latex_source'`` is selected for
            ``output``.

        ValueError: when `output` is not a valid selection.
    """
    config = user_config.get_config()
    # Get default 'output' from config file else use 'repr'
    default_output = "repr"
    if output is None:
        if config:
            default_output = config.get("state_drawer", "repr")
        output = default_output
    output = output.lower()

    # Choose drawing backend:
    drawers = {
        "text": TextMatrix,
        "latex_source": state_to_latex,
        "qsphere": plot_state_qsphere,
        "hinton": plot_state_hinton,
        "bloch": plot_bloch_multivector,
        "city": plot_state_city,
        "paulivec": plot_state_paulivec,
    }
    if output == "latex":
        _optionals.HAS_IPYTHON.require_now("state_drawer")
        from IPython.display import Latex

        draw_func = drawers["latex_source"]
        return Latex(f"$${draw_func(state, **drawer_args)}$$")

    if output == "repr":
        return state.__repr__()

    # Guard only the lookup. A KeyError raised *inside* a drawer is a genuine
    # error and must not be reported as an invalid ``output`` choice.
    try:
        draw_func = drawers[output]
    except KeyError as err:
        raise ValueError(
            f"""'{output}' is not a valid option for drawing {type(state).__name__}
             objects. Please choose from:
            'text', 'latex', 'latex_source', 'qsphere', 'hinton',
            'bloch', 'city' or 'paulivec'."""
        ) from err
    return draw_func(state, **drawer_args)
1488

1489

1490
def _bloch_multivector_data(state):
    """Compute one Bloch vector per qubit.

    For each qubit, the expectation values of X, Y and Z acting on that qubit
    (with identities elsewhere) are taken against the state's density matrix.

    Args:
        state (DensityMatrix or Statevector): an N-qubit state.

    Returns:
        list: one ``[x, y, z]`` list of real numbers per qubit.

    Raises:
        VisualizationError: if input is not an N-qubit state.
    """
    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    single_qubit_paulis = PauliList(["X", "Y", "Z"])

    def _embedded_paulis(qubit):
        # X, Y, Z on ``qubit`` padded with identities on every other qubit
        if num == 1:
            return single_qubit_paulis
        identities = PauliList.from_symplectic(
            np.zeros((3, num - 1), dtype=bool), np.zeros((3, num - 1), dtype=bool)
        )
        return identities.insert(qubit, single_qubit_paulis, qubit=True)

    return [
        [np.real(np.trace(np.dot(pauli, rho.data))) for pauli in _embedded_paulis(qubit).matrix_iter()]
        for qubit in range(num)
    ]
1518

1519

1520
def _paulivec_data(state):
    """Return paulivec data for plotting.

    The state's density matrix is expanded as a ``SparsePauliOp``. Each
    coefficient is scaled by ``2**num_qubits``; for a Pauli decomposition
    ``rho = sum_P c_P P`` this presumably yields ``Tr(P rho)`` for each Pauli.

    Args:
        state (DensityMatrix or Statevector): an N-qubit state.

    Returns:
        tuple: (labels, values) for Pauli vector.

    Raises:
        VisualizationError: if input is not an N-qubit state.
    """
    rho = DensityMatrix(state)
    # Validate before converting. Checking after ``SparsePauliOp.from_operator``
    # can never trigger, because a Pauli expansion only exists for N-qubit
    # operators, so the documented VisualizationError was unreachable.
    if rho.num_qubits is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    pauli_op = SparsePauliOp.from_operator(rho)
    return pauli_op.paulis.to_labels(), np.real(pauli_op.coeffs * 2**pauli_op.num_qubits)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc