• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MatthewGerber / rlai / 27471184649

13 Jun 2026 03:37PM UTC coverage: 0.0% (-79.4%) from 79.431%
27471184649

push

github

MatthewGerber
* Change machine type.
* Try xdist.

0 of 5484 relevant lines covered (0.0%)

0.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/rlai/state_value/function_approximation/models/feature_extraction.py
1
from abc import ABC, abstractmethod
×
2
from itertools import product
×
3
from typing import List, Optional, Dict, Callable, Any
×
4

5
import numpy as np
×
6

7
from rlai.core import MdpState
×
8
from rlai.docs import rl_text
×
9
from rlai.models.feature_extraction import FeatureExtractor, OneHotCategory, OneHotCategoricalFeatureInteracter
×
10

11

12
@rl_text(chapter='Feature Extractors', page=1)
class StateFeatureExtractor(FeatureExtractor, ABC):
    """
    Abstract base for feature extractors that map states to feature vectors.
    """

    def __init__(
            self,
            scale_features: bool
    ):
        """
        Construct the extractor.

        :param scale_features: True to scale the extracted features, False to leave them unscaled.
        """

        # stored for use by concrete extractors; this base class does not read it itself.
        self.scale_features = scale_features

    @abstractmethod
    def extract(
            self,
            states: List[MdpState],
            refit_scaler: bool
    ) -> np.ndarray:
        """
        Extract features for a batch of states.

        :param states: States to extract features for.
        :param refit_scaler: True to refit the feature scaler before scaling the extracted features. Refitting is
        only appropriate where nonstationarity is desired (e.g., during training); during evaluation the scaler should
        stay fixed, so pass False.
        :return: State-feature matrix with shape (#states, #features).
        """
45

46

47
class StateIndicator(ABC):
    """
    Abstract indicator used for one-hot encoding of states.

    As in other one-hot schemes, exactly one block of low-level features is switched on while all others are
    switched off. Which block is on is decided by a property of the aggregated state (its feature vector). For
    instance, in a policy-gradient setting one could partition the feature space by high-level properties, which
    partitions the policy and allows a separate control policy for each region of the state space. Subclasses decide
    how the indicator value is computed.
    """

    @abstractmethod
    def __str__(
            self
    ) -> str:
        """
        Describe the indicator as a string.

        :return: Description.
        """

    @abstractmethod
    def get_range(
            self
    ) -> List[Any]:
        """
        List every value this indicator can take.

        :return: Possible indicator values.
        """

    @abstractmethod
    def get_value(
            self,
            state_vector: np.ndarray
    ) -> Any:
        """
        Evaluate this indicator for a single state.

        :param state_vector: The state's vector.
        :return: Indicator value; must be one of the values returned by `get_range`.
        """
88

89

90
class StateLambdaIndicator(StateIndicator):
    """
    Indicator whose value is produced by calling a user-supplied function on the whole state vector.
    """

    def __init__(
            self,
            function: Callable[[np.ndarray], Any],
            function_range: List[Any]
    ):
        """
        Construct the indicator.

        :param function: Callable applied to each state vector.
        :param function_range: Every value the callable can return.
        """

        self.function, self.function_range = function, function_range

    def __str__(
            self
    ) -> str:
        """
        Describe the indicator as a string.

        :return: Description (the function itself is not rendered).
        """

        description = '<function>'

        return description

    def get_range(
            self
    ) -> List[Any]:
        """
        List every value this indicator can take.

        :return: The range supplied at construction.
        """

        possible_values = self.function_range

        return possible_values

    def get_value(
            self,
            state_vector: np.ndarray
    ) -> Any:
        """
        Evaluate the indicator for a single state.

        :param state_vector: The state's vector.
        :return: Result of the supplied function, expected to lie within `get_range`.
        """

        indicator_function = self.function

        return indicator_function(state_vector)
144

145

146
class StateDimensionIndicator(StateIndicator, ABC):
    """
    Abstract indicator whose value depends on one particular dimension of the state vector.
    """

    def __init__(
            self,
            dimension: int
    ):
        """
        Construct the indicator.

        :param dimension: Index of the state-vector dimension this indicator examines.
        """

        # subclasses index the state vector with this value.
        self.dimension: int = dimension
162

163

164
class StateDimensionSegment(StateDimensionIndicator):
    """
    Boolean indicator that is True when one dimension of the state vector lies within the half-open interval
    [low, high). A bound of None leaves that side of the interval unbounded.
    """

    @staticmethod
    def get_segments(
            dimension_breakpoints: Dict[int, List[float]]
    ) -> List[StateIndicator]:
        """
        Get segments for a dictionary of breakpoints.

        For a dimension with breakpoints b_0, b_1, ..., b_{k-1}, one segment is created per breakpoint:
        (unbounded, b_0), [b_0, b_1), ..., [b_{k-2}, b_{k-1}). No segment is created above the last breakpoint, so a
        value at or above b_{k-1} makes every segment of that dimension False. A dimension whose breakpoint list is
        empty contributes no segments.

        :param dimension_breakpoints: Breakpoints keyed on dimensions with breakpoints as values. Any sequence
        (list, tuple, array) of breakpoints is accepted. Breakpoints are presumably expected in ascending order; this
        is not validated here.
        :return: Segments, grouped by dimension in the iteration order of `dimension_breakpoints` and ordered by
        breakpoint within each dimension.
        """

        segments: List[StateIndicator] = []
        for dimension, breakpoints in dimension_breakpoints.items():

            # skip dimensions without breakpoints; otherwise the single None lower bound would not pair with any upper
            # bound and zip(strict=True) would raise.
            if len(breakpoints) == 0:
                continue

            # each breakpoint is the exclusive upper bound of a segment whose inclusive lower bound is the previous
            # breakpoint (unbounded for the first segment). unpacking accepts any sequence type, not only lists.
            lows: List[Optional[float]] = [None, *breakpoints[:-1]]
            segments.extend(
                StateDimensionSegment(dimension, low, high)
                for low, high in zip(lows, breakpoints, strict=True)
            )

        return segments

    def __init__(
            self,
            dimension: int,
            low: Optional[float],
            high: Optional[float]
    ):
        """
        Initialize the segment.

        :param dimension: Dimension index.
        :param low: Low value (inclusive) of the segment, or None for no lower bound.
        :param high: High value (exclusive) of the segment, or None for no upper bound.
        """

        super().__init__(dimension)

        self.low = low
        self.high = high

    def __str__(
            self
    ) -> str:
        """
        Get string.

        :return: String in interval notation, with "(" when the segment has no lower bound.
        """

        return f'd{self.dimension}:  {"(" if self.low is None else "["}{self.low}, {self.high})'

    def get_range(
            self
    ) -> List[Any]:
        """
        Get the range (possible values) of the current indicator.

        :return: [True, False].
        """

        return [True, False]

    def get_value(
            self,
            state_vector: np.ndarray
    ) -> Any:
        """
        Get the value of the current indicator for a state.

        :param state_vector: State vector.
        :return: True if the value in this segment's dimension is within [low, high), treating a None bound as
        unbounded; False otherwise.
        """

        dimension_value = float(state_vector[self.dimension])

        above_low = self.low is None or dimension_value >= self.low
        below_high = self.high is None or dimension_value < self.high

        return above_low and below_high
243

244

245
class StateDimensionLambda(StateDimensionIndicator):
    """
    Indicator whose value is produced by calling a user-supplied function on a single state dimension.
    """

    def __init__(
            self,
            dimension: int,
            function: Callable[[float], Any],
            function_range: List[Any]
    ):
        """
        Construct the indicator.

        :param dimension: Index of the state-vector dimension passed to the function.
        :param function: Callable applied to the (float-converted) value of that dimension.
        :param function_range: Every value the callable can return.
        """

        super().__init__(dimension)

        self.function, self.function_range = function, function_range

    def __str__(
            self
    ) -> str:
        """
        Describe the indicator as a string.

        :return: Description naming the dimension (the function itself is not rendered).
        """

        return f'd{self.dimension}:  <function>'

    def get_range(self) -> List[Any]:
        """
        List every value this indicator can take.

        :return: The range supplied at construction.
        """

        possible_values = self.function_range

        return possible_values

    def get_value(
            self,
            state_vector: np.ndarray
    ) -> Any:
        """
        Evaluate the indicator for a single state.

        :param state_vector: The state's vector.
        :return: Result of the supplied function applied to this indicator's dimension.
        """

        dimension_value = float(state_vector[self.dimension])

        return self.function(dimension_value)
301

302

303
class OneHotStateIndicatorFeatureInteracter:
    """
    Interacts state features with a one-hot encoding of the joint value of several state indicators. Each element of
    the Cartesian product of the indicators' ranges is one category.
    """

    def __init__(
            self,
            indicators: List[StateIndicator],
            scale_features: bool
    ):
        """
        Construct the interacter.

        :param indicators: State indicators whose joint value determines the one-hot category of a state.
        :param scale_features: Whether to scale features.
        """

        self.indicators = indicators

        # enumerate every possible joint indicator value; each one becomes a category of the underlying interacter.
        indicator_ranges = [
            indicator.get_range()
            for indicator in self.indicators
        ]
        all_categories = [
            OneHotCategory(*joint_value)
            for joint_value in product(*indicator_ranges)
        ]

        self.interacter = OneHotCategoricalFeatureInteracter(
            categories=all_categories,
            scale_features=scale_features
        )

    def interact(
            self,
            state_matrix: np.ndarray,
            state_feature_matrix: np.ndarray,
            refit_scaler: bool
    ) -> np.ndarray:
        """
        Interact a state-feature matrix with the one-hot encoding of its states' joint indicator values.

        :param state_matrix: State matrix (#obs, #state_dimensionality) from which indicator values are computed.
        :param state_feature_matrix: State-feature matrix (#obs, #features).
        :param refit_scaler: Whether to refit the scaler. Only has an effect if `scale_features` was True when
        initializing the current interacter.
        :return: Interacted state-feature matrix (#obs, #features * #joint_indicators).
        """

        # one category per row: the tuple of every indicator's value for that row's state vector.
        state_categories = []
        for state_vector in state_matrix:
            joint_value = [indicator.get_value(state_vector) for indicator in self.indicators]
            state_categories.append(OneHotCategory(*joint_value))

        return self.interacter.interact(
            feature_matrix=state_feature_matrix,
            categorical_values=state_categories,
            refit_scaler=refit_scaler
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc