• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyiron / structuretoolkit / 13152197077

05 Feb 2025 07:18AM UTC coverage: 83.109% (+0.2%) from 82.903%
13152197077

push

github

web-flow
extend ruff linter (#315)

* extend ruff linter

* fix broken comparison

33 of 60 new or added lines in 16 files covered. (55.0%)

4 existing lines in 1 file now uncovered.

1545 of 1859 relevant lines covered (83.11%)

0.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

18.64
/structuretoolkit/visualize.py
1
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
2
# Distributed under the terms of "New BSD License", see the LICENSE file.
3

4
import warnings
1✔
5
from typing import Any, Optional
1✔
6

7
import numpy as np
1✔
8
from ase.atoms import Atoms
1✔
9
from scipy.interpolate import interp1d
1✔
10

11
from structuretoolkit.common.helper import get_cell
1✔
12

13
# Module metadata: release information first, then the people responsible.
__version__ = "1.0"
__status__ = "production"
__date__ = "Sep 1, 2017"
__author__ = "Joerg Neugebauer, Sudarsan Surendralal"
__maintainer__ = "Sudarsan Surendralal"
__email__ = "surendralal@mpie.de"
__copyright__ = (
    "Copyright 2021, Max-Planck-Institut für Eisenforschung GmbH - "
    "Computational Materials Design (CM) Department"
)
23

24

25
def plot3d(
    structure: Atoms,
    mode: str = "NGLview",
    show_cell: bool = True,
    show_axes: bool = True,
    camera: str = "orthographic",
    spacefill: bool = True,
    particle_size: float = 1.0,
    select_atoms: Optional[np.ndarray] = None,
    background: str = "white",
    color_scheme: Optional[str] = None,
    colors: Optional[np.ndarray] = None,
    scalar_field: Optional[np.ndarray] = None,
    scalar_start: Optional[float] = None,
    scalar_end: Optional[float] = None,
    scalar_cmap: Optional[Any] = None,
    vector_field: Optional[np.ndarray] = None,
    vector_color: Optional[np.ndarray] = None,
    magnetic_moments: bool = False,
    view_plane: np.ndarray = np.array([0, 0, 1]),
    distance_from_camera: float = 1.0,
    opacity: float = 1.0,
    height: Optional[float] = None,
):
    """
    Visualize an atomic structure with NGLView, plotly or NGLView's ASE interface.

    `mode` selects the backend:
      * 'NGLview' writes the structure as a "protein database" (PDB) string and displays it with NGLView
        (see `_plot3d`);
      * 'plotly' builds a plotly 3D scatter figure (see `_plot3d_plotly`);
      * 'ase' displays the structure with `nglview.show_ase` (see `_plot3d_ase`).
    Each backend only receives the options it supports; the remaining ones are ignored.

    The final widget is returned. If it is assigned to a variable, the visualization is suppressed until that
    variable is evaluated, and in the meantime more NGL operations can be applied to it to modify the visualization.

    Args:
        structure (ase.atoms.Atoms): The structure to visualize.
        mode (str): 'NGLview', 'plotly' or 'ase', matched case-insensitively (so 'NGLView' works too).
            (Default is 'NGLview'.)
        show_cell (bool): Whether or not to show the frame. (Default is True.)
        show_axes (bool): Whether or not to show xyz axes. Not used by 'plotly'. (Default is True.)
        camera (str): 'perspective' or 'orthographic'. (Default is 'orthographic'.)
        spacefill (bool): Whether to use a space-filling or ball-and-stick representation. Not used by 'plotly'.
            (Default is True, use space-filling atoms.)
        particle_size (float): Size of the particles. (Default is 1.)
        select_atoms (numpy.ndarray): Indices of atoms to show, either as integers or a boolean array mask.
            Not used by 'ase'. (Default is None, show all atoms.)
        background (str): Background color. Not used by 'plotly'. (Default is 'white'.)
        color_scheme (str): NGLView color scheme to use. (Default is None, color by element.)
        colors (numpy.ndarray): A per-atom array of HTML color names or hex color codes to use for atomic colors.
            Only used by 'NGLview'. (Default is None, use coloring scheme.)
        scalar_field (numpy.ndarray): Color each atom according to the array value. Not used by 'ase'.
            (Default is None, use coloring scheme.)
        scalar_start (float): The scalar value to be mapped onto the low end of the color map (lower values are
            clipped). Only used by 'NGLview'. (Default is None, use the minimum value in `scalar_field`.)
        scalar_end (float): The scalar value to be mapped onto the high end of the color map (higher values are
            clipped). Only used by 'NGLview'. (Default is None, use the maximum value in `scalar_field`.)
        scalar_cmap (matplotlib.cm): The colormap to use. Only used by 'NGLview'. (Default is None, giving a
            blue-red divergent map.)
        vector_field (numpy.ndarray): Add vectors (3 values) originating at each atom. Only used by 'NGLview'.
            (Default is None, no vectors.)
        vector_color (numpy.ndarray): Colors for the vectors (only available with vector_field). (Default is None,
            vectors are colored by their direction.)
        magnetic_moments (bool): Plot magnetic moments as 'scalar_field' or 'vector_field'. Only used by
            'NGLview'. (Default is False.)
        view_plane (numpy.ndarray): A Nx3-array (N = 1,2,3); the first 3d-component of the array specifies
            which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes), the
            second 3d-component (if specified, otherwise [1, 0, 0]) gives the horizontal direction, and the third
            component (if specified) is the vertical component, which is ignored and calculated internally. The
            orthonormality of the orientation is internally ensured, and therefore is not required in the function
            call. Not used by 'ase'. (Default is np.array([0, 0, 1]), which is view normal to the x-y plane.)
        distance_from_camera (float): Distance of the camera from the structure. Higher = farther away.
            Not used by 'ase'. (Default is 1.0; the 'NGLview' backend multiplies it by 14 internally.)
        opacity (float): Opacity of the atoms. Only used by 'plotly'. (Default is 1.0.)
        height (int/float/None): Height of the plot area in pixels. Only used by 'plotly'; the other modes emit a
            warning when it is given. (Default is None, which the plotly backend renders as 600.)

        Possible NGLView color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget / plotly.graph_objects.Figure): The NGLView widget ('NGLview' and 'ase' modes) or the
            plotly figure ('plotly' mode), which can be operated on further or viewed as-is. The 'ase' mode
            returns None when `camera` is neither 'perspective' nor 'orthographic'.

    Raises:
        ValueError: If `mode` does not name one of the supported backends.

    Warnings:
        * Many features only work with space-filling atoms (e.g. coloring by a scalar field).
        * The colour interpretation of some hex codes is weird, e.g. 'green'.
    """
    # The docstring has always spelled the default backend 'NGLView' while the
    # code only accepted 'NGLview'; compare case-insensitively so both work.
    backend = mode.lower() if isinstance(mode, str) else mode
    if backend == "nglview":
        if height is not None:
            warnings.warn(
                "`height` is not implemented in NGLview", SyntaxWarning, stacklevel=2
            )
        return _plot3d(
            structure=structure,
            show_cell=show_cell,
            show_axes=show_axes,
            camera=camera,
            spacefill=spacefill,
            particle_size=particle_size,
            select_atoms=select_atoms,
            background=background,
            color_scheme=color_scheme,
            colors=colors,
            scalar_field=scalar_field,
            scalar_start=scalar_start,
            scalar_end=scalar_end,
            scalar_cmap=scalar_cmap,
            vector_field=vector_field,
            vector_color=vector_color,
            magnetic_moments=magnetic_moments,
            view_plane=view_plane,
            distance_from_camera=distance_from_camera,
        )
    elif backend == "plotly":
        return _plot3d_plotly(
            structure=structure,
            show_cell=show_cell,
            camera=camera,
            particle_size=particle_size,
            select_atoms=select_atoms,
            scalar_field=scalar_field,
            view_plane=view_plane,
            distance_from_camera=distance_from_camera,
            opacity=opacity,
            height=height,
        )
    elif backend == "ase":
        if height is not None:
            warnings.warn(
                "`height` is not implemented in ase", SyntaxWarning, stacklevel=2
            )
        return _plot3d_ase(
            structure=structure,
            show_cell=show_cell,
            show_axes=show_axes,
            camera=camera,
            spacefill=spacefill,
            particle_size=particle_size,
            background=background,
            color_scheme=color_scheme,
        )
    else:
        raise ValueError(
            f"plot method not recognized: {mode!r} "
            "(expected 'NGLview', 'plotly' or 'ase')"
        )
160

161

162
def _get_box_skeleton(cell: np.ndarray) -> np.ndarray:
    """
    Return the 12 edges of the parallelepiped spanned by a cell.

    Args:
        cell (np.ndarray): The 3x3 cell matrix; each row is one cell vector.

    Returns:
        np.ndarray: Array of shape (12, 2, 3) holding the Cartesian start and end point of every cell edge.
    """
    # Corners of the unit cube on a (2, 2, 2) grid; neighbouring entries along
    # the last grid axis differ only in their third coordinate, so pairing them
    # yields the four edges parallel to that axis.
    corners = np.stack(np.meshgrid([0, 1], [0, 1], [0, 1], indexing="ij"), axis=-1)
    # Cyclically permuting the coordinates maps those four edges onto the edges
    # parallel to the other two axes: 3 x 4 = 12 edges of the unit cube.
    permuted = np.array([np.roll(corners, shift, axis=-1) for shift in range(3)])
    unit_edges = permuted.reshape(-1, 2, 3)
    # Map fractional (unit-cube) coordinates to Cartesian ones.
    return np.matmul(unit_edges, cell)
180

181

182
def _draw_box_plotly(fig: Any, structure: Atoms, px: Any, go: Any) -> Any:
    """
    Add the edges of the simulation cell to a plotly figure.

    Each of the cell edges from `_get_box_skeleton` becomes its own black `px.line_3d` trace. The edge traces are
    placed in front of the traces already in `fig`, and a fresh `go.Figure` is built from the combined trace tuple,
    so only the traces of `fig` (not its layout) are carried over.

    Args:
        fig (go.Figure): The Plotly figure whose traces are kept.
        structure (Atoms): The atomic structure providing the cell.
        px (Any): The Plotly express module.
        go (Any): The Plotly graph objects module.

    Returns:
        go.Figure: A new figure containing the cell edges followed by the original traces.
    """
    edges = _get_box_skeleton(get_cell(structure))
    traces = fig.data
    for edge in edges:
        # edge has shape (2, 3); its transpose gives the x, y and z coordinate pairs.
        xs, ys, zs = edge.T
        edge_fig = px.line_3d(x=xs, y=ys, z=zs)
        edge_fig.update_traces(line_color="#000000")
        traces = edge_fig.data + traces
    return go.Figure(data=traces)
202

203

204
def _plot3d_plotly(
    structure: Atoms,
    show_cell: bool = True,
    scalar_field: Optional[np.ndarray] = None,
    select_atoms: Optional[np.ndarray] = None,
    particle_size: float = 1.0,
    camera: str = "orthographic",
    view_plane: np.ndarray = np.array([1, 1, 1]),
    distance_from_camera: float = 1.0,
    opacity: float = 1.0,
    height: Optional[float] = None,
):
    """
    Make a 3D plot of the atomic structure with plotly.

    Args:
        structure (Atoms): The structure to plot.
        show_cell (bool): Whether to draw the edges of the simulation cell. (Default is True.)
        scalar_field (numpy.ndarray): One value per atom of the full structure, used to color the atoms. When
            `select_atoms` is given, it is subset in the same way as the positions. (Default is None, color by
            chemical element.)
        select_atoms (numpy.ndarray): Indices of atoms to show, either as integers or a boolean array mask.
            (Default is None, show all atoms.)
        particle_size (float): Size of the particles. (Default is 1.)
        camera (str): Plotly projection type, 'perspective' or 'orthographic'. (Default is 'orthographic'.)
        view_plane (numpy.ndarray): A Nx3-array (N = 1,2,3); the first 3d-component of the array specifies
            which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes), the
            second 3d-component (if specified, otherwise [1, 0, 0]) gives the horizontal direction, and the third
            component (if specified) is the vertical component, which is ignored and calculated internally. The
            orthonormality of the orientation is internally ensured, and therefore is not required in the function
            call. (Default is np.array([1, 1, 1]).)
        distance_from_camera (float): Distance of the camera from the structure. Higher = farther away; the
            camera eye is scaled by 1.25 times this value. (Default is 1.0.)
        opacity (float): Opacity of the atom markers. (Default is 1.0.)
        height (int/float/None): Height of the plot area in pixels. (Default is None, which means 600.)

    Returns:
        (plotly.graph_objects.Figure): The plotly figure, which can be operated on further or viewed as-is.

    Raises:
        ModuleNotFoundError: If plotly is not installed.
    """
    try:
        import plotly.express as px
        import plotly.graph_objects as go
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError("plotly not installed - use plot3d instead") from err

    if select_atoms is None:
        select_atoms = np.arange(len(structure))
    else:
        select_atoms = np.asarray(select_atoms)

    # Every per-atom quantity handed to px.scatter_3d (coordinates, color and
    # size) must be restricted to the same atoms; previously only the
    # coordinates were, which gave mismatched lengths for any subset.
    positions = structure.positions[select_atoms]
    atomic_numbers = structure.get_atomic_numbers()[select_atoms]
    if scalar_field is None:
        scalar_field = np.asarray(structure.get_chemical_symbols())[select_atoms].tolist()
    else:
        scalar_field = np.asarray(scalar_field)[select_atoms]

    fig = px.scatter_3d(
        x=positions[:, 0],
        y=positions[:, 1],
        z=positions[:, 2],
        color=scalar_field,
        opacity=opacity,
        # NOTE(review): for a structure without a full 3D cell,
        # structure.get_volume() may fail or be 0, making this scale diverge.
        size=_atomic_number_to_radius(
            atomic_numbers,
            scale=particle_size / (0.1 * structure.get_volume() ** (1 / 3)),
        ),
    )
    if show_cell:
        fig = _draw_box_plotly(fig, structure, px, go)
    fig.layout.scene.camera.projection.type = camera
    # Row 0 of the transposed orientation is the viewing direction (used as the
    # camera eye, scaled by the distance), row 2 the vertical "up" direction.
    rot = _get_orientation(view_plane).T
    rot[0, :] *= distance_from_camera * 1.25
    angle = {
        "up": {"x": rot[2, 0], "y": rot[2, 1], "z": rot[2, 2]},
        "eye": {"x": rot[0, 0], "y": rot[0, 1], "z": rot[0, 2]},
    }
    fig.update_layout(scene_camera=angle)
    fig.update_traces(marker={"line": {"width": 0.1, "color": "DarkSlateGrey"}})
    fig.update_scenes(aspectmode="data")
    if height is None:
        height = 600
    fig.update_layout(autosize=True, height=height)
    fig.update_layout(legend={"itemsizing": "constant"})
    return fig
278

279

280
def _plot3d(
    structure: Atoms,
    show_cell: bool = True,
    show_axes: bool = True,
    camera: str = "orthographic",
    spacefill: bool = True,
    particle_size: float = 1.0,
    select_atoms: Optional[np.ndarray] = None,
    background: str = "white",
    color_scheme: Optional[str] = None,
    colors: Optional[np.ndarray] = None,
    scalar_field: Optional[np.ndarray] = None,
    scalar_start: Optional[float] = None,
    scalar_end: Optional[float] = None,
    scalar_cmap: Optional[Any] = None,
    vector_field: Optional[np.ndarray] = None,
    vector_color: Optional[np.ndarray] = None,
    magnetic_moments: bool = False,
    view_plane: np.ndarray = np.array([0, 0, 1]),
    distance_from_camera: float = 1.0,
):
    """
    Plot3d relies on NGLView to visualize atomic structures. Here, we construct a string in the "protein database"
    ("pdb") format, then turn it into an NGLView "structure". PDB is a white-space sensitive format, so the
    string snippets are carefully formatted.

    The final widget is returned. If it is assigned to a variable, the visualization is suppressed until that
    variable is evaluated, and in the meantime more NGL operations can be applied to it to modify the visualization.

    Args:
        structure (Atoms): The structure to visualize.
        show_cell (bool): Whether or not to show the frame. (Default is True.)
        show_axes (bool): Whether or not to show xyz axes. (Default is True.)
        camera (str): 'perspective' or 'orthographic'. (Default is 'orthographic'.)
        spacefill (bool): Whether to use a space-filling or ball-and-stick representation. (Default is True, use
            space-filling atoms.)
        particle_size (float): Size of the particles. (Default is 1.)
        select_atoms (numpy.ndarray): Indices of atoms to show, either as integers or a boolean array mask.
            (Default is None, show all atoms.)
        background (str): Background color. (Default is 'white'.)
        color_scheme (str): NGLView color scheme to use. (Default is None, color by element.)
        colors (numpy.ndarray): A per-atom array of HTML color names or hex color codes to use for atomic colors.
            (Default is None, use coloring scheme.)
        scalar_field (numpy.ndarray): Color each atom according to the array value (Default is None, use coloring
            scheme.)
        scalar_start (float): The scalar value to be mapped onto the low end of the color map (lower values are
            clipped). (Default is None, use the minimum value in `scalar_field`.)
        scalar_end (float): The scalar value to be mapped onto the high end of the color map (higher values are
            clipped). (Default is None, use the maximum value in `scalar_field`.)
        scalar_cmap (matplotlib.cm): The colormap to use. (Default is None, giving a blue-red divergent map.)
        vector_field (numpy.ndarray): Add vectors (3 values) originating at each atom. (Default is None, no
            vectors.)
        vector_color (numpy.ndarray): Colors for the vectors (only available with vector_field). Either one RGB
            triple for all vectors or one per atom. (Default is None, vectors are colored by their direction;
            zero-length vectors get grey.)
        magnetic_moments (bool): Plot the initial magnetic moments: collinear moments as 'scalar_field',
            non-collinear moments as 'vector_field'. (Default is False.)
        view_plane (numpy.ndarray): A Nx3-array (N = 1,2,3); the first 3d-component of the array specifies
            which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes), the
            second 3d-component (if specified, otherwise [1, 0, 0]) gives the horizontal direction, and the third
            component (if specified) is the vertical component, which is ignored and calculated internally. The
            orthonormality of the orientation is internally ensured, and therefore is not required in the function
            call. (Default is np.array([0, 0, 1]), which is view normal to the x-y plane.)
        distance_from_camera (float): Distance of the camera from the structure. Higher = farther away. The value
            is multiplied by 14 before being passed to NGLView. (Default is 1.0.)

        Possible NGLView color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget): The NGLView widget itself, which can be operated on further or viewed as-is.

    Raises:
        ImportError: If nglview is not installed.

    Warnings:
        * Many features only work with space-filling atoms (e.g. coloring by a scalar field).
        * The colour interpretation of some hex codes is weird, e.g. 'green'.
    """
    try:  # If the graphical packages are not available, the GUI will not work.
        import nglview
    except ImportError as err:
        raise ImportError(
            "The package nglview needs to be installed for the plot3d() function!"
        ) from err

    if magnetic_moments is True:
        moments = structure.get_initial_magnetic_moments()
        if np.sum(np.abs(moments)) > 0:
            # One value per atom -> scalar coloring; a vector per atom -> arrows.
            if moments.ndim == 1:
                scalar_field = moments
            else:
                vector_field = moments

    elements = structure.get_chemical_symbols()
    atomic_numbers = structure.get_atomic_numbers()
    positions = structure.positions

    # If `select_atoms` was given, visualize only a subset of the structure
    if select_atoms is not None:
        select_atoms = np.asarray(select_atoms)
        if select_atoms.dtype == bool:
            # A boolean mask must not be cast to int: that would turn it into
            # the indices 0 and 1 instead of selecting the masked atoms.
            select_atoms = np.flatnonzero(select_atoms)
        else:
            select_atoms = select_atoms.astype(int)
        elements = np.array(elements)[select_atoms]
        atomic_numbers = atomic_numbers[select_atoms]
        positions = positions[select_atoms]
        if colors is not None:
            colors = np.array(colors)[select_atoms]
        if scalar_field is not None:
            scalar_field = np.array(scalar_field)[select_atoms]
        if vector_field is not None:
            vector_field = np.array(vector_field)[select_atoms]
        if vector_color is not None:
            vector_color = np.array(vector_color)
            # Only per-atom colors are subset; a single RGB triple applies to
            # every arrow and must not be indexed by atom indices.
            if vector_color.ndim == 2:
                vector_color = vector_color[select_atoms]

    # Write the nglview protein-database-formatted string
    struct = nglview.TextStructure(
        _ngl_write_structure(elements, positions, structure.cell)
    )

    # Parse the string into the displayable widget
    view = nglview.NGLWidget(struct)

    if spacefill:
        # Color by scheme
        if color_scheme is not None:
            if colors is not None:
                warnings.warn("`color_scheme` is overriding `colors`", stacklevel=2)
            if scalar_field is not None:
                warnings.warn(
                    "`color_scheme` is overriding `scalar_field`", stacklevel=2
                )
            view = _add_colorscheme_spacefill(
                view, elements, atomic_numbers, particle_size, color_scheme
            )
        # Color by per-atom colors
        elif colors is not None:
            if scalar_field is not None:
                warnings.warn("`colors` is overriding `scalar_field`", stacklevel=2)
            view = _add_custom_color_spacefill(
                view, atomic_numbers, particle_size, colors
            )
        # Color by per-atom scalars
        elif scalar_field is not None:
            colors = _scalars_to_hex_colors(
                scalar_field, scalar_start, scalar_end, scalar_cmap
            )
            view = _add_custom_color_spacefill(
                view, atomic_numbers, particle_size, colors
            )
        # Color by element
        else:
            view = _add_colorscheme_spacefill(
                view, elements, atomic_numbers, particle_size
            )
        view.remove_ball_and_stick()
    else:
        view.add_ball_and_stick()

    if (
        show_cell
        and structure.cell is not None
        and all(np.max(structure.cell, axis=0) > 1e-2)
    ):
        view.add_unitcell()

    if vector_field is not None:
        # Arrows are drawn for the displayed atoms only, so colors are sized to
        # match `positions`, not the full structure.
        n_vectors = len(positions)
        if vector_color is None:
            # Color by direction: unit-vector components in [-1, 1] -> RGB in
            # [0, 1]. Zero-length vectors would give 0/0 = NaN, so they keep a
            # norm of 1 here and end up neutral grey (0.5, 0.5, 0.5).
            vector_field = np.asarray(vector_field)
            norms = np.linalg.norm(vector_field, axis=-1)[:, np.newaxis]
            vector_color = 0.5 * vector_field / np.where(norms > 0, norms, 1.0) + 0.5
        elif not hasattr(vector_color, "shape"):
            # Plain sequences (e.g. a list) are broadcast as given.
            vector_color = np.ones((n_vectors, 3)) * vector_color
        elif vector_color.shape != (n_vectors, 3):
            # A single color: normalize it and repeat it for every arrow.
            vector_color = np.outer(
                np.ones(n_vectors), vector_color / np.linalg.norm(vector_color)
            )

        for arr, pos, col in zip(vector_field, positions, vector_color):
            view.shape.add_arrow(list(pos), list(pos + arr), list(col), 0.2)

    if show_axes:  # Add axes
        axes_origin = -np.ones(3)
        arrow_radius = 0.1
        text_size = 1
        text_color = [0, 0, 0]
        arrow_names = ["x", "y", "z"]

        for n in [0, 1, 2]:
            start = list(axes_origin)
            shift = np.zeros(3)
            shift[n] = 1
            end = list(start + shift)
            color = list(shift)
            # We cast as list to avoid JSON warnings
            view.shape.add_arrow(start, end, color, arrow_radius)
            view.shape.add_text(end, text_color, text_size, arrow_names[n])

    if camera not in ("perspective", "orthographic"):
        warnings.warn(
            "Only perspective or orthographic is (likely to be) permitted for camera",
            stacklevel=2,
        )

    view.camera = camera
    view.background = background

    orientation = _get_flattened_orientation(
        view_plane=view_plane, distance_from_camera=distance_from_camera * 14
    )
    view.control.orient(orientation)

    return view
500

501

502
def _plot3d_ase(
    structure: Atoms,
    spacefill: bool = True,
    show_cell: bool = True,
    camera: str = "perspective",
    particle_size: float = 0.5,
    background: str = "white",
    color_scheme: str = "element",
    show_axes: bool = True,
):
    """
    Visualize a structure with `nglview.show_ase`.

    Args:
        structure (Atoms): The structure to visualize.
        spacefill (bool): Whether to use a space-filling or ball-and-stick representation. (Default is True.)
        show_cell (bool): Whether to draw the unit cell (only if every cell column has a maximum above 1e-2).
            (Default is True.)
        camera (str): 'perspective' or 'orthographic'. (Default is 'perspective'.)
        particle_size (float): Radius passed to the space-filling representation. (Default is 0.5.)
        background (str): Background color. (Default is 'white'.)
        color_scheme (str): NGLView color scheme of the space-filling representation. None falls back to
            'element'. (Default is 'element'.)
        show_axes (bool): Whether to draw red/green/blue arrows along x/y/z. (Default is True.)

        Possible color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget/None): The widget, or None (after printing a message) if `camera` is neither
            'perspective' nor 'orthographic'.

    Raises:
        ImportError: If nglview is not installed.
    """
    try:  # If the graphical packages are not available, the GUI will not work.
        import nglview
    except ImportError as err:
        raise ImportError(
            "The package nglview needs to be installed for the plot3d() function!"
        ) from err
    # Reject an unsupported camera before building and decorating a widget that
    # would be thrown away anyway.
    if camera not in ("perspective", "orthographic"):
        print("Only perspective or orthographic is permitted")
        return None
    # plot3d forwards its own default color_scheme=None ("color by element");
    # translate it so NGLView does not receive an empty scheme.
    if color_scheme is None:
        color_scheme = "element"
    # Always visualize the parent basis
    view = nglview.show_ase(structure)
    if spacefill:
        view.add_spacefill(
            radius_type="vdw", color_scheme=color_scheme, radius=particle_size
        )
        view.remove_ball_and_stick()
    else:
        view.add_ball_and_stick()
    if (
        show_cell
        and structure.cell is not None
        and all(np.max(structure.cell, axis=0) > 1e-2)
    ):
        view.add_unitcell()
    if show_axes:
        view.shape.add_arrow([-2, -2, -2], [2, -2, -2], [1, 0, 0], 0.5)
        view.shape.add_arrow([-2, -2, -2], [-2, 2, -2], [0, 1, 0], 0.5)
        view.shape.add_arrow([-2, -2, -2], [-2, -2, 2], [0, 0, 1], 0.5)
    view.camera = camera
    view.background = background
    return view
550

551

552
def _ngl_write_cell(
    a1: float,
    a2: float,
    a3: float,
    f1: float = 90.0,
    f2: float = 90.0,
    f3: float = 90.0,
):
    """
    Format the PDB `CRYST1` record describing the simulation cell.

    Args:
        a1, a2, a3 (float): Lengths of the cell vectors, written as `8.3f` fields.
        f1, f2, f3 (float): Cell angles in degrees, written as `6.2f` fields. (Default is 90 each.)

    Returns:
        (str): The newline-terminated `CRYST1` line with space group "P 1".
    """
    # Every field is preceded by one separating space; together with the fixed
    # widths this puts the values in the PDB column positions.
    record = ["CRYST1"]
    record.extend(" " + format(length, "8.3f") for length in (a1, a2, a3))
    record.extend(" " + format(angle, "6.2f") for angle in (f1, f2, f3))
    record.append(" P 1\n")
    return "".join(record)
571

572

573
def _ngl_write_atom(
    num: int,
    species: str,
    x: float,
    y: float,
    z: float,
    group: Optional[str] = None,
    num2: Optional[int] = None,
    occupancy: float = 1.0,
    temperature_factor: float = 0.0,
) -> str:
    """
    Format a PDB `ATOM` record for a single atom.

    Args:
        num (int): Atom serial number.
        species (str): Elemental species; used as the atom name and again as the element symbol at the end.
        x, y, z (float): Cartesian coordinates of the atom.
        group (str): Residue name. (Default is None, repeat the elemental species.)
        num2 (int): Residue sequence number. (Default is None, repeat `num`.)
        occupancy (float): PDB occupancy parameter. (Default is 1.)
        temperature_factor (float): PDB temperature factor parameter. (Default is 0.)

    Returns:
        (str): The newline-terminated line defining the atom in PDB format.

    Warnings:
        * The [PDB docs](https://www.cgl.ucsf.edu/chimera/docs/UsersGuide/tutorials/pdbintro.html) indicate that
            the xyz coordinates might need to be in some sort of orthogonal basis. If you have weird behaviour,
            this might be a good place to investigate.
    """
    residue = species if group is None else group
    residue_number = num if num2 is None else num2
    # Fixed-width fields separated by single spaces; PDB is column-sensitive,
    # so the widths below must not change.
    fields = (
        "ATOM",
        f"{num:>6}",
        f"{species:>4}",
        f"{residue:>4}",
        f"{residue_number:>5}",
        f"{x:10.3f}",
        f"{y:7.3f}",
        f"{z:7.3f}",
        f"{occupancy:5.2f}",
        f"{temperature_factor:5.2f}",
        f"{species:>11}",
    )
    return " ".join(fields) + " \n"
609

610

611
def _ngl_write_structure(
    elements: np.ndarray, positions: np.ndarray, cell: np.ndarray
) -> str:
    """
    Turns structure information into a NGLView-readable protein-database-formatted string.

    If `cell` is None, or any column of it has a maximum below 1e-2, a dummy orthorhombic cell spanning the
    positions is used instead (extents below 1e-2 are replaced by 10).

    Args:
        elements (numpy.ndarray/list): Element symbol for each atom.
        positions (numpy.ndarray/list): Vector of Cartesian atom positions.
        cell (numpy.ndarray/list): Simulation cell Bravais matrix.

    Returns:
        (str): The PDB-formatted representation of the structure: a `CRYST1` line, `MODEL 1`, one `ATOM` line
            per position and `ENDMDL`.
    """
    from ase.geometry import cell_to_cellpar, cellpar_to_cell

    if cell is None or any(np.max(cell, axis=0) < 1e-2):
        # Define a dummy cell if it doesn't exist (eg. for clusters)
        max_pos = np.max(positions, axis=0) - np.min(positions, axis=0)
        max_pos[np.abs(max_pos) < 1e-2] = 10
        cell = np.eye(3) * max_pos
    cellpar = cell_to_cellpar(cell)
    # PDB only stores lengths and angles, so the cell is re-expressed in the
    # orientation produced by cellpar_to_cell. `rotation` solves
    # cell @ rotation = exportedcell, hence positions @ rotation are the same
    # fractional coordinates expressed in that exported frame.
    exportedcell = cellpar_to_cell(cellpar)
    rotation = np.linalg.solve(cell, exportedcell)
    positions = np.array(positions).dot(rotation)

    # Collect the records and join once instead of growing a string in a loop.
    records = [_ngl_write_cell(*cellpar), "MODEL     1\n"]
    records.extend(
        _ngl_write_atom(i, elements[i], *p) for i, p in enumerate(positions)
    )
    records.append("ENDMDL \n")
    return "".join(records)
647

648

649
def _atomic_number_to_radius(
    atomic_number: int, shift: float = 0.2, slope: float = 0.1, scale: float = 1.0
) -> float:
    """
    Compute a display radius that grows with the square root of the atomic number.

    The result is `(shift + slope * sqrt(atomic_number)) * scale`. Because only numpy operations are used, an
    array of atomic numbers yields an array of radii.

    Args:
        atomic_number (int/float/numpy.ndarray): The atomic number(s).
        shift (float): Constant offset added to the radius. (Default is 0.2.)
        slope (float): Prefactor of the square root of the atomic number. (Default is 0.1.)
        scale (float): Overall multiplier applied last. (Default is 1.)

    Returns:
        (float/numpy.ndarray): The radius. (Not physical, just for visualization!)
    """
    radius = slope * np.sqrt(atomic_number)
    radius = radius + shift
    return radius * scale
665

666

667
def _add_colorscheme_spacefill(
    view,
    elements: np.ndarray,
    atomic_numbers: np.ndarray,
    particle_size: float,
    scheme: str = "element",
):
    """
    Set NGLView spacefill parameters according to a color-scheme.

    One space-filling representation is added per distinct element, selected with NGL's "#<element>" syntax and
    sized from that element's atomic number.

    Args:
        view (NGLWidget): The widget to work on.
        elements (numpy.ndarray/list): Elemental symbols.
        atomic_numbers (numpy.ndarray/list): Integer atomic numbers for determining atomic size.
        particle_size (float): A scale factor for the atomic size.
        scheme (str): The scheme to use. (Default is "element".)

        Possible NGLView color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget): The modified widget.
    """
    # dict.fromkeys de-duplicates the (symbol, number) pairs while keeping
    # first-seen order; iterating a set of strings would make the order of the
    # added representations depend on per-run string hash randomization.
    for elem, num in dict.fromkeys(zip(elements, atomic_numbers)):
        view.add_spacefill(
            selection="#" + elem,
            radius_type="vdw",
            radius=_atomic_number_to_radius(num, scale=particle_size),
            color_scheme=scheme,
        )
    return view
700

701

702
def _add_custom_color_spacefill(
    view, atomic_numbers: np.ndarray, particle_size: float, colors: np.ndarray
):
    """
    Add one NGLView spacefill representation per atom, each drawn in its own color.

    Args:
        view (NGLWidget): The widget to work on.
        atomic_numbers (numpy.ndarray/list): Integer atomic numbers for determining atomic size.
        particle_size (float): A scale factor for the atomic size.
        colors (numpy.ndarray/list): A per-atom list of HTML or hex color codes, indexed in the same order as
            `atomic_numbers`.

    Returns:
        (nglview.NGLWidget): The modified widget.
    """
    for atom_index, atomic_number in enumerate(atomic_numbers):
        # Radius depends only on the atomic number and the global scale factor
        atom_radius = _atomic_number_to_radius(atomic_number, scale=particle_size)
        view.add_spacefill(
            selection=[atom_index],
            radius_type="vdw",
            radius=atom_radius,
            color=colors[atom_index],
        )
    return view
725

726

727
def _scalars_to_hex_colors(
    scalar_field: np.ndarray,
    start: Optional[float] = None,
    end: Optional[float] = None,
    cmap=None,
):
    """
    Convert scalar values to hex codes using a colormap.

    Args:
        scalar_field (numpy.ndarray/list): Scalars to convert.
        start (float): Scalar value to map to the bottom of the colormap (values below are clipped). (Default is
            None, use the minimal scalar value.)
        end (float): Scalar value to map to the top of the colormap (values above are clipped). (Default is
            None, use the maximal scalar value.)
        cmap (matplotlib.cm): The colormap to use, i.e. a callable taking a value in [0, 1] and returning an
            RGB(A) sequence. (Default is None, which gives a blue-red divergent map from seaborn.)

    Returns:
        (list): The corresponding hex codes for each scalar value passed in. An empty input gives an empty
            list. If `start` equals `end` (e.g. a constant field), every value is mapped to the middle of the
            colormap.

    Raises:
        ImportError: If `cmap` is None and seaborn is not installed.
    """
    from matplotlib.colors import rgb2hex

    scalar_field = np.asarray(scalar_field)
    if scalar_field.size == 0:
        # np.amin/np.amax would raise on an empty array; nothing to color
        return []
    if start is None:
        start = np.amin(scalar_field)
    if end is None:
        end = np.amax(scalar_field)
    clipped_field = np.clip(scalar_field, start, end)
    if start == end:
        # Linear interpolation between two identical x-points is ill-defined (zero-width interval), so choose
        # the center of the colormap explicitly instead of relying on interp1d's behavior for that case.
        remapped_field = np.full(clipped_field.shape, 0.5)
    else:
        remapped_field = interp1d([start, end], [0, 1])(clipped_field)  # Map field onto [0,1]

    if cmap is None:
        try:
            from seaborn import diverging_palette
        except ImportError as err:
            # Previously this only printed and then failed with UnboundLocalError on `diverging_palette`
            raise ImportError(
                "The package seaborn needs to be installed for the plot3d() function!"
            ) from err
        cmap = diverging_palette(245, 15, as_cmap=True)  # A nice blue-red palette

    # The slice keeps RGB and drops alpha
    return [rgb2hex(cmap(scalar)[:3]) for scalar in remapped_field]
768

769

770
def _get_orientation(view_plane: np.ndarray) -> np.ndarray:
    """
    A helper method to plot3d, which generates a 3x3 rotation matrix from the input `view_plane`.

    Args:
        view_plane (numpy.ndarray/list): A Nx3-array/list (N = 1, 2, 3); the first 3d-component of the array
            specifies which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes),
            the second 3d-component (if specified, otherwise [0, 0, 1]) gives the horizontal direction, and the
            third component (if specified) is ignored: the third axis is calculated internally as the cross
            product of the first two. The vectors need not be normalized or orthogonal; they are normalized and
            the horizontal direction is orthogonalized against the viewing direction (Gram-Schmidt).

    Returns:
        (numpy.ndarray): 3x3 orthonormal orientation matrix whose columns are the (orthogonalized) horizontal
            direction, the computed third axis and the viewing direction. The identity matrix is returned when
            the viewing and horizontal directions are (nearly) parallel, e.g. for `view_plane = [0, 0, 1]`.

    Raises:
        ValueError: If `view_plane` cannot be reshaped to (N, 3) with N = 1, 2 or 3, or if the viewing or
            horizontal direction is a zero vector.
    """
    view_plane = np.asarray(view_plane, dtype=float)
    n_rows, remainder = divmod(view_plane.size, 3)
    if remainder != 0 or not 1 <= n_rows <= 3:
        raise ValueError(
            "The shape of view plane should be (N, 3), where N = 1, 2 or 3. Refer docs for more info."
        )
    # Default rows: [0, 1, 0], [0, 0, 1], [1, 0, 0]; the user-given rows overwrite the leading ones
    rotation_matrix = np.roll(np.eye(3), -1, axis=0)
    rotation_matrix[:n_rows] = view_plane.reshape(n_rows, 3)
    # Only the first two rows are used; the third is recomputed below, so it is not validated/normalized here
    norms = np.linalg.norm(rotation_matrix[:2], axis=-1)
    if np.any(norms == 0):
        raise ValueError(
            "The viewing and horizontal directions in view_plane must be non-zero vectors."
        )
    rotation_matrix[:2] /= norms[:, np.newaxis]
    # Gram-Schmidt: remove the component of the horizontal direction along the viewing direction
    rotation_matrix[1] -= (
        np.dot(rotation_matrix[0], rotation_matrix[1]) * rotation_matrix[0]
    )
    rotation_matrix[2] = np.cross(rotation_matrix[0], rotation_matrix[1])  # Specify third axis
    if np.isclose(np.linalg.det(rotation_matrix), 0):
        # Degenerate (parallel) directions; view_plane = [0,0,1] is the default view of NGLview,
        # so we do not modify it
        return np.eye(3)
    rotation_matrix /= np.linalg.norm(rotation_matrix, axis=-1)[:, np.newaxis]
    # Rows become [horizontal, third axis, viewing direction], then transposed into columns
    return np.roll(rotation_matrix, 2, axis=0).T
809

810

811
def _get_flattened_orientation(
    view_plane: np.ndarray, distance_from_camera: float
) -> list:
    """
    Build the flattened 4x4 camera matrix used as the argument of NGLview's `view.control.orient`.

    The upper-left 3x3 block is the rotation produced by `_get_orientation(view_plane)`. The remaining entries
    come from the 4x4 identity, and the whole matrix is then scaled by `distance_from_camera`.

    Args:
        view_plane (numpy.ndarray/list): A Nx3-array/list (N = 1,2,3); the first 3d-component of the array
            specifies which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes),
            the second 3d-component (if specified) gives the horizontal direction, and the third component
            (if specified) is ignored and calculated internally. The orthonormality of the orientation is
            internally ensured, and therefore is not required in the function call.
        distance_from_camera (float): Distance of the camera from the structure. Higher = farther away.
            Must be positive.

    Returns:
        (list): Flattened (row-major) list of len = 16, which is the input argument to `view.control.orient`.

    Raises:
        ValueError: If `distance_from_camera` is not positive.
    """
    if distance_from_camera <= 0:
        raise ValueError("´distance_from_camera´ must be a positive float!")
    camera_matrix = np.identity(4)
    camera_matrix[:3, :3] = _get_orientation(view_plane)
    scaled_matrix = distance_from_camera * camera_matrix
    return scaled_matrix.flatten().tolist()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc