• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stfc / janus-core / 20097762678

10 Dec 2025 11:56AM UTC coverage: 92.179% (-0.1%) from 92.315%
20097762678

push

github

web-flow
Remove deprecated model_path parameter (#640)

1 of 2 new or added lines in 2 files covered. (50.0%)

3 existing lines in 1 file now uncovered.

2911 of 3158 relevant lines covered (92.18%)

2.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.02
/janus_core/cli/utils.py
1
"""Utility functions for CLI."""
2

3
from __future__ import annotations
3✔
4

5
from collections.abc import Sequence
3✔
6
from copy import deepcopy
3✔
7
import datetime
3✔
8
import logging
3✔
9
from pathlib import Path
3✔
10
from typing import TYPE_CHECKING, Any
3✔
11

12
from typer import CallbackParam, secho
3✔
13
from typer.colors import YELLOW
3✔
14
from typer_config import conf_callback_factory, yaml_loader
3✔
15
import yaml
3✔
16

17
from janus_core.helpers.utils import build_file_dir
3✔
18

19
if TYPE_CHECKING:
20
    from ase import Atoms
21
    from typer import Context
22

23
    from janus_core.cli.types import CorrelationKwargs, TyperDict
24
    from janus_core.helpers.janus_types import (
25
        MaybeSequence,
26
        PathLike,
27
    )
28

29

30
def dict_paths_to_strs(dictionary: dict) -> None:
    """
    Recursively iterate over dictionary, converting Path values to strings.

    The dictionary is modified in place. Nested dictionaries are converted
    recursively, including dictionaries stored inside non-string sequences.
    Non-string sequence values are replaced by lists.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.
    """
    for key, value in dictionary.items():
        if isinstance(value, dict):
            dict_paths_to_strs(value)
        elif isinstance(value, Sequence) and not isinstance(value, str):
            converted = []
            for item in value:
                if isinstance(item, dict):
                    # Dictionaries inside sequences may also hold Path values.
                    dict_paths_to_strs(item)
                    converted.append(item)
                elif isinstance(item, Path):
                    converted.append(str(item))
                else:
                    converted.append(item)
            # Rebinding an existing key is safe while iterating over items().
            dictionary[key] = converted
        elif isinstance(value, Path):
            dictionary[key] = str(value)
48

49

50
def dict_tuples_to_lists(dictionary: dict) -> None:
    """
    Recursively iterate over dictionary, converting tuple values to lists.

    The dictionary is modified in place. Tuple and list values are treated
    alike: the value becomes a list, and any tuple it directly contains is
    also converted to a list. Dictionaries stored inside such sequences are
    converted recursively.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.
    """
    for key, value in dictionary.items():
        if isinstance(value, dict):
            dict_tuples_to_lists(value)
        elif isinstance(value, (tuple, list)):
            converted = []
            for item in value:
                if isinstance(item, dict):
                    dict_tuples_to_lists(item)
                    converted.append(item)
                elif isinstance(item, tuple):
                    converted.append(list(item))
                else:
                    converted.append(item)
            # Rebinding an existing key is safe while iterating over items().
            dictionary[key] = converted
66

67

68
def dict_remove_hyphens(dictionary: dict) -> dict:
    """
    Recursively iterate over dictionary, replacing hyphens with underscores in keys.

    A new dictionary is built; the input is left unmodified. Keys that are not
    strings (for example integer keys) are kept as they are. Only keys are
    changed, never values.

    Parameters
    ----------
    dictionary
        Dictionary to be converted.

    Returns
    -------
    dict
        Dictionary with hyphens in keys replaced with underscores.
    """
    return {
        (key.replace("-", "_") if isinstance(key, str) else key): (
            dict_remove_hyphens(value) if isinstance(value, dict) else value
        )
        for key, value in dictionary.items()
    }
86

87

88
def set_read_kwargs_index(read_kwargs: dict[str, Any]) -> None:
    """
    Set default read_kwargs["index"] to final image and check its value is an integer.

    To ensure only a single Atoms object is read, slices such as ":" are forbidden.

    Parameters
    ----------
    read_kwargs
        Keyword arguments to be passed to ase.io.read. If specified,
        read_kwargs["index"] must be an integer, and if not, a default value
        of -1 is set.

    Raises
    ------
    ValueError
        If read_kwargs["index"] cannot be converted to an integer.
    """
    read_kwargs.setdefault("index", -1)
    try:
        int(read_kwargs["index"])
    except (TypeError, ValueError) as e:
        # int() raises TypeError for values such as None or a list; report
        # these with the same error as non-numeric strings.
        raise ValueError("`read_kwargs['index']` must be an integer") from e
106

107

108
def parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
    """
    Convert list of TyperDict objects to list of dictionaries.

    Each entry of the list is replaced in place by its ``value`` attribute, or
    by an empty dictionary if the entry is falsy.

    Parameters
    ----------
    typer_dicts
        List of TyperDict objects to convert.

    Returns
    -------
    list[dict]
        List of converted dictionaries.

    Raises
    ------
    ValueError
        If items in list are not converted to dicts.
    """
    for index, entry in enumerate(typer_dicts):
        parsed = entry.value if entry else {}
        typer_dicts[index] = parsed
        if isinstance(parsed, dict):
            continue
        raise ValueError(
            f"""{parsed} must be passed as a dictionary wrapped in quotes.\
 For example, "{{'key': value}}" """
        )
    return typer_dicts
135

136

137
def yaml_converter_loader(config_file: str) -> dict[str, Any]:
    """
    Load yaml configuration, replace hyphens with underscores, and swap filter keyword.

    Parameters
    ----------
    config_file
        Yaml configuration file to read. If falsy, no file is read.

    Returns
    -------
    dict[str, Any]
        Dictionary with loaded configuration. Empty if no file was given.
    """
    if not config_file:
        return {}

    # Fall back to an empty configuration if the loader returns nothing,
    # so the membership test below cannot fail on None.
    # NOTE(review): presumably happens for an empty YAML file - confirm.
    config = yaml_loader(config_file) or {}

    # Rename filter
    if "filter" in config:
        config["filter_class"] = config.pop("filter")

    # Replace all "-" with "_" in conf
    return dict_remove_hyphens(config)
162

163

164
# Config callback built by ``conf_callback_factory`` from ``yaml_converter_loader``.
yaml_converter_callback = conf_callback_factory(yaml_converter_loader)
3✔
165

166

167
def start_summary(
    *,
    command: str,
    summary: Path,
    config: dict[str, Any],
    info: dict[str, Any],
    output_files: dict[str, PathLike],
) -> None:
    """
    Write initial summary contents.

    The ``config`` and ``output_files`` dictionaries passed in are modified in
    place: any "config" entry is removed from ``config``, the absolute summary
    path is added to ``output_files``, and Path and tuple values are converted
    for saving.

    Parameters
    ----------
    command
        Name of CLI command being used.
    summary
        Path to summary file being saved.
    config
        Inputs to CLI command to save.
    info
        Extra information to save.
    output_files
        Output files with labels to be generated by CLI command.
    """
    # Drop any "config" entry from the inputs before they are saved
    config.pop("config", None)
    output_files["summary"] = summary.absolute()

    started = datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
    contents = {
        "command": f"janus {command}",
        "start_time": started,
        "config": config,
        "info": info,
        "output_files": output_files,
    }

    # Make nested values suitable for saving: paths to strings, tuples to lists
    dict_paths_to_strs(contents)
    dict_tuples_to_lists(contents)

    build_file_dir(summary)
    with summary.open("w", encoding="utf8") as outfile:
        yaml.dump(contents, outfile, default_flow_style=False)
209

210

211
def carbon_summary(*, summary: Path, log: Path) -> None:
    """
    Calculate and write carbon tracking summary.

    Emissions are summed over all log entries whose "message" is a dictionary
    containing an "emissions" key, and the total is appended to the summary.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    log
        Path to log file with carbon emissions saved.
    """
    with open(log, encoding="utf8") as log_file:
        entries = yaml.safe_load(log_file)

    total = 0
    for entry in entries:
        message = entry["message"]
        # Only dictionary messages can carry an emissions value
        if isinstance(message, dict) and "emissions" in message:
            total = total + message["emissions"]

    with open(summary, "a", encoding="utf8") as outfile:
        yaml.dump({"emissions": total}, outfile, default_flow_style=False)
233

234

235
def end_summary(summary: Path) -> None:
    """
    Write final time to summary and close.

    The end time is appended to the summary file, then logging is shut down.

    Parameters
    ----------
    summary
        Path to summary file being saved.
    """
    with open(summary, "a", encoding="utf8") as outfile:
        finished = datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
        yaml.dump({"end_time": finished}, outfile, default_flow_style=False)
    logging.shutdown()
251

252

253
def get_struct_info(
    *,
    struct: MaybeSequence[Atoms],
    struct_path: Path,
) -> dict[str, Any]:
    """
    Add structure information to a dictionary.

    A single structure is described under the "struct" key. A sequence of
    structures is described under the "traj" key, using its first element for
    the number of atoms and the formula. Any other input gives an empty
    dictionary.

    Parameters
    ----------
    struct
        Structure to be simulated.
    struct_path
        Path of structure file.

    Returns
    -------
    dict[str, Any]
        Dictionary with structure information.
    """
    from ase import Atoms

    if isinstance(struct, Atoms):
        return {
            "struct": {
                "n_atoms": len(struct),
                "struct_path": struct_path,
                "formula": struct.get_chemical_formula(),
            }
        }

    if isinstance(struct, Sequence):
        n_images = len(struct)
        return {
            "traj": {
                "length": n_images,
                "struct_path": struct_path,
                "struct": {
                    "n_atoms": len(struct[0]),
                    "formula": struct[0].get_chemical_formula(),
                },
            }
        }

    return {}
294

295

296
def get_config(*, params: dict[str, Any], all_kwargs: dict[str, Any]) -> dict[str, Any]:
    """
    Get configuration and set kwargs dictionaries.

    ``params`` is updated in place and returned. Entries of ``all_kwargs``
    whose names are not already in ``params`` are ignored.

    Parameters
    ----------
    params
        CLI input parameters from ctx.
    all_kwargs
        Name and contents of all kwargs dictionaries.

    Returns
    -------
    dict[str, Any]
        Input parameters with parsed kwargs dictionaries substituted in.
    """
    # Collect replacements first, then apply them to existing keys only
    substitutions = {
        name: all_kwargs[name] for name in params if name in all_kwargs
    }
    params.update(substitutions)
    return params
317

318

319
def check_config(ctx: Context) -> None:
    """
    Check options in configuration file are valid options for CLI command.

    Parameters
    ----------
    ctx
        Typer (Click) Context within command.

    Raises
    ------
    ValueError
        If an option in the configuration file is not a parameter of the command.
    """
    # Compare options from config file (default_map) to function definition (params)
    # default_map may be None when no defaults have been set, so treat as empty
    for option in ctx.default_map or {}:
        # Check options individually so can inform user of specific issue
        if option not in ctx.params:
            raise ValueError(f"'{option}' in configuration file is not a valid option")
333

334

335
def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]:
    """
    Parse CLI CorrelationKwargs to md correlation_kwargs.

    For each named correlation, the observables "a" and "b" are looked up by
    name in ``janus_core.processing.observables`` and constructed with
    "a_kwargs" and "b_kwargs". If only one observable is given, it is used for
    both, and its kwargs are copied across unless the other kwargs were given.
    Either of "a_kwargs" or "b_kwargs" may be "." to repeat the other.

    Parameters
    ----------
    kwargs
        CLI correlation keyword options.

    Returns
    -------
    list[dict]
        The parsed correlation_kwargs for md.

    Raises
    ------
    ValueError
        If unexpected arguments are given, neither "a" nor "b" is given,
        "points" is missing, or both kwargs are ".".
    """
    from janus_core.processing import observables

    # Same for every correlation, so defined once outside the loop
    arguments = {
        "blocks",
        "points",
        "averaging",
        "update_frequency",
        "a_kwargs",
        "b_kwargs",
        "a",
        "b",
    }
    optional_arguments = {"blocks", "averaging", "update_frequency"}

    parsed_kwargs = []
    for name, cli_kwargs in kwargs.items():
        unexpected = set(cli_kwargs.keys()).difference(arguments)
        if unexpected:
            raise ValueError(
                f"correlation_kwargs got unexpected argument(s): {unexpected}"
            )

        if "a" not in cli_kwargs and "b" not in cli_kwargs:
            raise ValueError("At least one observable must be supplied as 'a' or 'b'")

        if "points" not in cli_kwargs:
            raise ValueError("Correlation keyword argument 'points' must be specified")

        # Accept an Observable to be replicated.
        if "b" not in cli_kwargs:
            a = cli_kwargs["a"]
            b = deepcopy(a)
            # Copying Observable, so can copy kwargs as well.
            if "b_kwargs" not in cli_kwargs and "a_kwargs" in cli_kwargs:
                cli_kwargs["b_kwargs"] = cli_kwargs["a_kwargs"]
        elif "a" not in cli_kwargs:
            b = cli_kwargs["b"]
            a = deepcopy(b)
            if "a_kwargs" not in cli_kwargs and "b_kwargs" in cli_kwargs:
                cli_kwargs["a_kwargs"] = cli_kwargs["b_kwargs"]
        else:
            a = cli_kwargs["a"]
            b = cli_kwargs["b"]

        a_kwargs = cli_kwargs.get("a_kwargs", {})
        b_kwargs = cli_kwargs.get("b_kwargs", {})

        # Accept "." in place of one kwargs to repeat.
        if a_kwargs == "." and b_kwargs == ".":
            raise ValueError("a_kwargs and b_kwargs cannot 'ditto' each other")
        # Resolve a ditto even when the other kwargs are empty, so "." is
        # never unpacked as keyword arguments below.
        if b_kwargs == ".":
            b_kwargs = a_kwargs
        elif a_kwargs == ".":
            a_kwargs = b_kwargs

        cor_kwargs = {
            "name": name,
            "points": cli_kwargs["points"],
            "a": getattr(observables, a)(**a_kwargs),
            "b": getattr(observables, b)(**b_kwargs),
        }

        for optional in cli_kwargs.keys() & optional_arguments:
            cor_kwargs[optional] = cli_kwargs[optional]

        parsed_kwargs.append(cor_kwargs)
    return parsed_kwargs
414

415

416
# Callback that warns when a deprecated option is used
def deprecated_option(param: CallbackParam, value: Any):
    """
    Print warning for deprecated option.

    The warning is printed in yellow, and only if the value is truthy.

    Parameters
    ----------
    param
        Callback parameters from typer.
    value
        Value of parameter.

    Returns
    -------
    Any
        Unmodified parameter value.
    """
    if not value:
        return value

    message = f"Warning: --{param.name} is deprecated."
    secho(message, fg=YELLOW)
    return value
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc