• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stfc / janus-core / 13588549274

28 Feb 2025 12:53PM UTC coverage: 92.636% (-0.2%) from 92.875%
13588549274

Pull #388

github

web-flow
Merge 301b42cd1 into 80882e842
Pull Request #388: Add Correlator CLI and defaults

27 of 37 new or added lines in 5 files covered. (72.97%)

7 existing lines in 2 files now uncovered.

2604 of 2811 relevant lines covered (92.64%)

2.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.67
/janus_core/cli/utils.py
1
"""Utility functions for CLI."""
2

3
from __future__ import annotations
3✔
4

5
from collections.abc import Sequence
3✔
6
import datetime
3✔
7
import logging
3✔
8
from pathlib import Path
3✔
9
from typing import TYPE_CHECKING, Any
3✔
10

11
from typer_config import conf_callback_factory, yaml_loader
3✔
12
import yaml
3✔
13

14
if TYPE_CHECKING:
3✔
UNCOV
15
    from ase import Atoms
×
16
    from typer import Context
×
17

UNCOV
18
    from janus_core.cli.types import CorrelationKwargs, TyperDict
×
NEW
19
    from janus_core.helpers.janus_types import (
×
20
        Architectures,
21
        ASEReadArgs,
22
        Devices,
23
        MaybeSequence,
24
    )
25

26
from janus_core.processing import observables
3✔
27

28

29
def dict_paths_to_strs(dictionary: dict) -> None:
    """
    Replace every Path value in a (possibly nested) dictionary with its string form.

    The dictionary is modified in place, and any dictionaries nested inside it
    are processed in the same way.

    Parameters
    ----------
    dictionary
        Dictionary whose Path values will be converted to strings.
    """
    # Walk nested dictionaries with an explicit stack instead of recursion.
    pending = [dictionary]
    while pending:
        current = pending.pop()
        for name, item in current.items():
            if isinstance(item, dict):
                pending.append(item)
            elif isinstance(item, Path):
                current[name] = str(item)
43

44

45
def dict_tuples_to_lists(dictionary: dict) -> None:
    """
    Recursively iterate over dictionary, converting tuple values to lists.

    The dictionary is modified in place, including any nested dictionaries.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.
    """
    for key, value in dictionary.items():
        if isinstance(value, dict):
            # Recurse with this same conversion so tuples held in nested
            # dictionaries are also turned into lists.
            dict_tuples_to_lists(value)
        elif isinstance(value, tuple):
            dictionary[key] = list(value)
59

60

61
def dict_remove_hyphens(dictionary: dict) -> dict:
    """
    Recursively iterate over dictionary, replacing hyphens with underscores in keys.

    A new dictionary is returned. The input dictionary, and any dictionaries
    nested inside it, are left unmodified. Keys that are not strings (e.g.
    integer keys loaded from YAML) are kept unchanged.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.

    Returns
    -------
    dict
        Dictionary with hyphens in keys replaced with underscores.
    """
    return {
        (key.replace("-", "_") if isinstance(key, str) else key): (
            dict_remove_hyphens(value) if isinstance(value, dict) else value
        )
        for key, value in dictionary.items()
    }
79

80

81
def set_read_kwargs_index(read_kwargs: dict[str, Any]) -> None:
    """
    Set default read_kwargs["index"] to final image and check its value is an integer.

    To ensure only a single Atoms object is read, slices such as ":" are forbidden.

    Parameters
    ----------
    read_kwargs
        Keyword arguments to be passed to ase.io.read. If read_kwargs["index"]
        is not specified, a default value of -1 (the final image) is set;
        otherwise it must be convertible to an integer. Modified in place.

    Raises
    ------
    ValueError
        If read_kwargs["index"] cannot be converted to an integer.
    """
    read_kwargs.setdefault("index", -1)
    # int() raises TypeError (not ValueError) for objects such as slice(None)
    # or None, so both are caught to always report the same ValueError.
    try:
        int(read_kwargs["index"])
    except (TypeError, ValueError) as e:
        raise ValueError("`read_kwargs['index']` must be an integer") from e
99

100

101
def parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
    """
    Unwrap each TyperDict in a list into a plain dictionary, in place.

    Falsy entries (e.g. ``None``) are replaced by empty dictionaries. The same
    list object that was passed in is returned.

    Parameters
    ----------
    typer_dicts
        List of TyperDict objects to convert.

    Returns
    -------
    list[dict]
        List of converted dictionaries.

    Raises
    ------
    ValueError
        If items in list are not converted to dicts.
    """
    for position, wrapped in enumerate(typer_dicts):
        converted = wrapped.value if wrapped else {}
        # Store before validating, so the list reflects the failing entry too.
        typer_dicts[position] = converted
        if isinstance(converted, dict):
            continue
        raise ValueError(
            f"{converted} must be passed as a dictionary wrapped in quotes. "
            "For example, \"{'key': value}\" "
        )
    return typer_dicts
128

129

130
def yaml_converter_loader(config_file: str) -> dict[str, Any]:
    """
    Read a YAML configuration file, converting hyphens in keys to underscores.

    Parameters
    ----------
    config_file
        Yaml configuration file to read. An empty value means no configuration.

    Returns
    -------
    dict[str, Any]
        Dictionary with loaded configuration, or an empty dictionary when no
        configuration file is given.
    """
    if config_file:
        loaded = yaml_loader(config_file)
        # Replace all "-" with "_" in the keys of the loaded configuration
        return dict_remove_hyphens(loaded)
    return {}
150

151

152
# Configuration callback built by typer_config's conf_callback_factory from
# yaml_converter_loader, so YAML config keys have hyphens replaced by underscores.
yaml_converter_callback = conf_callback_factory(yaml_converter_loader)
3✔
153

154

155
def start_summary(*, command: str, summary: Path, inputs: dict) -> None:
    """
    Create the summary file, recording the command, start time and inputs.

    Any existing file at ``summary`` is overwritten.

    Parameters
    ----------
    command
        Name of CLI command being used.
    summary
        Path to summary file being saved.
    inputs
        Inputs to CLI command to save.
    """
    started = datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
    header = {
        "command": f"janus {command}",
        "start_time": started,
        "inputs": inputs,
    }
    with open(summary, "w", encoding="utf8") as stream:
        yaml.dump(header, stream, default_flow_style=False)
175

176

177
def carbon_summary(*, summary: Path, log: Path) -> None:
    """
    Calculate and write carbon tracking summary.

    Sums the ``emissions`` value of every log entry whose ``message`` is a
    dictionary containing an ``emissions`` key, and appends the total to the
    summary file under the ``emissions`` key.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    log
        Path to log file with carbon emissions saved.
    """
    with open(log, encoding="utf8") as file:
        # yaml.safe_load returns None for an empty document; treat that as
        # a log with no entries rather than failing when iterating.
        logs = yaml.safe_load(file) or []

    # Entries without a "message" key are skipped rather than raising KeyError.
    messages = (entry.get("message") for entry in logs)
    emissions = sum(
        message["emissions"]
        for message in messages
        if isinstance(message, dict) and "emissions" in message
    )

    with open(summary, "a", encoding="utf8") as outfile:
        yaml.dump({"emissions": emissions}, outfile, default_flow_style=False)
199

200

201
def end_summary(summary: Path) -> None:
    """
    Append the finishing time to the summary file, then shut down logging.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    """
    with open(summary, "a", encoding="utf8") as stream:
        finished = datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
        yaml.dump({"end_time": finished}, stream, default_flow_style=False)
    # Flush and close all logging handlers now that the run is complete
    logging.shutdown()
217

218

219
def save_struct_calc(
    *,
    inputs: dict,
    struct: MaybeSequence[Atoms],
    struct_path: Path,
    arch: Architectures,
    device: Devices,
    model_path: str,
    read_kwargs: ASEReadArgs,
    calc_kwargs: dict[str, Any],
    log: Path,
) -> None:
    """
    Record structure, calculator and log details in a CLI inputs dictionary.

    ``inputs`` is modified in place. The raw structure/calculator/log options
    are removed and replaced by a summarised ``struct`` entry (single Atoms) or
    ``traj`` entry (sequence of Atoms), plus ``calc`` and ``log`` entries. All
    Path values in the resulting nested dictionary are converted to strings.

    Parameters
    ----------
    inputs
        Inputs dictionary to add information to.
    struct
        Structure to be simulated.
    struct_path
        Path of structure file.
    arch
        MLIP architecture.
    device
        Device to run calculations on.
    model_path
        Path to MLIP model.
    read_kwargs
        Keyword arguments to pass to ase.io.read.
    calc_kwargs
        Keyword arguments to pass to the calculator.
    log
        Path to log file.
    """
    from ase import Atoms

    # Drop raw parameters that are re-recorded below in summarised form
    duplicates = (
        "struct",
        "struct_path",
        "arch",
        "device",
        "model_path",
        "read_kwargs",
        "calc_kwargs",
        "log_kwargs",
    )
    for name in duplicates:
        inputs.pop(name, None)

    if isinstance(struct, Atoms):
        inputs["struct"] = {
            "n_atoms": len(struct),
            "struct_path": struct_path,
            "formula": struct.get_chemical_formula(),
        }
    elif isinstance(struct, Sequence):
        # Summarise a trajectory by its length and its first image
        first = struct[0]
        inputs["traj"] = {
            "length": len(struct),
            "struct_path": struct_path,
            "struct": {
                "n_atoms": len(first),
                "formula": first.get_chemical_formula(),
            },
        }

    inputs["calc"] = {
        "arch": arch,
        "device": device,
        "model_path": model_path,
        "read_kwargs": read_kwargs,
        "calc_kwargs": calc_kwargs,
    }
    inputs["log"] = log

    # Convert every Path in the nested inputs dictionary to a string
    dict_paths_to_strs(inputs)
298

299

300
def check_config(ctx: Context) -> None:
    """
    Check options in configuration file are valid options for CLI command.

    Parameters
    ----------
    ctx
        Typer (Click) Context within command.

    Raises
    ------
    ValueError
        If an option in the configuration file is not a parameter of the command.
    """
    # Compare options from config file (default_map) to function definition
    # (params). Click's default_map may be None when no defaults were set, in
    # which case there is nothing to check.
    for option in ctx.default_map or {}:
        # Check options individually so can inform user of specific issue
        if option not in ctx.params:
            raise ValueError(f"'{option}' in configuration file is not a valid option")
314

315

316
def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]:
    """
    Parse CLI CorrelationKwargs to md correlation_kwargs.

    Each entry of ``kwargs.value`` maps a correlation name to a dictionary of
    CLI options. At least one of ``"a"`` or ``"b"`` must name an observable in
    ``janus_core.processing.observables``; if only one is given it is used for
    both. ``"a_kwargs"`` and ``"b_kwargs"`` (default: empty) are passed to the
    observable constructors, and either may be ``"."`` to repeat the other.
    The optional keys ``"blocks"``, ``"points"``, ``"averaging"`` and
    ``"update_frequency"`` are copied through unchanged.

    Parameters
    ----------
    kwargs
        CLI correlation keyword options.

    Returns
    -------
    list[dict]
        The parsed correlation_kwargs for md.

    Raises
    ------
    ValueError
        If neither "a" nor "b" is supplied, or if both "a_kwargs" and
        "b_kwargs" are ".".
    """
    parsed_kwargs = []
    for name, cli_kwargs in kwargs.value.items():
        if "a" not in cli_kwargs and "b" not in cli_kwargs:
            raise ValueError("At least one observable must be supplied as 'a' or 'b'")

        # Accept one observable, to be used for both "a" and "b".
        a = cli_kwargs.get("a", cli_kwargs.get("b"))
        b = cli_kwargs.get("b", a)

        a_kwargs = cli_kwargs.get("a_kwargs", {})
        b_kwargs = cli_kwargs.get("b_kwargs", {})

        # Accept "." in place of one kwargs to repeat the other. This also
        # resolves "." when the other kwargs is missing/empty, instead of
        # passing the literal "." on as keyword arguments.
        if a_kwargs == "." and b_kwargs == ".":
            raise ValueError("a_kwargs and b_kwargs cannot 'ditto' each other")
        if b_kwargs == ".":
            b_kwargs = a_kwargs
        elif a_kwargs == ".":
            a_kwargs = b_kwargs

        cor_kwargs = {
            "name": name,
            "a": getattr(observables, a)(**a_kwargs),
            "b": getattr(observables, b)(**b_kwargs),
        }

        for optional in ("blocks", "points", "averaging", "update_frequency"):
            if optional in cli_kwargs:
                cor_kwargs[optional] = cli_kwargs[optional]

        parsed_kwargs.append(cor_kwargs)
    return parsed_kwargs
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc