• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

materialsproject / pymatgen / 4075885785

pending completion
4075885785

push

github

Shyue Ping Ong
Merge branch 'master' of github.com:materialsproject/pymatgen

96 of 96 new or added lines in 27 files covered. (100.0%)

81013 of 102710 relevant lines covered (78.88%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

57.02
/pymatgen/entries/correction_calculator.py
1
"""
2
This module calculates corrections for the species listed below, fitted to the experimental and computed
3
entries given to the CorrectionCalculator constructor.
4
"""
5

6
from __future__ import annotations
1✔
7

8
import os
1✔
9
import warnings
1✔
10

11
import numpy as np
1✔
12
import plotly.graph_objects as go
1✔
13
from monty.serialization import loadfn
1✔
14
from ruamel import yaml
1✔
15
from scipy.optimize import curve_fit
1✔
16

17
from pymatgen.analysis.reaction_calculator import ComputedReaction
1✔
18
from pymatgen.analysis.structure_analyzer import sulfide_type
1✔
19
from pymatgen.core.composition import Composition
1✔
20
from pymatgen.core.periodic_table import Element
1✔
21

22

23
class CorrectionCalculator:
    """
    A CorrectionCalculator contains experimental and computed entries which it uses to compute corrections.

    It graphs residual errors after applying the computed corrections and creates the MPCompatibility.yaml
    file the Correction classes use.

    Attributes:
        species: list of species that corrections are being calculated for
        exp_compounds: list of dictionaries which each contain a compound's formula and experimental data
        calc_compounds: dictionary of ComputedEntry objects
        corrections: list of corrections in same order as species list
        corrections_std_error: list of the standard errors of the corrections (square roots of the diagonal
            of the fit covariance matrix) in same order as species list
        corrections_dict: dictionary of format {'species': (value, uncertainty)} for easier correction lookup
    """

    # Species whose fitted compounds are tracked by name, mapped to the attribute holding those names.
    # graph_residual_error_per_species() uses these lists to tell oxygen/sulfur compound types apart.
    _TRACKED_SPECIES = {"oxide": "oxides", "peroxide": "peroxides", "superoxide": "superoxides", "S": "sulfides"}

    def __init__(
        self,
        species: list[str] | None = None,
        max_error: float = 0.1,
        allow_unstable: float | bool = 0.1,
        exclude_polyanions: list[str] | None = None,
    ) -> None:
        """
        Initializes a CorrectionCalculator.

        Args:
            species: list of species to calculate corrections for
            max_error: maximum tolerable relative uncertainty in experimental energy.
                Compounds with relative uncertainty greater than this value will be excluded from the fit
            allow_unstable: whether unstable entries are to be included in the fit. If True, all compounds will
                be included regardless of their energy above hull. If False or a number, compounds with
                energy above hull greater than the given value (defaults to 0.1 eV/atom) will be
                excluded. An explicit threshold of 0 is honoured.
            exclude_polyanions: a list of polyanions that contain additional sources of error that may negatively
                influence the quality of the fitted corrections. Compounds with these polyanions
                will be excluded from the fit
        """
        self.species = species or "oxide peroxide superoxide S F Cl Br I N Se Si Sb Te V Cr Mn Fe Co Ni W Mo H".split()

        self.max_error = max_error

        # Normalise once so that compute_corrections() only ever sees True (no filtering) or a float
        # threshold. Testing plain truthiness here would silently turn a threshold of 0.0 into 0.1.
        if allow_unstable is None or isinstance(allow_unstable, (bool, np.bool_)):
            self.allow_unstable: float | bool = True if allow_unstable else 0.1
        else:
            self.allow_unstable = float(allow_unstable)

        self.exclude_polyanions = (
            exclude_polyanions
            if exclude_polyanions is not None
            else "SO4 SO3 CO3 NO3 NO2 OCl3 ClO3 ClO4 HO ClO SeO3 TiO3 TiO4 WO4 SiO3 SiO4 Si2O5 PO3 PO4 P2O7".split()
        )

        self.corrections: list[float] = []
        self.corrections_std_error: list[float] = []
        self.corrections_dict: dict[str, tuple[float, float]] = {}  # {'species': (value, uncertainty)}

        # to help the graph_residual_error_per_species() method differentiate between oxygen containing compounds
        self.oxides: list[str] = []
        self.peroxides: list[str] = []
        self.superoxides: list[str] = []
        self.sulfides: list[str] = []

    def compute_from_files(self, exp_gz: str, comp_gz: str):
        """
        Loads experimental and computed data from files and computes the corrections.

        Args:
            exp_gz: name of .json.gz file that contains experimental data
                    data in .json.gz file should be a list of dictionary objects with the following keys/values:
                    {"formula": chemical formula, "exp energy": formation energy in eV/formula unit,
                    "uncertainty": uncertainty in formation energy}
            comp_gz: name of .json.gz file that contains computed entries
                    data in .json.gz file should be a dictionary of {chemical formula: ComputedEntry}

        Returns:
            The corrections dictionary returned by compute_corrections.
        """
        exp_entries = loadfn(exp_gz)
        calc_entries = loadfn(comp_gz)

        return self.compute_corrections(exp_entries, calc_entries)

    def compute_corrections(self, exp_entries: list, calc_entries: dict) -> dict:
        """
        Computes the corrections and fills in correction, corrections_std_error, and corrections_dict.

        Args:
            exp_entries: list of dictionary objects with the following keys/values:
                    {"formula": chemical formula, "exp energy": formation energy in eV/formula unit,
                    "uncertainty": uncertainty in formation energy}
            calc_entries: dictionary of computed entries, of the form {chemical formula: ComputedEntry}

        Returns:
            corrections_dict, of format {'species': (value, uncertainty)}, including a zero "ozonide" entry.

        Raises:
            ValueError: calc_compounds is missing an entry, an entry lacks e_above_hull data, a species
                cannot be detected, or no compound is left to fit after filtering
        """
        self.exp_compounds = exp_entries
        self.calc_compounds = calc_entries

        self.names: list[str] = []
        self.diffs: list[float] = []
        self.coeff_mat: list[list[float]] = []
        self.exp_uncer: list[float] = []

        # start the per-type name lists afresh so repeated calls do not accumulate stale names
        for attr in self._TRACKED_SPECIES.values():
            setattr(self, attr, [])

        # remove any corrections in calc_compounds
        for entry in self.calc_compounds.values():
            entry.correction = 0

        for cmpd_info in self.exp_compounds:
            # to get consistent element ordering in formula
            name = Composition(cmpd_info["formula"]).reduced_formula

            allow = True

            compound = self.calc_compounds.get(name, None)
            if not compound:
                warnings.warn(f"Compound {name} is not found in provided computed entries and is excluded from the fit")
                continue

            # filter out compounds with large uncertainties
            relative_uncertainty = abs(cmpd_info["uncertainty"] / cmpd_info["exp energy"])
            if relative_uncertainty > self.max_error:
                allow = False
                warnings.warn(
                    f"Compound {name} is excluded from the fit due to high experimental "
                    f"uncertainty ({relative_uncertainty:.1%})"
                )

            # filter out compounds containing certain polyanions
            for anion in self.exclude_polyanions:
                if anion in name or anion in cmpd_info["formula"]:
                    allow = False
                    warnings.warn(f"Compound {name} contains the polyanion {anion} and is excluded from the fit")
                    break

            # filter out compounds that are unstable (allow_unstable is True when no filtering is wanted)
            if isinstance(self.allow_unstable, float):
                try:
                    eah = compound.data["e_above_hull"]
                except KeyError as exc:
                    raise ValueError("Missing e above hull data") from exc
                if eah > self.allow_unstable:
                    allow = False
                    warnings.warn(f"Compound {name} is unstable and excluded from the fit (e_above_hull = {eah})")

            if not allow:
                continue

            comp = Composition(name)
            elems = list(comp.as_dict())

            # the formation reaction needs a computed entry for every constituent element
            reactants = []
            for elem in elems:
                try:
                    elem_name = Composition(elem).reduced_formula
                    reactants.append(self.calc_compounds[elem_name])
                except KeyError as exc:
                    raise ValueError("Computed entries missing " + elem) from exc

            rxn = ComputedReaction(reactants, [compound])
            rxn.normalize_to(comp)
            energy = rxn.calculated_reaction_energy

            # one coefficient per fitted species: the amount of that species in the formula unit
            coeff = []
            for specie in self.species:
                if specie in ("oxide", "peroxide", "superoxide"):
                    # oxygen only counts towards the species matching the entry's oxide type
                    if compound.data["oxide_type"] == specie:
                        coeff.append(comp["O"])
                        getattr(self, self._TRACKED_SPECIES[specie]).append(name)
                    else:
                        coeff.append(0)
                elif specie == "S":
                    if Element("S") in comp:
                        sf_type = "sulfide"
                        if compound.data.get("sulfide_type"):
                            sf_type = compound.data["sulfide_type"]
                        elif hasattr(compound, "structure"):
                            sf_type = sulfide_type(compound.structure)
                        if sf_type == "sulfide":
                            coeff.append(comp["S"])
                            self.sulfides.append(name)
                        else:
                            coeff.append(0)
                    else:
                        coeff.append(0)
                else:
                    try:
                        coeff.append(comp[specie])
                    except ValueError as exc:
                        raise ValueError(f"We can't detect this specie: {specie}") from exc

            self.names.append(name)
            self.diffs.append((cmpd_info["exp energy"] - energy) / comp.num_atoms)
            self.coeff_mat.append([i / comp.num_atoms for i in coeff])
            self.exp_uncer.append((cmpd_info["uncertainty"]) / comp.num_atoms)

        if not self.names:
            # fail with a clear message instead of an obscure error from inside curve_fit
            raise ValueError("No compounds are left to fit after filtering; cannot compute corrections")

        # for any exp entries with no uncertainty value, assign average uncertainty value
        sigma = np.array(self.exp_uncer)
        sigma[sigma == 0] = np.nan

        with warnings.catch_warnings():
            warnings.simplefilter(
                "ignore", category=RuntimeWarning
            )  # numpy raises warning if the entire array is nan values
            mean_uncer = np.nanmean(sigma)

        sigma = np.where(np.isnan(sigma), mean_uncer, sigma)

        # weight the fit by the experimental uncertainties only when at least one compound has one
        fit_kwargs = {} if np.isnan(mean_uncer) else {"sigma": sigma, "absolute_sigma": True}
        popt, self.pcov = curve_fit(
            lambda x, *m: np.dot(x, m),
            self.coeff_mat,
            self.diffs,
            p0=np.ones(len(self.species)),
            **fit_kwargs,
        )

        self.corrections = popt.tolist()
        self.corrections_std_error = np.sqrt(np.diag(self.pcov)).tolist()
        for i, v in enumerate(self.species):
            self.corrections_dict[v] = (
                round(self.corrections[i], 3),
                round(self.corrections_std_error[i], 4),
            )

        # set ozonide correction to 0 so that this species does not receive a correction
        # while other oxide types do
        self.corrections_dict["ozonide"] = (0, 0)

        return self.corrections_dict

    def _absolute_residuals(self) -> list:
        """Returns the absolute residual error of every fitted compound, in the same order as names."""
        return [abs(i) for i in np.asarray(self.diffs) - np.dot(self.coeff_mat, self.corrections)]

    def _select_species_rows(self, specie: str) -> tuple:
        """Returns (abs_errors, names, diffs) restricted to the fitted compounds that carry specie."""
        abs_errors = self._absolute_residuals()
        tracked_attr = self._TRACKED_SPECIES.get(specie)
        if tracked_attr is not None:
            # oxide types and sulfides are identified by the names recorded during the fit
            members = set(getattr(self, tracked_attr))
            keep = [label in members for label in self.names]
        else:
            keep = [bool(Composition(label)[specie]) for label in self.names]

        return (
            [err for err, flag in zip(abs_errors, keep) if flag],
            [label for label, flag in zip(self.names, keep) if flag],
            [diff for diff, flag in zip(self.diffs, keep) if flag],
        )

    @staticmethod
    def _residual_scatter(abs_errors: list, labels: list, title: str) -> go.Figure:
        """Builds the scatter plot of residual errors (points numbered from 1, labelled by compound)."""
        num = len(abs_errors)
        return go.Figure(
            data=go.Scatter(
                x=np.linspace(1, num, num),
                y=abs_errors,
                mode="markers",
                text=labels,
            ),
            layout=go.Layout(
                title=go.layout.Title(text=title),
                yaxis=go.layout.YAxis(title=go.layout.yaxis.Title(text="Residual Error (eV/atom)")),
            ),
        )

    @staticmethod
    def _print_error_summary(abs_errors: list, diffs: list) -> None:
        """Prints median/mean/std dev of the residual errors and of the uncorrected differences."""
        print("Residual Error:")
        print("Median = " + str(np.median(np.array(abs_errors))))
        print("Mean = " + str(np.mean(np.array(abs_errors))))
        print("Std Dev = " + str(np.std(np.array(abs_errors))))
        print("Original Error:")
        print("Median = " + str(abs(np.median(np.array(diffs)))))
        print("Mean = " + str(abs(np.mean(np.array(diffs)))))
        print("Std Dev = " + str(np.std(np.array(diffs))))

    def graph_residual_error(self) -> go.Figure:
        """
        Graphs the residual errors for all compounds after applying computed corrections.

        Returns:
            A plotly figure of the absolute residual errors, sorted in ascending order.

        Raises:
            RuntimeError: corrections have not been computed yet
        """
        if len(self.corrections) == 0:
            raise RuntimeError("Please call compute_corrections or compute_from_files to calculate corrections first")

        ranked = sorted(zip(self._absolute_residuals(), self.names))  # sort by error
        abs_errors = [err for err, _ in ranked]
        labels_graph = [label for _, label in ranked]

        fig = self._residual_scatter(abs_errors, labels_graph, "Residual Errors")
        self._print_error_summary(abs_errors, self.diffs)

        return fig

    def graph_residual_error_per_species(self, specie: str) -> go.Figure:
        """
        Graphs the residual errors for each compound that contains specie after applying computed corrections.

        Args:
            specie: the specie/group that residual errors are being plotted for

        Returns:
            A plotly figure of the absolute residual errors of the matching compounds, sorted in
            ascending order. The figure is empty when no fitted compound contains specie.

        Raises:
            ValueError: the specie is not a valid specie that this class fits corrections for
            RuntimeError: corrections have not been computed yet
        """
        if specie not in self.species:
            raise ValueError("not a valid specie")

        if len(self.corrections) == 0:
            raise RuntimeError("Please call compute_corrections or compute_from_files to calculate corrections first")

        abs_errors, labels_species, diffs_cpy = self._select_species_rows(specie)

        ranked = sorted(zip(abs_errors, labels_species))  # sort by error
        abs_errors = [err for err, _ in ranked]
        labels_species = [label for _, label in ranked]

        fig = self._residual_scatter(abs_errors, labels_species, "Residual Errors for " + specie)
        self._print_error_summary(abs_errors, diffs_cpy)

        return fig

    def make_yaml(self, name: str = "MP2020", dir: str | None = None) -> None:
        """
        Creates the _name_Compatibility.yaml file, which stores both the corrections and the
        corresponding correction uncertainties.

        Args:
            name: str, alternate name for the created .yaml file.
                Default: "MP2020"
            dir: str, directory in which to save the file. Pass None (default) to
                save the file in the current working directory.

        Raises:
            RuntimeError: corrections have not been computed yet
        """
        if len(self.corrections) == 0:
            raise RuntimeError("Please call compute_corrections or compute_from_files to calculate corrections first")

        # elements with U values
        ggau_correction_species = ["V", "Cr", "Mn", "Fe", "Co", "Ni", "W", "Mo"]

        comp_corr: dict[str, float] = {}
        o: dict[str, float] = {}
        f: dict[str, float] = {}

        comp_corr_error: dict[str, float] = {}
        o_error: dict[str, float] = {}
        f_error: dict[str, float] = {}

        for specie in [*self.species, "ozonide"]:
            value, uncertainty = self.corrections_dict[specie][0], self.corrections_dict[specie][1]
            if specie in ggau_correction_species:
                # the same fitted value is written to both the O and the F sections
                o[specie] = value
                f[specie] = value

                o_error[specie] = uncertainty
                f_error[specie] = uncertainty
            else:
                comp_corr[specie] = value
                comp_corr_error[specie] = uncertainty

        outline = """\
        Name:
        Corrections:
            GGAUMixingCorrections:
                O:
                F:
            CompositionCorrections:
        Uncertainties:
            GGAUMixingCorrections:
                O:
                F:
            CompositionCorrections:
        """
        fn = name + "Compatibility.yaml"
        path = os.path.join(dir, fn) if dir else fn

        yml = yaml.YAML()
        yml.default_flow_style = False
        contents = yml.load(outline)

        contents["Name"] = name

        # make CommentedMap so comments can be added
        contents["Corrections"]["GGAUMixingCorrections"]["O"] = yaml.comments.CommentedMap(o)
        contents["Corrections"]["GGAUMixingCorrections"]["F"] = yaml.comments.CommentedMap(f)
        contents["Corrections"]["CompositionCorrections"] = yaml.comments.CommentedMap(comp_corr)
        contents["Uncertainties"]["GGAUMixingCorrections"]["O"] = yaml.comments.CommentedMap(o_error)
        contents["Uncertainties"]["GGAUMixingCorrections"]["F"] = yaml.comments.CommentedMap(f_error)
        contents["Uncertainties"]["CompositionCorrections"] = yaml.comments.CommentedMap(comp_corr_error)

        contents["Corrections"].yaml_set_start_comment("Energy corrections in eV/atom", indent=2)
        contents["Corrections"]["GGAUMixingCorrections"].yaml_set_start_comment(
            "Composition-based corrections applied to transition metal oxides\nand fluorides to "
            + 'make GGA and GGA+U energies compatible\nwhen compat_type = "Advanced" (default)',
            indent=4,
        )
        contents["Corrections"]["CompositionCorrections"].yaml_set_start_comment(
            "Composition-based corrections applied to any compound containing\nthese species as anions",
            indent=4,
        )
        contents["Uncertainties"].yaml_set_start_comment(
            "Uncertainties corresponding to each energy correction (eV/atom)", indent=2
        )
        with open(path, "w", encoding="utf-8") as file:
            yml.dump(contents, file)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc