• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyiron / structuretoolkit / 9149077263

19 May 2024 04:34PM UTC coverage: 80.099% (-0.2%) from 80.317%
9149077263

Pull #185

github

web-flow
Merge 9f5c0de58 into 696fd525f
Pull Request #185: Create mesh visualization - just for CI

39 of 53 new or added lines in 5 files covered. (73.58%)

84 existing lines in 1 file now uncovered.

1457 of 1819 relevant lines covered (80.1%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

17.53
/structuretoolkit/visualize.py
1
# coding: utf-8
2
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
3
# Distributed under the terms of "New BSD License", see the LICENSE file.
4

5
from __future__ import annotations
1✔
6
import warnings
1✔
7

8
from ase.atoms import Atoms
1✔
9
import numpy as np
1✔
10
from typing import Optional
1✔
11
from scipy.interpolate import interp1d
1✔
12

13
# Module metadata (authorship, licensing contact and status).
__author__ = "Joerg Neugebauer, Sudarsan Surendralal"
__copyright__ = (
    "Copyright 2021, Max-Planck-Institut für Eisenforschung GmbH"
    " - Computational Materials Design (CM) Department"
)
__version__ = "1.0"
__maintainer__ = "Sudarsan Surendralal"
__email__ = "surendralal@mpie.de"
__status__ = "production"
__date__ = "Sep 1, 2017"
23

24

25
def plot3d(
    structure: Atoms,
    mode: str = "NGLview",
    show_cell: bool = True,
    show_axes: bool = True,
    camera: str = "orthographic",
    spacefill: bool = True,
    particle_size: float = 1.0,
    select_atoms: Optional[np.ndarray] = None,
    background: str = "white",
    color_scheme: Optional[str] = None,
    colors: Optional[np.ndarray] = None,
    scalar_field: Optional[np.ndarray] = None,
    scalar_start: Optional[float] = None,
    scalar_end: Optional[float] = None,
    scalar_cmap: Optional = None,
    vector_field: Optional[np.ndarray] = None,
    vector_color: Optional[np.ndarray] = None,
    magnetic_moments: bool = False,
    view_plane: np.ndarray = np.array([0, 0, 1]),
    distance_from_camera: float = 1.0,
    opacity: float = 1.0,
    height: Optional[float] = None,
):
    """
    Visualize an atomic structure with NGLView, plotly or NGLView-through-ase.

    This function dispatches on `mode` and forwards to each backend only the arguments it supports; all other
    arguments are silently ignored by that backend (e.g. `colors` and `vector_field` only matter for 'NGLview',
    `opacity` only for 'plotly').

    For the NGLView-based modes the widget is returned. If it is assigned to a variable, the visualization is
    suppressed until that variable is evaluated, and in the meantime more NGL operations can be applied to it
    to modify the visualization.

    Args:
        structure (ase.atoms.Atoms): The structure to plot.
        mode (str): `NGLview`, `plotly` or `ase`, matched case-insensitively. (Default is 'NGLview'.)
        show_cell (bool): Whether or not to show the frame. (Default is True.)
        show_axes (bool): Whether or not to show xyz axes. (Default is True.)
        camera (str): 'perspective' or 'orthographic'. (Default is 'orthographic'.)
        spacefill (bool): Whether to use a space-filling or ball-and-stick representation. (Default is True, use
            space-filling atoms.)
        particle_size (float): Size of the particles. (Default is 1.)
        select_atoms (numpy.ndarray): Indices of atoms to show, either as integers or a boolean array mask.
            (Default is None, show all atoms.)
        background (str): Background color. (Default is 'white'.)
        color_scheme (str): NGLView color scheme to use. (Default is None, color by element.)
        colors (numpy.ndarray): A per-atom array of HTML color names or hex color codes to use for atomic colors.
            (Default is None, use coloring scheme.)
        scalar_field (numpy.ndarray): Color each atom according to the array value (Default is None, use coloring
            scheme.)
        scalar_start (float): The scalar value to be mapped onto the low end of the color map (lower values are
            clipped). (Default is None, use the minimum value in `scalar_field`.)
        scalar_end (float): The scalar value to be mapped onto the high end of the color map (higher values are
            clipped). (Default is None, use the maximum value in `scalar_field`.)
        scalar_cmap (matplotlib.cm): The colormap to use. (Default is None, giving a blue-red divergent map.)
        vector_field (numpy.ndarray): Add vectors (3 values) originating at each atom. (Default is None, no
            vectors.)
        vector_color (numpy.ndarray): Colors for the vectors (only available with vector_field). (Default is None,
            vectors are colored by their direction.)
        magnetic_moments (bool): Plot magnetic moments as 'scalar_field' or 'vector_field'.
        view_plane (numpy.ndarray): A Nx3-array (N = 1,2,3); the first 3d-component of the array specifies
            which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes), the
            second 3d-component (if specified, otherwise [1, 0, 0]) gives the horizontal direction, and the third
            component (if specified) is the vertical component, which is ignored and calculated internally. The
            orthonormality of the orientation is internally ensured, and therefore is not required in the function
            call. (Default is np.array([0, 0, 1]), which is view normal to the x-y plane.)
        distance_from_camera (float): Distance of the camera from the structure. Higher = farther away.
            (Default is 1.)
        opacity (float): Opacity of the atoms (only available in plotly). (Default is 1.)
        height (int/float/None): height of the plot area in pixel (only
            available in plotly) Default: 600

        Possible NGLView color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget / plotly.graph_objects.Figure): The NGLView widget (modes 'NGLview' and 'ase') or the
            plotly figure (mode 'plotly'), which can be operated on further or viewed as-is.

    Raises:
        ValueError: If `mode` is not one of the supported backends.

    Warnings:
        * Many features only work with space-filling atoms (e.g. coloring by a scalar field).
        * The colour interpretation of some hex codes is weird, e.g. 'green'.
    """
    # The docstring historically spelled the default as `NGLView` while the code compared against "NGLview";
    # compare case-insensitively so every documented spelling works.
    backend = mode.lower() if isinstance(mode, str) else mode

    if backend == "nglview":
        if height is not None:
            # NOTE(review): SyntaxWarning is kept for backward compatibility with existing warning filters,
            # although UserWarning would be the conventional category.
            warnings.warn("`height` is not implemented in NGLview", SyntaxWarning)
        return _plot3d(
            structure=structure,
            show_cell=show_cell,
            show_axes=show_axes,
            camera=camera,
            spacefill=spacefill,
            particle_size=particle_size,
            select_atoms=select_atoms,
            background=background,
            color_scheme=color_scheme,
            colors=colors,
            scalar_field=scalar_field,
            scalar_start=scalar_start,
            scalar_end=scalar_end,
            scalar_cmap=scalar_cmap,
            vector_field=vector_field,
            vector_color=vector_color,
            magnetic_moments=magnetic_moments,
            view_plane=view_plane,
            distance_from_camera=distance_from_camera,
        )
    if backend == "plotly":
        return _plot3d_plotly(
            structure=structure,
            show_cell=show_cell,
            camera=camera,
            particle_size=particle_size,
            select_atoms=select_atoms,
            scalar_field=scalar_field,
            view_plane=view_plane,
            distance_from_camera=distance_from_camera,
            opacity=opacity,
            height=height,
        )
    if backend == "ase":
        if height is not None:
            warnings.warn("`height` is not implemented in ase", SyntaxWarning)
        return _plot3d_ase(
            structure=structure,
            show_cell=show_cell,
            show_axes=show_axes,
            camera=camera,
            spacefill=spacefill,
            particle_size=particle_size,
            background=background,
            color_scheme=color_scheme,
        )
    raise ValueError(
        f"plot method not recognized: {mode!r} (expected 'NGLview', 'plotly' or 'ase')"
    )
156

157

158
def _get_box_skeleton(cell: np.ndarray):
1✔
159
    lines_dz = np.stack(np.meshgrid(*3 * [[0, 1]], indexing="ij"), axis=-1)
1✔
160
    # eight corners of a unit cube, paired as four z-axis lines
161

162
    all_lines = np.reshape(
1✔
163
        [np.roll(lines_dz, i, axis=-1) for i in range(3)], (-1, 2, 3)
164
    )
165
    # All 12 two-point lines on the unit square
166
    return all_lines @ cell
1✔
167

168

169
def _plot3d_plotly(
    structure: Atoms,
    show_cell: bool = True,
    scalar_field: Optional[np.ndarray] = None,
    select_atoms: Optional[np.ndarray] = None,
    particle_size: float = 1.0,
    camera: str = "orthographic",
    view_plane: np.ndarray = np.array([1, 1, 1]),
    distance_from_camera: float = 1.0,
    opacity: float = 1.0,
    height: Optional[float] = None,
):
    """
    Make an interactive 3D scatter plot of the atomic structure with plotly.

    Args:
        structure (ase.atoms.Atoms): The structure to plot.
        show_cell (bool): Whether to draw the edges of the simulation cell as black lines. (Default is True.)
        scalar_field (numpy.ndarray): Per-atom values used to color the atoms, one entry per atom of the full
            structure (the `select_atoms` subset is taken internally). (Default is None, color by element.)
        select_atoms (numpy.ndarray): Indices of atoms to show, either as integers or a boolean array mask.
            (Default is None, show all atoms.)
        particle_size (float): Size of the particles, relative to the cube root of the cell volume.
            (Default is 1.)
        camera (str): plotly projection type, 'perspective' or 'orthographic'. (Default is 'orthographic'.)
        view_plane (numpy.ndarray): A Nx3-array (N = 1,2,3); the first 3d-component of the array specifies
            which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes), the
            second 3d-component (if specified, otherwise [1, 0, 0]) gives the horizontal direction, and the third
            component (if specified) is the vertical component, which is ignored and calculated internally. The
            orthonormality of the orientation is internally ensured, and therefore is not required in the function
            call. (Default is np.array([1, 1, 1]).)
        distance_from_camera (float): Distance of the camera from the structure. Higher = farther away.
            (Default is 1.)
        opacity (float): Opacity of the atom markers. (Default is 1.)
        height (int/float/None): height of the plot area in pixel. (Default is None, i.e. 600.)

    Returns:
        (plotly.graph_objects.Figure): The figure, which can be modified further or shown as-is.

    Raises:
        ModuleNotFoundError: If plotly is not installed.
    """
    try:
        import plotly.express as px
        import plotly.graph_objects as go
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError("plotly not installed - use plot3d instead") from err

    if select_atoms is None:
        select_atoms = np.arange(len(structure))

    # Every per-atom array handed to plotly must be reduced with the same selection as the positions;
    # otherwise the marker sizes / colors have the wrong length and plotly refuses the data.
    positions = structure.positions[select_atoms]
    atomic_numbers = structure.get_atomic_numbers()[select_atoms]
    if scalar_field is None:
        # Element symbols give a discrete (categorical) coloring with one legend entry per element.
        scalar_field = np.asarray(structure.get_chemical_symbols())[select_atoms].tolist()
    else:
        scalar_field = np.asarray(scalar_field)[select_atoms]

    fig = px.scatter_3d(
        x=positions[:, 0],
        y=positions[:, 1],
        z=positions[:, 2],
        color=scalar_field,
        opacity=opacity,
        # Marker sizes are scaled by the cube root of the cell volume so they look similar for any cell size.
        size=_atomic_number_to_radius(
            atomic_numbers,
            scale=particle_size / (0.1 * structure.get_volume() ** (1 / 3)),
        ),
    )
    if show_cell:
        data = fig.data
        for edge in _get_box_skeleton(structure.cell):
            edge_fig = px.line_3d(x=edge[:, 0], y=edge[:, 1], z=edge[:, 2])
            edge_fig.update_traces(line_color="#000000")
            # Cell edges are prepended so the atom markers are drawn on top.
            data = edge_fig.data + data
        fig = go.Figure(data=data)

    fig.layout.scene.camera.projection.type = camera
    rot = _get_orientation(view_plane).T
    rot[0, :] *= distance_from_camera * 1.25
    fig.update_layout(
        scene_camera=dict(
            up=dict(x=rot[2, 0], y=rot[2, 1], z=rot[2, 2]),
            eye=dict(x=rot[0, 0], y=rot[0, 1], z=rot[0, 2]),
        )
    )
    fig.update_traces(marker=dict(line=dict(width=0.1, color="DarkSlateGrey")))
    fig.update_scenes(aspectmode="data")
    fig.update_layout(autosize=True, height=600 if height is None else height)
    fig.update_layout(legend={"itemsizing": "constant"})
    return fig
248

249

250
def _plot3d(
    structure: Atoms,
    show_cell: bool = True,
    show_axes: bool = True,
    camera: str = "orthographic",
    spacefill: bool = True,
    particle_size: float = 1.0,
    select_atoms: Optional[np.ndarray] = None,
    background: str = "white",
    color_scheme: Optional[str] = None,
    colors: Optional[np.ndarray] = None,
    scalar_field: Optional[np.ndarray] = None,
    scalar_start: Optional[float] = None,
    scalar_end: Optional[float] = None,
    scalar_cmap: Optional = None,
    vector_field: Optional[np.ndarray] = None,
    vector_color: Optional[np.ndarray] = None,
    magnetic_moments: bool = False,
    view_plane: np.ndarray = np.array([0, 0, 1]),
    distance_from_camera: float = 1.0,
):
    """
    Plot3d relies on NGLView to visualize atomic structures. Here, we construct a string in the "protein database"
    ("pdb") format, then turn it into an NGLView "structure". PDB is a white-space sensitive format, so the
    string snippets are carefully formatted.

    The final widget is returned. If it is assigned to a variable, the visualization is suppressed until that
    variable is evaluated, and in the meantime more NGL operations can be applied to it to modify the visualization.

    Args:
        structure (ase.atoms.Atoms): The structure to plot.
        show_cell (bool): Whether or not to show the frame. (Default is True.)
        show_axes (bool): Whether or not to show xyz axes. (Default is True.)
        camera (str): 'perspective' or 'orthographic'. (Default is 'orthographic'.)
        spacefill (bool): Whether to use a space-filling or ball-and-stick representation. (Default is True, use
            space-filling atoms.)
        particle_size (float): Size of the particles. (Default is 1.)
        select_atoms (numpy.ndarray): Indices of atoms to show, either as integers or a boolean array mask.
            (Default is None, show all atoms.) The per-atom arrays below refer to the full structure and are
            reduced to this selection internally.
        background (str): Background color. (Default is 'white'.)
        color_scheme (str): NGLView color scheme to use. (Default is None, color by element.)
        colors (numpy.ndarray): A per-atom array of HTML color names or hex color codes to use for atomic colors.
            (Default is None, use coloring scheme.)
        scalar_field (numpy.ndarray): Color each atom according to the array value (Default is None, use coloring
            scheme.)
        scalar_start (float): The scalar value to be mapped onto the low end of the color map (lower values are
            clipped). (Default is None, use the minimum value in `scalar_field`.)
        scalar_end (float): The scalar value to be mapped onto the high end of the color map (higher values are
            clipped). (Default is None, use the maximum value in `scalar_field`.)
        scalar_cmap (matplotlib.cm): The colormap to use. (Default is None, giving a blue-red divergent map.)
        vector_field (numpy.ndarray): Add vectors (3 values) originating at each atom. (Default is None, no
            vectors.)
        vector_color (numpy.ndarray): Colors for the vectors (only available with vector_field): either one RGB
            triple per shown atom, or a single color used for every arrow (a single color given as a numpy array
            is normalized to unit length first). (Default is None, vectors are colored by their direction.)
        magnetic_moments (bool): If True and the structure has non-zero initial magnetic moments, plot them as
            'scalar_field' (one value per atom) or 'vector_field' (one vector per atom), replacing those arguments.
        view_plane (numpy.ndarray): A Nx3-array (N = 1,2,3); the first 3d-component of the array specifies
            which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes), the
            second 3d-component (if specified, otherwise [1, 0, 0]) gives the horizontal direction, and the third
            component (if specified) is the vertical component, which is ignored and calculated internally. The
            orthonormality of the orientation is internally ensured, and therefore is not required in the function
            call. (Default is np.array([0, 0, 1]), which is view normal to the x-y plane.)
        distance_from_camera (float): Relative distance of the camera from the structure; it is multiplied by 14
            before being handed to NGLView. Higher = farther away. (Default is 1.)

        Possible NGLView color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget): The NGLView widget itself, which can be operated on further or viewed as-is.

    Raises:
        ImportError: If nglview is not installed.

    Warnings:
        * Many features only work with space-filling atoms (e.g. coloring by a scalar field).
        * The colour interpretation of some hex codes is weird, e.g. 'green'.
    """
    try:  # If the graphical packages are not available, the GUI will not work.
        import nglview
    except ImportError as err:
        raise ImportError(
            "The package nglview needs to be installed for the plot3d() function!"
        ) from err

    if (
        magnetic_moments is True
        and np.sum(np.abs(structure.get_initial_magnetic_moments())) > 0
    ):
        moments = structure.get_initial_magnetic_moments()
        if moments.ndim == 1:
            scalar_field = moments  # one value per atom: color the atoms
        else:
            vector_field = moments  # one vector per atom: draw arrows

    elements = structure.get_chemical_symbols()
    atomic_numbers = structure.get_atomic_numbers()
    positions = structure.positions

    # If `select_atoms` was given, visualize only a subset of the `parent_basis`
    if select_atoms is not None:
        select_atoms = np.asarray(select_atoms)
        if select_atoms.dtype == bool:
            # A boolean mask must not be cast to int: True/False would silently become the indices 1/0.
            select_atoms = np.flatnonzero(select_atoms)
        else:
            select_atoms = select_atoms.astype(int)
        elements = np.array(elements)[select_atoms]
        atomic_numbers = atomic_numbers[select_atoms]
        positions = positions[select_atoms]
        if colors is not None:
            colors = np.array(colors)[select_atoms]
        if scalar_field is not None:
            scalar_field = np.array(scalar_field)[select_atoms]
        if vector_field is not None:
            vector_field = np.array(vector_field)[select_atoms]
        if vector_color is not None:
            vector_color = np.array(vector_color)[select_atoms]

    # Write the nglview protein-database-formatted string
    struct = nglview.TextStructure(
        _ngl_write_structure(elements, positions, structure.cell)
    )

    # Parse the string into the displayable widget
    view = nglview.NGLWidget(struct)

    if spacefill:
        # Color by scheme
        if color_scheme is not None:
            if colors is not None:
                warnings.warn("`color_scheme` is overriding `colors`")
            if scalar_field is not None:
                warnings.warn("`color_scheme` is overriding `scalar_field`")
            view = _add_colorscheme_spacefill(
                view, elements, atomic_numbers, particle_size, color_scheme
            )
        # Color by per-atom colors
        elif colors is not None:
            if scalar_field is not None:
                warnings.warn("`colors` is overriding `scalar_field`")
            view = _add_custom_color_spacefill(
                view, atomic_numbers, particle_size, colors
            )
        # Color by per-atom scalars
        elif scalar_field is not None:
            colors = _scalars_to_hex_colors(
                scalar_field, scalar_start, scalar_end, scalar_cmap
            )
            view = _add_custom_color_spacefill(
                view, atomic_numbers, particle_size, colors
            )
        # Color by element
        else:
            view = _add_colorscheme_spacefill(
                view, elements, atomic_numbers, particle_size
            )
        view.remove_ball_and_stick()
    else:
        view.add_ball_and_stick()

    if show_cell:
        if structure.cell is not None:
            # Only draw the cell if every Cartesian direction has a cell-matrix entry above 1e-2.
            if all(np.max(structure.cell, axis=0) > 1e-2):
                view.add_unitcell()

    if vector_field is not None:
        # Arrows are drawn only for the shown atoms, so colors are sized to match them (not to the full structure).
        n_shown = len(positions)
        if vector_color is None:
            # Color each arrow by its direction: map the unit vector from [-1, 1] into RGB values in [0, 1].
            norms = np.linalg.norm(vector_field, axis=-1)[:, np.newaxis]
            norms[norms == 0] = 1.0  # zero-length vectors would otherwise produce NaN colors
            vector_color = 0.5 * np.array(vector_field) / norms + 0.5
        else:
            try:
                if vector_color.shape != (n_shown, 3):
                    # A single color (numpy array) is normalized and repeated for every arrow.
                    vector_color = np.outer(
                        np.ones(n_shown),
                        vector_color / np.linalg.norm(vector_color),
                    )
            except AttributeError:
                # Non-array input (e.g. a list with one RGB triple) is broadcast as given.
                vector_color = np.ones((n_shown, 3)) * vector_color

        for arr, pos, col in zip(vector_field, positions, vector_color):
            view.shape.add_arrow(list(pos), list(pos + arr), list(col), 0.2)

    if show_axes:  # Add axes
        axes_origin = -np.ones(3)
        arrow_radius = 0.1
        text_size = 1
        text_color = [0, 0, 0]
        arrow_names = ["x", "y", "z"]

        for n in [0, 1, 2]:
            start = list(axes_origin)
            shift = np.zeros(3)
            shift[n] = 1
            end = list(start + shift)
            color = list(shift)
            # We cast as list to avoid JSON warnings
            view.shape.add_arrow(start, end, color, arrow_radius)
            view.shape.add_text(end, text_color, text_size, arrow_names[n])

    if camera not in ("perspective", "orthographic"):
        warnings.warn(
            "Only perspective or orthographic is (likely to be) permitted for camera"
        )

    view.camera = camera
    view.background = background

    orientation = _get_flattened_orientation(
        view_plane=view_plane, distance_from_camera=distance_from_camera * 14
    )
    view.control.orient(orientation)

    return view
465

466

467
def _plot3d_ase(
    structure: Atoms,
    spacefill: bool = True,
    show_cell: bool = True,
    camera: str = "perspective",
    particle_size: float = 0.5,
    background: str = "white",
    color_scheme: str = "element",
    show_axes: bool = True,
):
    """
    Visualize the structure with NGLView, loading it through `nglview.show_ase`.

    Args:
        structure (ase.atoms.Atoms): The structure to plot. The whole structure is always shown.
        spacefill (bool): Whether to use a space-filling or ball-and-stick representation. (Default is True.)
        show_cell (bool): Whether to draw the unit cell; it is only drawn if every Cartesian direction has a
            cell-matrix entry above 1e-2. (Default is True.)
        camera (str): 'perspective' or 'orthographic'; any other value returns None. (Default is 'perspective'.)
        particle_size (float): Radius passed to the spacefill representation. (Default is 0.5.)
        background (str): Background color. (Default is 'white'.)
        color_scheme (str): NGLView color scheme for the spacefill representation. (Default is 'element'.)
        show_axes (bool): Whether to draw x (red), y (green) and z (blue) arrows starting at (-2, -2, -2).
            (Default is True.)

        Possible color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget/None): The widget, or None if `camera` is not supported.

    Raises:
        ImportError: If nglview is not installed.
    """
    try:  # If the graphical packages are not available, the GUI will not work.
        import nglview
    except ImportError as err:
        raise ImportError(
            "The package nglview needs to be installed for the plot3d() function!"
        ) from err

    # Reject unsupported cameras before building the widget (the return value stays None as before).
    if camera not in ("perspective", "orthographic"):
        warnings.warn("Only perspective or orthographic is permitted")
        return None

    # Always visualize the parent basis
    view = nglview.show_ase(structure)
    if spacefill:
        view.add_spacefill(
            radius_type="vdw", color_scheme=color_scheme, radius=particle_size
        )
        view.remove_ball_and_stick()
    else:
        view.add_ball_and_stick()
    if show_cell:
        if structure.cell is not None:
            if all(np.max(structure.cell, axis=0) > 1e-2):
                view.add_unitcell()
    if show_axes:
        view.shape.add_arrow([-2, -2, -2], [2, -2, -2], [1, 0, 0], 0.5)
        view.shape.add_arrow([-2, -2, -2], [-2, 2, -2], [0, 1, 0], 0.5)
        view.shape.add_arrow([-2, -2, -2], [-2, -2, 2], [0, 0, 1], 0.5)
    view.camera = camera
    view.background = background
    return view
514

515

516
def _ngl_write_cell(
1✔
517
    a1: float,
518
    a2: float,
519
    a3: float,
520
    f1: float = 90.0,
521
    f2: float = 90.0,
522
    f3: float = 90.0,
523
):
524
    """
525
    Writes a PDB-formatted line to represent the simulation cell.
526

527
    Args:
528
        a1, a2, a3 (float): Lengths of the cell vectors.
529
        f1, f2, f3 (float): Angles between the cell vectors (which angles exactly?) (in degrees).
530

531
    Returns:
532
        (str): The line defining the cell in PDB format.
533
    """
UNCOV
534
    return "CRYST1 {:8.3f} {:8.3f} {:8.3f} {:6.2f} {:6.2f} {:6.2f} P 1\n".format(
×
535
        a1, a2, a3, f1, f2, f3
536
    )
537

538

539
def _ngl_write_atom(
1✔
540
    num: int,
541
    species: str,
542
    x: float,
543
    y: float,
544
    z: float,
545
    group: Optional[str] = None,
546
    num2: Optional[int] = None,
547
    occupancy: float = 1.0,
548
    temperature_factor: float = 0.0,
549
) -> str:
550
    """
551
    Writes a PDB-formatted line to represent an atom.
552

553
    Args:
554
        num (int): Atomic index.
555
        species (str): Elemental species.
556
        x, y, z (float): Cartesian coordinates of the atom.
557
        group (str): A...group name? (Default is None, repeat elemental species.)
558
        num2 (int): An "alternate" index. (Don't ask me...) (Default is None, repeat first number.)
559
        occupancy (float): PDB occupancy parameter. (Default is 1.)
560
        temperature_factor (float): PDB temperature factor parameter. (Default is 0.
561

562
    Returns:
563
        (str): The line defining an atom in PDB format
564

565
    Warnings:
566
        * The [PDB docs](https://www.cgl.ucsf.edu/chimera/docs/UsersGuide/tutorials/pdbintro.html) indicate that
567
            the xyz coordinates might need to be in some sort of orthogonal basis. If you have weird behaviour,
568
            this might be a good place to investigate.
569
    """
UNCOV
570
    if group is None:
×
UNCOV
571
        group = species
×
UNCOV
572
    if num2 is None:
×
573
        num2 = num
×
574
    return "ATOM {:>6} {:>4} {:>4} {:>5} {:10.3f} {:7.3f} {:7.3f} {:5.2f} {:5.2f} {:>11} \n".format(
×
575
        num, species, group, num2, x, y, z, occupancy, temperature_factor, species
576
    )
577

578

579
def _ngl_write_structure(
    elements: np.ndarray, positions: np.ndarray, cell: np.ndarray
) -> str:
    """
    Turns structure information into a NGLView-readable protein-database-formatted string.

    Args:
        elements (numpy.ndarray/list): Element symbol for each atom.
        positions (numpy.ndarray/list): Vector of Cartesian atom positions.
        cell (numpy.ndarray/list): Simulation cell Bravais matrix. If None, or if some Cartesian direction has no
            cell-matrix entry of at least 1e-2, an orthorhombic box spanning the atoms is used instead.

    Returns:
        (str): The PDB-formatted representation of the structure.
    """
    from ase.geometry import cell_to_cellpar, cellpar_to_cell

    positions = np.asarray(positions, dtype=float)
    if cell is None or any(np.max(cell, axis=0) < 1e-2):
        # Define a dummy cell if it doesn't exist (e.g. for clusters): the bounding box of the atoms, with a
        # length of 10 along any direction in which the atoms have (almost) no extent.
        if len(positions) > 0:
            extent = np.max(positions, axis=0) - np.min(positions, axis=0)
        else:
            extent = np.zeros(3)
        extent[np.abs(extent) < 1e-2] = 10
        cell = np.eye(3) * extent

    # PDB stores only lengths and angles; the cell rebuilt from them may be rotated relative to the original.
    # `rotation` satisfies cell @ rotation == exported_cell, so positions @ rotation are expressed in the frame
    # of the exported cell.
    cellpar = cell_to_cellpar(cell)
    exported_cell = cellpar_to_cell(cellpar)
    rotation = np.linalg.solve(cell, exported_cell)
    positions = positions @ rotation

    # Collect the records and join once instead of growing a string in the loop.
    records = [_ngl_write_cell(*cellpar), "MODEL     1\n"]
    records.extend(
        _ngl_write_atom(i, elements[i], *p) for i, p in enumerate(positions)
    )
    records.append("ENDMDL \n")
    return "".join(records)
615

616

617
def _atomic_number_to_radius(
1✔
618
    atomic_number: int, shift: float = 0.2, slope: float = 0.1, scale: float = 1.0
619
) -> float:
620
    """
621
    Give the atomic radius for plotting, which scales like the root of the atomic number.
622

623
    Args:
624
        atomic_number (int/float): The atomic number.
625
        shift (float): A constant addition to the radius. (Default is 0.2.)
626
        slope (float): A multiplier for the root of the atomic number. (Default is 0.1)
627
        scale (float): How much to rescale the whole thing by.
628

629
    Returns:
630
        (float): The radius. (Not physical, just for visualization!)
631
    """
UNCOV
632
    return (shift + slope * np.sqrt(atomic_number)) * scale
×
633

634

635
def _add_colorscheme_spacefill(
    view,
    elements: np.ndarray,
    atomic_numbers: np.ndarray,
    particle_size: float,
    scheme: str = "element",
):
    """
    Add one NGLView spacefill representation per distinct species, colored by a named color scheme.

    Each representation selects its atoms by element name (`#<symbol>`) and gets a radius derived from the
    atomic number.

    Args:
        view (NGLWidget): The widget to work on.
        elements (numpy.ndarray/list): Elemental symbols.
        atomic_numbers (numpy.ndarray/list): Integer atomic numbers for determining atomic size.
        particle_size (float): A scale factor for the atomic size.
        scheme (str): The scheme to use. (Default is "element".)

        Possible NGLView color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget): The same widget, with the representations added.
    """
    distinct_species = set(zip(elements, atomic_numbers))
    for symbol, z in distinct_species:
        radius = _atomic_number_to_radius(z, scale=particle_size)
        view.add_spacefill(
            selection="#" + symbol,
            radius_type="vdw",
            radius=radius,
            color_scheme=scheme,
        )
    return view
668

669

670
def _add_custom_color_spacefill(
    view, atomic_numbers: np.ndarray, particle_size: float, colors: np.ndarray
):
    """
    Add one NGLView spacefill representation per atom, each carrying its own color.

    Args:
        view (NGLWidget): The widget to modify.
        atomic_numbers (numpy.ndarray/list): Integer atomic numbers, used to size each atom.
        particle_size (float): Scale factor applied to every atomic radius.
        colors (numpy.ndarray/list): Per-atom HTML or hex color codes, indexed in the same order as
            `atomic_numbers`.

    Returns:
        (nglview.NGLWidget): The same widget, after the representations were added.
    """
    for atom_index, z in enumerate(atomic_numbers):
        atom_radius = _atomic_number_to_radius(z, scale=particle_size)
        view.add_spacefill(
            selection=[atom_index],
            radius_type="vdw",
            radius=atom_radius,
            color=colors[atom_index],
        )
    return view
×
693

694

695
def _scalars_to_hex_colors(
    scalar_field: np.ndarray,
    start: Optional[float] = None,
    end: Optional[float] = None,
    cmap=None,
):
    """
    Convert scalar values to hex codes using a colormap.

    Args:
        scalar_field (numpy.ndarray/list): Scalars to convert.
        start (float): Scalar value to map to the bottom of the colormap (values below are clipped). (Default is
            None, use the minimal scalar value.)
        end (float): Scalar value to map to the top of the colormap (values above are clipped).  (Default is
            None, use the maximal scalar value.)
        cmap (matplotlib.cm): The colormap to use. (Default is None, which gives a blue-red divergent map.)

    Returns:
        (list): The corresponding hex codes for each scalar value passed in. If `start` equals `end` (e.g. a
            constant field), every value is mapped to the middle of the colormap.

    Raises:
        ImportError: If `cmap` is None and seaborn, which provides the default palette, is not installed.
    """
    from matplotlib.colors import rgb2hex

    if start is None:
        start = np.amin(scalar_field)
    if end is None:
        end = np.amax(scalar_field)
    clipped_field = np.clip(scalar_field, start, end)
    if start == end:
        # A zero-width range would make the linear map divide by zero (NaN -> the colormap's "bad" color),
        # so use the center of the colormap instead.
        remapped_field = np.full(np.shape(clipped_field), 0.5)
    else:
        remapped_field = interp1d([start, end], [0, 1])(clipped_field)  # Map field onto [0,1]

    if cmap is None:
        try:
            from seaborn import diverging_palette
        except ImportError as err:
            # Without seaborn there is no default palette; fail with a clear message instead of a NameError.
            raise ImportError(
                "The package seaborn needs to be installed for the plot3d() function!"
            ) from err
        cmap = diverging_palette(245, 15, as_cmap=True)  # A nice blue-red palette

    # The slice gets RGB but leaves alpha
    return [rgb2hex(cmap(scalar)[:3]) for scalar in remapped_field]
736

737

738
def _get_orientation(view_plane: np.ndarray) -> np.ndarray:
1✔
739
    """
740
    A helper method to plot3d, which generates a rotation matrix from the input `view_plane`, and returns a
741
    flattened list of len = 16. This flattened list becomes the input argument to `view.contol.orient`.
742

743
    Args:
744
        view_plane (numpy.ndarray/list): A Nx3-array/list (N = 1,2,3); the first 3d-component of the array
745
            specifies which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes),
746
            the second 3d-component (if specified, otherwise [1, 0, 0]) gives the horizontal direction, and the
747
            third component (if specified) is the vertical component, which is ignored and calculated internally.
748
            The orthonormality of the orientation is internally ensured, and therefore is not required in the
749
            function call.
750

751
    Returns:
752
        (list): orientation tensor
753
    """
754
    if len(np.array(view_plane).flatten()) % 3 != 0:
1✔
UNCOV
755
        raise ValueError(
×
756
            "The shape of view plane should be (N, 3), where N = 1, 2 or 3. Refer docs for more info."
757
        )
758
    view_plane = np.array(view_plane).reshape(-1, 3)
1✔
759
    rotation_matrix = np.roll(np.eye(3), -1, axis=0)
1✔
760
    rotation_matrix[: len(view_plane)] = view_plane
1✔
761
    rotation_matrix /= np.linalg.norm(rotation_matrix, axis=-1)[:, np.newaxis]
1✔
762
    rotation_matrix[1] -= (
1✔
763
        np.dot(rotation_matrix[0], rotation_matrix[1]) * rotation_matrix[0]
764
    )  # Gran-Schmidt
765
    rotation_matrix[2] = np.cross(
1✔
766
        rotation_matrix[0], rotation_matrix[1]
767
    )  # Specify third axis
768
    if np.isclose(np.linalg.det(rotation_matrix), 0):
1✔
UNCOV
769
        return np.eye(
×
770
            3
771
        )  # view_plane = [0,0,1] is the default view of NGLview, so we do not modify it
772
    return np.roll(
1✔
773
        rotation_matrix / np.linalg.norm(rotation_matrix, axis=-1)[:, np.newaxis],
774
        2,
775
        axis=0,
776
    ).T
777

778

779
def _get_flattened_orientation(
    view_plane: np.ndarray, distance_from_camera: float
) -> list:
    """
    Build the camera matrix passed to `view.control.orient` by plot3d.

    The 3x3 rotation from `_get_orientation(view_plane)` is embedded in the upper-left corner of a 4x4
    identity, the whole matrix is scaled by `distance_from_camera`, and the result is flattened row by row.

    Args:
        view_plane (numpy.ndarray/list): A Nx3-array/list (N = 1,2,3) describing the viewing plane and,
            optionally, the horizontal direction; see `_get_orientation` for the exact convention.
        distance_from_camera (float): Distance of the camera from the structure. Higher = farther away.

    Returns:
        (list): Flattened list of len = 16, which is the input argument to `view.control.orient`.

    Raises:
        ValueError: If `distance_from_camera` is zero or negative.
    """
    if distance_from_camera <= 0:
        raise ValueError("´distance_from_camera´ must be a positive float!")
    camera_matrix = np.eye(4)
    camera_matrix[:3, :3] = _get_orientation(view_plane)
    scaled_matrix = distance_from_camera * camera_matrix
    return scaled_matrix.reshape(-1).tolist()
1✔
804

805

806
def plot_isosurface(
    mesh,
    value,
    structure_plot: Optional["plotly.graph_objs._figure.Figure"] = None,
    isomin: Optional[float] = None,
    isomax: Optional[float] = None,
    surface_fill: Optional[float] = None,
    opacity: Optional[float] = None,
    surface_count: Optional[int] = None,
    colorbar_nticks: Optional[int] = None,
    caps: Optional[dict] = None,
    colorscale: Optional[str] = None,
    height: Optional[float] = 600,
    camera: Optional[str] = "orthographic",
    **kwargs,
):
    """
    Make a mesh plot

    Args:
        mesh (numpy.ndarray): Mesh grid. Must have a shape of (3, nx, ny, nz).
            It can be generated from structuretoolkit.create_mesh
        value: (numpy.ndarray): Value to plot. Must have a shape of (nx, ny, nz)
        structure_plot (plotly.graph_objs._figure.Figure): Plot of the
            structure to overlay. You should basically always use
            structuretoolkit.plot3d(structure, mode="plotly"). Only its
            traces (`data`) are used; its layout is not carried over.
        isomin(float): Min color value
        isomax(float): Max color value
        surface_fill(float): Polygonal filling of the surface to choose between
            0 and 1
        opacity(float): Opacity
        surface_count(int): Number of isosurfaces, 2 by default, which means
            only min and max
        colorbar_nticks(int): Colorbar ticks correspond to isosurface values
        caps(dict): Whether to set cap on sides or not. Default (None) hides the
            caps on all sides. You can set:
            caps=dict(x_show=True, y_show=True, z_show=True)
        colorscale(str): Colorscale ("turbo", "gnbu" etc.)
        height(float): Height of the figure. 600px by default
        camera(str): Camera perspective to choose from "orthographic" and
            "perspective". Default is "orthographic"
        **kwargs: Further keyword arguments forwarded to
            plotly.graph_objects.Isosurface.

    Returns:
        (plotly.graph_objects.Figure): Figure with the isosurface trace,
            followed by the traces of `structure_plot` if one was given.

    Raises:
        ValueError: If `value` does not have exactly one entry per mesh point.
        ModuleNotFoundError: If plotly is not installed.
    """
    x_mesh = np.reshape(mesh, (3, -1))
    flat_value = np.asarray(value).flatten()
    # Plotly needs x, y, z and value of equal length; catch a mismatch here
    # instead of handing inconsistent arrays to the renderer.
    if flat_value.size != x_mesh.shape[1]:
        raise ValueError(
            f"value has {flat_value.size} entries but mesh has {x_mesh.shape[1]} grid points"
        )
    if caps is None:
        # Built per call to avoid sharing a mutable default argument.
        caps = dict(x_show=False, y_show=False, z_show=False)
    try:
        import plotly.graph_objects as go
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError("plotly not installed") from err
    data = go.Isosurface(
        x=x_mesh[0],
        y=x_mesh[1],
        z=x_mesh[2],
        value=flat_value,
        isomin=isomin,
        isomax=isomax,
        surface_fill=surface_fill,
        opacity=opacity,
        surface_count=surface_count,
        colorbar_nticks=colorbar_nticks,
        caps=caps,
        colorscale=colorscale,
        **kwargs,
    )
    fig = go.Figure(data=data)
    if structure_plot is not None:
        fig = go.Figure(data=fig.data + structure_plot.data)
    fig.update_scenes(aspectmode="data")
    fig.layout.scene.camera.projection.type = camera
    fig.update_layout(autosize=True, height=height)
    return fig
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc