• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyiron / structuretoolkit / 13067053796

31 Jan 2025 05:33AM UTC coverage: 82.903%. Remained the same
13067053796

Pull #313

github

web-flow
Merge 15c0437c3 into 07bdbb42e
Pull Request #313: Bump plotly from 5.24.1 to 6.0.0

1542 of 1860 relevant lines covered (82.9%)

0.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

18.33
/structuretoolkit/visualize.py
1
# coding: utf-8
2
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
3
# Distributed under the terms of "New BSD License", see the LICENSE file.
4

5
import warnings
1✔
6
from typing import Any, Optional
1✔
7

8
import numpy as np
1✔
9
from ase.atoms import Atoms
1✔
10
from scipy.interpolate import interp1d
1✔
11

12
from structuretoolkit.common.helper import get_cell
1✔
13

14
# Module metadata: authorship, copyright, version and maintenance status of this module.
__author__ = "Joerg Neugebauer, Sudarsan Surendralal"
__copyright__ = (
    "Copyright 2021, Max-Planck-Institut für Eisenforschung GmbH - "
    "Computational Materials Design (CM) Department"
)
__version__ = "1.0"
__maintainer__ = "Sudarsan Surendralal"
__email__ = "surendralal@mpie.de"
__status__ = "production"
__date__ = "Sep 1, 2017"
24

25

26
def plot3d(
    structure: Atoms,
    mode: str = "NGLview",
    show_cell: bool = True,
    show_axes: bool = True,
    camera: str = "orthographic",
    spacefill: bool = True,
    particle_size: float = 1.0,
    select_atoms: Optional[np.ndarray] = None,
    background: str = "white",
    color_scheme: Optional[str] = None,
    colors: Optional[np.ndarray] = None,
    scalar_field: Optional[np.ndarray] = None,
    scalar_start: Optional[float] = None,
    scalar_end: Optional[float] = None,
    scalar_cmap: Optional[Any] = None,
    vector_field: Optional[np.ndarray] = None,
    vector_color: Optional[np.ndarray] = None,
    magnetic_moments: bool = False,
    view_plane: np.ndarray = np.array([0, 0, 1]),
    distance_from_camera: float = 1.0,
    opacity: float = 1.0,
    height: Optional[float] = None,
):
    """
    Visualize an atomic structure with NGLView, plotly or NGLView's ASE bridge.

    In the default NGLView mode a string in the "protein database" ("pdb") format is constructed and turned into
    an NGLView structure. The final widget is returned. If it is assigned to a variable, the visualization is
    suppressed until that variable is evaluated, and in the meantime more NGL operations can be applied to it to
    modify the visualization.

    Args:
        structure (ase.atoms.Atoms): The structure to visualize.
        mode (str): `NGLview`, `plotly` or `ase`. The comparison is case-insensitive, so e.g. `NGLView` also works.
            (Default is 'NGLview'.)
        show_cell (bool): Whether or not to show the frame. (Default is True.)
        show_axes (bool): Whether or not to show xyz axes (NGLview and ase modes only). (Default is True.)
        camera (str): 'perspective' or 'orthographic'. (Default is 'orthographic'.)
        spacefill (bool): Whether to use a space-filling or ball-and-stick representation (NGLview and ase modes
            only). (Default is True, use space-filling atoms.)
        particle_size (float): Size of the particles. (Default is 1.)
        select_atoms (numpy.ndarray): Indices of atoms to show, either as integers or a boolean array mask
            (NGLview and plotly modes only). (Default is None, show all atoms.)
        background (str): Background color (NGLview and ase modes only). (Default is 'white'.)
        color_scheme (str): NGLView color scheme to use. (Default is None, color by element.)
        colors (numpy.ndarray): A per-atom array of HTML color names or hex color codes to use for atomic colors
            (NGLview mode only). (Default is None, use coloring scheme.)
        scalar_field (numpy.ndarray): Color each atom according to the array value (NGLview and plotly modes
            only). (Default is None, use coloring scheme.)
        scalar_start (float): The scalar value to be mapped onto the low end of the color map (lower values are
            clipped). (Default is None, use the minimum value in `scalar_field`.)
        scalar_end (float): The scalar value to be mapped onto the high end of the color map (higher values are
            clipped). (Default is None, use the maximum value in `scalar_field`.)
        scalar_cmap (matplotlib.cm): The colormap to use. (Default is None, giving a blue-red divergent map.)
        vector_field (numpy.ndarray): Add vectors (3 values) originating at each atom (NGLview mode only).
            (Default is None, no vectors.)
        vector_color (numpy.ndarray): Colors for the vectors (only available with vector_field). (Default is None,
            vectors are colored by their direction.)
        magnetic_moments (bool): Plot magnetic moments as 'scalar_field' or 'vector_field' (NGLview mode only).
        view_plane (numpy.ndarray): A Nx3-array (N = 1,2,3); the first 3d-component of the array specifies
            which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes), the
            second 3d-component (if specified, otherwise [1, 0, 0]) gives the horizontal direction, and the third
            component (if specified) is the vertical component, which is ignored and calculated internally. The
            orthonormality of the orientation is internally ensured, and therefore is not required in the function
            call. (Default is np.array([0, 0, 1]), which is view normal to the x-y plane.)
        distance_from_camera (float): Distance of the camera from the structure. Higher = farther away.
            (Default is 1.0; the NGLView backend multiplies it by 14.)
        opacity (float): Opacity of the atom markers (only available in plotly). (Default is 1.0.)
        height (int/float/None): height of the plot area in pixel (only
            available in plotly) Default: 600

        Possible NGLView color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget / plotly figure / None): The NGLView widget (NGLview and ase modes) or the plotly
            figure (plotly mode), which can be operated on further or viewed as-is. The ase mode returns None
            for an unsupported `camera`.

    Raises:
        ValueError: If `mode` is not one of the supported modes.

    Warnings:
        * Many features only work with space-filling atoms (e.g. coloring by a scalar field).
        * The colour interpretation of some hex codes is weird, e.g. 'green'.
    """
    # The docstring has always advertised `NGLView`, while the code only accepted `NGLview`;
    # compare case-insensitively so both (and any other capitalization) work.
    backend = mode.lower() if isinstance(mode, str) else mode
    if backend == "nglview":
        if height is not None:
            warnings.warn("`height` is not implemented in NGLview", SyntaxWarning)
        return _plot3d(
            structure=structure,
            show_cell=show_cell,
            show_axes=show_axes,
            camera=camera,
            spacefill=spacefill,
            particle_size=particle_size,
            select_atoms=select_atoms,
            background=background,
            color_scheme=color_scheme,
            colors=colors,
            scalar_field=scalar_field,
            scalar_start=scalar_start,
            scalar_end=scalar_end,
            scalar_cmap=scalar_cmap,
            vector_field=vector_field,
            vector_color=vector_color,
            magnetic_moments=magnetic_moments,
            view_plane=view_plane,
            distance_from_camera=distance_from_camera,
        )
    elif backend == "plotly":
        return _plot3d_plotly(
            structure=structure,
            show_cell=show_cell,
            camera=camera,
            particle_size=particle_size,
            select_atoms=select_atoms,
            scalar_field=scalar_field,
            view_plane=view_plane,
            distance_from_camera=distance_from_camera,
            opacity=opacity,
            height=height,
        )
    elif backend == "ase":
        if height is not None:
            warnings.warn("`height` is not implemented in ase", SyntaxWarning)
        return _plot3d_ase(
            structure=structure,
            show_cell=show_cell,
            show_axes=show_axes,
            camera=camera,
            spacefill=spacefill,
            particle_size=particle_size,
            background=background,
            # `None` is documented as "color by element"; do not let it override the
            # "element" default of the ase backend.
            color_scheme=color_scheme if color_scheme is not None else "element",
        )
    else:
        raise ValueError("plot method not recognized")
157

158

159
def _get_box_skeleton(cell: np.ndarray) -> np.ndarray:
1✔
160
    """
161
    Generate the skeleton of a box defined by the unit cell.
162

163
    Args:
164
        cell (np.ndarray): The unit cell of the structure.
165

166
    Returns:
167
        np.ndarray: The skeleton of the box defined by the unit cell.
168
    """
169
    lines_dz = np.stack(np.meshgrid(*3 * [[0, 1]], indexing="ij"), axis=-1)
1✔
170
    # eight corners of a unit cube, paired as four z-axis lines
171

172
    all_lines = np.reshape(
1✔
173
        [np.roll(lines_dz, i, axis=-1) for i in range(3)], (-1, 2, 3)
174
    )
175
    # All 12 two-point lines on the unit square
176
    return all_lines @ cell
1✔
177

178

179
def _draw_box_plotly(fig: Any, structure: Atoms, px: Any, go: Any) -> Any:
    """
    Add the edges of the simulation cell as black lines to a plotly figure.

    Args:
        fig (go.Figure): The Plotly figure object whose traces are kept.
        structure (Atoms): The atomic structure providing the cell.
        px (Any): The Plotly express module.
        go (Any): The Plotly graph objects module.

    Returns:
        go.Figure: A new figure containing one line trace per cell edge followed by the traces of `fig`.
    """
    box = get_cell(structure)
    traces = fig.data
    for edge in _get_box_skeleton(box):
        x_coords, y_coords, z_coords = edge.T
        segment = px.line_3d(x=x_coords, y=y_coords, z=z_coords)
        segment.update_traces(line_color="#000000")
        # Prepend the edge so that the traces of the incoming figure stay at the end.
        traces = segment.data + traces
    return go.Figure(data=traces)
199

200

201
def _plot3d_plotly(
    structure: Atoms,
    show_cell: bool = True,
    scalar_field: Optional[np.ndarray] = None,
    select_atoms: Optional[np.ndarray] = None,
    particle_size: float = 1.0,
    camera: str = "orthographic",
    view_plane: np.ndarray = np.array([1, 1, 1]),
    distance_from_camera: float = 1.0,
    opacity: float = 1.0,
    height: Optional[float] = None,
):
    """
    Make a 3D plot of the atomic structure with plotly.

    Args:
        structure (ase.atoms.Atoms): The structure to plot.
        show_cell (bool): Whether or not to draw the cell edges. (Default is True.)
        scalar_field (numpy.ndarray): Color each atom according to the array value; one value per atom of
            `structure`. (Default is None, color by element.)
        select_atoms (numpy.ndarray): Indices of atoms to show, either as integers or a boolean array mask. The
            selection is applied to the positions, the colors and the marker sizes alike. (Default is None, show
            all atoms.)
        particle_size (float): Size of the particles. (Default is 1.)
        camera (str): 'perspective' or 'orthographic'. (Default is 'orthographic'.)
        view_plane (numpy.ndarray): A Nx3-array (N = 1,2,3); the first 3d-component of the array specifies
            which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes), the
            second 3d-component (if specified, otherwise [1, 0, 0]) gives the horizontal direction, and the third
            component (if specified) is the vertical component, which is ignored and calculated internally. The
            orthonormality of the orientation is internally ensured, and therefore is not required in the function
            call. (Default is np.array([1, 1, 1]).)
        distance_from_camera (float): Distance of the camera from the structure. Higher = farther away. The
            camera eye vector is scaled by 1.25 times this value. (Default is 1.0.)
        opacity (float): Opacity of the atom markers. (Default is 1.0.)
        height (int/float/None): height of the plot area in pixel. Default: 600

    Returns:
        (plotly figure): The plotly figure, which can be operated on further or viewed as-is.

    Raises:
        ModuleNotFoundError: If plotly is not installed.
    """
    try:
        import plotly.express as px
        import plotly.graph_objects as go
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "plotly not installed - use plot3d(mode='NGLview') instead"
        ) from err
    if select_atoms is None:
        select_atoms = np.arange(len(structure))
    # Apply the selection to *every* per-atom quantity; otherwise positions, colors and marker sizes
    # have different lengths and plotly rejects the input.
    positions = structure.positions[select_atoms]
    elements = np.asarray(structure.get_chemical_symbols())[select_atoms].tolist()
    atomic_numbers = structure.get_atomic_numbers()[select_atoms]
    if scalar_field is None:
        scalar_field = elements
    else:
        scalar_field = np.asarray(scalar_field)[select_atoms]
    fig = px.scatter_3d(
        x=positions[:, 0],
        y=positions[:, 1],
        z=positions[:, 2],
        color=scalar_field,
        opacity=opacity,
        size=_atomic_number_to_radius(
            atomic_numbers,
            # Scale the markers with the linear size of the cell so that they look similar
            # regardless of the structure size.
            scale=particle_size / (0.1 * structure.get_volume() ** (1 / 3)),
        ),
    )
    if show_cell:
        fig = _draw_box_plotly(fig, structure, px, go)
    fig.layout.scene.camera.projection.type = camera
    rot = _get_orientation(view_plane).T
    rot[0, :] *= distance_from_camera * 1.25
    angle = dict(
        up=dict(x=rot[2, 0], y=rot[2, 1], z=rot[2, 2]),
        eye=dict(x=rot[0, 0], y=rot[0, 1], z=rot[0, 2]),
    )
    fig.update_layout(scene_camera=angle)
    fig.update_traces(marker=dict(line=dict(width=0.1, color="DarkSlateGrey")))
    fig.update_scenes(aspectmode="data")
    if height is None:
        height = 600
    fig.update_layout(autosize=True, height=height)
    fig.update_layout(legend={"itemsizing": "constant"})
    return fig
275

276

277
def _plot3d(
    structure: Atoms,
    show_cell: bool = True,
    show_axes: bool = True,
    camera: str = "orthographic",
    spacefill: bool = True,
    particle_size: float = 1.0,
    select_atoms: Optional[np.ndarray] = None,
    background: str = "white",
    color_scheme: Optional[str] = None,
    colors: Optional[np.ndarray] = None,
    scalar_field: Optional[np.ndarray] = None,
    scalar_start: Optional[float] = None,
    scalar_end: Optional[float] = None,
    scalar_cmap: Optional[Any] = None,
    vector_field: Optional[np.ndarray] = None,
    vector_color: Optional[np.ndarray] = None,
    magnetic_moments: bool = False,
    view_plane: np.ndarray = np.array([0, 0, 1]),
    distance_from_camera: float = 1.0,
):
    """
    Plot3d relies on NGLView to visualize atomic structures. Here, we construct a string in the "protein database"
    ("pdb") format, then turn it into an NGLView "structure". PDB is a white-space sensitive format, so the
    string snippets are carefully formatted.

    The final widget is returned. If it is assigned to a variable, the visualization is suppressed until that
    variable is evaluated, and in the meantime more NGL operations can be applied to it to modify the visualization.

    Args:
        structure (ase.atoms.Atoms): The structure to visualize.
        show_cell (bool): Whether or not to show the frame. (Default is True.)
        show_axes (bool): Whether or not to show xyz axes. (Default is True.)
        camera (str): 'perspective' or 'orthographic'. (Default is 'orthographic'.)
        spacefill (bool): Whether to use a space-filling or ball-and-stick representation. (Default is True, use
            space-filling atoms.)
        particle_size (float): Size of the particles. (Default is 1.)
        select_atoms (numpy.ndarray): Indices of atoms to show, either as integers or a boolean array mask.
            Per-atom `colors`, `scalar_field`, `vector_field` and (N, 3)-shaped `vector_color` are subset
            accordingly. (Default is None, show all atoms.)
        background (str): Background color. (Default is 'white'.)
        color_scheme (str): NGLView color scheme to use. (Default is None, color by element.)
        colors (numpy.ndarray): A per-atom array of HTML color names or hex color codes to use for atomic colors.
            (Default is None, use coloring scheme.)
        scalar_field (numpy.ndarray): Color each atom according to the array value (Default is None, use coloring
            scheme.)
        scalar_start (float): The scalar value to be mapped onto the low end of the color map (lower values are
            clipped). (Default is None, use the minimum value in `scalar_field`.)
        scalar_end (float): The scalar value to be mapped onto the high end of the color map (higher values are
            clipped). (Default is None, use the maximum value in `scalar_field`.)
        scalar_cmap (matplotlib.cm): The colormap to use. (Default is None, giving a blue-red divergent map.)
        vector_field (numpy.ndarray): Add vectors (3 values) originating at each atom. (Default is None, no
            vectors.)
        vector_color (numpy.ndarray): Colors for the vectors (only available with vector_field): either one RGB
            row per atom or a single color applied to all vectors. (Default is None, vectors are colored by
            their direction; zero-length vectors get the neutral color [0.5, 0.5, 0.5].)
        magnetic_moments (bool): Plot magnetic moments as 'scalar_field' (collinear) or 'vector_field'
            (non-collinear). Overrides the corresponding argument if any moment is non-zero.
        view_plane (numpy.ndarray): A Nx3-array (N = 1,2,3); the first 3d-component of the array specifies
            which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes), the
            second 3d-component (if specified, otherwise [1, 0, 0]) gives the horizontal direction, and the third
            component (if specified) is the vertical component, which is ignored and calculated internally. The
            orthonormality of the orientation is internally ensured, and therefore is not required in the function
            call. (Default is np.array([0, 0, 1]), which is view normal to the x-y plane.)
        distance_from_camera (float): Distance of the camera from the structure. Higher = farther away. It is
            multiplied by 14 before being passed to NGLView. (Default is 1.0.)

        Possible NGLView color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget): The NGLView widget itself, which can be operated on further or viewed as-is.

    Raises:
        ImportError: If nglview is not installed.

    Warnings:
        * Many features only work with space-filling atoms (e.g. coloring by a scalar field).
        * The colour interpretation of some hex codes is weird, e.g. 'green'.
    """
    try:  # If the graphical packages are not available, the GUI will not work.
        import nglview
    except ImportError:
        raise ImportError(
            "The package nglview needs to be installed for the plot3d() function!"
        )

    if magnetic_moments:
        initial_moments = structure.get_initial_magnetic_moments()
        if np.sum(np.abs(initial_moments)) > 0:
            # One value per atom (collinear) -> color by scalar; one vector per atom -> draw arrows.
            if initial_moments.ndim == 1:
                scalar_field = initial_moments
            else:
                vector_field = initial_moments

    elements = structure.get_chemical_symbols()
    atomic_numbers = structure.get_atomic_numbers()
    positions = structure.positions

    # If `select_atoms` was given, visualize only a subset of the `parent_basis`
    if select_atoms is not None:
        select_atoms = np.asarray(select_atoms)
        # Keep boolean masks as masks: casting them to int would turn them into the indices 0 and 1.
        if select_atoms.dtype != bool:
            select_atoms = select_atoms.astype(int)
        elements = np.array(elements)[select_atoms]
        atomic_numbers = atomic_numbers[select_atoms]
        positions = positions[select_atoms]
        if colors is not None:
            colors = np.array(colors)[select_atoms]
        if scalar_field is not None:
            scalar_field = np.array(scalar_field)[select_atoms]
        if vector_field is not None:
            vector_field = np.array(vector_field)[select_atoms]
        if vector_color is not None:
            vector_color = np.array(vector_color)
            # Only per-atom colors (one row per atom) are subset; a single RGB color applies to all atoms.
            if vector_color.ndim == 2:
                vector_color = vector_color[select_atoms]

    # Write the nglview protein-database-formatted string
    struct = nglview.TextStructure(
        _ngl_write_structure(elements, positions, structure.cell)
    )

    # Parse the string into the displayable widget
    view = nglview.NGLWidget(struct)

    if spacefill:
        # Color by scheme
        if color_scheme is not None:
            if colors is not None:
                warnings.warn("`color_scheme` is overriding `colors`")
            if scalar_field is not None:
                warnings.warn("`color_scheme` is overriding `scalar_field`")
            view = _add_colorscheme_spacefill(
                view, elements, atomic_numbers, particle_size, color_scheme
            )
        # Color by per-atom colors
        elif colors is not None:
            if scalar_field is not None:
                warnings.warn("`colors` is overriding `scalar_field`")
            view = _add_custom_color_spacefill(
                view, atomic_numbers, particle_size, colors
            )
        # Color by per-atom scalars
        elif scalar_field is not None:
            colors = _scalars_to_hex_colors(
                scalar_field, scalar_start, scalar_end, scalar_cmap
            )
            view = _add_custom_color_spacefill(
                view, atomic_numbers, particle_size, colors
            )
        # Color by element
        else:
            view = _add_colorscheme_spacefill(
                view, elements, atomic_numbers, particle_size
            )
        view.remove_ball_and_stick()
    else:
        view.add_ball_and_stick()

    if show_cell:
        if structure.cell is not None:
            if all(np.max(structure.cell, axis=0) > 1e-2):
                view.add_unitcell()

    if vector_field is not None:
        vector_field = np.asarray(vector_field)
        # Colors must match the atoms actually shown, not the full structure.
        n_shown = len(positions)
        if vector_color is None:
            # Color each vector by its direction: map the unit vector from [-1, 1] to RGB in [0, 1].
            norms = np.linalg.norm(vector_field, axis=-1)[:, np.newaxis]
            # Zero-length vectors would divide by zero; give them the neutral color 0.5 instead of NaN.
            norms[norms == 0] = 1.0
            vector_color = 0.5 * vector_field / norms + 0.5
        else:
            try:
                if vector_color.shape != (n_shown, 3):
                    # A single array color: normalize it and repeat it for every shown atom.
                    vector_color = np.outer(
                        np.ones(n_shown),
                        vector_color / np.linalg.norm(vector_color),
                    )
            except AttributeError:
                # Not an array (e.g. a plain list): broadcast it onto every shown atom as given.
                vector_color = np.ones((n_shown, 3)) * vector_color

        for arr, pos, col in zip(vector_field, positions, vector_color):
            view.shape.add_arrow(list(pos), list(pos + arr), list(col), 0.2)

    if show_axes:  # Add axes
        axes_origin = -np.ones(3)
        arrow_radius = 0.1
        text_size = 1
        text_color = [0, 0, 0]
        arrow_names = ["x", "y", "z"]

        for n in [0, 1, 2]:
            start = list(axes_origin)
            shift = np.zeros(3)
            shift[n] = 1
            end = list(start + shift)
            color = list(shift)
            # We cast as list to avoid JSON warnings
            view.shape.add_arrow(start, end, color, arrow_radius)
            view.shape.add_text(end, text_color, text_size, arrow_names[n])

    if camera != "perspective" and camera != "orthographic":
        warnings.warn(
            "Only perspective or orthographic is (likely to be) permitted for camera"
        )

    view.camera = camera
    view.background = background

    orientation = _get_flattened_orientation(
        view_plane=view_plane, distance_from_camera=distance_from_camera * 14
    )
    view.control.orient(orientation)

    return view
492

493

494
def _plot3d_ase(
    structure: Atoms,
    spacefill: bool = True,
    show_cell: bool = True,
    camera: str = "perspective",
    particle_size: float = 0.5,
    background: str = "white",
    color_scheme: str = "element",
    show_axes: bool = True,
):
    """
    Visualize a structure through NGLView's ASE bridge (`nglview.show_ase`).

    Args:
        structure (ase.atoms.Atoms): The structure to visualize (always the full structure).
        spacefill (bool): Use a space-filling instead of a ball-and-stick representation. (Default is True.)
        show_cell (bool): Whether or not to show the unit cell. (Default is True.)
        camera (str): 'perspective' or 'orthographic'. (Default is 'perspective'.)
        particle_size (float): Radius passed to the space-filling representation. (Default is 0.5.)
        background (str): Background color. (Default is 'white'.)
        color_scheme (str): NGLView color scheme of the space-filling representation. (Default is 'element'.)
        show_axes (bool): Whether or not to draw red/green/blue arrows along x/y/z. (Default is True.)

        Possible color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget/None): The widget, or None (after printing a message) if `camera` is neither
            'perspective' nor 'orthographic'.

    Raises:
        ImportError: If nglview is not installed.
    """
    try:  # If the graphical packages are not available, the GUI will not work.
        import nglview
    except ImportError:
        raise ImportError(
            "The package nglview needs to be installed for the plot3d() function!"
        )
    widget = nglview.show_ase(structure)
    if not spacefill:
        widget.add_ball_and_stick()
    else:
        widget.add_spacefill(
            radius_type="vdw", color_scheme=color_scheme, radius=particle_size
        )
        widget.remove_ball_and_stick()
    # Only draw the cell when every cell vector has a non-negligible extent.
    if (
        show_cell
        and structure.cell is not None
        and all(np.max(structure.cell, axis=0) > 1e-2)
    ):
        widget.add_unitcell()
    if show_axes:
        axis_tips_and_colors = (
            ([2, -2, -2], [1, 0, 0]),
            ([-2, 2, -2], [0, 1, 0]),
            ([-2, -2, 2], [0, 0, 1]),
        )
        for tip, rgb in axis_tips_and_colors:
            widget.shape.add_arrow([-2, -2, -2], tip, rgb, 0.5)
    if camera not in ("perspective", "orthographic"):
        print("Only perspective or orthographic is permitted")
        return None
    widget.camera = camera
    widget.background = background
    return widget
541

542

543
def _ngl_write_cell(
1✔
544
    a1: float,
545
    a2: float,
546
    a3: float,
547
    f1: float = 90.0,
548
    f2: float = 90.0,
549
    f3: float = 90.0,
550
):
551
    """
552
    Writes a PDB-formatted line to represent the simulation cell.
553

554
    Args:
555
        a1, a2, a3 (float): Lengths of the cell vectors.
556
        f1, f2, f3 (float): Angles between the cell vectors (which angles exactly?) (in degrees).
557

558
    Returns:
559
        (str): The line defining the cell in PDB format.
560
    """
561
    return "CRYST1 {:8.3f} {:8.3f} {:8.3f} {:6.2f} {:6.2f} {:6.2f} P 1\n".format(
×
562
        a1, a2, a3, f1, f2, f3
563
    )
564

565

566
def _ngl_write_atom(
1✔
567
    num: int,
568
    species: str,
569
    x: float,
570
    y: float,
571
    z: float,
572
    group: Optional[str] = None,
573
    num2: Optional[int] = None,
574
    occupancy: float = 1.0,
575
    temperature_factor: float = 0.0,
576
) -> str:
577
    """
578
    Writes a PDB-formatted line to represent an atom.
579

580
    Args:
581
        num (int): Atomic index.
582
        species (str): Elemental species.
583
        x, y, z (float): Cartesian coordinates of the atom.
584
        group (str): A...group name? (Default is None, repeat elemental species.)
585
        num2 (int): An "alternate" index. (Don't ask me...) (Default is None, repeat first number.)
586
        occupancy (float): PDB occupancy parameter. (Default is 1.)
587
        temperature_factor (float): PDB temperature factor parameter. (Default is 0.
588

589
    Returns:
590
        (str): The line defining an atom in PDB format
591

592
    Warnings:
593
        * The [PDB docs](https://www.cgl.ucsf.edu/chimera/docs/UsersGuide/tutorials/pdbintro.html) indicate that
594
            the xyz coordinates might need to be in some sort of orthogonal basis. If you have weird behaviour,
595
            this might be a good place to investigate.
596
    """
597
    if group is None:
×
598
        group = species
×
599
    if num2 is None:
×
600
        num2 = num
×
601
    return "ATOM {:>6} {:>4} {:>4} {:>5} {:10.3f} {:7.3f} {:7.3f} {:5.2f} {:5.2f} {:>11} \n".format(
×
602
        num, species, group, num2, x, y, z, occupancy, temperature_factor, species
603
    )
604

605

606
def _ngl_write_structure(
    elements: np.ndarray, positions: np.ndarray, cell: np.ndarray
) -> str:
    """
    Turns structure information into a NGLView-readable protein-database-formatted string.

    Args:
        elements (numpy.ndarray/list): Element symbol for each atom.
        positions (numpy.ndarray/list): Vector of Cartesian atom positions.
        cell (numpy.ndarray/list): Simulation cell Bravais matrix. If it is None or any cell-matrix column has
            a maximum below 0.01, an orthorhombic dummy cell spanning the atoms is used instead (extents below
            0.01 are replaced by 10).

    Returns:
        (str): The PDB-formatted representation of the structure: a CRYST1 line, one MODEL with an ATOM line
            per atom, and ENDMDL.
    """
    from ase.geometry import cell_to_cellpar, cellpar_to_cell

    if cell is None or any(np.max(cell, axis=0) < 1e-2):
        # Define a dummy cell if it doesn't exist (eg. for clusters)
        extent = np.max(positions, axis=0) - np.min(positions, axis=0)
        extent[np.abs(extent) < 1e-2] = 10
        cell = np.eye(3) * extent
    cellpar = cell_to_cellpar(cell)
    # CRYST1 only stores lengths and angles. `cellpar_to_cell` rebuilds a cell with those parameters in ASE's
    # default orientation; express the positions in that frame so they match the written cell.
    rotation = np.linalg.solve(cell, cellpar_to_cell(cellpar))
    rotated_positions = np.array(positions).dot(rotation)

    pdb_lines = [_ngl_write_cell(*cellpar), "MODEL     1\n"]
    pdb_lines.extend(
        _ngl_write_atom(i, elements[i], *p) for i, p in enumerate(rotated_positions)
    )
    pdb_lines.append("ENDMDL \n")
    # Join once instead of repeated `+=` on a growing string.
    return "".join(pdb_lines)
642

643

644
def _atomic_number_to_radius(
1✔
645
    atomic_number: int, shift: float = 0.2, slope: float = 0.1, scale: float = 1.0
646
) -> float:
647
    """
648
    Give the atomic radius for plotting, which scales like the root of the atomic number.
649

650
    Args:
651
        atomic_number (int/float): The atomic number.
652
        shift (float): A constant addition to the radius. (Default is 0.2.)
653
        slope (float): A multiplier for the root of the atomic number. (Default is 0.1)
654
        scale (float): How much to rescale the whole thing by.
655

656
    Returns:
657
        (float): The radius. (Not physical, just for visualization!)
658
    """
659
    return (shift + slope * np.sqrt(atomic_number)) * scale
×
660

661

662
def _add_colorscheme_spacefill(
    view,
    elements: np.ndarray,
    atomic_numbers: np.ndarray,
    particle_size: float,
    scheme: str = "element",
):
    """
    Set NGLView spacefill parameters according to a color-scheme.

    One spacefill representation is added per distinct (element, atomic number) pair, selected by
    element ("#<symbol>"), in the order in which the pairs first appear.

    Args:
        view (NGLWidget): The widget to work on.
        elements (numpy.ndarray/list): Elemental symbols.
        atomic_numbers (numpy.ndarray/list): Integer atomic numbers for determining atomic size.
        particle_size (float): A scale factor for the atomic size.
        scheme (str): The scheme to use. (Default is "element".)

        Possible NGLView color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget): The modified widget.
    """
    # `dict.fromkeys` de-duplicates like a set but keeps first-seen order. Iterating a set of
    # string tuples depends on the interpreter's hash seed, so the representations would be
    # added in a different order from run to run.
    for elem, num in dict.fromkeys(zip(elements, atomic_numbers)):
        view.add_spacefill(
            selection="#" + elem,
            radius_type="vdw",
            radius=_atomic_number_to_radius(num, scale=particle_size),
            color_scheme=scheme,
        )
    return view
695

696

697
def _add_custom_color_spacefill(
    view, atomic_numbers: np.ndarray, particle_size: float, colors: np.ndarray
):
    """
    Add one NGLView spacefill representation per atom, each drawn in its own color.

    Every atom is selected individually by its index. Its radius comes from `_atomic_number_to_radius`,
    scaled by `particle_size`.

    Args:
        view (NGLWidget): The widget to work on.
        atomic_numbers (numpy.ndarray/list): Integer atomic numbers for determining atomic size.
        particle_size (float): A scale factor for the atomic size.
        colors (numpy.ndarray/list): A per-atom list of HTML or hex color codes, indexed like `atomic_numbers`.

    Returns:
        (nglview.NGLWidget): The modified widget.
    """
    for atom_index, atomic_number in enumerate(atomic_numbers):
        atom_radius = _atomic_number_to_radius(atomic_number, scale=particle_size)
        view.add_spacefill(
            selection=[atom_index],
            radius_type="vdw",
            radius=atom_radius,
            color=colors[atom_index],
        )
    return view
720

721

722
def _scalars_to_hex_colors(
    scalar_field: np.ndarray,
    start: Optional[float] = None,
    end: Optional[float] = None,
    cmap=None,
):
    """
    Convert scalar values to hex codes using a colormap.

    Args:
        scalar_field (numpy.ndarray/list): Scalars to convert.
        start (float): Scalar value to map to the bottom of the colormap (values below are clipped). (Default is
            None, use the minimal scalar value.)
        end (float): Scalar value to map to the top of the colormap (values above are clipped).  (Default is
            None, use the maximal scalar value.)
        cmap (matplotlib.cm): The colormap to use, i.e. a callable mapping a value in [0, 1] to an RGBA tuple.
            (Default is None, which gives a blue-red divergent map from seaborn.)

    Returns:
        (list): The corresponding hex codes for each scalar value passed in. If `start == end` (e.g. a uniform
            field with default bounds), every value maps to the centre of the colormap.

    Raises:
        ImportError: If `cmap` is None and seaborn is not installed.
    """
    from matplotlib.colors import rgb2hex

    field = np.asarray(scalar_field, dtype=float)
    if field.size == 0:
        return []

    if start is None:
        start = np.amin(field)
    if end is None:
        end = np.amax(field)
    clipped = np.clip(field, start, end)
    span = end - start
    if span == 0:
        # No range to map onto [0, 1]; dividing by zero would hand NaN to the colormap.
        remapped_field = np.full_like(clipped, 0.5)
    else:
        remapped_field = (clipped - start) / span  # Map field onto [0,1]

    if cmap is None:
        try:
            from seaborn import diverging_palette
        except ImportError as err:
            # Previously only printed, which then failed with a NameError on `diverging_palette`.
            raise ImportError(
                "The package seaborn needs to be installed for the plot3d() function!"
            ) from err
        cmap = diverging_palette(245, 15, as_cmap=True)  # A nice blue-red palette

    # The slice keeps RGB and drops the alpha channel returned by the colormap.
    return [rgb2hex(cmap(scalar)[:3]) for scalar in remapped_field]
763

764

765
def _get_orientation(view_plane: np.ndarray) -> np.ndarray:
    """
    A helper method to plot3d, which generates a 3x3 rotation matrix from the input `view_plane`.
    `_get_flattened_orientation` embeds this matrix into the flattened list passed to `view.control.orient`.

    Args:
        view_plane (numpy.ndarray/list): A Nx3-array/list (N = 1,2,3); the first 3d-component of the array
            specifies which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes),
            the second 3d-component (if specified, otherwise [0, 0, 1]) gives the horizontal direction, and the
            third component (if specified) is the vertical component, which is ignored and calculated internally.
            The orthonormality of the orientation is internally ensured, and therefore is not required in the
            function call.

    Returns:
        (numpy.ndarray): The 3x3 orientation matrix. If the horizontal direction is parallel to the first
            component (e.g. `view_plane=[0, 0, 1]` with the default horizontal direction), the identity matrix is
            returned, i.e. the default NGLView orientation is kept.

    Raises:
        ValueError: If `view_plane` cannot be reshaped to (N, 3) with N <= 3, or if its first or second
            component is a zero vector.
    """
    view_plane = np.array(view_plane)
    if view_plane.size % 3 != 0 or view_plane.size > 9:
        raise ValueError(
            "The shape of view plane should be (N, 3), where N = 1, 2 or 3. Refer docs for more info."
        )
    view_plane = view_plane.reshape(-1, 3)
    # A zero vector cannot be normalized and would silently produce a NaN orientation.
    # The third row is not checked because it is always recomputed below.
    if np.any(np.linalg.norm(view_plane[:2], axis=-1) == 0):
        raise ValueError(
            "The first and second components of view plane must be non-zero vectors."
        )
    # Default rows [[0, 1, 0], [0, 0, 1], [1, 0, 0]]; the first N rows are replaced by `view_plane`.
    rotation_matrix = np.roll(np.eye(3), -1, axis=0)
    rotation_matrix[: len(view_plane)] = view_plane
    # Normalize only the two input rows; row 2 is overwritten by the cross product anyway.
    rotation_matrix[:2] /= np.linalg.norm(rotation_matrix[:2], axis=-1)[:, np.newaxis]
    rotation_matrix[1] -= (
        np.dot(rotation_matrix[0], rotation_matrix[1]) * rotation_matrix[0]
    )  # Gram-Schmidt: make the horizontal direction orthogonal to the first row
    rotation_matrix[2] = np.cross(
        rotation_matrix[0], rotation_matrix[1]
    )  # Specify third axis
    if np.isclose(np.linalg.det(rotation_matrix), 0):
        # The horizontal direction was parallel to the first row. For example, view_plane = [0,0,1] is the
        # default view of NGLview, so we do not modify it.
        return np.eye(3)
    return np.roll(
        rotation_matrix / np.linalg.norm(rotation_matrix, axis=-1)[:, np.newaxis],
        2,
        axis=0,
    ).T
804

805

806
def _get_flattened_orientation(
    view_plane: np.ndarray, distance_from_camera: float
) -> list:
    """
    A helper method to plot3d that turns `view_plane` into a flattened 4x4 camera matrix (a list of length 16).
    This list becomes the input argument to `view.control.orient`.

    The 3x3 rotation produced by `_get_orientation(view_plane)` fills the upper-left block of a 4x4 identity
    matrix. The whole 4x4 matrix, including its bottom-right element, is then multiplied by
    `distance_from_camera` and flattened row by row.

    Args:
        view_plane (numpy.ndarray/list): A Nx3-array/list (N = 1,2,3); the first 3d-component of the array
            specifies which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes),
            the second 3d-component (if specified) gives the horizontal direction, and the third component (if
            specified) is the vertical component, which is ignored and calculated internally. The orthonormality
            of the orientation is internally ensured, and therefore is not required in the function call.
        distance_from_camera (float): Distance of the camera from the structure. Higher = farther away.

    Returns:
        (list): Flattened list of len = 16, which is the input argument to `view.control.orient`.

    Raises:
        ValueError: If `distance_from_camera` is not positive.
    """
    if distance_from_camera <= 0:
        raise ValueError("´distance_from_camera´ must be a positive float!")
    camera_matrix = np.identity(4)
    camera_matrix[:3, :3] = _get_orientation(view_plane)
    scaled_matrix = distance_from_camera * camera_matrix
    return scaled_matrix.reshape(16).tolist()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc