• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stfc / janus-core / 13675396748

05 Mar 2025 11:55AM UTC coverage: 92.587% (+0.002%) from 92.585%
13675396748

Pull #462

github

web-flow
Merge cfda1fe27 into a38fc0155
Pull Request #462: Add default output directory

34 of 36 new or added lines in 11 files covered. (94.44%)

90 existing lines in 14 files now uncovered.

2598 of 2806 relevant lines covered (92.59%)

2.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.51
/janus_core/cli/utils.py
1
"""Utility functions for CLI."""
2

3
from __future__ import annotations
3✔
4

5
from collections.abc import Sequence
3✔
6
import datetime
3✔
7
import logging
3✔
8
from pathlib import Path
3✔
9
from typing import TYPE_CHECKING, Any
3✔
10

11
from typer_config import conf_callback_factory, yaml_loader
3✔
12
import yaml
3✔
13

14
if TYPE_CHECKING:
3✔
NEW
15
    from ase import Atoms
×
16
    from typer import Context
×
17

18
    from janus_core.cli.types import TyperDict
×
UNCOV
19
    from janus_core.helpers.janus_types import (
×
20
        Architectures,
21
        ASEReadArgs,
22
        Devices,
23
        MaybeSequence,
24
    )
25

26

27
def dict_paths_to_strs(dictionary: dict) -> None:
    """
    Replace every Path value in a (possibly nested) dictionary with its string form.

    The dictionary is modified in place. Nested dictionaries are walked
    recursively; values of any other type are left untouched.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.
    """
    for name, item in dictionary.items():
        if isinstance(item, Path):
            # Overwriting an existing key does not change the dict's size,
            # so this is safe while iterating.
            dictionary[name] = str(item)
        elif isinstance(item, dict):
            dict_paths_to_strs(item)
3✔
41

42

43
def dict_tuples_to_lists(dictionary: dict) -> None:
    """
    Recursively iterate over dictionary, converting tuple values to lists.

    The dictionary is modified in place. Nested dictionaries are converted
    recursively; tuples are converted shallowly (their elements are kept
    as-is), and all other values are left unchanged.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.
    """
    for key, value in dictionary.items():
        if isinstance(value, dict):
            # Recurse with this function so tuples nested inside
            # sub-dictionaries are converted as well.
            dict_tuples_to_lists(value)
        elif isinstance(value, tuple):
            dictionary[key] = list(value)
3✔
57

58

59
def dict_remove_hyphens(dictionary: dict) -> dict:
    """
    Build a copy of a dictionary whose keys use underscores instead of hyphens.

    Nested dictionaries are converted recursively, and each converted
    sub-dictionary is also written back into the input dictionary under its
    original key. The top-level returned dictionary is new.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.

    Returns
    -------
    dict
        Dictionary with hyphens in keys replaced with underscores.
    """
    # First pass: convert nested dictionaries in place.
    for name in dictionary:
        child = dictionary[name]
        if isinstance(child, dict):
            dictionary[name] = dict_remove_hyphens(child)

    # Second pass: rebuild the top level with renamed keys.
    renamed = {}
    for name, child in dictionary.items():
        renamed[name.replace("-", "_")] = child
    return renamed
3✔
77

78

79
def set_read_kwargs_index(read_kwargs: dict[str, Any]) -> None:
    """
    Set default read_kwargs["index"] to final image and check its value is an integer.

    To ensure only a single Atoms object is read, slices such as ":" are forbidden.

    Parameters
    ----------
    read_kwargs
        Keyword arguments to be passed to ase.io.read. If specified,
        read_kwargs["index"] must be an integer, and if not, a default value
        of -1 is set. The dictionary is modified in place.

    Raises
    ------
    ValueError
        If read_kwargs["index"] cannot be converted to an integer.
    """
    read_kwargs.setdefault("index", -1)
    try:
        int(read_kwargs["index"])
    # ``int`` raises ValueError for unparsable strings such as ":" but
    # TypeError for objects such as ``slice(None)``; report both the same way.
    except (TypeError, ValueError) as e:
        raise ValueError("`read_kwargs['index']` must be an integer") from e
3✔
97

98

99
def parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
    """
    Unwrap a list of TyperDict objects into plain dictionaries, in place.

    Each entry is replaced by its ``value``; falsy entries become empty
    dictionaries. The same list object is returned.

    Parameters
    ----------
    typer_dicts
        List of TyperDict objects to convert.

    Returns
    -------
    list[dict]
        List of converted dictionaries.

    Raises
    ------
    ValueError
        If items in list are not converted to dicts.
    """
    for idx, wrapped in enumerate(typer_dicts):
        parsed = wrapped.value if wrapped else {}
        typer_dicts[idx] = parsed
        if isinstance(parsed, dict):
            continue
        raise ValueError(
            f"""{parsed} must be passed as a dictionary wrapped in quotes.\
 For example, "{{'key': value}}" """
        )
    return typer_dicts
3✔
126

127

128
def yaml_converter_loader(config_file: str) -> dict[str, Any]:
    """
    Read a YAML configuration file, normalising hyphenated keys to underscores.

    Parameters
    ----------
    config_file
        Yaml configuration file to read. A falsy value means no file.

    Returns
    -------
    dict[str, Any]
        Dictionary with loaded configuration, or an empty dictionary when
        no file is given.
    """
    if config_file:
        # Keys are rewritten so "-" becomes "_" throughout the loaded config.
        return dict_remove_hyphens(yaml_loader(config_file))
    return {}
3✔
148

149

150
# Configuration callback created by typer_config's ``conf_callback_factory``
# using ``yaml_converter_loader`` as its loader, so configuration files are
# read with hyphens in keys replaced by underscores.
yaml_converter_callback = conf_callback_factory(yaml_converter_loader)
3✔
151

152

153
def start_summary(*, command: str, summary: Path, inputs: dict) -> None:
    """
    Create the summary file with the command, start time and inputs.

    Any existing file at ``summary`` is overwritten.

    Parameters
    ----------
    command
        Name of CLI command being used.
    summary
        Path to summary file being saved.
    inputs
        Inputs to CLI command to save.
    """
    full_command = f"janus {command}"
    started = datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
    header = {"command": full_command, "start_time": started, "inputs": inputs}
    with open(summary, "w", encoding="utf8") as handle:
        yaml.dump(header, handle, default_flow_style=False)
3✔
173

174

175
def carbon_summary(*, summary: Path, log: Path) -> None:
    """
    Calculate and write carbon tracking summary.

    Sums the ``emissions`` entries found in the ``message`` dictionaries of
    the YAML log records, and appends the total to the summary file.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    log
        Path to log file with carbon emissions saved.
    """
    with open(log, encoding="utf8") as file:
        logs = yaml.safe_load(file)

    # ``yaml.safe_load`` returns None for an empty document (e.g. no log
    # records were written); treat that as having no records, total 0.
    records = logs or []

    emissions = sum(
        lg["message"]["emissions"]
        for lg in records
        if isinstance(lg["message"], dict) and "emissions" in lg["message"]
    )

    with open(summary, "a", encoding="utf8") as outfile:
        yaml.dump({"emissions": emissions}, outfile, default_flow_style=False)
3✔
197

198

199
def end_summary(summary: Path) -> None:
    """
    Append the finishing time to the summary file and shut down logging.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    """
    with open(summary, "a", encoding="utf8") as handle:
        finished = datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
        yaml.dump({"end_time": finished}, handle, default_flow_style=False)
    # Flush and close all logging handlers now that the run is complete.
    logging.shutdown()
3✔
215

216

217
def save_struct_calc(
    *,
    inputs: dict,
    struct: MaybeSequence[Atoms],
    struct_path: Path,
    arch: Architectures,
    device: Devices,
    model_path: str,
    read_kwargs: ASEReadArgs,
    calc_kwargs: dict[str, Any],
    log: Path,
) -> None:
    """
    Record structure and calculator details in an inputs dictionary.

    The raw structure/calculator entries are removed from ``inputs`` and
    replaced by grouped summaries. A single Atoms object is summarised under
    "struct"; a sequence is summarised under "traj" using its first image.
    Finally, every Path value in ``inputs`` (including nested dictionaries)
    is converted to a string. ``inputs`` is modified in place.

    Parameters
    ----------
    inputs
        Inputs dictionary to add information to.
    struct
        Structure to be simulated.
    struct_path
        Path of structure file.
    arch
        MLIP architecture.
    device
        Device to run calculations on.
    model_path
        Path to MLIP model.
    read_kwargs
        Keyword arguments to pass to ase.io.read.
    calc_kwargs
        Keyword arguments to pass to the calculator.
    log
        Path to log file.
    """
    from ase import Atoms

    # These raw entries are re-added below in grouped form, so drop them.
    redundant = (
        "struct",
        "struct_path",
        "arch",
        "device",
        "model_path",
        "read_kwargs",
        "calc_kwargs",
        "log_kwargs",
    )
    for name in redundant:
        inputs.pop(name, None)

    if isinstance(struct, Atoms):
        inputs["struct"] = {
            "n_atoms": len(struct),
            "struct_path": struct_path,
            "formula": struct.get_chemical_formula(),
        }
    elif isinstance(struct, Sequence):
        n_images = len(struct)
        first_image = struct[0]
        inputs["traj"] = {
            "length": n_images,
            "struct_path": struct_path,
            "struct": {
                "n_atoms": len(first_image),
                "formula": first_image.get_chemical_formula(),
            },
        }

    inputs["calc"] = dict(
        arch=arch,
        device=device,
        model_path=model_path,
        read_kwargs=read_kwargs,
        calc_kwargs=calc_kwargs,
    )
    inputs["log"] = log

    # Make the nested inputs serialisable by converting Path values to str.
    dict_paths_to_strs(inputs)
3✔
296

297

298
def check_config(ctx: Context) -> None:
    """
    Check options in configuration file are valid options for CLI command.

    Parameters
    ----------
    ctx
        Typer (Click) Context within command.

    Raises
    ------
    ValueError
        If an option from the configuration file is not a parameter of the
        command.
    """
    # Click leaves ``default_map`` as None when no defaults were supplied,
    # so fall back to an empty mapping instead of failing to iterate None.
    config_options = ctx.default_map or {}

    # Compare options from config file (default_map) to function definition (params)
    for option in config_options:
        # Check options individually so can inform user of specific issue
        if option not in ctx.params:
            raise ValueError(f"'{option}' in configuration file is not a valid option")
3✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc