• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

KarlNaumann / MacroStat / 21241273980

22 Jan 2026 08:24AM UTC coverage: 96.528% (+0.09%) from 96.437%
21241273980

push

github

web-flow
Differentiation tools (#61)

* Add core autodiff and numerical Jacobian infrastructure

* Add LINEAR2D testing model following standard Behavior pattern

* Make apply_parameter_shocks robust to missing vector_sectors

* Add LINEAR2D-based tests for diff Jacobian tools

* Document diff tools and LINEAR2D testing model

301 of 306 branches covered (98.37%)

Branch coverage included in aggregate %.

226 of 244 new or added lines in 11 files covered. (92.62%)

2 existing lines in 1 file now uncovered.

2090 of 2171 relevant lines covered (96.27%)

0.96 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.57
/src/macrostat/core/behavior.py
1
"""
Behavior classes for the MacroStat model.

Provides ``Behavior``, the ``torch.nn.Module`` base class from which model
behaviors derive by implementing ``initialize()`` and ``step()``.
"""

__license__ = "MIT"
__author__ = ["Karl Naumann-Woleske"]
__maintainer__ = ["Karl Naumann-Woleske"]
__credits__ = ["Karl Naumann-Woleske"]
9

10
import logging
1✔
11

12
import torch
1✔
13

14
from macrostat.core.parameters import Parameters
1✔
15
from macrostat.core.scenarios import Scenarios
1✔
16
from macrostat.core.variables import Variables
1✔
17

18
# Module-level logger; the Behavior methods emit their debug messages on it.
logger = logging.getLogger(name=__name__)
19

20

21
class Behavior(torch.nn.Module):
    """Base class for the behavior of the MacroStat model.

    Subclasses implement ``initialize()`` and ``step()`` (and, for steady-state
    analysis, ``compute_theoretical_steady_state_per_step()``). ``forward()``
    drives the simulation loop and returns
    ``self.variables.gather_timeseries()``.
    """

    def __init__(
        self,
        parameters: Parameters,
        scenarios: Scenarios,
        variables: Variables,
        scenario: int = 0,
        differentiable: bool = False,
        debug: bool = False,
    ):
        """Initialize the behavior of the MacroStat model.

        Parameters
        ----------
        parameters: macrostat.core.parameters.Parameters
            The parameters of the model. Converted with
            ``parameters.to_nn_parameters()`` and stored as ``self.params``;
            ``parameters.hyper`` is stored as ``self.hyper``.
        scenarios: macrostat.core.scenarios.Scenarios
            The scenarios of the model. Only the selected ``scenario`` is
            converted (``to_nn_parameters(scenario=scenario)``) and stored as
            ``self.scenarios``.
        variables: macrostat.core.variables.Variables
            The variables of the model, used to create, record and gather the
            model state.
        scenario: int
            The scenario to use for the model run (stored as
            ``self.scenarioID``).
        differentiable: bool
            Stored as ``self.differentiable``; not read by this base class.
        debug: bool
            Whether to print debug information (stored as ``self.debug``).
        """
        # Initialize the parent class
        super().__init__()

        # Initialize the parameters
        self.params = parameters.to_nn_parameters()
        self.hyper = parameters.hyper

        # Initialize the scenarios
        self.scenarios = scenarios.to_nn_parameters(scenario=scenario)
        self.scenarioID = scenario

        # Initialize the variables
        self.variables = variables

        # Settings
        self.differentiable = differentiable
        self.debug = debug

    ############################################################################
    # Simulation of the model
    ############################################################################

    def forward(self):
        """Forward pass of the behavior.

        This runs the model's main loop. Users implement ``initialize()`` and
        ``step()``, which are called from here; if additional steps are
        necessary, users may wish to override this function.

        The run has two phases:

        1. ``initialize()`` sets ``self.state`` once. That state is recorded
           for every timestep ``0..timesteps_initialization`` and pushed into
           the history the same number of times.
        2. For each remaining timestep, a fresh state is created, the scenario
           values and shocked parameters for that timestep are computed, and
           ``step()`` is called. The resulting state is recorded, appended to
           the history, and becomes ``self.prior`` for the next step.

        Returns
        -------
        The result of ``self.variables.gather_timeseries()``.
        """
        # Set the seed so that runs are reproducible
        torch.manual_seed(self.hyper["seed"])

        # Initialize the output tensors
        self.state, self.history = self.variables.initialize_tensors()

        # Initialize the model
        logger.debug(
            f"Initializing model (t=0...{self.hyper['timesteps_initialization']})"
        )
        self.initialize()

        n_init = self.hyper["timesteps_initialization"]
        for t in range(n_init + 1):
            self.variables.record_state(t, self.state)

        for _ in range(n_init + 1):
            self.history = self.variables.update_history(self.state)

        # Initialize the prior and state
        self.prior = self.state

        # Run the model for the remaining timesteps
        # NOTE(review): the loop below starts at t=timesteps_initialization, so
        # record_state() is called a second time for that timestep, while the
        # log message reports the range as starting at
        # timesteps_initialization + 1. Confirm which is intended.
        logger.debug(
            f"Simulating model (t={self.hyper['timesteps_initialization'] + 1}...{self.hyper['timesteps']})"
        )

        for t in range(n_init, self.hyper["timesteps"]):
            self.state = self.variables.new_state()

            # Scenario values and shocked parameters for this point in time
            scenario = self._scenario_at(t)
            params = self.apply_parameter_shocks(t, scenario)

            # Step the model
            self.step(t=t, scenario=scenario, params=params)

            # Store the outputs
            self.variables.record_state(t, self.state)
            self.history = self.variables.update_history(self.state)
            self.prior = self.state

        return self.variables.gather_timeseries()

    def _scenario_at(self, t: int) -> dict:
        """Return the value of every scenario series at timestep ``t``.

        Each series ``v`` is contracted with a one-hot row vector that selects
        position ``t`` (``idx @ v``) instead of being indexed directly.
        """
        idx = torch.where(
            torch.arange(self.hyper["timesteps"]) == t,
            torch.ones(1),
            torch.zeros(1),
        )
        return {k: idx @ v for k, v in self.scenarios.items()}

    def initialize(self):
        """Initialize the behavior.

        This should include the model's initialization steps, and set all of the
        necessary state variables. They only need to be set for one period, and
        will then be copied to the history and prior to be used in the step
        function.

        Raises
        ------
        NotImplementedError
            Always; models must override this method.
        """
        raise NotImplementedError("Behavior.initialize() to be implemented by model")

    def step(self, t: int, scenario: dict, params: dict | None = None):
        """Step function of the behavior.

        This should include one iteration of the model's main loop, writing the
        new values into ``self.state`` (``self.prior`` holds the previous
        state).

        Parameters
        ----------
        t: int
            The current timestep.
        scenario: dict
            The scenario information for the current timestep.
        params: dict | None
            The parameters for the current timestep after
            ``apply_parameter_shocks()`` (as passed by ``forward()``).

        Raises
        ------
        NotImplementedError
            Always; models must override this method.
        """
        raise NotImplementedError("Behavior.step() to be implemented by model")

    def apply_parameter_shocks(self, t: int, scenario: dict):
        """Apply parameter shocks to the model.

        Any parameter in the model can be shocked/changed during the simulation
        using the scenario information. Specifically, for a parameter alpha, the
        user can pass two types of potential shocks:

        1. A multiplicative shock, generically named alpha_multiply
        2. An additive shock, generically named alpha_add

        For vector (1-D) and matrix (2-D) parameters the shocks are
        sector-specific and require ``hyper["vector_sectors"]``:

        - vector: ``{sector}_alpha_multiply`` / ``{sector}_alpha_add`` change
          only the entry belonging to ``sector``;
        - matrix: ``{row}_{col}_alpha_multiply`` / ``{row}_{col}_alpha_add``
          change only the ``(row, col)`` entry.

        Entries that are not shocked keep a multiplier of 1 and an addend of 0.

        This function applies the shocks to the parameters and returns a
        dictionary with the updated parameters, ``value * mul + add``. The
        application of the shocks is independent, that is, the multiplicative
        shock does not affect the additive shock and vice versa: the
        multiplicative shock is applied first, then the additive shock.

        Parameters
        ----------
        t: int
            The current timestep.
        scenario: dict
            The scenario information for the current timestep.

        Returns
        -------
        dict
            A dictionary with the updated parameters.
        """
        # Optional sectoral structure for vector/matrix parameters; a missing
        # or None entry means "no sectors".
        vsecs = self.hyper.get("vector_sectors") or []
        n = len(vsecs)

        sec_vectors = {}
        sec_matrices = {}
        if n > 0:
            # One-hot vector per sector
            one, zero = torch.ones(n), torch.zeros(n)
            sec_vectors = {
                s: torch.where(torch.arange(n) == i, one, zero)
                for i, s in enumerate(vsecs)
            }

            # One-hot (n x n) matrix per (row, col) sector pair, row-major
            pairs = [(row, col) for row in vsecs for col in vsecs]
            one, zero = torch.ones(n, n), torch.zeros(n, n)
            flat_index = torch.arange(n * n).reshape(n, n)
            sec_matrices = {
                pair: torch.where(flat_index == i, one, zero)
                for i, pair in enumerate(pairs)
            }

        params = {}
        for key, value in self.params.items():
            if value.dim() == 0:
                mul, add = torch.tensor(1.0), torch.tensor(0.0)
                if f"{key}_multiply" in scenario:
                    mul = scenario[f"{key}_multiply"]
                if f"{key}_add" in scenario:
                    add = scenario[f"{key}_add"]

            else:
                add = torch.zeros_like(value)
                mul = torch.ones_like(value)

                # (scenario-key prefix, one-hot mask) pairs relevant to this rank
                if value.dim() == 1:
                    masks = list(sec_vectors.items())
                elif value.dim() == 2:
                    masks = [
                        (f"{row}_{col}", ix)
                        for (row, col), ix in sec_matrices.items()
                    ]
                else:
                    masks = []

                for s, ix in masks:
                    if f"{s}_{key}_multiply" in scenario:
                        # Scale only the masked entry; every other entry keeps
                        # a multiplier of 1 (multiplying by ix * shock alone
                        # would zero them out).
                        shock = scenario[f"{s}_{key}_multiply"]
                        mul = mul * (ix * shock + (1 - ix))
                    if f"{s}_{key}_add" in scenario:
                        add = add + (ix * scenario[f"{s}_{key}_add"])

            params[key] = value * mul + add
        return params

    ############################################################################
    # Steady State
    ############################################################################

    def compute_theoretical_steady_state(self, **kwargs):
        """Compute the theoretical steady state of the model.

        This process generally follows the structure of the forward() function,
        but instead of simulating the model, the steady state is computed at
        each timestep. Therefore, (1) the model is initialized, and (2) for
        each timestep the parameter and scenario information is passed to the
        compute_theoretical_steady_state_per_step() function that computes the
        steady state at that timestep. Results are written through
        ``self.variables.record_state``; the method itself returns None.

        Parameters
        ----------
        **kwargs: dict
            Additional keyword arguments (not used by this base
            implementation).
        """
        # Set the seed
        torch.manual_seed(self.hyper["seed"])

        # Initialize the output tensors
        self.state, _ = self.variables.initialize_tensors()

        # Initialize the model
        info = f"(t=0...{self.hyper['timesteps_initialization']})"
        logger.debug(f"Initializing model {info}")
        self.initialize()

        # NOTE(review): unlike forward(), this records t=0..ti-1 and then
        # computes from t=ti+1, so t=timesteps_initialization is never
        # recorded here. Confirm whether that gap is intended.
        for t in range(self.hyper["timesteps_initialization"]):
            self.variables.record_state(t, self.state)

        # Compute the steady state
        info = f"(t={self.hyper['timesteps_initialization'] + 1}...{self.hyper['timesteps']})"
        logger.debug(f"Computing theoretical steady state {info}")

        for t in range(
            self.hyper["timesteps_initialization"] + 1, self.hyper["timesteps"]
        ):
            self.state = self.variables.new_state()

            # Scenario values and shocked parameters for this point in time
            scenario = self._scenario_at(t)
            params = self.apply_parameter_shocks(t, scenario)

            # Compute the steady state
            self.compute_theoretical_steady_state_per_step(
                t=t, params=params, scenario=scenario
            )

            # Store the outputs
            self.variables.record_state(t, self.state)

        return None

    def compute_theoretical_steady_state_per_step(self, **kwargs):
        """Compute the theoretical steady state of the model per step.

        Called by ``compute_theoretical_steady_state()`` with keyword arguments
        ``t``, ``params`` and ``scenario``; implementations should write the
        result into ``self.state``.

        Raises
        ------
        NotImplementedError
            Always; models must override this method.
        """
        raise NotImplementedError(
            "Behavior.compute_theoretical_steady_state_per_step() to be implemented by model"
        )

    ############################################################################
    # Some Differentiable PyTorch Alternatives
    ############################################################################

    def diffwhere(self, condition, x1, x2):
        """Where condition that is differentiable with respect to the condition.

        Requires:
            self.hyper['sigmoid_constant'] as a large number
        Intended for use when self.hyper['diffwhere'] is True (not checked
        here).

        Note: For non-NaN/inf, where(x > eps, z, y) is (x - eps > 0) * (z - y) + y,
        so we can use the sigmoid function to approximate the where function.

        Parameters
        ----------
        condition : torch.Tensor
            Condition to be evaluated expressed as x - eps
        x1 : torch.Tensor
            Value to be returned if condition is True
        x2 : torch.Tensor
            Value to be returned if condition is False
        """
        sig = torch.sigmoid(torch.mul(condition, self.hyper["sigmoid_constant"]))
        return torch.add(torch.mul(sig, torch.sub(x1, x2)), x2)

    def tanhmask(self, x):
        """Convert a variable into 0 (x<0) and 1 (x>0).

        Computes ``(1 + tanh(c * x)) / 2`` with ``c = self.hyper['tanh_constant']``,
        which approaches a step function as ``c`` grows.

        Requires:
            self.hyper['tanh_constant'] as a large number

        Parameters
        ----------
        x: torch.Tensor
            The variable to be converted.

        Notes
        -----
        The constants are float64 tensors created with ``requires_grad=True``,
        so the result is float64 by type promotion.
        NOTE(review): the constants are created on the default device; confirm
        behavior for inputs on other devices.
        """
        kwg = {"dtype": torch.float64, "requires_grad": True}
        return torch.div(
            torch.add(
                torch.ones(x.size(), **kwg),
                torch.tanh(torch.mul(x, self.hyper["tanh_constant"])),
            ),
            torch.tensor(2.0, **kwg),
        )

    def diffmin(self, x1, x2):
        """Smooth approximation to the minimum.

        Computes ``-1/r * log(exp(-r x1) + exp(-r x2))``, which approaches
        ``min(x1, x2)`` as ``r`` grows. It uses ``torch.logaddexp`` so that
        large ``r`` does not overflow ``exp``.
        B: https://mathoverflow.net/questions/35191/a-differentiable-approximation-to-the-minimum-function

        Requires:
            self.hyper['min_constant'] as a large number

        Parameters
        ----------
        x1: torch.Tensor
            The first variable to be compared.
        x2: torch.Tensor
            The second variable to be compared.
        """
        r = self.hyper["min_constant"]
        return -torch.logaddexp(torch.mul(x1, -r), torch.mul(x2, -r)) / r

    def diffmax(self, x1, x2):
        """Smooth approximation to the maximum.

        Computes ``1/r * log(exp(r x1) + exp(r x2))``, which approaches
        ``max(x1, x2)`` as ``r`` grows. It uses ``torch.logaddexp`` so that
        large ``r`` does not overflow ``exp``.
        B: https://mathoverflow.net/questions/35191/a-differentiable-approximation-to-the-minimum-function

        Requires:
            self.hyper['max_constant'] as a large number

        Parameters
        ----------
        x1: torch.Tensor
            The first variable to be compared.
        x2: torch.Tensor
            The second variable to be compared.
        """
        r = self.hyper["max_constant"]
        return torch.logaddexp(torch.mul(x1, r), torch.mul(x2, r)) / r

    def diffmin_v(self, x):
        """Smooth approximation to the minimum over all elements of a tensor.

        Computes ``-1/r * log(sum(exp(-r x)))`` via ``torch.logsumexp`` (stable
        for large ``r``). See diffmin.

        Parameters
        ----------
        x: torch.Tensor
            The variable to be reduced (all elements are used).

        Requires:
            self.hyper['min_constant'] as a large number
        """
        r = self.hyper["min_constant"]
        return -torch.logsumexp(torch.mul(x, -r).reshape(-1), dim=0) / r

    def diffmax_v(self, x):
        """Smooth approximation to the maximum over all elements of a tensor.

        Computes ``1/r * log(sum(exp(r x)))`` via ``torch.logsumexp`` (stable
        for large ``r``). See diffmax.

        Requires:
            self.hyper['max_constant'] as a large number

        Parameters
        ----------
        x: torch.Tensor
            The variable to be reduced (all elements are used).
        """
        r = self.hyper["max_constant"]
        return torch.logsumexp(torch.mul(x, r).reshape(-1), dim=0) / r

415

416
if __name__ == "__main__":
    # The module defines no command-line behavior; running it directly is a no-op.
    ...
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc