• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stfc / janus-core / 13588719633

28 Feb 2025 01:03PM UTC coverage: 92.639% (-0.2%) from 92.875%
13588719633

Pull #388

github

web-flow
Merge 6a65b7672 into 80882e842
Pull Request #388: Add Correlator CLI and defaults

27 of 37 new or added lines in 5 files covered. (72.97%)

8 existing lines in 2 files now uncovered.

2605 of 2812 relevant lines covered (92.64%)

2.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.79
/janus_core/cli/utils.py
1
"""Utility functions for CLI."""
2

3
from __future__ import annotations
3✔
4

5
from collections.abc import Sequence
3✔
6
from copy import deepcopy
3✔
7
import datetime
3✔
8
import logging
3✔
9
from pathlib import Path
3✔
10
from typing import TYPE_CHECKING, Any
3✔
11

12
from typer_config import conf_callback_factory, yaml_loader
3✔
13
import yaml
3✔
14

15
if TYPE_CHECKING:
3✔
16
    from ase import Atoms
×
17
    from typer import Context
×
18

NEW
19
    from janus_core.cli.types import CorrelationKwargs, TyperDict
×
20
    from janus_core.helpers.janus_types import (
×
21
        Architectures,
22
        ASEReadArgs,
23
        Devices,
24
        MaybeSequence,
25
    )
26

27
from janus_core.processing import observables
3✔
28

29

30
def dict_paths_to_strs(dictionary: dict) -> None:
    """
    Convert every Path value in a (possibly nested) dictionary to a string.

    The dictionary is modified in place; nested dictionaries are handled
    recursively.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.
    """
    for key in dictionary:
        entry = dictionary[key]
        if isinstance(entry, Path):
            # Re-assigning an existing key is safe while iterating.
            dictionary[key] = str(entry)
        elif isinstance(entry, dict):
            dict_paths_to_strs(entry)
3✔
44

45

46
def dict_tuples_to_lists(dictionary: dict) -> None:
    """
    Recursively iterate over dictionary, converting tuple values to lists.

    The dictionary is modified in place. Only tuple values are changed; all
    other values (including those in nested dictionaries) are left untouched.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.
    """
    for key, value in dictionary.items():
        if isinstance(value, dict):
            # Recurse with this function so nested tuples are converted too.
            dict_tuples_to_lists(value)
        elif isinstance(value, tuple):
            dictionary[key] = list(value)
3✔
60

61

62
def dict_remove_hyphens(dictionary: dict) -> dict:
    """
    Build a copy of a dictionary whose keys use underscores instead of hyphens.

    Nested dictionaries are converted recursively. The nested values of the
    input dictionary are replaced in place by their converted versions, while
    the top-level result is returned as a new dictionary.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.

    Returns
    -------
    dict
        Dictionary with hyphens in keys replaced with underscores.
    """
    # First pass: convert nested dictionaries, storing them back on the input.
    nested_keys = [key for key, value in dictionary.items() if isinstance(value, dict)]
    for key in nested_keys:
        dictionary[key] = dict_remove_hyphens(dictionary[key])

    # Second pass: rename the keys at this level.
    renamed = {}
    for key, value in dictionary.items():
        renamed[key.replace("-", "_")] = value
    return renamed
3✔
80

81

82
def set_read_kwargs_index(read_kwargs: dict[str, Any]) -> None:
    """
    Default read_kwargs["index"] to the final image and validate its value.

    Slices such as ":" are rejected so that only a single Atoms object is read.

    Parameters
    ----------
    read_kwargs
        Keyword arguments to be passed to ase.io.read. If "index" is absent it
        is set to -1; if present, it must be convertible to an integer.

    Raises
    ------
    ValueError
        If read_kwargs["index"] cannot be interpreted as an integer.
    """
    # setdefault returns the stored value, whether pre-existing or the default.
    index = read_kwargs.setdefault("index", -1)
    try:
        int(index)
    except ValueError as err:
        raise ValueError("`read_kwargs['index']` must be an integer") from err
3✔
100

101

102
def parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
    """
    Unwrap a list of TyperDict objects into plain dictionaries.

    The list is updated in place (each entry is replaced by its ``value``, or
    by an empty dictionary if the entry is falsy) and the same list is returned.

    Parameters
    ----------
    typer_dicts
        List of TyperDict objects to convert.

    Returns
    -------
    list[dict]
        List of converted dictionaries.

    Raises
    ------
    ValueError
        If items in list are not converted to dicts.
    """
    for idx, entry in enumerate(typer_dicts):
        converted = entry.value if entry else {}
        # Store the unwrapped value before validating it.
        typer_dicts[idx] = converted
        if isinstance(converted, dict):
            continue
        raise ValueError(
            f"""{typer_dicts[idx]} must be passed as a dictionary wrapped in quotes.\
 For example, "{{'key': value}}" """
        )
    return typer_dicts
3✔
129

130

131
def yaml_converter_loader(config_file: str) -> dict[str, Any]:
    """
    Read a yaml configuration file and normalise its keys.

    Parameters
    ----------
    config_file
        Yaml configuration file to read. A falsy value means no file was given.

    Returns
    -------
    dict[str, Any]
        Loaded configuration with hyphens in keys replaced by underscores, or
        an empty dictionary if no configuration file was given.
    """
    if config_file:
        # Replace every "-" with "_" in the loaded configuration keys.
        return dict_remove_hyphens(yaml_loader(config_file))
    return {}
3✔
151

152

153
# Callback built by typer_config's conf_callback_factory around
# yaml_converter_loader, so configuration keys are loaded with hyphens
# replaced by underscores.
# NOTE(review): presumably attached to the CLI config option - confirm at call sites.
yaml_converter_callback = conf_callback_factory(yaml_converter_loader)
3✔
154

155

156
def start_summary(*, command: str, summary: Path, inputs: dict) -> None:
    """
    Create the summary file and write its opening entries.

    The file is opened in write mode, so any existing content is replaced.

    Parameters
    ----------
    command
        Name of CLI command being used.
    summary
        Path to summary file being saved.
    inputs
        Inputs to CLI command to save.
    """
    started = datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
    save_info = {
        "command": f"janus {command}",
        "start_time": started,
        "inputs": inputs,
    }

    with open(summary, "w", encoding="utf8") as outfile:
        yaml.dump(save_info, outfile, default_flow_style=False)
3✔
176

177

178
def carbon_summary(*, summary: Path, log: Path) -> None:
    """
    Calculate and write carbon tracking summary.

    Sums the "emissions" value of every log entry whose "message" is a
    dictionary containing that key, and appends the total to the summary file.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    log
        Path to log file with carbon emissions saved.
    """
    with open(log, encoding="utf8") as file:
        # An empty log file is parsed as None; treat it as having no entries
        # so that zero emissions are reported rather than raising a TypeError.
        logs = yaml.safe_load(file) or []

    emissions = sum(
        lg["message"]["emissions"]
        for lg in logs
        if isinstance(lg["message"], dict) and "emissions" in lg["message"]
    )

    with open(summary, "a", encoding="utf8") as outfile:
        yaml.dump({"emissions": emissions}, outfile, default_flow_style=False)
3✔
200

201

202
def end_summary(summary: Path) -> None:
    """
    Append the finishing time to the summary file and shut down logging.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    """
    finished = datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
    closing_info = {"end_time": finished}

    with open(summary, "a", encoding="utf8") as outfile:
        yaml.dump(closing_info, outfile, default_flow_style=False)

    # Shut down logging once the summary is complete.
    logging.shutdown()
3✔
218

219

220
def save_struct_calc(
    *,
    inputs: dict,
    struct: MaybeSequence[Atoms],
    struct_path: Path,
    arch: Architectures,
    device: Devices,
    model_path: str,
    read_kwargs: ASEReadArgs,
    calc_kwargs: dict[str, Any],
    log: Path,
) -> None:
    """
    Record structure and calculator inputs in a dictionary, in place.

    Any raw copies of these parameters already present in ``inputs`` are
    removed and replaced by grouped "struct" (or "traj"), "calc" and "log"
    entries. All Path values in ``inputs`` are then converted to strings.

    Parameters
    ----------
    inputs
        Inputs dictionary to add information to.
    struct
        Structure to be simulated.
    struct_path
        Path of structure file.
    arch
        MLIP architecture.
    device
        Device to run calculations on.
    model_path
        Path to MLIP model.
    read_kwargs
        Keyword arguments to pass to ase.io.read.
    calc_kwargs
        Keyword arguments to pass to the calculator.
    log
        Path to log file.
    """
    from ase import Atoms

    # Remove raw parameters that are re-added below in grouped form.
    duplicated = (
        "struct",
        "struct_path",
        "arch",
        "device",
        "model_path",
        "read_kwargs",
        "calc_kwargs",
        "log_kwargs",
    )
    for name in duplicated:
        inputs.pop(name, None)

    if isinstance(struct, Atoms):
        # Single structure.
        struct_info = {
            "n_atoms": len(struct),
            "struct_path": struct_path,
            "formula": struct.get_chemical_formula(),
        }
        inputs["struct"] = struct_info
    elif isinstance(struct, Sequence):
        # Sequence of structures: the first image is used for the description.
        traj_length = len(struct)
        first_image = struct[0]
        inputs["traj"] = {
            "length": traj_length,
            "struct_path": struct_path,
            "struct": {
                "n_atoms": len(first_image),
                "formula": first_image.get_chemical_formula(),
            },
        }

    calc_info = {
        "arch": arch,
        "device": device,
        "model_path": model_path,
        "read_kwargs": read_kwargs,
        "calc_kwargs": calc_kwargs,
    }
    inputs["calc"] = calc_info
    inputs["log"] = log

    # Convert all paths to strings throughout the nested inputs dictionary.
    dict_paths_to_strs(inputs)
3✔
299

300

301
def check_config(ctx: Context) -> None:
    """
    Check options in configuration file are valid options for CLI command.

    Parameters
    ----------
    ctx
        Typer (Click) Context within command.

    Raises
    ------
    ValueError
        If an option in the configuration file is not a parameter of the
        command.
    """
    # Compare options from config file (default_map) to function definition (params).
    # A default_map of None is treated as an empty configuration rather than
    # raising a TypeError on iteration.
    for option in ctx.default_map or {}:
        # Check options individually so can inform user of specific issue
        if option not in ctx.params:
            raise ValueError(f"'{option}' in configuration file is not a valid option")
3✔
315

316

317
def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]:
    """
    Parse CLI CorrelationKwargs to md correlation_kwargs.

    Each entry of ``kwargs.value`` maps a correlation name to its options.
    The observables "a" and "b" are given by name and looked up as attributes
    of ``janus_core.processing.observables``. If only one of them is supplied
    it is used for both. Similarly, "a_kwargs" or "b_kwargs" may be given as
    "." to reuse the other observable's keyword arguments.

    Parameters
    ----------
    kwargs
        CLI correlation keyword options.

    Returns
    -------
    list[dict]
        The parsed correlation_kwargs for md.

    Raises
    ------
    ValueError
        If neither "a" nor "b" is supplied for a correlation, or if both
        "a_kwargs" and "b_kwargs" are ".".
    """
    parsed_kwargs = []
    for name, cli_kwargs in kwargs.value.items():
        if "a" not in cli_kwargs and "b" not in cli_kwargs:
            raise ValueError("At least one observable must be supplied as 'a' or 'b'")

        # Accept one Observable to be replicated.
        if "b" not in cli_kwargs:
            a = cli_kwargs["a"]
            b = deepcopy(a)
        elif "a" not in cli_kwargs:
            a = cli_kwargs["b"]
            b = deepcopy(a)
        else:
            a = cli_kwargs["a"]
            b = cli_kwargs["b"]

        a_kwargs = cli_kwargs.get("a_kwargs", {})
        b_kwargs = cli_kwargs.get("b_kwargs", {})

        # Accept "." in place of one kwargs to repeat.
        if a_kwargs == "." and b_kwargs == ".":
            raise ValueError("a_kwargs and b_kwargs cannot 'ditto' each other")
        # Resolve the ditto even when the other kwargs are empty or omitted,
        # so "." is never unpacked as keyword arguments below.
        if b_kwargs == ".":
            b_kwargs = a_kwargs
        elif a_kwargs == ".":
            a_kwargs = b_kwargs

        cor_kwargs = {
            "name": name,
            "a": getattr(observables, a)(**a_kwargs),
            "b": getattr(observables, b)(**b_kwargs),
        }

        # Copy through only the optional settings that were supplied.
        for optional in ["blocks", "points", "averaging", "update_frequency"]:
            if optional in cli_kwargs:
                cor_kwargs[optional] = cli_kwargs[optional]

        parsed_kwargs.append(cor_kwargs)
    return parsed_kwargs
3✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc