• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stfc / janus-core / 13943776973

19 Mar 2025 10:01AM UTC coverage: 92.534% (-0.1%) from 92.657%
13943776973

Pull #388

github

web-flow
Merge 75651ab37 into 09d1a281f
Pull Request #388: Add Correlator CLI and defaults

30 of 37 new or added lines in 5 files covered. (81.08%)

8 existing lines in 2 files now uncovered.

2677 of 2893 relevant lines covered (92.53%)

2.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.27
/janus_core/cli/utils.py
1
"""Utility functions for CLI."""
2

3
from __future__ import annotations
3✔
4

5
from collections.abc import Sequence
3✔
6
from copy import deepcopy
3✔
7
import datetime
3✔
8
import logging
3✔
9
from pathlib import Path
3✔
10
from typing import TYPE_CHECKING, Any
3✔
11

12
from typer_config import conf_callback_factory, yaml_loader
3✔
13
import yaml
3✔
14

15
from janus_core.helpers.utils import build_file_dir
3✔
16

17
if TYPE_CHECKING:
3✔
18
    from ase import Atoms
×
19
    from typer import Context
×
20

NEW
21
    from janus_core.cli.types import CorrelationKwargs, TyperDict
×
22
    from janus_core.helpers.janus_types import (
×
23
        MaybeSequence,
24
    )
25

26
from janus_core.processing import observables
3✔
27

28

29
def dict_paths_to_strs(dictionary: dict) -> None:
    """
    Convert every Path value in a nested dictionary to a string, in place.

    Only dictionary values are visited; Paths held inside other containers
    (for example lists) are left untouched.

    Parameters
    ----------
    dictionary
        Dictionary to be converted. Nested dictionaries are converted recursively.
    """
    for name in dictionary:
        entry = dictionary[name]
        if isinstance(entry, Path):
            # Re-assigning an existing key is safe while iterating
            dictionary[name] = str(entry)
        elif isinstance(entry, dict):
            dict_paths_to_strs(entry)
43

44

45
def dict_tuples_to_lists(dictionary: dict) -> None:
    """
    Convert every tuple value in a nested dictionary to a list, in place.

    Only dictionary values are visited; tuples nested inside other containers
    are left untouched.

    Parameters
    ----------
    dictionary
        Dictionary to be converted. Nested dictionaries are converted recursively.
    """
    for name in dictionary:
        entry = dictionary[name]
        if isinstance(entry, tuple):
            # Re-assigning an existing key is safe while iterating
            dictionary[name] = list(entry)
        elif isinstance(entry, dict):
            dict_tuples_to_lists(entry)
59

60

61
def dict_remove_hyphens(dictionary: dict) -> dict:
    """
    Recursively build a copy of dictionary with hyphens in keys made underscores.

    The input dictionary, including any nested dictionaries, is left unmodified.
    Keys that are not strings (e.g. integers loaded from YAML) are kept as they are.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.

    Returns
    -------
    dict
        New dictionary with hyphens in keys replaced with underscores.
    """
    return {
        (key.replace("-", "_") if isinstance(key, str) else key): (
            dict_remove_hyphens(value) if isinstance(value, dict) else value
        )
        for key, value in dictionary.items()
    }
79

80

81
def set_read_kwargs_index(read_kwargs: dict[str, Any]) -> None:
    """
    Set default read_kwargs["index"] to final image and check its value is an integer.

    To ensure only a single Atoms object is read, slices such as ":" are forbidden.

    Parameters
    ----------
    read_kwargs
        Keyword arguments to be passed to ase.io.read. If specified,
        read_kwargs["index"] must be an integer, and if not, a default value
        of -1 is set.

    Raises
    ------
    ValueError
        If read_kwargs["index"] cannot be interpreted as an integer.
    """
    read_kwargs.setdefault("index", -1)
    try:
        # Validation only: the stored value is deliberately left unchanged
        int(read_kwargs["index"])
    except (TypeError, ValueError) as e:
        # int() raises TypeError for None, lists, etc., not just ValueError
        raise ValueError("`read_kwargs['index']` must be an integer") from e
99

100

101
def parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
    """
    Replace each TyperDict in a list with the dictionary it wraps.

    The list is converted in place and also returned. Falsy entries (such as
    None) become empty dictionaries.

    Parameters
    ----------
    typer_dicts
        List of TyperDict objects to convert.

    Returns
    -------
    list[dict]
        The same list, now holding the converted dictionaries.

    Raises
    ------
    ValueError
        If the value wrapped by an item is not a dict.
    """
    for idx, entry in enumerate(typer_dicts):
        converted = entry.value if entry else {}
        typer_dicts[idx] = converted
        if isinstance(converted, dict):
            continue
        raise ValueError(
            f"{converted} must be passed as a dictionary wrapped in quotes."
            """ For example, "{'key': value}" """
        )
    return typer_dicts
128

129

130
def yaml_converter_loader(config_file: str) -> dict[str, Any]:
    """
    Load yaml configuration and replace hyphens with underscores.

    Parameters
    ----------
    config_file
        Yaml configuration file to read. If empty, no file is read.

    Returns
    -------
    dict[str, Any]
        Dictionary with loaded configuration, or an empty dictionary if no
        configuration file was given or the file holds no content.
    """
    if not config_file:
        return {}

    # NOTE(review): an empty YAML document is expected to load as None rather
    # than {} (as yaml.safe_load does), so fall back to an empty mapping
    config = yaml_loader(config_file) or {}
    # Replace all "-" with "_" in config keys
    return dict_remove_hyphens(config)
150

151

152
# Config callback produced by typer_config's conf_callback_factory from the
# hyphen-normalising loader; presumably attached to CLI config options —
# confirm at call sites.
yaml_converter_callback = conf_callback_factory(yaml_converter_loader)
3✔
153

154

155
def start_summary(
    *, command: str, summary: Path, config: dict[str, Any], info: dict[str, Any]
) -> None:
    """
    Create the summary file and write the command, start time and inputs to it.

    Note that ``config`` and ``info`` are modified in place: any "config" entry
    is removed from ``config``, and Path and tuple values in both dictionaries
    are converted to strings and lists respectively before being written.

    Parameters
    ----------
    command
        Name of CLI command being used.
    summary
        Path to summary file being saved.
    config
        Inputs to CLI command to save.
    info
        Extra information to save.
    """
    # Drop the "config" entry (if any) so it is not written to the summary
    if "config" in config:
        del config["config"]

    started = datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
    contents = {
        "command": f"janus {command}",
        "start_time": started,
        "config": config,
        "info": info,
    }

    # Convert Paths to strings and tuples to lists throughout the nested contents
    for convert in (dict_paths_to_strs, dict_tuples_to_lists):
        convert(contents)

    build_file_dir(summary)
    with open(summary, "w", encoding="utf8") as outfile:
        yaml.dump(contents, outfile, default_flow_style=False)
188

189

190
def carbon_summary(*, summary: Path, log: Path) -> None:
    """
    Calculate and write carbon tracking summary.

    Sums the "emissions" value of every log entry whose "message" is a
    dictionary containing that key, and appends the total to the summary file.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    log
        Path to log file with carbon emissions saved.
    """
    with open(log, encoding="utf8") as file:
        # An empty log file loads as None; treat it as having no entries
        logs = yaml.safe_load(file) or []

    emissions = sum(
        lg["message"]["emissions"]
        for lg in logs
        if isinstance(lg["message"], dict) and "emissions" in lg["message"]
    )

    with open(summary, "a", encoding="utf8") as outfile:
        yaml.dump({"emissions": emissions}, outfile, default_flow_style=False)
212

213

214
def end_summary(summary: Path) -> None:
    """
    Append the end time to the summary file and shut down logging.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    """
    with open(summary, "a", encoding="utf8") as outfile:
        finished = datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
        yaml.dump({"end_time": finished}, outfile, default_flow_style=False)

    # Flush and close all logging handlers now the command has finished
    logging.shutdown()
230

231

232
def get_struct_info(
    *,
    struct: MaybeSequence[Atoms],
    struct_path: Path,
) -> dict[str, Any]:
    """
    Collect summary information about a structure or a sequence of structures.

    Parameters
    ----------
    struct
        Structure to be simulated.
    struct_path
        Path of structure file.

    Returns
    -------
    dict[str, Any]
        Dictionary with structure information, under the key "struct" for a
        single structure or "traj" for a sequence. Empty if ``struct`` is neither.
    """
    # Imported here rather than at module level; at module level Atoms is
    # only imported for type checking
    from ase import Atoms

    if isinstance(struct, Atoms):
        return {
            "struct": {
                "n_atoms": len(struct),
                "struct_path": struct_path,
                "formula": struct.get_chemical_formula(),
            }
        }

    if isinstance(struct, Sequence):
        # The first image is used to describe the structures in the sequence
        first_image = struct[0]
        return {
            "traj": {
                "length": len(struct),
                "struct_path": struct_path,
                "struct": {
                    "n_atoms": len(first_image),
                    "formula": first_image.get_chemical_formula(),
                },
            }
        }

    return {}
273

274

275
def get_config(*, params: dict[str, Any], all_kwargs: dict[str, Any]) -> dict[str, Any]:
    """
    Substitute parsed kwargs dictionaries into the CLI input parameters.

    ``params`` is updated in place and returned; entries of ``all_kwargs`` with
    no matching key in ``params`` are ignored.

    Parameters
    ----------
    params
        CLI input parameters from ctx.
    all_kwargs
        Name and contents of all kwargs dictionaries.

    Returns
    -------
    dict[str, Any]
        Input parameters with parsed kwargs dictionaries substituted in.
    """
    replacements = {name: all_kwargs[name] for name in params if name in all_kwargs}
    params.update(replacements)
    return params
296

297

298
def check_config(ctx: Context) -> None:
    """
    Check options in configuration file are valid options for CLI command.

    Parameters
    ----------
    ctx
        Typer (Click) Context within command.

    Raises
    ------
    ValueError
        If an option in the configuration file is not a parameter of the command.
    """
    # Compare options from config file (default_map) to function definition (params).
    # default_map is None when no defaults have been set, so treat that as empty.
    for option in ctx.default_map or {}:
        # Check options individually so can inform user of specific issue
        if option not in ctx.params:
            raise ValueError(f"'{option}' in configuration file is not a valid option")
312

313

314
def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]:
    """
    Parse CLI CorrelationKwargs to md correlation_kwargs.

    Each correlation names one or two observables ("a" and/or "b"); if only one
    is given it is used for both. Either of "a_kwargs" or "b_kwargs" may be set
    to "." to reuse the other's keyword arguments.

    Parameters
    ----------
    kwargs
        CLI correlation keyword options.

    Returns
    -------
    list[dict]
        The parsed correlation_kwargs for md.

    Raises
    ------
    ValueError
        If neither "a" nor "b" is supplied for a correlation, or if both
        "a_kwargs" and "b_kwargs" are ".".
    """
    parsed_kwargs = []
    for name, cli_kwargs in kwargs.value.items():
        if "a" not in cli_kwargs and "b" not in cli_kwargs:
            raise ValueError("At least one observable must be supplied as 'a' or 'b'")

        # Accept one Observable to be replicated. These are observable names
        # (looked up below with getattr), so no copy is needed to reuse one.
        a = cli_kwargs["a"] if "a" in cli_kwargs else cli_kwargs["b"]
        b = cli_kwargs["b"] if "b" in cli_kwargs else cli_kwargs["a"]

        a_kwargs = cli_kwargs.get("a_kwargs", {})
        b_kwargs = cli_kwargs.get("b_kwargs", {})

        # Accept "." in place of one kwargs to repeat.
        if a_kwargs == "." and b_kwargs == ".":
            raise ValueError("a_kwargs and b_kwargs cannot 'ditto' each other")
        # Resolve the ditto even when the other kwargs are empty or absent, so
        # the "." placeholder is never unpacked as keyword arguments.
        if b_kwargs == ".":
            b_kwargs = a_kwargs
        elif a_kwargs == ".":
            a_kwargs = b_kwargs

        cor_kwargs = {
            "name": name,
            "a": getattr(observables, a)(**a_kwargs),
            "b": getattr(observables, b)(**b_kwargs),
        }

        # Pass through optional correlation settings only when given
        for optional in ("blocks", "points", "averaging", "update_frequency"):
            if optional in cli_kwargs:
                cor_kwargs[optional] = cli_kwargs[optional]

        parsed_kwargs.append(cor_kwargs)
    return parsed_kwargs
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc