• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyiron / structuretoolkit / 5603956606

pending completion
5603956606

Pull #39

github-actions

web-flow
Merge 109267b08 into 4a76a3735
Pull Request #39: Drop Python 3.8

2154 of 2479 relevant lines covered (86.89%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

16.9
/structuretoolkit/visualize.py
1
# coding: utf-8
2
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
3
# Distributed under the terms of "New BSD License", see the LICENSE file.
4

5
import warnings
1✔
6

7
import numpy as np
1✔
8
from scipy.interpolate import interp1d
1✔
9

10
# Module provenance information.
__author__ = "Joerg Neugebauer, Sudarsan Surendralal"
__maintainer__ = "Sudarsan Surendralal"
__email__ = "surendralal@mpie.de"
__copyright__ = (
    "Copyright 2021, "
    "Max-Planck-Institut für Eisenforschung GmbH - "
    "Computational Materials Design (CM) Department"
)
__version__ = "1.0"
__status__ = "production"
__date__ = "Sep 1, 2017"
20

21

22
def plot3d(
    structure,
    mode="NGLview",
    show_cell=True,
    show_axes=True,
    camera="orthographic",
    spacefill=True,
    particle_size=1.0,
    select_atoms=None,
    background="white",
    color_scheme=None,
    colors=None,
    scalar_field=None,
    scalar_start=None,
    scalar_end=None,
    scalar_cmap=None,
    vector_field=None,
    vector_color=None,
    magnetic_moments=False,
    view_plane=np.array([0, 0, 1]),
    distance_from_camera=1.0,
    opacity=1.0,
):
    """
    Visualize an atomic structure with NGLView, plotly, or NGLView's ASE interface.

    With the NGLView-based backends ('NGLview' and 'ase') the widget is returned. If it is assigned to a variable,
    the visualization is suppressed until that variable is evaluated, and in the meantime more NGL operations can
    be applied to it to modify the visualization. With the 'plotly' backend a plotly figure is returned.

    Args:
        structure: The atomic structure to draw (presumably an ASE `Atoms`-like object).
        mode (str): Backend to use: 'NGLview' (default), 'plotly' or 'ase'. Matched case-insensitively, so the
            spelling 'NGLView' works as well.
        show_cell (bool): Whether or not to show the frame. (Default is True.)
        show_axes (bool): Whether or not to show xyz axes. (Default is True.)
        camera (str): 'perspective' or 'orthographic'. (Default is 'orthographic'.)
        spacefill (bool): Whether to use a space-filling or ball-and-stick representation. (Default is True, use
            space-filling atoms.)
        particle_size (float): Size of the particles. (Default is 1.)
        select_atoms (numpy.ndarray): Indices of atoms to show, either as integers or a boolean array mask.
            (Default is None, show all atoms.)
        background (str): Background color. (Default is 'white'.)
        color_scheme (str): NGLView color scheme to use. (Default is None, color by element.)
        colors (numpy.ndarray): A per-atom array of HTML color names or hex color codes to use for atomic colors.
            (Default is None, use coloring scheme.)
        scalar_field (numpy.ndarray): Color each atom according to the array value (Default is None, use coloring
            scheme.)
        scalar_start (float): The scalar value to be mapped onto the low end of the color map (lower values are
            clipped). (Default is None, use the minimum value in `scalar_field`.)
        scalar_end (float): The scalar value to be mapped onto the high end of the color map (higher values are
            clipped). (Default is None, use the maximum value in `scalar_field`.)
        scalar_cmap (matplotlib.cm): The colormap to use. (Default is None, giving a blue-red divergent map.)
        vector_field (numpy.ndarray): Add vectors (3 values) originating at each atom. (Default is None, no
            vectors.)
        vector_color (numpy.ndarray): Colors for the vectors (only available with vector_field). (Default is None,
            vectors are colored by their direction.)
        magnetic_moments (bool): Plot magnetic moments as 'scalar_field' or 'vector_field'. (Default is False.)
        view_plane (numpy.ndarray): A Nx3-array (N = 1,2,3); the first 3d-component of the array specifies
            which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes), the
            second 3d-component (if specified) gives the in-plane orientation, and the third component (if
            specified) is ignored and calculated internally. The orthonormality of the orientation is internally
            ensured, and therefore is not required in the function call. (Default is np.array([0, 0, 1]), which
            is view normal to the x-y plane.)
        distance_from_camera (float): Relative distance of the camera from the structure. Higher = farther away.
            (Default is 1.0.)
        opacity (float): Marker opacity; only used by the 'plotly' backend. (Default is 1.0.)

        Backend support:
            'NGLview' uses every argument except `opacity`.
            'plotly' uses `camera`, `particle_size`, `select_atoms`, `scalar_field`, `view_plane`,
            `distance_from_camera` and `opacity`.
            'ase' uses `show_cell`, `show_axes`, `camera`, `spacefill`, `particle_size`, `background` and
            `color_scheme`.

        Possible NGLView color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget / plotly figure): The NGLView widget ('NGLview', 'ase') or the plotly figure
            ('plotly'), which can be operated on further or viewed as-is.

    Raises:
        ValueError: If `mode` is not one of the supported backends.

    Warnings:
        * Many features only work with space-filling atoms (e.g. coloring by a scalar field).
        * The colour interpretation of some hex codes is weird, e.g. 'green'.
    """
    # The docstring historically advertised "NGLView" while the code only
    # accepted "NGLview"; match case-insensitively so both spellings work.
    backend = mode.lower() if isinstance(mode, str) else mode
    if backend == "nglview":
        return _plot3d(
            structure=structure,
            show_cell=show_cell,
            show_axes=show_axes,
            camera=camera,
            spacefill=spacefill,
            particle_size=particle_size,
            select_atoms=select_atoms,
            background=background,
            color_scheme=color_scheme,
            colors=colors,
            scalar_field=scalar_field,
            scalar_start=scalar_start,
            scalar_end=scalar_end,
            scalar_cmap=scalar_cmap,
            vector_field=vector_field,
            vector_color=vector_color,
            magnetic_moments=magnetic_moments,
            view_plane=view_plane,
            distance_from_camera=distance_from_camera,
        )
    elif backend == "plotly":
        return _plot3d_plotly(
            structure=structure,
            camera=camera,
            particle_size=particle_size,
            select_atoms=select_atoms,
            scalar_field=scalar_field,
            view_plane=view_plane,
            distance_from_camera=distance_from_camera,
            opacity=opacity,
        )
    elif backend == "ase":
        return _plot3d_ase(
            structure=structure,
            show_cell=show_cell,
            show_axes=show_axes,
            camera=camera,
            spacefill=spacefill,
            particle_size=particle_size,
            background=background,
            color_scheme=color_scheme,
        )
    raise ValueError(
        f"plot method not recognized: {mode!r} (expected 'NGLview', 'plotly' or 'ase')"
    )
144

145

146
def _plot3d_plotly(
    structure,
    scalar_field=None,
    select_atoms=None,
    particle_size=1.0,
    camera="orthographic",
    view_plane=np.array([1, 1, 1]),
    distance_from_camera=1,
    opacity=1,
):
    """
    Make an interactive 3D scatter plot of the atomic structure with plotly.

    Args:
        structure: The atomic structure; its `positions`, `get_chemical_symbols()`, `get_atomic_numbers()`,
            `get_volume()` and `len()` are used (presumably an ASE `Atoms`-like object).
        scalar_field (numpy.ndarray/list): One value per atom of the full structure, used as the marker color.
            Only the entries of the selected atoms are plotted. (Default is None, color by chemical element.)
        select_atoms (numpy.ndarray/list): Indices (or a boolean mask) of the atoms to show. (Default is None,
            show all atoms.)
        particle_size (float): Size of the particles. (Default is 1.)
        camera (str): plotly projection type, 'perspective' or 'orthographic'. (Default is 'orthographic'.)
        view_plane (numpy.ndarray): A Nx3-array (N = 1,2,3); the first 3d-component of the array specifies
            which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes), the
            second 3d-component (if specified) gives the in-plane orientation, and the third component (if
            specified) is ignored and calculated internally. The orthonormality of the orientation is internally
            ensured, and therefore is not required in the function call. (Default is np.array([1, 1, 1]).)
        distance_from_camera (float): Distance of the camera from the structure. Higher = farther away. The
            camera eye is placed at 1.25 * distance_from_camera along the viewing direction. (Default is 1.)
        opacity (float): Marker opacity passed to plotly. (Default is 1.)

    Returns:
        (plotly figure): The figure returned by `plotly.express.scatter_3d`, with camera and marker outlines
            configured. It can be modified further or shown as-is.

    Raises:
        ModuleNotFoundError: If plotly is not installed.
    """
    try:
        import plotly.express as px
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError("plotly not installed - use plot3d instead") from err
    if select_atoms is None:
        select_atoms = np.arange(len(structure))
    # Every per-atom array handed to plotly must be restricted to the same
    # atoms; otherwise plotly rejects arguments of unequal length.
    select_atoms = np.asarray(select_atoms)
    positions = np.asarray(structure.positions)[select_atoms]
    elements = np.asarray(structure.get_chemical_symbols())[select_atoms]
    atomic_numbers = np.asarray(structure.get_atomic_numbers())[select_atoms]
    if scalar_field is None:
        scalar_field = elements
    else:
        scalar_field = np.asarray(scalar_field)[select_atoms]
    fig = px.scatter_3d(
        x=positions[:, 0],
        y=positions[:, 1],
        z=positions[:, 2],
        color=scalar_field,
        opacity=opacity,
        size=_atomic_number_to_radius(
            atomic_numbers,
            # Normalize marker sizes by the linear size of the cell.
            scale=particle_size / (0.1 * structure.get_volume() ** (1 / 3)),
        ),
    )
    fig.layout.scene.camera.projection.type = camera
    rot = _get_orientation(view_plane).T
    # Row 0 of `rot` is the viewing direction (eye position), row 2 the "up" vector.
    rot[0, :] *= distance_from_camera * 1.25
    angle = dict(
        up=dict(x=rot[2, 0], y=rot[2, 1], z=rot[2, 2]),
        eye=dict(x=rot[0, 0], y=rot[0, 1], z=rot[0, 2]),
    )
    fig.update_layout(scene_camera=angle)
    fig.update_traces(marker=dict(line=dict(width=0.1, color="DarkSlateGrey")))
    return fig
209

210

211
def _plot3d(
    structure,
    show_cell=True,
    show_axes=True,
    camera="orthographic",
    spacefill=True,
    particle_size=1.0,
    select_atoms=None,
    background="white",
    color_scheme=None,
    colors=None,
    scalar_field=None,
    scalar_start=None,
    scalar_end=None,
    scalar_cmap=None,
    vector_field=None,
    vector_color=None,
    magnetic_moments=False,
    view_plane=np.array([0, 0, 1]),
    distance_from_camera=1.0,
):
    """
    Plot3d relies on NGLView to visualize atomic structures. Here, we construct a string in the "protein database"
    ("pdb") format, then turn it into an NGLView "structure". PDB is a white-space sensitive format, so the
    string snippets are carefully formatted.

    The final widget is returned. If it is assigned to a variable, the visualization is suppressed until that
    variable is evaluated, and in the meantime more NGL operations can be applied to it to modify the visualization.

    Args:
        structure: The atomic structure to draw (presumably an ASE `Atoms`-like object).
        show_cell (bool): Whether or not to show the frame. (Default is True.)
        show_axes (bool): Whether or not to show xyz axes. (Default is True.)
        camera (str): 'perspective' or 'orthographic'. (Default is 'orthographic'.)
        spacefill (bool): Whether to use a space-filling or ball-and-stick representation. (Default is True, use
            space-filling atoms.)
        particle_size (float): Size of the particles. (Default is 1.)
        select_atoms (numpy.ndarray): Indices of atoms to show, either as integers or a boolean array mask.
            (Default is None, show all atoms.)
        background (str): Background color. (Default is 'white'.)
        color_scheme (str): NGLView color scheme to use. (Default is None, color by element.)
        colors (numpy.ndarray): A per-atom array of HTML color names or hex color codes to use for atomic colors.
            (Default is None, use coloring scheme.)
        scalar_field (numpy.ndarray): Color each atom according to the array value (Default is None, use coloring
            scheme.)
        scalar_start (float): The scalar value to be mapped onto the low end of the color map (lower values are
            clipped). (Default is None, use the minimum value in `scalar_field`.)
        scalar_end (float): The scalar value to be mapped onto the high end of the color map (higher values are
            clipped). (Default is None, use the maximum value in `scalar_field`.)
        scalar_cmap (matplotlib.cm): The colormap to use. (Default is None, giving a blue-red divergent map.)
        vector_field (numpy.ndarray): Add vectors (3 values) originating at each atom. (Default is None, no
            vectors.)
        vector_color (numpy.ndarray): Colors for the vectors (only available with vector_field): either one RGB
            triplet per atom or a single color used for all arrows. (Default is None, vectors are colored by
            their direction.)
        magnetic_moments (bool): Plot the initial magnetic moments as 'scalar_field' (one value per atom) or
            'vector_field' (one vector per atom). (Default is False.)
        view_plane (numpy.ndarray): A Nx3-array (N = 1,2,3); the first 3d-component of the array specifies
            which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes), the
            second 3d-component (if specified) gives the in-plane orientation, and the third component (if
            specified) is ignored and calculated internally. The orthonormality of the orientation is internally
            ensured, and therefore is not required in the function call. (Default is np.array([0, 0, 1]), which
            is view normal to the x-y plane.)
        distance_from_camera (float): Relative distance of the camera from the structure. Higher = farther away.
            It is multiplied by 14 before being handed to NGLView. (Default is 1.0.)

        Possible NGLView color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget): The NGLView widget itself, which can be operated on further or viewed as-is.

    Raises:
        ImportError: If nglview is not installed.

    Warnings:
        * Many features only work with space-filling atoms (e.g. coloring by a scalar field).
        * The colour interpretation of some hex codes is weird, e.g. 'green'.
    """
    try:  # If the graphical packages are not available, the GUI will not work.
        import nglview
    except ImportError as err:
        raise ImportError(
            "The package nglview needs to be installed for the plot3d() function!"
        ) from err

    if magnetic_moments is True:
        initial_moments = structure.get_initial_magnetic_moments()
        if np.sum(np.abs(initial_moments)) > 0:
            # One value per atom (collinear) -> color the atoms;
            # one vector per atom (non-collinear) -> draw arrows.
            if len(initial_moments.shape) == 1:
                scalar_field = initial_moments
            else:
                vector_field = initial_moments

    elements = structure.get_chemical_symbols()
    atomic_numbers = structure.get_atomic_numbers()
    positions = structure.positions

    # If `select_atoms` was given, visualize only that subset of atoms; every
    # per-atom array is restricted to the same subset.
    if select_atoms is not None:
        select_atoms = np.asarray(select_atoms)
        if select_atoms.dtype == bool:
            # Casting a boolean mask to int would turn it into indices 0 and 1.
            select_atoms = np.flatnonzero(select_atoms)
        else:
            select_atoms = select_atoms.astype(int)
        # `get_chemical_symbols()` may return a plain list, which cannot be
        # indexed with an index array.
        elements = np.asarray(elements)[select_atoms]
        atomic_numbers = np.asarray(atomic_numbers)[select_atoms]
        positions = np.asarray(positions)[select_atoms]
        if colors is not None:
            colors = np.asarray(colors)[select_atoms]
        if scalar_field is not None:
            scalar_field = np.asarray(scalar_field)[select_atoms]
        if vector_field is not None:
            vector_field = np.asarray(vector_field)[select_atoms]
        if vector_color is not None:
            vector_color = np.asarray(vector_color)
            # Only per-atom (N, 3) colors are subset; a single RGB triplet is
            # applied to every arrow further below.
            if vector_color.ndim == 2:
                vector_color = vector_color[select_atoms]

    # Write the nglview protein-database-formatted string
    struct = nglview.TextStructure(
        _ngl_write_structure(elements, positions, structure.cell)
    )

    # Parse the string into the displayable widget
    view = nglview.NGLWidget(struct)

    if spacefill:
        if color_scheme is not None:  # Color by scheme
            if colors is not None:
                warnings.warn("`color_scheme` is overriding `colors`")
            if scalar_field is not None:
                warnings.warn("`color_scheme` is overriding `scalar_field`")
            view = _add_colorscheme_spacefill(
                view, elements, atomic_numbers, particle_size, color_scheme
            )
        elif colors is not None:  # Color by per-atom colors
            if scalar_field is not None:
                warnings.warn("`colors` is overriding `scalar_field`")
            view = _add_custom_color_spacefill(
                view, atomic_numbers, particle_size, colors
            )
        elif scalar_field is not None:  # Color by per-atom scalars
            colors = _scalars_to_hex_colors(
                scalar_field, scalar_start, scalar_end, scalar_cmap
            )
            view = _add_custom_color_spacefill(
                view, atomic_numbers, particle_size, colors
            )
        else:  # Color by element
            view = _add_colorscheme_spacefill(
                view, elements, atomic_numbers, particle_size
            )
        view.remove_ball_and_stick()
    else:
        view.add_ball_and_stick()

    if show_cell and structure.cell is not None:
        # Skip cells that are (nearly) zero along some axis.
        if all(np.max(structure.cell, axis=0) > 1e-2):
            view.add_unitcell()

    if vector_field is not None:
        n_shown = len(positions)
        if vector_color is None:
            # Color each arrow by its direction: map the unit vector's
            # components from [-1, 1] onto RGB values in [0, 1].
            vectors = np.asarray(vector_field, dtype=float)
            norms = np.linalg.norm(vectors, axis=-1)
            norms[norms == 0] = 1.0  # zero vectors get neutral grey instead of NaN
            vector_color = 0.5 * vectors / norms[:, np.newaxis] + 0.5
        elif hasattr(vector_color, "shape"):
            # An array that is not one RGB triplet per shown atom is treated as
            # a single color, normalized and repeated for every arrow.
            if vector_color.shape != (n_shown, 3):
                vector_color = np.outer(
                    np.ones(n_shown),
                    vector_color / np.linalg.norm(vector_color),
                )
        else:
            # Plain sequences / scalars are broadcast onto every arrow.
            vector_color = np.ones((n_shown, 3)) * vector_color
        for arr, pos, col in zip(vector_field, positions, vector_color):
            view.shape.add_arrow(list(pos), list(pos + arr), list(col), 0.2)

    if show_axes:  # Add axes
        axes_origin = -np.ones(3)
        arrow_radius = 0.1
        text_size = 1
        text_color = [0, 0, 0]
        arrow_names = ["x", "y", "z"]

        for n in [0, 1, 2]:
            start = list(axes_origin)
            shift = np.zeros(3)
            shift[n] = 1
            end = list(start + shift)
            color = list(shift)
            # We cast as list to avoid JSON warnings
            view.shape.add_arrow(start, end, color, arrow_radius)
            view.shape.add_text(end, text_color, text_size, arrow_names[n])

    if camera != "perspective" and camera != "orthographic":
        warnings.warn(
            "Only perspective or orthographic is (likely to be) permitted for camera"
        )

    view.camera = camera
    view.background = background

    orientation = _get_flattened_orientation(
        view_plane=view_plane, distance_from_camera=distance_from_camera * 14
    )
    view.control.orient(orientation)

    return view
426

427

428
def _plot3d_ase(
    structure,
    spacefill=True,
    show_cell=True,
    camera="perspective",
    particle_size=0.5,
    background="white",
    color_scheme="element",
    show_axes=True,
):
    """
    Visualize a structure through NGLView's ASE interface (`nglview.show_ase`).

    Args:
        structure: The structure to draw. It is passed to `nglview.show_ase` unchanged (presumably an ASE `Atoms`
            object), and its `cell` is read to decide whether to draw the unit cell.
        spacefill (bool): Use a space-filling representation (`radius_type="vdw"`, `radius=particle_size`)
            instead of ball-and-stick. (Default is True.)
        show_cell (bool): Whether to draw the unit cell; skipped if the cell is None or (nearly) zero along an
            axis. (Default is True.)
        camera (str): 'perspective' or 'orthographic'. (Default is 'perspective'.)
        particle_size (float): Radius passed to the spacefill representation. (Default is 0.5.)
        background (str): Background color. (Default is 'white'.)
        color_scheme (str): NGLView color scheme for the spacefill representation. (Default is 'element'.)
        show_axes (bool): Whether to draw red/green/blue arrows along x/y/z starting at (-2, -2, -2).
            (Default is True.)

        Possible color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget/None): The widget, or None (after emitting a warning) if `camera` is neither
            'perspective' nor 'orthographic'.

    Raises:
        ImportError: If nglview is not installed.
    """
    try:  # If the graphical packages are not available, the GUI will not work.
        import nglview
    except ImportError as err:
        raise ImportError(
            "The package nglview needs to be installed for the plot3d() function!"
        ) from err
    view = nglview.show_ase(structure)
    if spacefill:
        view.add_spacefill(
            radius_type="vdw", color_scheme=color_scheme, radius=particle_size
        )
        view.remove_ball_and_stick()
    else:
        view.add_ball_and_stick()
    if show_cell and structure.cell is not None:
        if all(np.max(structure.cell, axis=0) > 1e-2):
            view.add_unitcell()
    if show_axes:
        # Red/green/blue arrows along x/y/z, all starting from (-2, -2, -2).
        view.shape.add_arrow([-2, -2, -2], [2, -2, -2], [1, 0, 0], 0.5)
        view.shape.add_arrow([-2, -2, -2], [-2, 2, -2], [0, 1, 0], 0.5)
        view.shape.add_arrow([-2, -2, -2], [-2, -2, 2], [0, 0, 1], 0.5)
    if camera != "perspective" and camera != "orthographic":
        # Library code should not print to stdout; use a warning as the
        # NGLView backend `_plot3d` does.
        warnings.warn("Only perspective or orthographic is permitted")
        return None
    view.camera = camera
    view.background = background
    return view
475

476

477
def _ngl_write_cell(a1, a2, a3, f1=90, f2=90, f3=90):
1✔
478
    """
479
    Writes a PDB-formatted line to represent the simulation cell.
480

481
    Args:
482
        a1, a2, a3 (float): Lengths of the cell vectors.
483
        f1, f2, f3 (float): Angles between the cell vectors (which angles exactly?) (in degrees).
484

485
    Returns:
486
        (str): The line defining the cell in PDB format.
487
    """
488
    return "CRYST1 {:8.3f} {:8.3f} {:8.3f} {:6.2f} {:6.2f} {:6.2f} P 1\n".format(
×
489
        a1, a2, a3, f1, f2, f3
490
    )
491

492

493
def _ngl_write_atom(
1✔
494
    num,
495
    species,
496
    x,
497
    y,
498
    z,
499
    group=None,
500
    num2=None,
501
    occupancy=1.0,
502
    temperature_factor=0.0,
503
):
504
    """
505
    Writes a PDB-formatted line to represent an atom.
506

507
    Args:
508
        num (int): Atomic index.
509
        species (str): Elemental species.
510
        x, y, z (float): Cartesian coordinates of the atom.
511
        group (str): A...group name? (Default is None, repeat elemental species.)
512
        num2 (int): An "alternate" index. (Don't ask me...) (Default is None, repeat first number.)
513
        occupancy (float): PDB occupancy parameter. (Default is 1.)
514
        temperature_factor (float): PDB temperature factor parameter. (Default is 0.
515

516
    Returns:
517
        (str): The line defining an atom in PDB format
518

519
    Warnings:
520
        * The [PDB docs](https://www.cgl.ucsf.edu/chimera/docs/UsersGuide/tutorials/pdbintro.html) indicate that
521
            the xyz coordinates might need to be in some sort of orthogonal basis. If you have weird behaviour,
522
            this might be a good place to investigate.
523
    """
524
    if group is None:
×
525
        group = species
×
526
    if num2 is None:
×
527
        num2 = num
×
528
    return "ATOM {:>6} {:>4} {:>4} {:>5} {:10.3f} {:7.3f} {:7.3f} {:5.2f} {:5.2f} {:>11} \n".format(
×
529
        num, species, group, num2, x, y, z, occupancy, temperature_factor, species
530
    )
531

532

533
def _ngl_write_structure(elements, positions, cell):
    """
    Turn structure information into an NGLView-readable protein-database-formatted string.

    Args:
        elements (numpy.ndarray/list): Element symbol for each atom.
        positions (numpy.ndarray/list): Vector of Cartesian atom positions.
        cell (numpy.ndarray/list): Simulation cell Bravais matrix. If None, or (nearly) zero along some axis, an
            orthorhombic box spanning the atoms is used instead (10 along axes without extent).

    Returns:
        (str): The PDB-formatted representation of the structure: a CRYST1 line, then "MODEL     1", one ATOM
            line per atom, and "ENDMDL ".
    """
    from ase.geometry import cell_to_cellpar, cellpar_to_cell

    if cell is None or any(np.max(cell, axis=0) < 1e-2):
        # Define a dummy cell if it doesn't exist (eg. for clusters)
        extent = np.max(positions, axis=0) - np.min(positions, axis=0)
        extent[np.abs(extent) < 1e-2] = 10
        cell = np.eye(3) * extent
    cellpar = cell_to_cellpar(cell)
    # PDB stores only lengths and angles. The cell rebuilt from them may be
    # rotated relative to `cell`; `rotation` solves cell @ rotation == rebuilt
    # cell, so positions @ rotation are expressed in the rebuilt cell's frame.
    exported_cell = cellpar_to_cell(cellpar)
    rotation = np.linalg.solve(cell, exported_cell)
    rotated_positions = np.asarray(positions).dot(rotation)

    # Collect the lines and join once instead of quadratic string `+=`.
    lines = [_ngl_write_cell(*cellpar), "MODEL     1\n"]
    lines.extend(
        _ngl_write_atom(i, elements[i], *p) for i, p in enumerate(rotated_positions)
    )
    lines.append("ENDMDL \n")
    return "".join(lines)
567

568

569
def _atomic_number_to_radius(atomic_number, shift=0.2, slope=0.1, scale=1.0):
1✔
570
    """
571
    Give the atomic radius for plotting, which scales like the root of the atomic number.
572

573
    Args:
574
        atomic_number (int/float): The atomic number.
575
        shift (float): A constant addition to the radius. (Default is 0.2.)
576
        slope (float): A multiplier for the root of the atomic number. (Default is 0.1)
577
        scale (float): How much to rescale the whole thing by.
578

579
    Returns:
580
        (float): The radius. (Not physical, just for visualization!)
581
    """
582
    return (shift + slope * np.sqrt(atomic_number)) * scale
×
583

584

585
def _add_colorscheme_spacefill(
    view, elements, atomic_numbers, particle_size, scheme="element"
):
    """
    Add one NGLView spacefill representation per distinct element, colored by a named scheme.

    Each distinct (element, atomic number) pair gets its own `add_spacefill` call. The selection string is
    "#" + element symbol (assumed to select the atoms of that element in NGL's selection language -- confirm),
    and the radius is derived from the atomic number via `_atomic_number_to_radius`, scaled by `particle_size`.

    Args:
        view (NGLWidget): The widget to work on.
        elements (numpy.ndarray/list): Elemental symbols, one per atom.
        atomic_numbers (numpy.ndarray/list): Integer atomic numbers, one per atom, for determining atomic size.
        particle_size (float): A scale factor for the atomic size.
        scheme (str): The scheme to use. (Default is "element".)

        Possible NGLView color schemes:
          " ", "picking", "random", "uniform", "atomindex", "residueindex",
          "chainindex", "modelindex", "sstruc", "element", "resname", "bfactor",
          "hydrophobicity", "value", "volume", "occupancy"

    Returns:
        (nglview.NGLWidget): The same widget, with the representations added.
    """
    distinct_species = set(zip(elements, atomic_numbers))
    for symbol, z in distinct_species:
        species_radius = _atomic_number_to_radius(z, scale=particle_size)
        view.add_spacefill(
            selection="#" + symbol,
            radius_type="vdw",
            radius=species_radius,
            color_scheme=scheme,
        )
    return view
614

615

616
def _add_custom_color_spacefill(view, atomic_numbers, particle_size, colors):
    """
    Add one NGLView spacefill representation per atom, each with its own color.

    Args:
        view (NGLWidget): The widget to work on.
        atomic_numbers (numpy.ndarray/list): Integer atomic numbers for determining atomic size; the position in
            this sequence is used as the atom selection index.
        particle_size (float): A scale factor for the atomic size.
        colors (numpy.ndarray/list): A per-atom list of HTML or hex color codes, indexed like `atomic_numbers`.

    Returns:
        (nglview.NGLWidget): The same widget, with the representations added.
    """
    for atom_index, z in enumerate(atomic_numbers):
        atom_radius = _atomic_number_to_radius(z, scale=particle_size)
        atom_color = colors[atom_index]
        view.add_spacefill(
            selection=[atom_index],
            radius_type="vdw",
            radius=atom_radius,
            color=atom_color,
        )
    return view
637

638

639
def _scalars_to_hex_colors(scalar_field, start=None, end=None, cmap=None):
    """
    Convert scalar values to hex color codes using a colormap.

    Values are clipped to [start, end], mapped linearly onto [0, 1] and then looked up in `cmap`.

    Args:
        scalar_field (numpy.ndarray/list): Scalars to convert.
        start (float): Scalar value to map to the bottom of the colormap (values below are clipped). (Default is
            None, use the minimal scalar value.)
        end (float): Scalar value to map to the top of the colormap (values above are clipped).  (Default is
            None, use the maximal scalar value.)
        cmap (matplotlib.cm): The colormap to use. (Default is None, which gives a blue-red divergent map from
            seaborn.)

    Returns:
        (list): The corresponding hex codes for each scalar value passed in. If `start == end` (e.g. a constant
            field), every value gets the color at the middle of the colormap.

    Raises:
        ImportError: If `cmap` is None and seaborn is not installed.
    """
    from matplotlib.colors import rgb2hex

    if start is None:
        start = np.amin(scalar_field)
    if end is None:
        end = np.amax(scalar_field)
    clipped = np.clip(scalar_field, start, end)
    if start == end:
        # A zero-width range cannot be interpolated (division by zero yields
        # NaN colors); use the neutral middle of the colormap instead.
        remapped_field = np.full(np.shape(clipped), 0.5)
    else:
        interp = interp1d([start, end], [0, 1])
        remapped_field = interp(clipped)  # Map field onto [0,1]

    if cmap is None:
        try:
            from seaborn import diverging_palette
        except ImportError as err:
            # Previously only a message was printed, followed by a NameError
            # on `diverging_palette`; fail with a clear ImportError instead.
            raise ImportError(
                "The package seaborn needs to be installed for the plot3d() function!"
            ) from err
        cmap = diverging_palette(245, 15, as_cmap=True)  # A nice blue-red palette

    # cmap returns RGBA; the slice keeps RGB and drops alpha.
    return [rgb2hex(cmap(scalar)[:3]) for scalar in remapped_field]
675

676

677
def _get_orientation(view_plane):
1✔
678
    """
679
    A helper method to plot3d, which generates a rotation matrix from the input `view_plane`, and returns a
680
    flattened list of len = 16. This flattened list becomes the input argument to `view.contol.orient`.
681

682
    Args:
683
        view_plane (numpy.ndarray/list): A Nx3-array/list (N = 1,2,3); the first 3d-component of the array
684
            specifies which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes),
685
            the second 3d-component (if specified, otherwise [1, 0, 0]) gives the horizontal direction, and the
686
            third component (if specified) is the vertical component, which is ignored and calculated internally.
687
            The orthonormality of the orientation is internally ensured, and therefore is not required in the
688
            function call.
689

690
    Returns:
691
        (list): orientation tensor
692
    """
693
    if len(np.array(view_plane).flatten()) % 3 != 0:
1✔
694
        raise ValueError(
×
695
            "The shape of view plane should be (N, 3), where N = 1, 2 or 3. Refer docs for more info."
696
        )
697
    view_plane = np.array(view_plane).reshape(-1, 3)
1✔
698
    rotation_matrix = np.roll(np.eye(3), -1, axis=0)
1✔
699
    rotation_matrix[: len(view_plane)] = view_plane
1✔
700
    rotation_matrix /= np.linalg.norm(rotation_matrix, axis=-1)[:, np.newaxis]
1✔
701
    rotation_matrix[1] -= (
1✔
702
        np.dot(rotation_matrix[0], rotation_matrix[1]) * rotation_matrix[0]
703
    )  # Gran-Schmidt
704
    rotation_matrix[2] = np.cross(
1✔
705
        rotation_matrix[0], rotation_matrix[1]
706
    )  # Specify third axis
707
    if np.isclose(np.linalg.det(rotation_matrix), 0):
1✔
708
        return np.eye(
×
709
            3
710
        )  # view_plane = [0,0,1] is the default view of NGLview, so we do not modify it
711
    return np.roll(
1✔
712
        rotation_matrix / np.linalg.norm(rotation_matrix, axis=-1)[:, np.newaxis],
713
        2,
714
        axis=0,
715
    ).T
716

717

718
def _get_flattened_orientation(view_plane, distance_from_camera):
1✔
719
    """
720
    A helper method to plot3d, which generates a rotation matrix from the input `view_plane`, and returns a
721
    flattened list of len = 16. This flattened list becomes the input argument to `view.contol.orient`.
722

723
    Args:
724
        view_plane (numpy.ndarray/list): A Nx3-array/list (N = 1,2,3); the first 3d-component of the array
725
            specifies which plane of the system to view (for example, [1, 0, 0], [1, 1, 0] or the [1, 1, 1] planes),
726
            the second 3d-component (if specified, otherwise [1, 0, 0]) gives the horizontal direction, and the
727
            third component (if specified) is the vertical component, which is ignored and calculated internally.
728
            The orthonormality of the orientation is internally ensured, and therefore is not required in the
729
            function call.
730
        distance_from_camera (float): Distance of the camera from the structure. Higher = farther away.
731

732
    Returns:
733
        (list): Flattened list of len = 16, which is the input argument to `view.contol.orient`
734
    """
735
    if distance_from_camera <= 0:
1✔
736
        raise ValueError("´distance_from_camera´ must be a positive float!")
×
737
    flattened_orientation = np.eye(4)
1✔
738
    flattened_orientation[:3, :3] = _get_orientation(view_plane)
1✔
739

740
    return (distance_from_camera * flattened_orientation).ravel().tolist()
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc