• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

maurergroup / dfttoolkit / 20237393685

15 Dec 2025 03:18PM UTC coverage: 32.439% (-0.02%) from 32.455%
20237393685

Pull #132

github

web-flow
Merge cdf500a1d into c73822af7
Pull Request #132: Removed automatic write from parameter functions

1366 of 4211 relevant lines covered (32.44%)

0.32 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

63.82
dfttoolkit/parameters.py
1
from typing import Any
1✔
2
from warnings import warn
1✔
3

4
import numpy as np
1✔
5
import numpy.typing as npt
1✔
6

7
from .base import Parser
1✔
8
from .utils.file_utils import MultiDict
1✔
9
from .utils.periodic_table import PeriodicTable
1✔
10

11

12
class Parameters(Parser):
1✔
13
    """
14
    Handle files that control parameters for electronic structure calculations.
15

16
    If contributing a new parser, please subclass this class, add the new supported file
17
    type to _supported_files and match statement in this class' `__init__()`, and call
18
    the `super().__init__()` method, include the new file type as a kwarg in the
19
    `super().__init__()`. Optionally include the `self.lines` line
20
    in examples.
21

22
    ...
23

24
    Attributes
25
    ----------
26
    _supported_files : dict
27
        Mapping of supported file type keywords to their file extensions.
28
    """
29

30
    def __init__(self, **kwargs: str):
        """
        Set up the parser for one of the supported parameter files.

        Parameters
        ----------
        **kwargs : str
            The file type keyword (a key of `_supported_files`) mapped to the path
            of the file, e.g. ``control_in="./control.in"``.
        """
        # Hand the supported file types and the requested file to the base parser
        supported_files = self._supported_files
        super().__init__(supported_files, **kwargs)

        # NOTE(review): presumably asserts that the file is not binary - confirm
        # against `Parser._check_binary`
        self._check_binary(False)
1✔
35

36
    @property
1✔
37
    def _supported_files(self) -> dict:
1✔
38
        # FHI-aims, ...
39
        return {"control_in": ".in", "cube": ".cube"}
1✔
40

41
    def __repr__(self) -> str:
1✔
42
        return f"{self.__class__.__name__}({self._format}={self._name})"
×
43

44
    def __init_subclass__(cls, **kwargs: str):
1✔
45
        # Override the parent's __init_subclass__ without calling it
46
        pass
1✔
47

48
    def write(self) -> None:
1✔
49
        """Write the parameters file to disk."""
50
        with open(self.path, "w") as f:
1✔
51
            f.writelines(self.lines)
1✔
52

53

54
class AimsControl(Parameters):
1✔
55
    """
56
    FHI-aims control file parser.
57

58
    ...
59

60
    Attributes
61
    ----------
62
    path: str
63
        path to the control.in file
64
    lines: List[str]
65
        contents of the control.in file
66

67
    Examples
68
    --------
69
    >>> ac = AimsControl(control_in="./control.in")
70
    """
71

72
    def __init__(self, control_in: str = "control.in"):
        """
        Open an FHI-aims control file.

        Parameters
        ----------
        control_in : str, default="control.in"
            Path to the control.in file.
        """
        super().__init__(control_in=control_in)
1✔
74

75
    # Use normal methods instead of properties for these methods as we want to specify
76
    # the setter method using kwargs instead of assigning the value as a dictionary.
77
    # Then, for consistency, keep get_keywords as a normal function.
78
    def get_keywords(self) -> MultiDict:
        """
        Get the keywords from the control.in file.

        Parsing stops where the keyword section ends: at the third ``#===...``
        delimiter line if the file was written by ASE, otherwise at the first line
        of 80 ``#`` characters that marks the start of the basis set definitions.
        Comment lines (including indented ones) and blank lines are skipped.

        Returns
        -------
        MultiDict
            Keywords in the control.in file, mapped to the rest of their line.
        """
        keywords = MultiDict()

        ase_delimiter = "#" + ("=" * 79)
        basis_delimiter = "#" * 80

        # The counter and the ASE flag must persist across lines; resetting them
        # for every line means the third delimiter is never detected
        written_by_ase = False
        n_delims = 0

        for line in self.lines:
            stripped = line.strip()
            spl = stripped.split()
            if not spl:
                continue

            if spl[-1] == "(ASE)":
                written_by_ase = True

            if stripped == ase_delimiter:
                n_delims += 1
                if written_by_ase and n_delims == 3:
                    # Reached end of keywords
                    break
                continue

            if basis_delimiter in stripped:
                # Reached the basis set definitions
                break

            # Check the first token rather than the first character so that
            # indented comments are not stored as keywords
            if not spl[0].startswith("#"):
                keywords[spl[0]] = " ".join(spl[1:])

        return keywords
1✔
108

109
    def get_species(self) -> list[str]:
1✔
110
        """
111
        Get the species from a control.in file.
112

113
        Returns
114
        -------
115
        List[str]
116
            A list of the species in the control.in file.
117
        """
118
        species = []
1✔
119
        for line in self.lines:
1✔
120
            spl = line.split()
1✔
121
            if len(spl) > 0 and spl[0] == "species":
1✔
122
                species.append(line.split()[1])
1✔
123

124
        return species
1✔
125

126
    def get_default_basis_funcs(
1✔
127
        self, elements: list[str] | None = None
128
    ) -> dict[str, str]:
129
        """
130
        Get the basis functions.
131

132
        Parameters
133
        ----------
134
        elements : List[str], optional, default=None
135
            The elements to parse the basis functions for as chemical symbols.
136

137
        Returns
138
        -------
139
        Dict[str, str]
140
            A dictionary of the basis functions for the specified elements.
141
        """
142
        # Check that the given elements are valid
143
        if elements is not None and not set(elements).issubset(
1✔
144
            set(PeriodicTable.element_symbols())
145
        ):
146
            raise ValueError("Invalid element(s) given")
×
147

148
        # Warn if the requested elements aren't found in control.in
149
        if elements is not None and not set(elements).issubset(self.get_species()):
1✔
150
            warn("Could not find all requested elements in control.in", stacklevel=2)
×
151

152
        basis_funcs = {}
1✔
153

154
        for i, line_1 in enumerate(self.lines):
1✔
155
            spl_1 = line_1.split()
1✔
156
            if "species" in spl_1[0]:
1✔
157
                species = spl_1[1]
1✔
158

159
                if elements is not None and species not in elements:
1✔
160
                    continue
×
161

162
                for line_2 in self.lines[i + 1 :]:
1✔
163
                    spl = line_2.split()
1✔
164
                    if "species" in spl[0]:
1✔
165
                        break
1✔
166

167
                    if "#" in spl[0]:
1✔
168
                        continue
1✔
169

170
                    if "hydro" in line_2:
1✔
171
                        if species in basis_funcs:
1✔
172
                            basis_funcs[species].append(line_2.strip())
1✔
173
                        else:
174
                            basis_funcs[species] = [line_2.strip()]
1✔
175

176
        return basis_funcs
1✔
177

178
    def add_keywords(self, *args: tuple[str, Any]) -> None:
1✔
179
        """
180
        Add keywords to the AimsControl instance.
181

182
        Note that files written by ASE or in a format where the keywords are at the top
183
        of the file followed by the basis sets are the only formats that are supported
184
        by this function. The keywords need to be added in a Tuple format rather than as
185
        **kwargs because we need to be able to add multiple of the same keyword.
186

187
        Parameters
188
        ----------
189
        *args : Tuple[str, Any]
190
            Keywords to be added to the control.in file.
191
        """
192
        # Get the location of the start of the basis sets
193
        basis_set_start = False
1✔
194

195
        # if ASE wrote the file, use the 'add' point as the end of keywords delimiter
196
        # otherwise, use the start of the basis sets as 'add' point
197
        for i, line_1 in enumerate(self.lines):
1✔
198
            if line_1.strip() == "#" * 80:
1✔
199
                if self.lines[2].split()[-1] == "(ASE)":
1✔
200
                    for j, line_2 in enumerate(reversed(self.lines[:i])):
1✔
201
                        if line_2.strip() == "#" + ("=" * 79):
1✔
202
                            basis_set_start = i - j - 1
1✔
203
                            break
1✔
204
                    break
1✔
205

206
                # not ASE
207
                basis_set_start = i
1✔
208
                break
1✔
209

210
        # Check to make sure basis sets were found
211
        if not basis_set_start:
1✔
212
            raise IndexError("Could not detect basis sets in control.in")
×
213

214
        # Add the new keywords above the basis sets
215
        for arg in reversed(args):
1✔
216
            self.lines.insert(basis_set_start, f"{arg[0]:<34} {arg[1]}\n")
1✔
217

218
    def add_cube_cell(self, cell_matrix: npt.NDArray, resolution: int = 100) -> None:
1✔
219
        """
220
        Add cube output settings to cover the unit cell specified in `cell_matrix`.
221

222
        Since the default behaviour of FHI-AIMS for generating cube files for periodic
223
        structures with vacuum gives confusing results, this function ensures the cube
224
        output quantity is calculated for the full unit cell.
225

226
        Parameters
227
        ----------
228
        cell_matrix : NDArray
229
            2D array defining the unit cell.
230

231
        resolution : int
232
            Number of cube voxels to use for the shortest side of the unit cell.
233
        """
234
        if not self.check_periodic():  # Fail for non-periodic structures
1✔
235
            raise TypeError("add_cube_cell doesn't support non-periodic structures")
1✔
236

237
        shortest_side = min(np.sum(cell_matrix, axis=1))
1✔
238
        resolution = shortest_side / 100.0
1✔
239

240
        cube_x = (
1✔
241
            2 * int(np.ceil(0.5 * np.linalg.norm(cell_matrix[0, :]) / resolution)) + 1
242
        )  # Number of cubes along x axis
243
        x_vector = cell_matrix[0, :] / np.linalg.norm(cell_matrix[0, :]) * resolution
1✔
244
        cube_y = (
1✔
245
            2 * int(np.ceil(0.5 * np.linalg.norm(cell_matrix[1, :]) / resolution)) + 1
246
        )
247
        y_vector = cell_matrix[1, :] / np.linalg.norm(cell_matrix[1, :]) * resolution
1✔
248
        cube_z = (
1✔
249
            2 * int(np.ceil(0.5 * np.linalg.norm(cell_matrix[2, :]) / resolution)) + 1
250
        )
251
        z_vector = cell_matrix[2, :] / np.linalg.norm(cell_matrix[2, :]) * resolution
1✔
252
        self.add_keywords(  # Add cube options to control.in
1✔
253
            (
254
                "cube",
255
                "origin {} {} {}\n".format(
256
                    *(np.transpose(cell_matrix @ [0.5, 0.5, 0.5]))
257
                )
258
                + "cube edge {} {} {} {}\n".format(cube_x, *x_vector)
259
                + "cube edge {} {} {} {}\n".format(cube_y, *y_vector)
260
                + "cube edge {} {} {} {}\n".format(cube_z, *z_vector),
261
            )
262
        )
263

264
    def remove_keywords(self, *args: str) -> None:
1✔
265
        """
266
        Remove keywords from the control.in file.
267

268
        Note that this will not remove keywords that are commented with a '#'.
269

270
        Parameters
271
        ----------
272
        *args : str
273
            Keywords to be removed from the control.in file.
274
        """
275
        for keyword in args:
1✔
276
            for i, line in enumerate(self.lines):
1✔
277
                spl = line.split()
1✔
278
                if len(spl) > 0 and spl[0] != "#" and keyword == spl[0]:
1✔
279
                    self.lines.pop(i)
1✔
280

281
    def check_periodic(self) -> bool:
1✔
282
        """Check if the system is periodic."""
283
        return "k_grid" in self.get_keywords()
1✔
284

285

286
class CubeParameters(Parameters):
1✔
287
    """
288
    Cube file settings that can be used to generate a control file.
289

290
    Attributes
291
    ----------
292
    type : str
293
        type of cube file; all that comes after output cube
294

295
    Parameters
296
    ----------
297
    cube: str
298
        path to the cube file
299
    text: str | None
300
        text to parse
301

302
    Functions
303
    -------------------
304
        parse(text): parses textlines
305

306
        get_text(): returns the cube file specification string for a control file
307
    """
308

309
    def __init__(self, cube: str = "cube.cube", text: str | None = None):
        """
        Read a cube file and optionally parse cube settings from `text`.

        Parameters
        ----------
        cube : str, default="cube.cube"
            Path to the cube file.
        text : str | None, default=None
            Cube settings to parse; nothing is parsed if None.
        """
        super().__init__(cube=cube)

        self._check_binary(False)

        # Set attrs here rather than in `File.__post_init__()`: `Cube.__init__()`
        # uses ASE to parse the data from a cube file, so they are defined in
        # `Cube.__init__()` and `File.__post_init__()` doesn't add these attributes
        # if a cube file extension is detected.
        with open(self.path) as cube_file:
            self.lines = cube_file.readlines()
        self.data = b""
        self._binary = False

        self._type = ""

        def first_as_int(tokens):
            return int(tokens[0])

        def as_ints(tokens):
            return [int(k) for k in tokens]

        def as_floats(tokens):
            return [float(k) for k in tokens]

        def as_edge(tokens):
            return [int(tokens[0])] + [float(k) for k in tokens[1:]]

        def ints_to_str(values):
            return "  ".join([str(k) for k in values])

        def floats_to_str(values):
            return "  ".join([f"{k: 15.10f}" for k in values])

        def edge_to_str(values):
            return str(int(values[0])) + "  " + floats_to_str(values[1:])

        # parsers for specific cube keywords:
        # keyword: [string_to_number, number_to_string]
        self._parsing_functions = {
            "spinstate": [first_as_int, str],
            "kpoint": [first_as_int, str],
            "divisor": [first_as_int, str],
            "spinmask": [as_ints, ints_to_str],
            "origin": [as_floats, floats_to_str],
            "edge": [as_edge, edge_to_str],
        }

        self._settings = MultiDict()

        if text is not None:
            self.parse(text)
×
354

355
    def __repr__(self):
1✔
356
        text = "CubeSettings object with content:\n"
×
357
        text += self.get_text()
×
358
        return text
×
359

360
    @property
1✔
361
    def type(self) -> str:
1✔
362
        """Everything that comes after output cube as a single string."""
363
        return self._type
×
364

365
    @type.setter
1✔
366
    def type(self, value: str) -> None:
1✔
367
        """Set the type of the cube file."""
368
        self._type = value
×
369

370
    @property
1✔
371
    def parsing_functions(self) -> dict[str, list[int | str]]:
1✔
372
        """Parsing functions for specific cube keywords."""
373
        return self._parsing_functions
×
374

375
    @property
    def settings(self) -> MultiDict:
        """Settings for the cube file, keyed by cube keyword."""
        current_settings = self._settings
        return current_settings
×
379

380
    @property
1✔
381
    def origin(self) -> npt.NDArray[np.float64]:
1✔
382
        """Origin of the cube file."""
383
        raise NotImplementedError(
×
384
            "Decide if this property should return the "
385
            "dictionary value or the first component as a numpy array"
386
        )
387

388
        return self.setting["origin"]
389
        return np.array(self.settings["origin"][0])
390

391
    @origin.setter
1✔
392
    def origin(self, origin: npt.NDArray[np.float64]) -> None:
1✔
393
        self.settings["origin"] = [[origin[0], origin[1], origin[2]]]
×
394

395
    @property
1✔
396
    def edges(self) -> npt.NDArray[np.float64]:
1✔
397
        """Set the edge vectors."""
398
        return np.array(self.settings["edge"])
×
399

400
    @edges.setter
1✔
401
    def edges(self, edges: tuple[list[int], list[float]]) -> None:
1✔
402
        """
403
        TODO.
404

405
        Parameters
406
        ----------
407
        edges : tuple[list[int], list[float]]
408
            TODO
409
        """
410
        raise NotImplementedError("Type annotations need to be fixed")
×
411

412
        # self.settings["edge"] = []
413
        # for i, d in enumerate(edges[0]):
414
        #     self.settings["edge"].append([d, *list(edges[1][i, :])])
415

416
    @property
1✔
417
    def grid_vectors(self) -> float:
1✔
418
        raise NotImplementedError("See edges.setter")
×
419

420
        # edges = self.edges
421
        # return edges[:, 1:]
422

423
    @property
1✔
424
    def divisions(self) -> float:
1✔
425
        raise NotImplementedError("See edges.setter")
×
426

427
        # edges = self.edges
428
        # return edges[:, 0]
429

430
    @divisions.setter
1✔
431
    def divisions(self, divs: npt.NDArray[np.float64]) -> None:
1✔
432
        if len(divs) != 3:
×
433
            raise ValueError("Requires divisions for all three lattice vectors")
×
434

435
        for i in range(3):
×
436
            self.settings["edge"][i][0] = divs[i]
×
437

438
    def parse(self, text: str) -> None:
1✔
439
        """
440
        TODO.
441

442
        Parameters
443
        ----------
444
        str
445
            TODO
446
        """
447
        cubelines = []
×
448
        for line in text:
×
449
            strip = line.strip()
×
450
            # parse only lines that start with cube and are not comments
451
            if not strip.startswith("#"):
×
452
                if strip.startswith("cube"):
×
453
                    cubelines.append(strip)
×
454
                elif strip.startswith("output"):
×
455
                    self.type = " ".join(strip.split()[2:])
×
456

457
        # parse cubelines to self.settings
458
        for line in cubelines:
×
459
            nc_lines = line.split("#")[0]  # remove comments
×
460
            splitline = nc_lines.split()
×
461
            keyword = splitline[1]  # parse keyword
×
462
            values = splitline[2:]  # parse all values
×
463

464
            # check if parsing function exists
465
            if keyword in self.parsing_functions:
×
466
                value = self.parsing_functions[keyword]
×
467

468
            # reconvert to single string otherwise
469
            else:
470
                value = " ".join(values)
×
471

472
            # save all values as list, append to list if key already exists
473
            if keyword in self.settings:
×
474
                self.settings[keyword].append(value)
×
475
            else:
476
                self.settings[keyword] = [value]
×
477

478
    def has_vertical_unit_cell(self) -> bool:
1✔
479
        conditions = [
×
480
            self.settings["edge"][0][3] == 0.0,
481
            self.settings["edge"][1][3] == 0.0,
482
            self.settings["edge"][2][1] == 0.0,
483
            self.settings["edge"][2][1] == 0.0,
484
        ]
485
        return False not in conditions
×
486

487
    def set_z_slice(self, z_bottom: float, z_top: float) -> None:
1✔
488
        """
489
        Crops the cubefile to only include the space between z_bottom and z_top.
490

491
        The cubefile could go slightly beyond z_bottom and z_top in order to preserve
492
        the distance between grid points.
493

494
        Parameters
495
        ----------
496
        z_bottom: float
497
            TODO
498
        z_top: float
499
            TODO
500
        """
501
        if z_top < z_bottom:
×
502
            raise ValueError("Ensure that `z_bottom` is smaller than `z_top`")
×
503

504
        if not self.has_vertical_unit_cell():
×
505
            raise ValueError(
×
506
                "This function is only supported for systems where the "
507
                "cell is parallel to the z-axis"
508
            )
509

510
        diff = z_top - z_bottom
×
511
        average = z_bottom + diff / 2
×
512

513
        # set origin Z
514
        self.settings["origin"][0][2] = average
×
515

516
        # set edge, approximating for excess
517
        z_size = self.settings["edge"][2][0] * self.settings["edge"][2][3]
×
518
        fraction_of_z_size = z_size / diff
×
519
        new_z = self.settings["edge"][2][0] / fraction_of_z_size
×
520

521
        if new_z % 1 != 0:
×
522
            new_z = int(new_z) + 1.0
×
523

524
        self.settings["edge"][2][0] = new_z
×
525

526
    def set_grid_by_box_dimensions(
1✔
527
        self,
528
        x_limits: tuple[float, float],
529
        y_limits: tuple[float, float],
530
        z_limits: tuple[float, float],
531
        spacing: float | tuple[float, float, float],
532
    ) -> None:
533
        """
534
        Set origin and edge as a cuboid box.
535

536
        The ranging is within the given limits, with voxel size specified by spacing.
537

538
        Parameters
539
        ----------
540
        x_limits: tuple[float, float]
541
            min and max of...TODO
542
        y_limits: tuple[float, float]
543
            min and max of...TODO
544
        z_limits: tuple[float, float]
545
            min and max of...TODO
546
        spacing: float | tuple[float, float, float]
547
            TODO
548
        """
549
        raise NotImplementedError("Origin parameter needs to be fixed")
×
550

551
        # self.origin = [0, 0, 0]
552
        # self.settings["edge"] = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
553

554
        # # set one dimension at a time
555
        # for i, lim in enumerate([x_limits, y_limits, z_limits]):
556
        #     if lim[0] >= lim[1]:
557
        #         raise ValueError("Ensure the minimum is given first")
558

559
        #     diff = lim[1] - lim[0]
560

561
        #     # set origin
562
        #     center = lim[0] + (diff / 2)
563
        #     self.settings["origin"][0][i] = center
564

565
        #     # set edges
566
        #     space = spacing[i] if isinstance(spacing, list) else spacing
567

568
        #     # size of voxel
569
        #     self.settings["edge"][i][i + 1] = space
570

571
        #     # number of voxels
572
        #     n_voxels = int(diff / space) + 1
573
        #     self.settings["edge"][i][0] = n_voxels
574

575
    def get_text(self) -> str:
1✔
576
        """
577
        TODO.
578

579
        Returns
580
        -------
581
        TODO
582
        """
583
        raise NotImplementedError("Fix self.parsing_functions type")
×
584

585
        # text = ""
586
        # if len(self.type) > 0:
587
        #     text += "output cube " + self.type + "\n"
588
        # else:
589
        #     warn("No cube type specified", stacklevel=2)
590
        #     text += "output cube" + "CUBETYPE" + "\n"
591

592
        # for key, values in self.settings.items():
593
        #     for v in values:
594
        #         text += "cube " + key + " "
595
        #         if key in self.parsing_functions:
596
        #             text += self.parsing_functions[key][1](v) + "\n"
597
        #         else:
598
        #             print(v)
599
        #             text += v + "\n"
600

601
        # return text
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc