• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

materialsproject / pymatgen / 4075885785

pending completion
4075885785

push

github

Shyue Ping Ong
Merge branch 'master' of github.com:materialsproject/pymatgen

96 of 96 new or added lines in 27 files covered. (100.0%)

81013 of 102710 relevant lines covered (78.88%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

61.19
/pymatgen/io/abinit/variable.py
1
"""Support for Abinit input variables."""
2
from __future__ import annotations
1✔
3

4
import collections
1✔
5
import collections.abc
1✔
6
import string
1✔
7

8
import numpy as np
1✔
9

10
# Names exported by ``from ... import *``: only the input-variable class is public.
__all__ = ["InputVariable"]
13

14
# Non-digit characters that may appear in the dataset-index suffix of a variable name.
_SPECIAL_DATASET_INDICES = (":", "+", "?")

# Every character that may form a dataset-index suffix (digits plus the special ones above).
# Used as the character set passed to ``str.rstrip`` to strip such a suffix from a name.
_DATASET_INDICES = string.digits + "".join(_SPECIAL_DATASET_INDICES)

# NOTE(review): presumably identifier-safe spellings of the special indices above
# (paired position by position in ``_SPECIAL_CONVERSION``); confirm with callers.
_INTERNAL_DATASET_INDICES = ("__s", "__i", "__a")

# (internal, special) pairs. Stored as a tuple rather than a bare ``zip`` object so the
# constant can be iterated any number of times (a ``zip`` is exhausted after one pass).
_SPECIAL_CONVERSION = tuple(zip(_INTERNAL_DATASET_INDICES, _SPECIAL_DATASET_INDICES))

# Unit names accepted by Abinit, mapped to the factor converting one such unit into
# atomic units (bohr for lengths, hartree for energies).
# NOTE(review): the angstrom/eV factors look like CODATA 2006 values
# (1 bohr = 0.52917720859 Å, 1 Ha = 27.21138386 eV); confirm they match the Abinit in use.
_UNITS = {
    "bohr": 1.0,
    "angstrom": 1.8897261328856432,
    "hartree": 1.0,
    "Ha": 1.0,
    "eV": 0.03674932539796232,
}
29

30

31
class InputVariable:
    """
    An Abinit input variable: a name, a value and optional units.

    ``str(var)`` renders the declaration of the variable as written in an input file.
    """

    def __init__(self, name, value, units="", valperline=3):
        """
        Args:
            name: Name of the variable, possibly followed by a dataset index (e.g. "ecut2").
            value: Value of the variable. If it is a non-string sequence whose last item
                is one of the unit strings known to this module (e.g. ``[10, "eV"]``),
                that item is removed from the value and used as the units.
            units: String specifying one of the units supported by Abinit. Default: atomic units.
            valperline: Number of items printed per line.
        """
        self._name = name
        self.value = value
        self._units = units

        # Maximum number of values per line.
        self.valperline = valperline
        if name in ["bdgw"]:
            self.valperline = 2

        if self._has_trailing_units(self.value):
            self.value = list(self.value)
            self._units = self.value.pop(-1)

    @staticmethod
    def _has_trailing_units(value):
        """Return True if ``value`` is a non-string sequence ending with a unit string from ``_UNITS``."""
        if isinstance(value, str) or not is_iter(value):
            return False
        try:
            last = value[-1]
        except (IndexError, KeyError, TypeError):
            # Empty sequences, 0-d arrays, mappings and unsubscriptable iterables
            # cannot carry a trailing unit (previously these crashed the constructor).
            return False
        return isinstance(last, str) and last in _UNITS

    def get_value(self):
        """
        Return the value, with the units appended as the last item when units are set.

        Iterable values are returned as a list ``[*value, units]``; scalars (including
        strings) are returned as ``[value, units]``. Without units the value is returned as is.
        """
        if not self.units:
            return self.value
        if isinstance(self.value, str) or not is_iter(self.value):
            # list(10) would raise and list("abc") would split the string into characters.
            return [self.value, self.units]
        return [*self.value, self.units]

    @property
    def name(self):
        """Name of the variable."""
        return self._name

    @property
    def basename(self):
        """Return the name trimmed of any trailing dataset-index characters (digits, ':', '+', '?')."""
        return self.name.rstrip(_DATASET_INDICES)

    @property
    def dataset(self):
        """Return the dataset index in string form ("" if the name has none)."""
        # basename is a prefix of name, so the index is simply the remaining suffix.
        # (Splitting on basename raised ValueError when basename was empty.)
        return self.name[len(self.basename):]

    @property
    def units(self):
        """Return the units."""
        return self._units

    def __str__(self):
        """Declaration of the variable in the input file ("" when the value is None or empty string)."""
        value = self.value
        if value is None or not str(value):
            return ""

        var = self.name
        line = " " + var

        # By default, do not impose a number of decimal points.
        floatdecimal = 0

        # For some inputs, enforce number of decimal points (the formatters cap it at 10)...
        if any(inp in var for inp in ("xred", "xcart", "rprim", "qpt", "kpt")):
            floatdecimal = 16

        # ...but not for those, whose names contain the substrings above.
        if any(inp in var for inp in ("ngkpt", "kptrlatt", "ngqpt", "ng2qpt")):
            floatdecimal = 0

        if isinstance(value, np.ndarray):
            # Arrays of any rank are written as a flat list in C order.
            value = list(np.ravel(value))

        # values in lists
        if isinstance(value, (list, tuple)):
            if value and all(isinstance(v, (list, tuple)) for v in value):
                # A list of lists is written one inner list per line.
                line += self.format_list2d(value, floatdecimal)
            else:
                line += self.format_list(value, floatdecimal)

        # scalar values
        else:
            line += " " + str(value)

        # Add units
        if self.units:
            line += " " + self.units

        return line

    @staticmethod
    def format_scalar(val, floatdecimal=0):
        """
        Format a single numerical value into a string with the appropriate number of decimals.

        Integer-looking values are returned unchanged when ``floatdecimal`` is 0, and values
        that cannot be converted to a finite float are returned as ``str(val)``. Other values
        use fixed-point notation when 0 or 1e-3 < |val| < 1e4, and exponent notation otherwise,
        with ``d`` as the exponent marker. The number of decimals is the one shown by ``str()``
        for the fractional part, at least ``floatdecimal`` and at most 10.
        """
        sval = str(val)
        if sval.lstrip("-").lstrip("+").isdigit() and floatdecimal == 0:
            return sval

        try:
            fval = float(val)
        except (TypeError, ValueError, OverflowError):
            return sval

        # inf/nan have no fractional part to measure: int(inf) would raise.
        if not np.isfinite(fval):
            return sval

        if fval == 0 or (1e-3 < abs(fval) < 1e4):
            form, addlen = "f", 5
        else:
            form, addlen = "e", 8

        ndec = min(max(len(str(fval - int(fval))) - 2, floatdecimal), 10)

        sval = f"{fval:>{ndec + addlen}.{ndec}{form}}"
        return sval.replace("e", "d")

    @staticmethod
    def format_list2d(values, floatdecimal=0):
        """
        Format a list of lists, one inner list per line, in right-aligned columns.

        All entries share one format: integer if every entry is an ``int``, float if every
        entry converts to ``float`` (numeric strings included), plain string otherwise.
        The result starts with a newline; an empty input yields "".
        """
        # Each entry of an inner list is formatted directly, so one level of flattening
        # covers every input that can be formatted.
        lvals = [v for row in values for v in row]

        # Determine the representation
        fvals = []
        if all(isinstance(v, int) for v in lvals):
            type_all = int
        else:
            try:
                fvals = [float(v) for v in lvals]
            except (TypeError, ValueError, OverflowError):
                type_all = str
            else:
                type_all = float

        # Determine the format (default=0: empty input has nothing to measure).
        width = max((len(str(s)) for s in lvals), default=0)
        if type_all is int:
            formatspec = f">{width}d"
        elif type_all is str:
            formatspec = f">{width}"
        else:
            # Number of decimals, as shown by str() of each fractional part (inf/nan skipped).
            maxdec = max(
                (len(str(f - int(f))) - 2 for f in fvals if np.isfinite(f)),
                default=0,
            )
            ndec = min(max(maxdec, floatdecimal), 10)

            if all(f == 0 or 1e-3 < abs(f) < 1e4 for f in fvals):
                formatspec = f">{ndec + 5}.{ndec}f"
            else:
                formatspec = f">{ndec + 8}.{ndec}e"

        def _as_formattable(val):
            # Numeric strings must become floats before a float format spec is applied.
            return float(val) if type_all is float else val

        rows = ["".join(f" {_as_formattable(val):{formatspec}}" for val in row) for row in values]
        return ("\n" + "\n".join(rows)).rstrip("\n")

    def format_list(self, values, floatdecimal=0):
        """
        Format a flat list of values into a string.

        Each value goes through ``format_scalar``; a newline is inserted after every
        ``self.valperline`` values (never when ``valperline`` is None). When the output
        spans several lines it is prefixed with a newline.
        """
        line = ""

        # Format the line declaring the value
        for i, val in enumerate(values):
            line += " " + self.format_scalar(val, floatdecimal)
            if self.valperline is not None and (i + 1) % self.valperline == 0:
                line += "\n"

        # Add a carriage return in case of several lines
        if "\n" in line.rstrip("\n"):
            line = "\n" + line

        return line.rstrip("\n")
219

220

221
def is_iter(obj) -> bool:
    """Tell whether ``obj`` exposes an ``__iter__`` attribute, i.e. looks list-like.

    Strings qualify as well.
    """
    sentinel = object()
    # getattr with a default swallows AttributeError only, exactly like hasattr.
    return getattr(obj, "__iter__", sentinel) is not sentinel
224

225

226
def flatten(iterable):
    """Make an iterable flat, i.e. a 1d iterable object.

    Nested iterables are expanded depth-first, left to right; strings are kept whole.
    Returns a tuple. An explicit stack of iterators is used instead of recursion,
    so arbitrarily deep nesting is supported.
    """
    flat = []
    pending = [iter(iterable)]
    while pending:
        for item in pending[-1]:
            if isinstance(item, collections.abc.Iterable) and not isinstance(item, str):
                # Descend into the nested iterable; the parent iterator keeps its position.
                pending.append(iter(item))
                break
            flat.append(item)
        else:
            # Current iterator exhausted: resume its parent.
            pending.pop()
    return tuple(flat)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc