• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

p-ortega / mf6rtm / 25156475946

25 Mar 2026 07:32AM UTC coverage: 83.028%. Remained the same
25156475946

push

github

web-flow
Merge pull request #57 from p-ortega/develop

Develop - migrated to pixi,  removed unused local deps, improves cmd run

19 of 35 new or added lines in 4 files covered. (54.29%)

46 existing lines in 2 files now uncovered.

1448 of 1744 relevant lines covered (83.03%)

0.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.91
/mf6rtm/simulation/solver.py
1
"""The solver module provides the Mf6RTM class that couples modflowapi and
2
phreeqcrm, along with functions to run the coupled simulations.
3
"""
4
import os
1✔
5
# import warnings
6
import numpy as np
1✔
7

8
from datetime import datetime
1✔
9
from typing import Any, Union, Optional
1✔
10
from pathlib import Path
1✔
11

12
from PIL import Image
1✔
13
from mf6rtm.simulation.mf6api import Mf6API
1✔
14
from mf6rtm.simulation.phreeqcbmi import PhreeqcBMI
1✔
15
from mf6rtm.simulation.discretization import total_cells_in_grid
1✔
16
from mf6rtm.config.config import MF6RTMConfig
1✔
17
from mf6rtm.io.externalio import SelectedOutput
1✔
18
from mf6rtm.utils import utils
1✔
19

20
# warnings.filterwarnings("ignore")
21
# warnings.filterwarnings("ignore", category=DeprecationWarning)
22

23
# global variables
24
# Timestamp format used when printing the simulation start/finish times.
DT_FMT = "%Y-%m-%d %H:%M:%S"

# Seconds contained in one MODFLOW 6 TDIS time unit; set_time_conversion
# divides 1.0 by the matching entry to build the PhreeqcRM time conversion.
time_units_dict = {
    "seconds": 1,
    "minutes": 60,
    "hours": 60 * 60,
    "days": 24 * 60 * 60,
    "years": 365 * 24 * 60 * 60,
    "unknown": 1,  # if unknown assume seconds
}
34

35
def check_config_file(wd: os.PathLike) -> "MF6RTMConfig":
    """Load and validate the ``mf6rtm.toml`` configuration of a model.

    Parameters
    ----------
    wd : os.PathLike
        Model directory expected to contain ``mf6rtm.toml``.

    Returns
    -------
    MF6RTMConfig
        The configuration parsed from ``mf6rtm.toml``, after
        ``_validate_config()`` has been run on it.

    Raises
    ------
    AssertionError
        If ``mf6rtm.toml`` is not present in ``wd``.
    """
    config_file = os.path.join(wd, "mf6rtm.toml")
    # Explicit raise instead of ``assert`` so the check survives ``python -O``
    # while keeping the exception type existing callers may rely on.
    if not os.path.exists(config_file):
        raise AssertionError("mf6rtm.toml not found in model directory")
    config = MF6RTMConfig.from_toml_file(config_file)

    # validate config values like timing and tsteps
    config._validate_config()
    return config
45

46
def check_nam_files(wd: os.PathLike) -> str:
    """Return the path of the ``mfsim.nam`` simulation name file in ``wd``.

    Parameters
    ----------
    wd : os.PathLike
        Model directory to inspect.

    Returns
    -------
    str
        ``os.path.join(wd, "mfsim.nam")``.

    Raises
    ------
    AssertionError
        If ``mfsim.nam`` is not present in ``wd``.
    FileNotFoundError
        If ``wd`` itself does not exist (raised by ``os.listdir``).
    """
    # Explicit raise (not ``assert``) so the check is not stripped by ``-O``.
    if "mfsim.nam" not in os.listdir(wd):
        raise AssertionError("mfsim.nam file not found in model directory")
    return os.path.join(wd, "mfsim.nam")
52

53
def prep_to_run(wd:os.PathLike, libname: Path | None = None) -> tuple[os.PathLike,os.PathLike]:
1✔
54
    """
55
    Prepares the model to run by checking if the model directory (wd) contains the necessary files
56
    and returns the path to the yaml file (phreeqcrm) and the dll file (mf6 api)
57

58
    Parameters
59
    ----------
60
    wd :os.PathLike
61
        The path to the working directory of model directory
62
    Returns
63
    -------
64
    tuple[PathLike,os.PathLike]
65
        The path to the phreeqcrm model file (yaml) and the path to the MODFLOW 6 dll (associated with mf6api).
66
    """
67
    # check if wd exists
68
    assert os.path.exists(wd), f"Path {wd} not found"
1✔
69

70
    # check if file starting with libmf6 exists
71
    dll_files = [f for f in os.listdir(wd) if f.startswith("libmf6")]
1✔
72
    if len(dll_files) == 0:
1✔
73
        # no libmf6 in directory
74
        print("Libname", libname)
1✔
75
        if libname is None:
1✔
76
            # fallback to system PATH
NEW
77
            print("libmf6 not found in model directory, assuming it is available in PATH/env")
×
NEW
78
            dll = "libmf6"
×
79
        else:
80
            # use provided Path or string
81
            lib_path = Path(libname)
1✔
82
            print("Using provided libmf6 path:", lib_path)
1✔
83
            if lib_path.exists():
1✔
84
                dll = str(lib_path)
1✔
85
            else:
NEW
86
                raise FileNotFoundError(f"Provided libmf6 path does not exist: {libname}")
×
NEW
87
    elif len(dll_files) == 1:
×
NEW
88
        print(f"Using libmf6 found in model directory: {dll_files[0]}")
×
NEW
89
        dll = str(Path(wd) / dll_files[0])
×
90
    else:
91
        # multiple DLLs found
NEW
92
        raise AssertionError(
×
93
            f"Multiple libmf6 files found in model directory: {dll_files}. "
94
            "Please keep only one version."
95
        )
96

97
    config = check_config_file(wd)
1✔
98
    check_nam_files(wd)
1✔
99
    if config.reactive['externalio']:
1✔
100
        from mf6rtm.io.externalio import Regenerator
1✔
101
        print("WARNING: Flag for external IO mode is active")
1✔
102
        regcls = Regenerator.regenerate_from_external_files(wd=wd,
1✔
103
                                                phinpfile='phinp.dat',
104
                                                yamlfile='mf6rtm.yaml',
105
                                                dllfile=dll
106
                                                )
107
        yamlfile = regcls.yamlfile
1✔
108
    else:
109
        yamlfile = os.path.join(wd, "mf6rtm.yaml")
1✔
110
    assert os.path.exists(
1✔
111
        os.path.join(wd, yamlfile)
112
    ), f"{yamlfile} not found in model directory {wd}"
113
    return yamlfile, dll
1✔
114

115
def solve(wd: os.PathLike, reactive: Union[bool, None] = None, nthread: int = 1, libname: Optional[str] = None) -> bool:
    """Build the coupled model found in ``wd`` and run it.

    Parameters
    ----------
    wd : os.PathLike
        Model directory, passed to ``initialize_interfaces``.
    reactive : bool or None, optional
        When a bool that differs from the configured reactive flag, it
        overrides the configuration for this run. ``None`` (or any non-bool)
        keeps the configured mode.
    nthread : int, optional
        Thread count passed to ``initialize_interfaces``.
    libname : str or None, optional
        MODFLOW 6 shared library path passed to ``initialize_interfaces``.

    Returns
    -------
    bool
        The success flag returned by ``Mf6RTM._solve``.
    """
    mf6rtm = initialize_interfaces(wd, nthread=nthread, libname=libname)
    # isinstance already excludes None, so no separate None check is needed
    if isinstance(reactive, bool) and reactive != mf6rtm.reactive:
        print(
            f"Mode changed from "
            f"{'reactive' if mf6rtm.reactive else 'non-reactive'} to "
            f"{'reactive' if reactive else 'non-reactive'}\n"
        )
        mf6rtm._set_reactive(reactive)
        if not reactive:
            # let mf6 manage this for conservative runs; a run forced to be
            # reactive keeps its selected output
            mf6rtm.selected_output.get_selected_output_on = False
    mf6rtm.print_warning_user_active()
    return mf6rtm._solve()
131

132

133
# TODO: we should maybe move this into the Mf6API as an alternative constructor
134
def initialize_interfaces(wd: os.PathLike, nthread: int = 1, libname: Optional[str] = None) -> "Mf6RTM":
    """Create the modflowapi and phreeqcrm interfaces and wrap them in ``Mf6RTM``.

    Parameters
    ----------
    wd : os.PathLike
        Model directory, validated by ``prep_to_run``.
    nthread : int, optional
        When greater than 1, written into the PhreeqcRM yaml file as
        ``nthreads`` before PhreeqcRM is created.
    libname : str or None, optional
        MODFLOW 6 shared library path, forwarded to ``prep_to_run``.

    Returns
    -------
    Mf6RTM
        The coupled model object (not the bare ``Mf6API``).
    """
    yamlfile, dll = prep_to_run(wd, libname=libname)

    if nthread > 1:
        # set nthreads in the yaml file to nthread
        set_nthread_yaml(yamlfile, nthread=nthread)

    # initialize the interfaces
    mf6api = Mf6API(wd, dll)
    # PhreeqcBMI was noted not to work with path-like objects, so pass a str
    phreeqcrm = PhreeqcBMI(str(yamlfile))
    return Mf6RTM(wd, mf6api, phreeqcrm)
148

149

150
def set_nthread_yaml(yamlfile: os.PathLike, nthread: int = 1) -> None:
    """Rewrite the ``nthreads`` entry of a PhreeqcRM yaml file in place.

    Only lines whose key is ``nthreads`` (optionally preceded by a ``- ``
    list marker) are rewritten as ``nthreads: <nthread>``. The line's original
    indentation/marker is kept, and every other line (including comments that
    merely mention "nthreads") is left untouched. If no such key exists the
    file is rewritten unchanged.

    Parameters
    ----------
    yamlfile : os.PathLike
        Path to the yaml file to modify.
    nthread : int, optional
        Thread count to write, default 1.
    """
    import re

    key_line = re.compile(r"^([ \t]*(?:-[ \t]+)?)nthreads[ \t]*:")

    with open(yamlfile, "r") as f:
        lines = f.readlines()
    for i, line in enumerate(lines):
        match = key_line.match(line)
        if match:
            # keep the existing indentation so the yaml structure is preserved
            lines[i] = f"{match.group(1)}nthreads: {nthread}\n"
    with open(yamlfile, "w") as f:
        f.writelines(lines)
160

161

162
class Mf6RTM(object):
    """Couple a MODFLOW 6 transport simulation with PhreeqcRM reactions.

    Each time step MODFLOW 6 solves transport; on reactive time steps the
    transported concentrations are handed to PhreeqcRM, reacted, and written
    back into the MODFLOW 6 groundwater transport models.
    """

    def __init__(
        self,
        wd: os.PathLike,
        mf6api: Mf6API,
        phreeqcbmi: PhreeqcBMI,
    ) -> None:
        """
        Initialize the Mf6RTM instance with specified working directory, MF6API,
        and PhreeqcBMI instances.

        Parameters
        ----------
        wd : os.PathLike
            The working directory path for the model; must contain
            ``mf6rtm.toml``.
        mf6api : Mf6API
            An instance of the Mf6API class, representing the Modflow 6 API.
        phreeqcbmi : PhreeqcBMI
            An instance of the PhreeqcBMI class, representing the PHREEQC BMI.

        Attributes
        ----------
        mf6api : Mf6API
            The Modflow 6 API instance.
        phreeqcbmi : PhreeqcBMI
            The PHREEQC BMI instance.
        charge_offset : float
            Offset added to the charge component after
            ``utils.concentration_l_to_m3`` when writing to MODFLOW 6 and
            removed before ``utils.concentration_m3_to_l`` when reading back;
            initialized to 0.0.
        wd : pathlib.Path
            The working directory path.
        epsaqu : float
            Relative concentration-change threshold passed to
            ``get_conc_change_mask``; initialized to 0.0.
        fixed_components : Any
            Placeholder, initialized to None.
        selected_output : SelectedOutput
            Selected-output handler bound to this instance.
        component_model_dict : dict[str, str]
            Dictionary mapping PHREEQC aqueous chemical components to their
            corresponding Modflow 6 groundwater transport (gwt6) model names.
        conservative_transport_models : list[str]
            List of Modflow 6 groundwater transport (gwt6) models not coupled
            with PhreeqcRM.
        nxyz : int
            Total number of cells in the grid.
        time_conversion : float
            Set by ``set_time_conversion``.
        config : MF6RTMConfig
            Configuration read from ``wd / "mf6rtm.toml"``.
        reactive : bool
            Taken from ``config.reactive['enabled']``.
        ml_output : bool
            Set by ``set_emulator_training``.
        """
        assert isinstance(mf6api, Mf6API), "MF6API must be an instance of Mf6API"
        assert isinstance(
            phreeqcbmi, PhreeqcBMI
        ), "PhreeqcBMI must be an instance of PhreeqcBMI"
        self.mf6api = mf6api
        self.phreeqcbmi = phreeqcbmi
        self.charge_offset = 0.0
        self.wd = Path(wd)
        self.epsaqu = 0.0
        self.fixed_components = None
        self.selected_output = SelectedOutput(self)

        # set component model dictionary & list of conservative_transport_models
        self.component_model_dict, self.conservative_transport_models = self._create_component_model_dict()

        # set discretization
        self.nxyz = total_cells_in_grid(self.mf6api)
        # set time conversion factor
        self.set_time_conversion()

        self.config = MF6RTMConfig.from_toml_file(self.wd / "mf6rtm.toml")
        self.reactive = self.config.reactive['enabled']
        self.set_emulator_training()

    def set_emulator_training(self) -> None:
        """
        Configure emulator training output.

        Reads ``emulator_training_data`` from the configuration. If enabled,
        sets up emulator output variables; otherwise training data is not
        written.

        Attributes
        ----------
        ml_output : bool
            Whether emulator training data output is enabled.
        """
        self.ml_output = bool(getattr(self.config, "emulator_training_data", False))

        if self.ml_output:
            self.set_emulator_output_add_variables()
            print("Saving emulator training data for surrogating")

    def set_emulator_output_add_variables(self) -> None:
        """
        Add emulator target and feature variables to the output.

        Updates ``selected_output`` with variables defined in the configuration,
        defaulting to empty lists if not provided.

        Attributes
        ----------
        selected_output.target_var : list of str
            Target variables for emulator training.
        selected_output.feat_var : list of str
            Feature variables for emulator training.
        """
        self.selected_output.target_var = getattr(
            self.config, "emulator_target_variables", []
        )
        self.selected_output.feat_var = getattr(
            self.config, "emulator_feature_variables", []
        )

    def print_warning_user_active(self) -> None:
        """Print the reactive (period, time step) pairs when timing is 'user'."""
        if self.config.reactive['timing'] != 'user':
            return
        print("WARNING: Running reaction only in the following periods and time steps:")
        for period, timestep in self.config.reactive['tsteps']:
            print(f"  Period {period}, Time step {timestep}")

    def get_saturation_from_mf6(self) -> np.ndarray:
        """
        Read the cell saturation from MODFLOW 6 and hand it to PhreeqcBMI.

        The saturation (``FMI/GWFSAT``) is read from the transport model of
        the first PhreeqcRM component; it is treated as identical for all
        components, so only that one model is queried.

        Returns
        -------
        np.ndarray
            The saturation array, also stored in ``phreeqcbmi.sat_now``.
        """
        first_component = self.phreeqcbmi.components[0]
        sat = self.mf6api.get_value(
            self.mf6api.get_var_address(
                "FMI/GWFSAT",
                f"{self.component_model_dict[first_component]}"
            )
        )
        self.phreeqcbmi.sat_now = sat  # set phreeqcbmi saturation
        return sat

    def get_time_units_from_mf6(self) -> str:
        """Return the TDIS time units of the MODFLOW 6 simulation."""
        return self.mf6api.sim.tdis.time_units.get_data()

    def set_time_conversion(self) -> None:
        """Set the factor converting MODFLOW 6 time units for PhreeqcRM.

        An unset (None) time unit is treated like MODFLOW 6's "unknown"
        option, which ``time_units_dict`` maps to seconds.

        Raises
        ------
        KeyError
            If the time units are not a key of ``time_units_dict``.
        """
        time_units = self.get_time_units_from_mf6()
        units_key = "unknown" if time_units is None else str(time_units).lower()
        self.time_conversion = 1.0 / time_units_dict[units_key]
        self.phreeqcbmi.SetTimeConversion(self.time_conversion)

    def _create_component_model_dict(self) -> tuple[dict[str, str], list[str]]:
        """
        Match PHREEQC aqueous chemical components to Modflow 6 GWT model names.

        Model names are stripped of their longest common substring (the shared
        "stem") and compared case-insensitively with the component names; a
        ``charge`` component may also match a model whose stripped name is
        ``ch``.

        Returns
        -------
        component_model_dict : dict[str, str]
            Component name -> transport model name.
        conservative_transport_models : list[str]
            GWT model names not matched to any component.

        Raises
        ------
        AssertionError
            If a component cannot be matched to a transport model.
        """
        components = self.phreeqcbmi.get_value_ptr("Components")
        # convert np.array to list of pure python strings
        components = [str(component) for component in components]

        gwt_model_names = [
            name for name in self.mf6api.sim.model_names
            if self.mf6api.sim.get_model(name).model_type == 'gwt6'
        ]
        gwt_name_prefix = longest_common_substring(gwt_model_names)

        def _stem(model_name: str) -> str:
            # model name without the shared stem, lower-cased for matching
            return model_name.replace(gwt_name_prefix, "").lower()

        component_model_dict = dict.fromkeys(components)
        for component in components:
            for model_name in gwt_model_names:
                if _stem(model_name) == component.lower():
                    component_model_dict[component] = model_name
            if component.lower() == 'charge' and component_model_dict[component] is None:
                for model_name in gwt_model_names:
                    if _stem(model_name) == 'ch':
                        component_model_dict[component] = model_name
            # The previous ``assert (cond, msg)`` asserted a non-empty tuple and
            # therefore never fired; check the condition for real.
            if component_model_dict[component] is None:
                raise AssertionError(
                    f"Component {component} is not matched with a transport model"
                )

        conservative_transport_models = list(
            set(gwt_model_names) - set(component_model_dict.values())
        )

        return component_model_dict, conservative_transport_models

    # TODO: remove or have raise not implemented error
    def _set_fixed_components(self, fixed_components):
        """Placeholder kept for compatibility; currently does nothing."""

    # TODO: make reactive a property
    def _set_reactive(self, reactive: bool) -> None:
        """Set the model to run only transport or transport and reactions"""
        self.reactive = reactive

    def _prepare_to_solve(self) -> None:
        """Remove stale selected output, initialize both APIs, write sout headers."""
        # check if sout fname exists; if found remove it
        if self.selected_output._check_sout_exist():
            self.selected_output._rm_sout_file()

        self.mf6api._prepare_mf6()
        self.phreeqcbmi._prepare_phreeqcrm_bmi()

        # get and write sout headers
        self.selected_output._write_sout_headers()

    def _set_ctime(self) -> float:
        """Set the current time of the simulation from mf6api"""
        self.ctime = self.mf6api.get_current_time()
        self.phreeqcbmi._set_ctime(self.ctime)
        return self.ctime

    def _set_etime(self) -> float:
        """Set the end time of the simulation from mf6api"""
        self.etime = self.mf6api.get_end_time()
        return self.etime

    def _set_time_step(self) -> float:
        """Store and return the current time step length from mf6api."""
        self.time_step = self.mf6api.get_time_step()
        return self.time_step

    def _finalize(self) -> None:
        """Finalize the APIs"""
        self._finalize_mf6api()
        self._finalize_phreeqcrm()

    def _finalize_mf6api(self) -> None:
        """Finalize the mf6api"""
        self.mf6api.finalize()

    def _finalize_phreeqcrm(self) -> None:
        """Finalize the phreeqcrm api"""
        self.phreeqcbmi.finalize()

    def _get_cdlbl_vect(self) -> np.ndarray:
        """Return PhreeqcRM concentrations as an array of shape ``(ncomps, nxyz)``.

        Row ``i`` holds component ``i`` for every cell. The data is copied so
        later in-place edits never modify the object ``GetConcentrations``
        returned.
        """
        c_dbl_vect = self.phreeqcbmi.GetConcentrations()
        return np.array(c_dbl_vect).reshape(-1, self.nxyz)

    def _set_conc_at_current_kstep(self, c_dbl_vect: np.ndarray):
        """Save the concentrations sent to PhreeqcRM this step as (ncomps, nxyz)."""
        self.current_iteration_conc = np.reshape(
            c_dbl_vect, (self.phreeqcbmi.ncomps, self.nxyz)
        )

    def _set_conc_at_previous_kstep(self, c_dbl_vect: np.ndarray):
        """Save the reacted concentrations as (ncomps, nxyz) for the next step."""
        self.previous_iteration_conc = np.reshape(
            c_dbl_vect, (self.phreeqcbmi.ncomps, self.nxyz)
        )

    def _transfer_array_to_mf6(self) -> np.ndarray:
        """Copy reacted PhreeqcRM concentrations into the MODFLOW 6 GWT models.

        Returns
        -------
        np.ndarray
            The ``(ncomps, nxyz)`` concentrations that were transferred.
        """
        c_dbl_vect = self._get_cdlbl_vect()

        # reactive cells skipped due to small changes from transport get their
        # previous concentrations back
        if self._check_previous_conc_exists() and self._check_inactive_cells_exist(
            self.diffmask
        ):
            c_dbl_vect = self._replace_inactive_cells(c_dbl_vect, self.diffmask)

        for i, component in enumerate(self.phreeqcbmi.components):
            conc_m3 = utils.concentration_l_to_m3(c_dbl_vect[i])
            # Decide by component name, matching _transfer_array_to_phreeqcrm
            # which removes the offset for the same component.
            if component.lower() == "charge":
                conc_m3 = conc_m3 + self.charge_offset
            gwt_model_name = self.component_model_dict[component]
            self.mf6api.set_value(f"{gwt_model_name.upper()}/X", conc_m3)
        return c_dbl_vect

    def _check_previous_conc_exists(self) -> bool:
        """Return True once a previous reaction step has stored its concentrations."""
        return hasattr(self, "previous_iteration_conc")

    def _check_inactive_cells_exist(self, diffmask: np.ndarray) -> bool:
        """Return True if ``utils.get_indices(0, diffmask)`` finds any cells."""
        inact = utils.get_indices(0, diffmask)
        return len(inact) > 0

    def _replace_inactive_cells(
        self,
        c_dbl_vect: np.ndarray,
        diffmask: np.ndarray,
    ) -> np.ndarray:
        """Restore previous-step concentrations in cells that were not reacted.

        Cells selected by ``utils.get_indices(0, diffmask)`` receive, for
        every component, their values from ``previous_iteration_conc``.

        Returns
        -------
        np.ndarray
            A new ``(ncomps, nxyz)`` array; the input is not modified.
        """
        conc = np.array(c_dbl_vect).reshape(self.phreeqcbmi.ncomps, self.nxyz)
        inactive_idx = utils.get_indices(0, diffmask)
        conc[:, inactive_idx] = self.previous_iteration_conc[:, inactive_idx]
        return conc

    def _transfer_array_to_phreeqcrm(self) -> np.ndarray:
        """Copy MODFLOW 6 transported concentrations into PhreeqcRM.

        Returns
        -------
        np.ndarray
            Flat vector of ``ncomps * nxyz`` concentrations, component by
            component, as passed to ``SetConcentrations``.
        """
        mf6_conc_array = []
        for component in self.phreeqcbmi.components:
            values = self.mf6api.get_value(
                self.mf6api.get_var_address(
                    "X",
                    f"{self.component_model_dict[component].upper()}",
                )
            )
            if component.lower() == "charge":
                # undo the offset added in _transfer_array_to_mf6
                values = values - self.charge_offset
            mf6_conc_array.append(utils.concentration_m3_to_l(values))
        c_dbl_vect = np.reshape(mf6_conc_array, self.nxyz * self.phreeqcbmi.ncomps)
        self.phreeqcbmi.SetConcentrations(c_dbl_vect)

        # set the kper and kstp
        self.phreeqcbmi._get_kper_kstp_from_mf6api(
            self.mf6api
        )  # FIXME: calling this func here is not ideal

        return c_dbl_vect

    def _solve(self) -> bool:
        """Alias for the solve method to provide backward compatibility"""
        return self.solve()

    def is_reactive_tstep(self) -> bool:
        """
        Check if the current timestep should be reactive based on configuration.

        Returns:
            bool: True if current timestep should be reactive, False otherwise
        """
        # Early return if not in reactive mode
        if not self.reactive:
            return False

        timing = self.config.reactive['timing']
        if timing == 'all':
            return True
        if timing == 'user':
            current_tstep = (self.mf6api.kper, self.mf6api.kstp)
            # compare as tuples so list- and tuple-valued tsteps both match
            return any(
                tuple(tstep) == current_tstep
                for tstep in self.config.reactive['tsteps']
            )
        # Handle unknown strategy
        print(f"Warning: Unknown strategy '{self.config.reactive['timing']}'. Defaulting to reactive.")
        return True

    def set_kiter(self) -> int:
        """Advance the iteration counter (starting at 0) and return it."""
        if hasattr(self, "kiter"):
            self.kiter += 1
        else:
            self.kiter = 0
        return self.kiter

    def _react_time_step(self, dt: float) -> None:
        """Run PhreeqcRM for one time step and push the results back to MODFLOW 6."""
        c_dbl_vect = self._transfer_array_to_phreeqcrm()
        self._set_conc_at_current_kstep(c_dbl_vect)

        # Export ML feature arrays if option is on
        if self.ml_output:
            self.selected_output.write_ml_arrays(
                self.current_iteration_conc,
                self.kiter,
                add_var_names=self.selected_output.feat_var,
                fname='_features.csv',
            )

        if self._check_previous_conc_exists():
            self.diffmask = get_conc_change_mask(
                self.current_iteration_conc,
                self.previous_iteration_conc,
                self.phreeqcbmi.ncomps,
                self.nxyz,
                treshold=self.epsaqu,
            )
        else:
            # First reactive step (not necessarily at time 0 with 'user'
            # timing): nothing to compare against, so react every cell.
            self.diffmask = np.ones(self.nxyz)

        # solve reactions
        self.phreeqcbmi._solve_phreeqcrm(dt, diffmask=self.diffmask)
        c_dbl_vect = self._transfer_array_to_mf6()
        self._set_conc_at_previous_kstep(c_dbl_vect)

    def solve(self) -> bool:
        """Run the coupled simulation from the current time to the end time.

        Returns
        -------
        bool
            True once both APIs were finalized without raising, else False.
        """
        success = False  # initialize success flag
        sim_start = datetime.now()
        self._prepare_to_solve()

        # check sout was created
        assert self.selected_output._check_sout_exist(), f"{self.selected_output.sout_fname} not found"

        print("Starting Solution at {0}".format(sim_start.strftime(DT_FMT)))
        ctime = self._set_ctime()
        etime = self._set_etime()
        while ctime < etime:
            self.set_kiter()
            # length of the current solve time
            dt = self._set_time_step()
            self.mf6api.prepare_time_step(dt)
            self.mf6api._solve_gwt()

            # refresh phreeqcbmi.sat_now for this step
            self.get_saturation_from_mf6()
            if self.is_reactive_tstep():
                self._react_time_step(dt)

            self.mf6api.finalize_time_step()
            ctime = self._set_ctime()  # update the current time tracking
            if self.selected_output.get_selected_output_on:
                # get sout and update df
                self.selected_output._update_selected_output()
                # append current sout rows to file
                self.selected_output._append_to_soutdf_file()
                # Export ML target arrays if option is on; targets only exist
                # after at least one reaction step has completed.
                if self.ml_output and self._check_previous_conc_exists():
                    self.selected_output.write_ml_arrays(
                        self.previous_iteration_conc,
                        self.kiter,
                        add_var_names=self.selected_output.target_var,
                        fname='_targets.csv',
                    )

        sim_end = datetime.now()
        td = (sim_end - sim_start).total_seconds() / 60.0

        self.mf6api._check_num_fails()

        # Clean up and close api objs
        try:
            self._finalize()
            success = True
            print(mrbeaker())
            print(
                "\nMR BEAKER IMPORTANT MESSAGE: MODEL RUN FINISHED BUT CHECK THE RESULTS .. THEY ARE PROLY RUBBISH\n"
            )
        except Exception as err:
            # Report the failure instead of swallowing it (a bare ``except``
            # also hid KeyboardInterrupt/SystemExit).
            print("MR BEAKER IMPORTANT MESSAGE: SOMETHING WENT WRONG. BUMMER\n")
            print(f"{type(err).__name__}: {err}")
        print(
            "Solution finished at {0}. Running time: {1:10.5G} mins".format(
                sim_end.strftime(DT_FMT), td
            )
        )
        return success
663

664

665
def get_less_than_zero_idx(arr):
    """Locate the strictly negative entries of ``arr``.

    Returns
    -------
    tuple of np.ndarray
        One index array per dimension of ``arr`` (the ``np.nonzero``
        convention).
    """
    negative = arr < 0
    return np.nonzero(negative)
669

670

671
def get_inactive_idx(arr: np.ndarray, val: float = 1e30):
    """Return the indices of the entries of ``arr`` that are ``>= val``.

    Parameters
    ----------
    arr : array_like
        Values to scan; converted with ``np.asarray`` so lists work too.
        Indices along the first axis are returned, so a 1-D input is expected.
    val : float, optional
        Threshold; entries greater than or equal to it are reported. The
        default ``1e30`` is presumably a sentinel for inactive cells (confirm).

    Returns
    -------
    list
        Matching indices (numpy integers), in ascending order.
    """
    return list(np.where(np.asarray(arr) >= val)[0])
675

676

677
def get_conc_change_mask(
    ci: np.ndarray,
    ck: np.ndarray,
    ncomp: int,
    nxyz: int,
    treshold: float = 1e-10,
) -> np.ndarray:
    """Flag the cells whose concentrations changed enough to need reacting.

    Concentrations are laid out component by component, i.e. shape
    ``(ncomp, nxyz)`` or the equivalent flat vector of ``ncomp * nxyz``
    values. The previous implementation reshaped them to ``(nxyz, ncomp)``,
    which is not a transpose and mixed values from different cells.

    Parameters
    ----------
    ci : np.ndarray
        Current concentrations.
    ck : np.ndarray
        Previous concentrations, same layout as ``ci``.
    ncomp : int
        Number of components.
    nxyz : int
        Number of cells.
    treshold : float, optional
        Relative change ``|ci - ck| / |ci|`` below which a component counts
        as unchanged. Where ``ci`` is 0 the change is 0 if ``ck`` is also 0
        and infinite otherwise. With 0.0 every cell is flagged.

    Returns
    -------
    np.ndarray
        Integer array of length ``nxyz``: 1 where at least one component
        changed (cell must be reacted), 0 where all components are unchanged.
    """
    current = np.asarray(ci, dtype=float).reshape(ncomp, nxyz)
    previous = np.asarray(ck, dtype=float).reshape(ncomp, nxyz)

    abs_change = np.abs(current - previous)
    scale = np.abs(current)
    # guard the division by ``ci`` instead of producing inf/nan for zero concentrations
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_change = np.where(
            scale > 0.0,
            abs_change / scale,
            np.where(abs_change > 0.0, np.inf, 0.0),
        )

    # NaN inputs compare False here, so such cells stay flagged as changed
    unchanged = rel_change < treshold
    return np.where(unchanged.all(axis=0), 0, 1)
697

698

699
def longest_common_substring(strings):
    """Return the longest substring shared by every string in ``strings``.

    Used here to find the common "stem" of the GWT model names so it can be
    stripped before matching them with PhreeqcRM components. Ties are broken
    in favor of the candidate that starts earliest in ``strings[0]``; an
    empty input yields ``""``.
    """
    if not strings:
        return ""

    reference, others = strings[0], strings[1:]
    # Try candidate lengths from longest to shortest; the first candidate that
    # every other string contains is the answer.
    for length in range(len(reference), 0, -1):
        for start in range(len(reference) - length + 1):
            candidate = reference[start:start + length]
            if all(candidate in other for other in others):
                return candidate
    return ""
728

729

730
def mrbeaker() -> str:
    """Render the bundled Mr. Beaker image as ASCII art.

    Returns
    -------
    str
        70 characters per row, each row preceded by a newline.
    """
    from mf6rtm.assets import mrbeaker_path

    terminal_width = 70  # Adjust this based on your terminal width
    # Grayscale values 0-255 // 64 only index the first four characters.
    ascii_chars = "%,.?>#*+=-:."

    # Context manager so the image file handle is always released.
    with Image.open(mrbeaker_path()) as source:
        aspect_ratio = source.width / source.height
        # 0.5 presumably compensates for character cells being taller than wide;
        # keep at least one row so resize never gets a zero height.
        terminal_height = max(1, int(terminal_width / aspect_ratio * 0.5))
        image = source.resize((terminal_width, terminal_height)).convert("L")

    rows = [
        "".join(
            ascii_chars[image.getpixel((x, y)) // 64]
            for x in range(int(image.width))
        )
        for y in range(int(image.height))
    ]
    return "".join("\n" + row for row in rows)
758

759
def run_cmd(cwd: Optional[os.PathLike] = None) -> None:
    """Console entrypoint compatibility wrapper.

    The console script calls ``mf6rtm:run_cmd`` without arguments, so
    ``cwd`` is optional; when omitted the process's current working
    directory is solved.
    """
    model_dir = os.getcwd() if cwd is None else cwd
    solve(model_dir)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc