• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stfc / aiida-mlip / 9096383187

15 May 2024 01:15PM UTC coverage: 77.778% (-15.7%) from 93.432%
9096383187

Pull #123

github

web-flow
Merge f5bfba8f9 into 18695f779
Pull Request #123: 121 add training

0 of 95 new or added lines in 2 files covered. (0.0%)

441 of 567 relevant lines covered (77.78%)

3.11 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/aiida_mlip/calculations/train.py
1
"""Class for training machine learning models."""
2

NEW
3
from pathlib import Path
×
4

NEW
5
from aiida.common import InputValidationError, datastructures
×
NEW
6
import aiida.common.folders
×
NEW
7
from aiida.engine import CalcJob, CalcJobProcessSpec
×
NEW
8
import aiida.engine.processes
×
NEW
9
from aiida.orm import Dict, FolderData, SinglefileData
×
10

NEW
11
from aiida_mlip.data.config import JanusConfigfile
×
NEW
12
from aiida_mlip.data.model import ModelData
×
13

14

NEW
15
def validate_inputs(
    inputs: dict, port_namespace: aiida.engine.processes.ports.PortNamespace
):
    """
    Check if the inputs are valid.

    The ``mlip_config`` input must be present and its dictionary form must
    contain the keys ``train_file``, ``valid_file``, ``test_file`` and ``name``.
    The three ``*_file`` entries must additionally point to existing paths.

    Parameters
    ----------
    inputs : dict
        The inputs dictionary.

    port_namespace : `aiida.engine.processes.ports.PortNamespace`
        An instance of aiida's `PortNameSpace`.

    Raises
    ------
    InputValidationError
        If no config file is given, if a mandatory key is missing from it, or
        if one of the data file paths does not exist.
    """
    # Nothing to validate when the port is not part of this namespace
    if "mlip_config" not in port_namespace:
        return

    if "mlip_config" not in inputs:
        raise InputValidationError("No config file given")

    # Item access works for plain dicts as well as attribute-style dicts
    mlip_dict = inputs["mlip_config"].as_dictionary()

    required_keys = ("train_file", "valid_file", "test_file", "name")
    for key in required_keys:
        if key not in mlip_dict:
            raise InputValidationError(f"Mandatory key {key} not in config file")
        # Every mandatory key except "name" holds a path; "name" is only the
        # label used for the output files. The check must be made on the value
        # stored in the config, not on the key string itself.
        if key != "name" and not Path(mlip_dict[key]).exists():
            raise InputValidationError(f"Path given for {key} does not exist")
×
47

48

NEW
49
class Train(CalcJob):  # numpydoc ignore=PR01
    """
    Calcjob implementation to train mlips.

    Attributes
    ----------
    DEFAULT_OUTPUT_FILE : str
        Default stdout file name.
    DEFAULT_INPUT_FILE : str
        Default input file name.
    LOG_FILE : str
        Default log file name.

    Methods
    -------
    define(spec: CalcJobProcessSpec) -> None:
        Define the process specification, its inputs, outputs and exit codes.
    prepare_for_submission(folder: Folder) -> CalcInfo:
        Create the input files for the `CalcJob`.
    """

    DEFAULT_OUTPUT_FILE = "aiida-stdout.txt"
    DEFAULT_INPUT_FILE = "aiida.xyz"
    LOG_FILE = "aiida.log"

    @classmethod
    def define(cls, spec: CalcJobProcessSpec) -> None:
        """
        Define the process specification, its inputs, outputs and exit codes.

        Parameters
        ----------
        spec : `aiida.engine.CalcJobProcessSpec`
            The calculation job process spec to define.
        """
        super().define(spec)
        spec.inputs.validator = validate_inputs
        # Define inputs
        # The port is declared optional, but `validate_inputs` rejects any
        # submission that does not provide it.
        spec.input(
            "mlip_config",
            valid_type=JanusConfigfile,
            required=False,
            help="Config file with parameters for training",
        )
        # Give the stdout file a default name so that `stdout_name` and the
        # retrieve list in `prepare_for_submission` are always well defined.
        spec.input(
            "metadata.options.output_filename",
            valid_type=str,
            default=cls.DEFAULT_OUTPUT_FILE,
        )
        spec.output("model", valid_type=ModelData)
        spec.output("compiled_model", valid_type=SinglefileData)
        spec.output(
            "results_dict",
            valid_type=Dict,
            help="The `results_dict` output node of the successful calculation.",
        )
        spec.output("logs", valid_type=FolderData)
        spec.output("checkpoints", valid_type=FolderData)
        spec.default_output_node = "results_dict"
        # Exit codes
        spec.exit_code(
            305,
            "ERROR_MISSING_OUTPUT_FILES",
            message="Some output files missing or cannot be read",
        )

    # pylint: disable=too-many-locals
    def prepare_for_submission(
        self, folder: aiida.common.folders.Folder
    ) -> datastructures.CalcInfo:
        """
        Create the input files for the `Calcjob`.

        Parameters
        ----------
        folder : aiida.common.folders.Folder
            An `aiida.common.folders.Folder` to temporarily write files on disk.

        Returns
        -------
        aiida.common.datastructures.CalcInfo
            An instance of `aiida.common.datastructures.CalcInfo`.
        """
        # Read both the raw content and the parsed form from the declared
        # `mlip_config` input (there is no `config` input on this calcjob).
        mlip_config = self.inputs.mlip_config
        config_parse = mlip_config.get_content()
        mlip_dict = mlip_config.as_dictionary()

        # Copy config file content inside the folder where the calculation is
        # run. The same name is passed on the command line below, so the code
        # is pointed at the file that is actually written.
        config_filename = "mlip_train.yaml"
        with folder.open(config_filename, "w", encoding="utf-8") as configfile:
            configfile.write(config_parse)

        cmd_line = {"mlip-config": config_filename}

        model_dir = Path(mlip_dict.get("model_dir", "."))
        model_output = model_dir / f"{mlip_dict['name']}.model"
        compiled_model_output = model_dir / f"{mlip_dict['name']}_compiled.model"

        codeinfo = datastructures.CodeInfo()

        # Initialize cmdline_params with train command
        codeinfo.cmdline_params = ["train"]

        for flag, value in cmd_line.items():
            codeinfo.cmdline_params += [f"--{flag}", str(value)]

        # Node where the code is saved
        codeinfo.code_uuid = self.inputs.code.uuid
        # Save name of output as you need it for running the code
        codeinfo.stdout_name = self.metadata.options.output_filename

        calcinfo = datastructures.CalcInfo()
        calcinfo.codes_info = [codeinfo]
        # Save the info about the node where the calc is stored
        calcinfo.uuid = str(self.uuid)
        # Retrieve output files. Entries are given as strings rather than
        # `Path` objects, and the output directories are optional in the
        # config just like `model_dir`.
        # NOTE(review): the fallback directory names are assumed to mirror the
        # training code's own defaults - confirm against the trainer's CLI.
        calcinfo.retrieve_list = [
            self.metadata.options.output_filename,
            str(self.uuid),
            str(mlip_dict.get("log_dir", "logs")),
            str(mlip_dict.get("result_dir", "results")),
            str(mlip_dict.get("checkpoint_dir", "checkpoints")),
            str(model_output),
            str(compiled_model_output),
        ]

        return calcinfo
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc