• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 4960685499

pending completion
4960685499

Pull #34

github

GitHub
Merge 7b7c1d7ba into d1de71ffa
Pull Request #34: Preparation for JOSS submission

184 of 184 new or added lines in 28 files covered. (100.0%)

1071 of 2324 relevant lines covered (46.08%)

0.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.78
/coolest/template/classes/parameter.py
1
# Single parameter of a profile
2

3
from typing import List
1✔
4

5
from coolest.template.classes.base import APIBaseObject
1✔
6
from coolest.template.classes.probabilities import Prior, PosteriorStatistics
1✔
7
from coolest.template.classes.grid import PixelatedRegularGrid, IrregularGrid
1✔
8

9
import numpy as np
1✔
10

11

12
# Public names exported by ``from coolest.template.classes.parameter import *``.
# DefinitionRange, PointEstimate and ParameterSet are included because they are
# public, documented classes that appear in Parameter's signature.
__all__ = [
    'DefinitionRange',
    'PointEstimate',
    'Parameter',
    'NonLinearParameter',
    'LinearParameter',
    'HyperParameter',
    'ParameterSet',
    'LinearParameterSet',
    'NonLinearParameterSet',
    'PixelatedRegularGridParameter',
    'IrregularGridParameter',
]
22

23

24
class DefinitionRange(APIBaseObject):
    """Closed interval on which a parameter is allowed to take values.

    Either bound may be left as None, meaning that side is unbounded.

    Parameters
    ----------
    min_value : (int, float), optional
        Inclusive lower bound of the interval, by default None
    max_value : (int, float), optional
        Inclusive upper bound of the interval, by default None
    """

    def __init__(self, min_value=None, max_value=None):
        # Bounds are stored as given; no ordering or type check is done here.
        self.min_value, self.max_value = min_value, max_value
38

39

40
class PointEstimate(APIBaseObject):
    """A single location in parameter space.

    Parameters
    ----------
    value : float, optional
        Parameter value at this point, by default None
    """

    def __init__(self, value=None):
        # Stored verbatim: no conversion or validation happens at this level.
        self.value = value
51

52

53
class Parameter(APIBaseObject):
    """Base class of a generic model parameter.

    Parameters
    ----------
    documentation : str
        Short description of the parameter.
    definition_range : DefinitionRange, optional
        Interval over which the parameter is defined, by default None
        (no bounds are enforced when setting a point estimate).
    units : str, optional
        Unit of the parameter, if any, by default None
    fixed : bool, optional
        If True, the parameter is considered fixed
        (i.e. should not be, or has not been optimized), by default False
    point_estimate : PointEstimate, optional
        Point-estimate value of the parameter, by default None.
        Any value that is not already a PointEstimate is wrapped in one.
    posterior_stats : PosteriorStatistics, optional
        Summary statistics of the marginalized posterior
        distribution of the parameter, by default None
        (an empty PosteriorStatistics instance is then created).
    prior : Prior, optional
        Prior assigned to the parameter, if any, by default None
        (an empty Prior instance is then created).
    latex_str : str, optional
        LaTeX representation of the parameter, by default None
    """
    def __init__(self,
                 documentation: str,
                 definition_range: DefinitionRange = None,
                 units: str = None,
                 fixed: bool = False,
                 point_estimate: PointEstimate = None,
                 posterior_stats: PosteriorStatistics = None,
                 prior: Prior = None,
                 latex_str: str = None) -> None:
        self.documentation = documentation
        self.units = units
        self.definition_range = definition_range
        self.fixed = fixed
        # Always store a PointEstimate instance, wrapping raw values (or None).
        if not isinstance(point_estimate, PointEstimate):
            self.point_estimate = PointEstimate(point_estimate)
        else:
            self.point_estimate = point_estimate
        # Default to empty instances rather than None.
        if posterior_stats is None:
            posterior_stats = PosteriorStatistics()
        self.posterior_stats = posterior_stats
        if prior is None:
            prior = Prior()
        self.prior = prior
        self.latex_str = latex_str
        self.id = None
        super().__init__()

    def set_point_estimate(self, point_estimate):
        """Set the point estimate value of the parameter.

        The new value is checked against the definition range (when one is
        set) before being stored. A rejected value therefore leaves the
        current point estimate unchanged.

        Parameters
        ----------
        point_estimate : int, float, list, tuple, numpy.ndarray, numpy scalar, PointEstimate
            Parameter value, or directly a PointEstimate instance. Tuples and
            arrays are converted to (nested) lists. Numpy integer/floating
            scalars are converted to the equivalent Python number.

        Raises
        ------
        ValueError
            If the provided point_estimate does not have a supported type.
        ValueError
            If the parameter value is below its minimum allowed value.
        ValueError
            If the parameter value is above its maximum allowed value.
        """
        if isinstance(point_estimate, (float, int, list)):
            new_estimate = PointEstimate(value=point_estimate)
        elif isinstance(point_estimate, tuple):
            new_estimate = PointEstimate(value=list(point_estimate))
        elif isinstance(point_estimate, np.ndarray):
            new_estimate = PointEstimate(value=point_estimate.tolist())
        elif isinstance(point_estimate, (np.integer, np.floating)):
            # e.g. np.int64 / np.float32 are not subclasses of int / float
            new_estimate = PointEstimate(value=point_estimate.item())
        elif isinstance(point_estimate, PointEstimate):
            new_estimate = point_estimate
        else:
            raise ValueError("Parameter point estimate must be either a PointEstimate instance "
                             "or a single number (float or int) or an array (tuple, list or ndarray).")
        value = new_estimate.value
        # definition_range defaults to None, in which case there are no bounds to check.
        if value is not None and self.definition_range is not None:
            min_val = self.definition_range.min_value
            max_val = self.definition_range.max_value
            if min_val is not None and np.any(np.asarray(value) < np.asarray(min_val)):
                raise ValueError(f"Value cannot be smaller than {min_val}.")
            if max_val is not None and np.any(np.asarray(value) > np.asarray(max_val)):
                raise ValueError(f"Value cannot be larger than {max_val}.")
        # Only store once validation has passed.
        self.point_estimate = new_estimate

    def remove_point_estimate(self):
        """Remove the current point estimate of the parameter.

        Replaces it with an empty PointEstimate (value None).
        """
        self.point_estimate = PointEstimate()

    def set_posterior(self, posterior_stats):
        """Set the posterior statistics of the parameter.

        Parameters
        ----------
        posterior_stats : PosteriorStatistics
            Instance of the PosteriorStatistics object.

        Raises
        ------
        ValueError
            If the argument is not a PosteriorStatistics instance.
        """
        if not isinstance(posterior_stats, PosteriorStatistics):
            raise ValueError("Parameter posterior statistics must be a PosteriorStatistics instance.")
        self.posterior_stats = posterior_stats

    def remove_posterior(self):
        """Remove the current posterior statistics of the parameter.

        Replaces them with an empty PosteriorStatistics instance.
        """
        self.posterior_stats = PosteriorStatistics()

    def set_prior(self, prior):
        """Associate a prior distribution to the parameter.

        Parameters
        ----------
        prior : Prior
            Instance of Prior object.

        Raises
        ------
        ValueError
            If the argument is not a Prior instance.
        """
        if not isinstance(prior, Prior):
            raise ValueError("Parameter prior must be a Prior instance.")
        self.prior = prior

    def remove_prior(self):
        """Remove the current prior of the parameter.

        Replaces it with an empty Prior instance.
        """
        self.prior = Prior()

    def fix(self):
        """Set the fixed attribute to True, marking the parameter as fixed.

        Raises
        ------
        ValueError
            If no point estimate value has been set, since a fixed parameter
            needs a value.
        """
        if self.point_estimate.value is None:
            raise ValueError("Cannot fix parameter as no point estimate value has been set.")
        self.fixed = True

    def unfix(self):
        """Set the fixed attribute to False, marking the parameter as free to vary."""
        self.fixed = False
199

200

201
class NonLinearParameter(Parameter):
    """Non-linear parameter of a lens model.

    Thin subclass of Parameter: every argument is forwarded unchanged to the
    base class.

    Warning: this class may be removed in the future.
    """

    def __init__(self, *args, **kwargs):
        # No extra state; delegate entirely to Parameter.__init__.
        super().__init__(*args, **kwargs)
209

210

211
class LinearParameter(Parameter):
    """Define a linear parameter of a lens model.

    Thin subclass of Parameter: every argument is forwarded unchanged to the
    base class.

    Warning: this class may be removed in the future, as it adds an
    unnecessary level of abstraction.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
219

220

221
class HyperParameter(Parameter):
    """Hyper-parameter of a model.

    Thin subclass of Parameter: every argument is forwarded unchanged to the
    base class.
    """

    def __init__(self, *args, **kwargs):
        # Pure delegation; no hyper-parameter-specific state.
        super().__init__(*args, **kwargs)
226

227

228
class ParameterSet(Parameter):
    """Parameter holding a list of values, typically for analytical basis sets.

    The ``point_estimate`` keyword argument must be a list of values. It
    defaults to an empty list when missing or None. Tuples and numpy arrays
    are converted to (nested) lists, as ``Parameter.set_point_estimate``
    does. The number of values is stored in ``num_values``.

    Raises
    ------
    ValueError
        If ``point_estimate`` is not a list (nor a tuple or ndarray).
    """

    def __init__(self, *args, **kwargs) -> None:
        point_estimate = kwargs.get('point_estimate')
        if point_estimate is None:
            point_estimate = []
        elif isinstance(point_estimate, tuple):
            point_estimate = list(point_estimate)
        elif isinstance(point_estimate, np.ndarray):
            # A 0-d array yields a scalar here and is rejected below.
            point_estimate = point_estimate.tolist()
        if not isinstance(point_estimate, list):
            raise ValueError("For any ParameterSet, `point_estimate` must be a list of values.")
        kwargs['point_estimate'] = point_estimate
        super().__init__(*args, **kwargs)
        self.num_values = len(self.point_estimate.value)
238

239

240
class LinearParameterSet(ParameterSet):
    """Set of linear parameters, typically for analytical basis sets.

    Warning: this class may be removed in the future, as it adds an
    unnecessary level of abstraction.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
248
        
249
        
250
class NonLinearParameterSet(ParameterSet):
    """Set of non-linear parameters, typically for positions of point sources.

    Warning: this class may be removed in the future, as it adds an
    unnecessary level of abstraction.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
257

258

259
class PixelatedRegularGridParameter(PixelatedRegularGrid):
    """Parameter defined on a regular pixel grid, typically for pixelated profiles.

    Parameters
    ----------
    documentation : str
        Short description of the parameter.
    **kwargs_grid
        Keyword arguments forwarded unchanged to PixelatedRegularGrid.
    """

    def __init__(self, documentation, **kwargs_grid) -> None:
        # The description is attached before the grid initializer runs.
        self.documentation = documentation
        super().__init__(**kwargs_grid)
265

266

267
class IrregularGridParameter(IrregularGrid):
    """Parameter defined on an irregular grid, typically for pixelated profiles.

    Parameters
    ----------
    documentation : str
        Short description of the parameter.
    **kwargs_grid
        Keyword arguments forwarded unchanged to IrregularGrid.
    """

    def __init__(self, documentation, **kwargs_grid) -> None:
        # The description is attached before the grid initializer runs.
        self.documentation = documentation
        super().__init__(**kwargs_grid)
273
        
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc