• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stfc / janus-core / 14172215793

31 Mar 2025 01:03PM UTC coverage: 92.205% (-0.4%) from 92.585%
14172215793

Pull #421

github

web-flow
Merge 36c48fbff into 47f7cc79b
Pull Request #421: add fairchem fixes #420

1 of 11 new or added lines in 1 file covered. (9.09%)

124 existing lines in 19 files now uncovered.

2756 of 2989 relevant lines covered (92.2%)

2.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.86
/janus_core/cli/utils.py
1
"""Utility functions for CLI."""
2

3
from __future__ import annotations
3✔
4

5
from collections.abc import Sequence
3✔
6
from copy import deepcopy
3✔
7
import datetime
3✔
8
import logging
3✔
9
from pathlib import Path
3✔
10
from typing import TYPE_CHECKING, Any
3✔
11

12
from typer_config import conf_callback_factory, yaml_loader
3✔
13
import yaml
3✔
14

15
from janus_core.helpers.utils import build_file_dir
3✔
16

17
if TYPE_CHECKING:
3✔
18
    from ase import Atoms
×
19
    from typer import Context
×
20

UNCOV
21
    from janus_core.cli.types import CorrelationKwargs, TyperDict
×
UNCOV
22
    from janus_core.helpers.janus_types import (
×
23
        MaybeSequence,
24
        PathLike,
25
    )
26

27

28
def dict_paths_to_strs(dictionary: dict) -> None:
    """
    Recursively iterate over dictionary, converting Path values to strings.

    Nested dictionaries are converted in place. Sequence values (other than
    ``str`` and ``bytes``) are replaced by lists in which ``Path`` items are
    converted to strings and dictionary items are converted recursively.

    Parameters
    ----------
    dictionary
        Dictionary to be converted. It is modified in place.
    """
    for key, value in dictionary.items():
        if isinstance(value, dict):
            dict_paths_to_strs(value)
        # bytes is a Sequence too, but must not be exploded into a list of ints
        elif isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
            dictionary[key] = [_path_item_to_str(item) for item in value]
        elif isinstance(value, Path):
            dictionary[key] = str(value)


def _path_item_to_str(item: Any) -> Any:
    """
    Convert a single sequence item for `dict_paths_to_strs`.

    Parameters
    ----------
    item
        Item taken from a sequence value.

    Returns
    -------
    Any
        ``str(item)`` for a Path. Otherwise the item itself, after recursively
        converting it in place if it is a dictionary.
    """
    if isinstance(item, Path):
        return str(item)
    if isinstance(item, dict):
        # Dictionaries nested in lists (e.g. lists of kwargs) may hold Paths too
        dict_paths_to_strs(item)
    return item
46

47

48
def dict_tuples_to_lists(dictionary: dict) -> None:
    """
    Recursively iterate over dictionary, converting tuple values to lists.

    Nested dictionaries are converted in place. Tuple and list values are
    replaced by new lists in which nested tuples/lists are converted to lists
    at every depth, and dictionary items are converted recursively.

    Parameters
    ----------
    dictionary
        Dictionary to be converted. It is modified in place.
    """
    for key, value in dictionary.items():
        if isinstance(value, dict):
            dict_tuples_to_lists(value)
        elif isinstance(value, (tuple, list)):
            dictionary[key] = [_tuple_item_to_list(item) for item in value]


def _tuple_item_to_list(item: Any) -> Any:
    """
    Convert a single sequence item for `dict_tuples_to_lists`.

    Parameters
    ----------
    item
        Item taken from a tuple or list value.

    Returns
    -------
    Any
        A list (recursively converted) for a tuple or list. Otherwise the item
        itself, after recursively converting it in place if it is a dictionary.
    """
    if isinstance(item, (tuple, list)):
        return [_tuple_item_to_list(element) for element in item]
    if isinstance(item, dict):
        # Dictionaries nested in lists may hold tuples too
        dict_tuples_to_lists(item)
    return item
64

65

66
def dict_remove_hyphens(dictionary: dict) -> dict:
    """
    Recursively iterate over dictionary, replacing hyphens with underscores in keys.

    A new dictionary is returned and the input is left unmodified. Keys that
    are not strings (e.g. integer keys loaded from YAML) are kept unchanged.
    If two keys map to the same name after conversion, the later one wins.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.

    Returns
    -------
    dict
        Dictionary with hyphens in keys replaced with underscores.
    """
    return {
        # Only string keys have hyphens to replace
        (key.replace("-", "_") if isinstance(key, str) else key): (
            dict_remove_hyphens(value) if isinstance(value, dict) else value
        )
        for key, value in dictionary.items()
    }
84

85

86
def set_read_kwargs_index(read_kwargs: dict[str, Any]) -> None:
    """
    Set default read_kwargs["index"] to final image and check its value is an integer.

    To ensure only a single Atoms object is read, slices such as ":" are forbidden.

    Parameters
    ----------
    read_kwargs
        Keyword arguments to be passed to ase.io.read. If "index" is not
        specified, it is set to -1 (the final image). If it is specified, it
        must be convertible to an integer. The dictionary is modified in place.

    Raises
    ------
    ValueError
        If read_kwargs["index"] cannot be converted to an integer.
    """
    read_kwargs.setdefault("index", -1)
    try:
        int(read_kwargs["index"])
    # int() raises TypeError (not ValueError) for e.g. None or slice objects
    except (TypeError, ValueError) as e:
        raise ValueError("`read_kwargs['index']` must be an integer") from e
104

105

106
def parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
    """
    Replace each TyperDict in a list with the dictionary it wraps.

    The list is updated in place, one entry at a time, and is also returned.
    Falsy entries (e.g. ``None``) become empty dictionaries.

    Parameters
    ----------
    typer_dicts
        List of TyperDict objects to convert.

    Returns
    -------
    list[dict]
        The same list, now holding the unwrapped dictionaries.

    Raises
    ------
    ValueError
        If an unwrapped entry is not a dictionary.
    """
    for position, wrapped in enumerate(typer_dicts):
        converted = wrapped.value if wrapped else {}
        typer_dicts[position] = converted
        if not isinstance(converted, dict):
            raise ValueError(
                f"{converted} must be passed as a dictionary wrapped in quotes."
                " For example, \"{'key': value}\" "
            )
    return typer_dicts
133

134

135
def yaml_converter_loader(config_file: str) -> dict[str, Any]:
    """
    Load yaml configuration and replace hyphens with underscores.

    Parameters
    ----------
    config_file
        Yaml configuration file to read. If empty, no file is read.

    Returns
    -------
    dict[str, Any]
        Dictionary with loaded configuration. It is empty if no file is given
        or the file contains no document.

    Raises
    ------
    ValueError
        If the configuration file does not contain a mapping of options.
    """
    if not config_file:
        return {}

    config = yaml_loader(config_file)
    # NOTE(review): yaml_loader is expected to return None for an empty file
    # (as yaml.safe_load does); treat that as "no options" instead of crashing
    if config is None:
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file '{config_file}' must contain a mapping of options"
        )

    # Replace all "-" with "_" in config keys
    return dict_remove_hyphens(config)
155

156

157
# Config-file callback built by typer_config from yaml_converter_loader, so options
# loaded from YAML have hyphens in their keys replaced with underscores.
yaml_converter_callback = conf_callback_factory(yaml_converter_loader)
3✔
158

159

160
def start_summary(
    *,
    command: str,
    summary: Path,
    config: dict[str, Any],
    info: dict[str, Any],
    output_files: dict[str, PathLike],
) -> None:
    """
    Write the opening section of a command's YAML summary file.

    Side effects: the ``"config"`` entry is removed from ``config``, the
    absolute summary path is stored as ``output_files["summary"]``, and Path
    and tuple values nested in ``config``, ``info`` and ``output_files`` are
    converted in place to strings and lists respectively.

    Parameters
    ----------
    command
        Name of CLI command being used.
    summary
        Path to summary file being saved.
    config
        Inputs to CLI command to save.
    info
        Extra information to save.
    output_files
        Output files with labels to be generated by CLI command.
    """
    # Drop the "config" entry (if any) before recording the inputs
    config.pop("config", None)
    output_files["summary"] = summary.absolute()

    summary_contents = dict(
        command=f"janus {command}",
        start_time=datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S"),
        config=config,
        info=info,
        output_files=output_files,
    )

    # Make nested values plain YAML types: Paths -> strings, then tuples -> lists
    for converter in (dict_paths_to_strs, dict_tuples_to_lists):
        converter(summary_contents)

    build_file_dir(summary)
    with open(summary, "w", encoding="utf8") as summary_file:
        yaml.dump(summary_contents, summary_file, default_flow_style=False)
202

203

204
def carbon_summary(*, summary: Path, log: Path) -> None:
    """
    Calculate and write carbon tracking summary.

    Reads the YAML log, sums the ``"emissions"`` values of every record whose
    ``"message"`` is a dictionary containing that key, and appends the total
    to the summary file as ``emissions``.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    log
        Path to log file with carbon emissions saved.
    """
    with open(log, encoding="utf8") as file:
        logs = yaml.safe_load(file)

    # yaml.safe_load returns None for an empty file: treat it as no records
    records = logs or []
    emissions = sum(
        record["message"]["emissions"]
        for record in records
        # Skip malformed records rather than failing at the end of a run
        if isinstance(record, dict)
        and isinstance(record.get("message"), dict)
        and "emissions" in record["message"]
    )

    with open(summary, "a", encoding="utf8") as outfile:
        yaml.dump({"emissions": emissions}, outfile, default_flow_style=False)
226

227

228
def end_summary(summary: Path) -> None:
    """
    Append the finishing time to the summary file and shut down logging.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    """
    with open(summary, "a", encoding="utf8") as summary_file:
        finish = datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
        yaml.dump({"end_time": finish}, summary_file, default_flow_style=False)
    # Flush and close all logging handlers now that the command is finished
    logging.shutdown()
244

245

246
def get_struct_info(
    *,
    struct: MaybeSequence[Atoms],
    struct_path: Path,
) -> dict[str, Any]:
    """
    Describe the structure (or trajectory) being simulated.

    Parameters
    ----------
    struct
        Structure to be simulated, or a sequence of structures.
    struct_path
        Path of structure file.

    Returns
    -------
    dict[str, Any]
        ``{"struct": {...}}`` for a single Atoms object. ``{"traj": {...}}``
        for a sequence, described by its length and its first image. An empty
        dictionary for anything else.
    """
    from ase import Atoms

    if isinstance(struct, Atoms):
        return {
            "struct": {
                "n_atoms": len(struct),
                "struct_path": struct_path,
                "formula": struct.get_chemical_formula(),
            }
        }

    if isinstance(struct, Sequence):
        # The first image stands in for the whole trajectory
        first_image = struct[0]
        return {
            "traj": {
                "length": len(struct),
                "struct_path": struct_path,
                "struct": {
                    "n_atoms": len(first_image),
                    "formula": first_image.get_chemical_formula(),
                },
            }
        }

    return {}
287

288

289
def get_config(*, params: dict[str, Any], all_kwargs: dict[str, Any]) -> dict[str, Any]:
    """
    Substitute parsed kwargs dictionaries into the CLI parameters.

    ``params`` is updated in place and returned. Only names already present
    in ``params`` are replaced; extra names in ``all_kwargs`` are ignored.

    Parameters
    ----------
    params
        CLI input parameters from ctx.
    all_kwargs
        Name and contents of all kwargs dictionaries.

    Returns
    -------
    dict[str, Any]
        Input parameters with parsed kwargs dictionaries substituted in.
    """
    params.update({name: all_kwargs[name] for name in params if name in all_kwargs})
    return params
310

311

312
def check_config(ctx: Context) -> None:
    """
    Check options in configuration file are valid options for CLI command.

    Parameters
    ----------
    ctx
        Typer (Click) Context within command.

    Raises
    ------
    ValueError
        If an option from the configuration file is not a parameter of the command.
    """
    # Compare options from config file (default_map) to function definition (params).
    # Click leaves default_map as None when no defaults were provided.
    for option in ctx.default_map or {}:
        # Check options individually so can inform user of specific issue
        if option not in ctx.params:
            raise ValueError(f"'{option}' in configuration file is not a valid option")
326

327

328
def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]:
    """
    Parse CLI CorrelationKwargs to md correlation_kwargs.

    Each entry of ``kwargs`` maps a correlation name to its CLI options.

    - At least one observable must be given as ``"a"`` or ``"b"``. If only one
      is given, it is used for both.
    - In that single-observable case, if only one of ``"a_kwargs"`` or
      ``"b_kwargs"`` is given, it is reused for the other.
    - Either kwargs may be ``"."`` to repeat ("ditto") the other.

    The input dictionaries are not modified.

    Parameters
    ----------
    kwargs
        CLI correlation keyword options, keyed by correlation name.

    Returns
    -------
    list[dict]
        The parsed correlation_kwargs for md, one dictionary per correlation.

    Raises
    ------
    ValueError
        If any of the following holds:

        - unexpected options are given;
        - neither ``"a"`` nor ``"b"`` is given;
        - ``"points"`` is missing;
        - both ``a_kwargs`` and ``b_kwargs`` are ``"."``.
    """
    from janus_core.processing import observables

    # Options accepted for each correlation (loop-invariant)
    arguments = {
        "blocks",
        "points",
        "averaging",
        "update_frequency",
        "a_kwargs",
        "b_kwargs",
        "a",
        "b",
    }
    # Options passed through unchanged when present, copied in a fixed order
    optional_arguments = ("blocks", "averaging", "update_frequency")

    parsed_kwargs = []
    for name, cli_kwargs in kwargs.items():
        unexpected = set(cli_kwargs.keys()) - arguments
        if unexpected:
            raise ValueError(
                f"correlation_kwargs got unexpected argument(s): {unexpected}"
            )

        if "a" not in cli_kwargs and "b" not in cli_kwargs:
            raise ValueError("At least one observable must be supplied as 'a' or 'b'")

        if "points" not in cli_kwargs:
            raise ValueError("Correlation keyword argument 'points' must be specified")

        # Work on local copies of the references so cli_kwargs is never modified
        a_kwargs = cli_kwargs["a_kwargs"] if "a_kwargs" in cli_kwargs else {}
        b_kwargs = cli_kwargs["b_kwargs"] if "b_kwargs" in cli_kwargs else {}

        # Accept an Observable to be replicated.
        if "b" not in cli_kwargs:
            a = cli_kwargs["a"]
            b = deepcopy(a)
            # Copying Observable, so can copy kwargs as well.
            if "b_kwargs" not in cli_kwargs and "a_kwargs" in cli_kwargs:
                b_kwargs = a_kwargs
        elif "a" not in cli_kwargs:
            b = cli_kwargs["b"]
            a = deepcopy(b)
            if "a_kwargs" not in cli_kwargs and "b_kwargs" in cli_kwargs:
                a_kwargs = b_kwargs
        else:
            a = cli_kwargs["a"]
            b = cli_kwargs["b"]

        # Accept "." in place of one kwargs to repeat the other, even when the
        # other is empty (otherwise "." would later be unpacked as **kwargs).
        if a_kwargs == "." and b_kwargs == ".":
            raise ValueError("a_kwargs and b_kwargs cannot 'ditto' each other")
        if b_kwargs == ".":
            b_kwargs = a_kwargs
        elif a_kwargs == ".":
            a_kwargs = b_kwargs

        cor_kwargs = {
            "name": name,
            "points": cli_kwargs["points"],
            "a": getattr(observables, a)(**a_kwargs),
            "b": getattr(observables, b)(**b_kwargs),
        }

        for optional in optional_arguments:
            if optional in cli_kwargs:
                cor_kwargs[optional] = cli_kwargs[optional]

        parsed_kwargs.append(cor_kwargs)
    return parsed_kwargs
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc