• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

quaquel / EMAworkbench / 18214982978

03 Oct 2025 06:39AM UTC coverage: 88.703% (+0.04%) from 88.664%
18214982978

Pull #422

github

web-flow
Merge fe026872f into 592d0cd98
Pull Request #422: ruff fixes

53 of 73 new or added lines in 16 files covered. (72.6%)

2 existing lines in 2 files now uncovered.

7852 of 8852 relevant lines covered (88.7%)

0.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

72.46
/ema_workbench/em_framework/parameters.py
1
"""parameters and related helper classes and functions."""
2

3
import abc
1✔
4
import numbers
1✔
5

6
import pandas as pd
1✔
7
import scipy as sp
1✔
8

9
from ..util import get_module_logger
1✔
10
from .util import NamedObject, NamedObjectMap, Variable
1✔
11

12
# Created on Jul 14, 2016
13
#
14
# .. codeauthor::jhkwakkel <j.h.kwakkel (at) tudelft (dot) nl>
15

16
# Public names of this module. ``Variable`` is imported from ``.util`` and
# re-exported here.
__all__ = [
    # parameter classes
    "BooleanParameter",
    "CategoricalParameter",
    "Category",
    "Constant",
    "IntegerParameter",
    "Parameter",
    "RealParameter",
    "Variable",
    # csv helpers
    "parameters_from_csv",
    "parameters_to_csv",
]

# module level logger
_logger = get_module_logger(__name__)
29

30

31
class Bound(metaclass=abc.ABCMeta):
    """Bounds class.

    Data descriptor for a lazily computed, per-instance cached bound. On first
    access the bound is computed by :meth:`get_bound` and stored in the
    instance ``__dict__`` under ``"_" + <attribute name>``; later accesses
    return the stored value. Assigning to the attribute overrides the bound.
    """

    def __get__(self, instance, cls):
        """Return the bound for *instance*, computing and caching it if needed."""
        # Accessed on the class itself (e.g. ``Parameter.lower_bound`` or by
        # introspection tools): hand back the descriptor instead of failing
        # on ``None.__dict__``.
        if instance is None:
            return self
        try:
            return instance.__dict__[self.internal_name]
        except KeyError:
            bound = self.get_bound(instance)
            self.__set__(instance, bound)
            return bound

    def __set__(self, instance, value):
        """Store *value* as the bound of *instance*."""
        instance.__dict__[self.internal_name] = value

    def __set_name__(self, cls, name):
        """Record the attribute name and derive the private storage key."""
        self.name = name
        self.internal_name = "_" + name

    @abc.abstractmethod
    def get_bound(self, instance):
        """Compute the bound for *instance* when none has been stored yet."""
51

52

53
class UpperBound(Bound):
    """Bound descriptor defaulting to the upper end of ``instance.dist``."""

    def get_bound(self, instance):
        """Return the distribution's value at quantile 1.0 (``ppf(1.0)``)."""
        return instance.dist.ppf(1.0)
57

58

59
class LowerBound(Bound):
    """Bound descriptor defaulting to the lower end of ``owner.dist``."""

    def get_bound(self, owner):
        """Return the distribution's value at the lowest quantile.

        ppf at actual zero for rv_discrete gives lower bound - 1 due to a
        quirk in the scipy.stats implementation, so for discrete
        distributions the smallest positive float is used as quantile.
        """
        frozen = owner.dist
        is_discrete = isinstance(frozen.dist, sp.stats.rv_discrete)  # @UndefinedVariable
        quantile = 5e-324 if is_discrete else 0
        return frozen.ppf(quantile)
71

72

73
class Constant(Variable):
    """Constant class.

    Model input that is always set to one fixed value.

    Parameters
    ----------
    name : str
    value : object
        the fixed value
    variable_name : str, or list of str, optional

    """

    def __init__(self, name, value, variable_name=None):
        """Init."""
        super().__init__(name, variable_name=variable_name)
        self.value = value

    def __repr__(self, *args, **kwargs):  # noqa: D105
        class_name = type(self).__name__
        return f"{class_name}('{self.name}', {self.value})"
87

88

89
class Category(NamedObject):
    """A named option of a categorical parameter.

    Parameters
    ----------
    name : str
    value : object
        the value this category stands for

    """

    def __init__(self, name, value):
        """Init."""
        super().__init__(name)
        self.value = value
96

97

98
def create_category(cat):
    """Return *cat* as a Category, wrapping it when needed.

    Category instances are returned unchanged; any other object is wrapped
    in a new Category named ``str(cat)`` whose value is *cat* itself.
    """
    if not isinstance(cat, Category):
        cat = Category(str(cat), cat)
    return cat
104

105

106
class Parameter(Variable, metaclass=abc.ABCMeta):
    """Base class for any model input parameter.

    Parameters
    ----------
    name : str
    lower_bound : int or float
    upper_bound : int or float
    resolution : collection
    default : object, optional
    variable_name : str, or list of str, optional

    Raises
    ------
    ValueError
        if lower bound is larger than upper bound
    ValueError
        if entries in resolution are outside range of lower_bound and
        upper_bound

    """

    # Bounds are descriptors: when not assigned explicitly (e.g. for
    # parameters created via ``from_dist``) they are derived lazily from
    # ``self.dist``.
    lower_bound = LowerBound()
    upper_bound = UpperBound()
    default = None

    @property
    def resolution(self):
        """Getter for resolution."""
        return self._resolution

    @resolution.setter
    def resolution(self, value):
        """Setter for resolution; checks entries against the bounds."""
        # an empty or None resolution is stored without validation
        if value:  # noqa: SIM102
            if (min(value) < self.lower_bound) or (max(value) > self.upper_bound):
                raise ValueError(
                    f"Resolution ({value}) not consistent with lower ({self.lower_bound}) and upper bound ({self.upper_bound})."
                )
        self._resolution = value

    def __init__(
        self,
        name: str,
        lower_bound,
        upper_bound,
        resolution=None,
        default=None,
        variable_name: str | list[str] | None = None,
    ):
        """Init."""
        # Enforce the documented contract. Only real numbers are compared;
        # other bound types (e.g. numeric strings) are left to subclasses,
        # which convert them themselves.
        if (
            isinstance(lower_bound, numbers.Real)
            and isinstance(upper_bound, numbers.Real)
            and lower_bound > upper_bound
        ):
            raise ValueError(
                f"Lower bound ({lower_bound}) is larger than upper bound ({upper_bound})."
            )

        super().__init__(name, variable_name=variable_name)
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.resolution = resolution
        self.default = default
        self.dist = None
        self.uniform = True

    @classmethod
    def from_dist(cls, name, dist, **kwargs):
        """Factory method for creating a Parameter from a scipy distribution.

        Alternative constructor for creating a parameter from a frozen
        scipy.stats distribution directly

        Parameters
        ----------
        dist : scipy stats frozen dist
        **kwargs : valid keyword arguments for Parameter instance

        Raises
        ------
        ValueError
            if a keyword other than default, resolution, or variable_name
            is passed

        """
        assert isinstance(
            dist, sp.stats._distn_infrastructure.rv_frozen
        )  # @UndefinedVariable
        # bypass __init__: the bounds are not given and are derived lazily
        # from the distribution by the bound descriptors
        self = cls.__new__(cls)
        self.dist = dist
        self.name = name
        self.resolution = None
        self.variable_name = None
        self.uniform = False

        for k, v in kwargs.items():
            if k in {"default", "resolution", "variable_name"}:
                setattr(self, k, v)
            else:
                raise ValueError(f"Unknown property {k} for Parameter")

        return self

    def __eq__(self, other):  # noqa: D105
        if not isinstance(self, other.__class__):
            return False

        self_keys = set(self.__dict__.keys())
        other_keys = set(other.__dict__.keys())
        if self_keys - other_keys:
            return False

        for key in self_keys:
            if key != "dist":
                if getattr(self, key) != getattr(other, key):
                    return False
                continue

            # distributions are compared by family name and parameters
            # rather than by object identity
            self_dist = getattr(self, key)
            other_dist = getattr(other, key)
            if self_dist is None or other_dist is None:
                if self_dist is not other_dist:
                    return False
                continue
            if self_dist.dist.name != other_dist.dist.name:
                return False
            # parameters can be passed positionally (args) or by keyword
            # (kwds, e.g. ``norm(loc=0, scale=1)``); both must match
            if self_dist.args != other_dist.args or self_dist.kwds != other_dist.kwds:
                return False

        return True

    def __hash__(self):
        """Hashing function.

        Hashes on the name only. ``__eq__`` compares every instance attribute
        (name included), so equal parameters get equal hashes. Hashing all of
        ``__dict__`` would fail on unhashable values such as a list
        resolution, and would give equal parameters with distinct but
        equivalent ``dist`` objects different hashes.
        """
        return hash(self.name)

    def __str__(self):  # noqa: D105
        return self.name
224

225

226
class RealParameter(Parameter):
    """real valued model input parameter.

    Parameters
    ----------
    name : str
    lower_bound : int or float
    upper_bound : int or float
    resolution : iterable
    default : optional
    variable_name : str, or list of str

    Raises
    ------
    ValueError
        if lower bound is larger than upper bound
    ValueError
        if entries in resolution are outside range of lower_bound and
        upper_bound

    """

    def __init__(
        self,
        name: str,
        lower_bound,
        upper_bound,
        resolution=None,
        default=None,
        variable_name: str | list[str] | None = None,
    ):
        """Init."""
        super().__init__(
            name,
            lower_bound,
            upper_bound,
            resolution=resolution,
            default=default,
            variable_name=variable_name,
        )

        # scipy's uniform takes (loc, scale) and spans [loc, loc + scale]
        self.dist = sp.stats.uniform(
            lower_bound, upper_bound - lower_bound
        )  # @UndefinedVariable

    @classmethod
    def from_dist(cls, name, dist, **kwargs):  # noqa: D102
        if not isinstance(dist.dist, sp.stats.rv_continuous):  # @UndefinedVariable
            raise ValueError(
                f"dist should be instance of rv_continuous, not {dist.dist}"
            )
        return super().from_dist(name, dist, **kwargs)

    def __repr__(self):  # noqa: D105
        if isinstance(self.dist, sp.stats._distn_infrastructure.rv_continuous_frozen):
            return (
                f"RealParameter('{self.name}', {self.lower_bound}, {self.upper_bound}, "
                f"resolution={self.resolution}, default={self.default}, variable_name={self.variable_name})"
            )
        else:
            return super().__repr__()
286

287

288
class IntegerParameter(Parameter):
    """integer valued model input parameter.

    Parameters
    ----------
    name : str
    lower_bound : int
    upper_bound : int
    resolution : iterable
    default : optional
    variable_name : str, or list of str

    Raises
    ------
    ValueError
        if lower bound is larger than upper bound
    ValueError
        if entries in resolution are outside range of lower_bound and
        upper_bound, or not an integer instance
    ValueError
        if lower_bound or upper_bound is not an integer instance

    """

    def __init__(
        self,
        name,
        lower_bound,
        upper_bound,
        resolution=None,
        default=None,
        variable_name=None,
    ):
        """Init."""
        super().__init__(
            name,
            lower_bound,
            upper_bound,
            resolution=resolution,
            default=default,
            variable_name=variable_name,
        )

        # integral floats such as 3.0 are accepted and converted below
        lb_int = float(lower_bound).is_integer()
        up_int = float(upper_bound).is_integer()

        if not (lb_int and up_int):
            raise ValueError(
                f"Lower bound and upper bound must be integers, not {type(lower_bound)} and {type(upper_bound)}"
            )

        self.lower_bound = int(lower_bound)
        self.upper_bound = int(upper_bound)

        # randint's upper argument is exclusive, hence + 1 to make the
        # upper bound inclusive
        self.dist = sp.stats.randint(
            self.lower_bound, self.upper_bound + 1
        )  # @UndefinedVariable

        if self.resolution is not None:
            integer_resolution = []
            for entry in self.resolution:
                if not float(entry).is_integer():
                    raise ValueError(
                        f"All entries in resolution should be integers, not {type(entry)}"
                    )
                integer_resolution.append(int(entry))
            # store a new list instead of writing into the caller's object;
            # this also converts tuples and other read-only sequences
            self.resolution = integer_resolution

    @classmethod
    def from_dist(cls, name, dist, **kwargs):  # noqa: D102
        if not isinstance(dist.dist, sp.stats.rv_discrete):  # @UndefinedVariable
            raise ValueError(f"dist should be instance of rv_discrete, not {dist.dist}")
        return super().from_dist(name, dist, **kwargs)

    def __repr__(self):  # noqa: D105
        if isinstance(self.dist, sp.stats._distn_infrastructure.rv_discrete_frozen):
            return (
                f"IntegerParameter('{self.name}', {self.lower_bound}, {self.upper_bound}, "
                f"resolution={self.resolution}, default={self.default}, variable_name={self.variable_name})"
            )
        else:
            return super().__repr__()
371

372

373
class CategoricalParameter(IntegerParameter):
    """categorical model input parameter.

    Parameters
    ----------
    name : str
    categories : collection of obj
    default : obj, optional
    variable_name : str, or list of str
    multivalue : boolean
                 if categories have a set of values, for each variable_name
                 a different one.
    # TODO: should multivalue not be a separate class?
    # TODO: multivalue as label is also horrible

    Raises
    ------
    ValueError
        if fewer than 2 categories are given

    """

    @property
    def categories(self):  # noqa: D102
        return self._categories

    @categories.setter
    def categories(self, values):
        # NOTE: assignment *extends* the existing categories rather than
        # replacing them
        self._categories.extend(values)

    def __init__(
        self,
        name,
        categories,
        default=None,
        variable_name=None,
        multivalue=False,
    ):
        """Init."""
        lower_bound = 0
        upper_bound = len(categories) - 1

        # ``< 1`` rather than ``== 0`` so an empty collection is rejected too
        if upper_bound < 1:
            raise ValueError(
                f"There should be more than 1 category, instead of {len(categories)}"
            )

        # the parameter itself ranges over the indices 0 .. n-1
        super().__init__(
            name,
            lower_bound,
            upper_bound,
            resolution=None,
            default=default,
            variable_name=variable_name,
        )
        cats = [create_category(cat) for cat in categories]

        self._categories = NamedObjectMap(Category)

        self.categories = cats
        self.resolution = list(range(len(self.categories)))
        self.multivalue = multivalue

    def index_for_cat(self, category):
        """Return index of category.

        Parameters
        ----------
        category : object
            name of the category

        Returns
        -------
        int

        Raises
        ------
        ValueError
            if no category with that name exists

        """
        for i, cat in enumerate(self.categories):
            if cat.name == category:
                return i
        raise ValueError(f"Category {category} not found")

    def cat_for_index(self, index):
        """Return category given index.

        Parameters
        ----------
        index  : int

        Returns
        -------
        object

        """
        return self.categories[index]

    def __repr__(self, *args, **kwargs):  # noqa: D105
        # show the category values so the repr mirrors the constructor call;
        # ``resolution`` only holds their indices
        values = [category.value for category in self.categories]
        # ``is not None`` so falsy defaults such as 0 or False are shown
        if self.default is not None:
            return f"CategoricalParameter('{self.name}', {values}, default={self.default})"
        return f"CategoricalParameter('{self.name}', {values})"

    @classmethod
    def from_dist(cls, name, dist, **kwargs):  # noqa: D102
        # classmethod like the base-class factory, so a call on the class
        # raises NotImplementedError instead of a TypeError about arguments
        # TODO:: how to handle this
        # probably need to pass categories as list and zip
        # categories to integers implied by dist
        raise NotImplementedError(
            "Custom distributions over categories not supported yet"
        )
480

481

482
class BooleanParameter(CategoricalParameter):
    """boolean model input parameter.

    A BooleanParameter is a CategoricalParameter whose categories are fixed
    to False and True (in that order).

    Parameters
    ----------
    name : str
    default : optional
    variable_name : str, or list of str

    """

    def __init__(self, name, default=None, variable_name=None):
        """Init."""
        bool_categories = [False, True]
        super().__init__(
            name,
            categories=bool_categories,
            default=default,
            variable_name=variable_name,
        )

    def __repr__(self):  # noqa: D105
        parts = [
            f"'{self.name}'",
            f"default={self.default}",
            f"variable_name={self.variable_name}",
        ]
        return f"BooleanParameter({', '.join(parts)})"
509

510

511
def parameters_to_csv(parameters, file_name):
    """Helper function for writing a collection of parameters to a csv file.

    Parameters
    ----------
    parameters : collection of Parameter instances
    file_name :  str


    The function iterates over the collection and turns these into a data
    frame prior to storing them. Each row holds the parameter name followed
    by either the category values (for categorical parameters) or the lower
    and upper bound. The resulting csv can be loaded using the
    parameters_from_csv function. Note that currently we don't store
    resolution and default attributes.

    """
    rows = []
    n_value_columns = 0

    for param in parameters:
        if isinstance(param, CategoricalParameter):
            # write the category values themselves; ``resolution`` only holds
            # their indices, which cannot be turned back into categories
            values = [category.value for category in param.categories]
        else:
            values = [param.lower_bound, param.upper_bound]

        row = {"name": param.name}
        row.update(enumerate(values))
        rows.append(row)
        n_value_columns = max(n_value_columns, len(values))

    # name first for readability; an explicit column list also keeps the
    # header valid when ``parameters`` is empty
    columns = ["name", *range(n_value_columns)]
    pd.DataFrame(rows, columns=columns).to_csv(file_name, index=False)
549

550

551
def _coerce_number(value):
    """Return *value* as int or float when it is a numeric string.

    Non-string values are returned unchanged. Raises ValueError for a string
    that cannot be parsed as a number.
    """
    if not isinstance(value, str):
        return value
    text = value.strip()
    try:
        return int(text)
    except ValueError:
        return float(text)


def _infer_parameter_type(values):
    """Guess the parameter type ("int", "real" or "cat") from a row's values."""
    if len(values) != 2:
        return "cat"
    try:
        lower, upper = (_coerce_number(value) for value in values)
    except ValueError:
        # two non-numeric values can only describe categories
        return "cat"
    if isinstance(lower, numbers.Integral) and isinstance(upper, numbers.Integral):
        return "int"
    return "real"


def parameters_from_csv(uncertainties, **kwargs):
    """Helper function for creating many Parameters based on a DataFrame or csv file.

    Parameters
    ----------
    uncertainties : str, DataFrame
    **kwargs : dict, arguments to pass to pandas.read_csv

    Returns
    -------
    list of Parameter instances

    Raises
    ------
    IndexError
        if there is no name column
    ValueError
        if a row that is not 'cat' or 'bool' does not have exactly 2 values,
        or its values cannot be interpreted as numbers


    This helper function creates uncertainties. It assumes that the
    DataFrame or csv file has a column titled 'name' (or 'NAME'). Optionally
    a type column ('type' or 'TYPE') with values {int, real, cat, bool} can
    be included as well. The remainder of the columns are handled as values
    for the parameters. If type is not specified, the function will try to
    infer type from the values. Numeric strings are converted to numbers;
    e.g. when a value column also holds category labels, pandas may read its
    numbers as strings. For 'bool' rows any values are ignored, because a
    BooleanParameter always has the categories False and True.

    Note that this function does not support the resolution and default kwargs
    on parameters.

    An example of a csv:

    NAME,TYPE,,,
    a_real,real,0,1.1,
    an_int,int,1,9,
    a_categorical,cat,a,b,c

    this CSV file would result in the equivalent of

    [RealParameter('a_real', 0, 1.1),
     IntegerParameter('an_int', 1, 9),
     CategoricalParameter('a_categorical', ['a', 'b', 'c'])]

    """
    if isinstance(uncertainties, str):
        uncertainties = pd.read_csv(uncertainties, **kwargs)
    elif not isinstance(uncertainties, pd.DataFrame):
        uncertainties = pd.DataFrame.from_dict(uncertainties)
    else:
        # work on a copy because columns are dropped below
        uncertainties = uncertainties.copy()

    parameter_map = {
        "int": IntegerParameter,
        "real": RealParameter,
        "cat": CategoricalParameter,
        "bool": BooleanParameter,
    }

    # check if names column is there
    if ("NAME" not in uncertainties) and ("name" not in uncertainties):
        raise IndexError("name column missing")
    elif "NAME" in uncertainties.columns:
        names = uncertainties["NAME"]
        uncertainties.drop(["NAME"], axis=1, inplace=True)
    else:
        names = uncertainties["name"]
        uncertainties.drop(["name"], axis=1, inplace=True)

    # check if type column is there
    infer_type = False
    if ("TYPE" not in uncertainties) and ("type" not in uncertainties):
        infer_type = True
    elif "TYPE" in uncertainties:
        types = uncertainties["TYPE"]
        uncertainties.drop(["TYPE"], axis=1, inplace=True)
    else:
        types = uncertainties["type"]
        uncertainties.drop(["type"], axis=1, inplace=True)

    uncs = []
    for i, row in uncertainties.iterrows():
        name = names[i]
        # drop empty cells so rows may have different numbers of values
        values = row.values[row.notnull().values]

        if infer_type:
            param_type = _infer_parameter_type(values)
        else:
            param_type = types[i]
            if param_type not in ("cat", "bool") and len(values) != 2:
                raise ValueError(
                    f"Too many values specified for {name}, is {values.shape[0]}, should be 2"
                )

        if param_type == "cat":
            uncs.append(parameter_map[param_type](name, values))
        elif param_type == "bool":
            # a BooleanParameter's categories are fixed; passing the row's
            # values positionally would set default and variable_name
            uncs.append(parameter_map[param_type](name))
        else:
            try:
                bounds = [_coerce_number(value) for value in values]
            except ValueError as err:
                raise ValueError(
                    f"Non-numeric bounds specified for {name}: {list(values)}"
                ) from err
            uncs.append(parameter_map[param_type](name, *bounds))
    return uncs
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc