• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MatthewGerber / rlai / 27471184649

13 Jun 2026 03:37PM UTC coverage: 0.0% (-79.4%) from 79.431%
27471184649

push

github

MatthewGerber
* Change machine type.
* Try xdist.

0 of 5484 relevant lines covered (0.0%)

0.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/rlai/state_value/function_approximation/__init__.py
1
from argparse import ArgumentParser
×
2
from typing import List, Tuple, Optional
×
3

4
import numpy as np
×
5
from matplotlib.backends.backend_pdf import PdfPages
×
6
from numpy.random import RandomState
×
7

8
from rlai.core import MdpState, Environment
×
9
from rlai.docs import rl_text
×
10
from rlai.models.feature_extraction import StationaryFeatureScaler
×
11
from rlai.state_value import StateValueEstimator, ValueEstimator
×
12
from rlai.state_value.function_approximation.models import StateFunctionApproximationModel
×
13
from rlai.state_value.function_approximation.models.feature_extraction import StateFeatureExtractor
×
14
from rlai.utils import parse_arguments, load_class
×
15

16

17
@rl_text(chapter='Value Estimation', page=195)
class ApproximateValueEstimator(ValueEstimator):
    """
    Value estimator for a single state. All work is delegated to the approximate state-value estimator that created
    this object.
    """

    def __init__(
            self,
            estimator: 'ApproximateStateValueEstimator',
            state: MdpState
    ):
        """
        Bind a state to the approximate state-value estimator that serves it.

        :param estimator: Approximate state-value estimator that receives samples and evaluates the state.
        :param state: State whose value this object reads and updates.
        """

        self.state = state
        self.estimator = estimator

    def get_value(
            self
    ) -> float:
        """
        Evaluate the owning estimator at the bound state.

        :return: Current value estimate.
        """

        current_estimate = self.estimator.evaluate(self.state)

        return current_estimate

    def update(
            self,
            value: float,
            weight: Optional[float] = None
    ):
        """
        Record a new value sample for the bound state and count the update on the owning estimator.

        :param value: New value.
        :param weight: Weight of the sample, or None for no weight.
        """

        owner = self.estimator

        # the sample is only queued here; it is the owning estimator that consumes it.
        owner.add_sample(self.state, value, weight)
        owner.update_count += 1
×
63

64

65
@rl_text(chapter='Value Estimation', page=195)
class ApproximateStateValueEstimator(StateValueEstimator):
    """
    Approximate state-value estimator. Samples of experience are buffered via `add_sample` and used to fit a function
    approximation model when `improve` is called. State values are obtained by evaluating the model on features
    extracted from states.
    """

    @classmethod
    def get_argument_parser(
            cls
    ) -> ArgumentParser:
        """
        Get argument parser.

        :return: Argument parser.
        """

        parser = ArgumentParser(
            prog=f'{cls.__module__}.{cls.__name__}',
            parents=[super().get_argument_parser()],
            allow_abbrev=False,
            add_help=False
        )

        parser.add_argument(
            '--function-approximation-model',
            type=str,
            help='Fully-qualified type name of function approximation model.'
        )

        parser.add_argument(
            '--feature-extractor',
            type=str,
            help='Fully-qualified type name of feature extractor.'
        )

        parser.add_argument(
            '--scale-outcomes',
            action='store_true',
            help='Whether to scale (standardize) outcomes before fitting the function approximation model.'
        )

        return parser

    @classmethod
    def init_from_arguments(
            cls,
            args: List[str],
            random_state: RandomState,
            environment: Environment
    ) -> Tuple[StateValueEstimator, List[str]]:
        """
        Initialize a state-value estimator from arguments.

        :param args: Arguments.
        :param random_state: Random state.
        :param environment: Environment.
        :return: 2-tuple of a state-value estimator and a list of unparsed arguments.
        """

        parsed_args, unparsed_args = parse_arguments(cls, args)

        # load feature extractor
        feature_extractor_class = load_class(parsed_args.feature_extractor)
        fex, unparsed_args = feature_extractor_class.init_from_arguments(
            args=unparsed_args,
            environment=environment
        )

        # remove the consumed argument so that it is not forwarded to the constructor via **vars(parsed_args) below.
        del parsed_args.feature_extractor

        # load model. the model only fits an intercept if the feature extractor does not already extract one.
        model_class = load_class(parsed_args.function_approximation_model)
        model, unparsed_args = model_class.init_from_arguments(
            args=unparsed_args,
            random_state=random_state,
            fit_intercept=not fex.extracts_intercept()
        )
        del parsed_args.function_approximation_model

        # initialize estimator with the remaining parsed arguments passed as keyword arguments
        estimator = cls(
            model=model,
            feature_extractor=fex,
            **vars(parsed_args)
        )

        return estimator, unparsed_args

    def __init__(
            self,
            model: StateFunctionApproximationModel,
            feature_extractor: StateFeatureExtractor,
            scale_outcomes: bool
    ):
        """
        Initialize the estimator.

        :param model: Model.
        :param feature_extractor: Feature extractor.
        :param scale_outcomes: Whether to scale state-value outcomes before fitting the estimator model.
        """

        super().__init__()

        self.model = model
        self.feature_extractor = feature_extractor
        self.scale_outcomes = scale_outcomes

        # experience buffered by `add_sample` and consumed by `improve`. `weights` is either None (no pending sample
        # carries a weight) or an array with exactly one entry per pending sample.
        self.experience_states: List[MdpState] = []
        self.experience_values: List[float] = []
        self.weights: Optional[np.ndarray] = None
        self.experience_pending: bool = False
        self.value_scaler = StationaryFeatureScaler()

    def add_sample(
            self,
            state: MdpState,
            value: float,
            weight: Optional[float]
    ):
        """
        Add a sample of experience to the estimator. The collection of samples will be used to fit the function
        approximation model when `improve` is called.

        Within one batch of pending experience, either every sample must carry a weight or none may. Mixing the two
        would leave the weight vector shorter than, and misaligned with, the buffered states and values.

        :param state: State.
        :param value: Value.
        :param weight: Weight, or None for an unweighted sample.
        :raises ValueError: If weighted and unweighted samples are mixed within the pending experience.
        """

        # work out the new weight vector before touching any state, so that a rejected sample leaves the buffers
        # exactly as they were.
        if weight is None:
            if self.weights is not None:
                raise ValueError(
                    'Cannot add an unweighted sample when the pending experience contains weighted samples.'
                )
            updated_weights = None
        elif self.weights is None:
            if self.experience_values:
                raise ValueError(
                    'Cannot add a weighted sample when the pending experience contains unweighted samples.'
                )
            updated_weights = np.array([weight])
        else:
            updated_weights = np.append(self.weights, [weight], axis=0)

        self.experience_states.append(state)
        self.experience_values.append(value)
        self.weights = updated_weights
        self.experience_pending = True

    def improve(
            self
    ):
        """
        Fit the function approximation model to the experience collected through calls to `add_sample`, and then
        clear that experience. Does nothing if no experience is pending. Nothing is returned.
        """

        # if we have pending experience, then fit the model and reset the data.
        if self.experience_pending:

            state_feature_matrix = self.extract_features(self.experience_states, True)

            outcomes = np.array(self.experience_values)
            if self.scale_outcomes:

                # the scaler operates on a column, so reshape to a single column and flatten the result back.
                outcomes = self.value_scaler.scale_features(outcomes.reshape(-1, 1), True).flatten()

            # feature extractors may return a matrix with no columns if extraction was not possible
            if state_feature_matrix.shape[1] > 0:
                self.model.fit(
                    feature_matrix=state_feature_matrix,
                    outcomes=outcomes,
                    weights=self.weights
                )

            # the experience is discarded whether or not the model was fit.
            self.experience_states.clear()
            self.experience_values.clear()
            self.weights = None
            self.experience_pending = False

    def evaluate(
            self,
            state: MdpState
    ) -> float:
        """
        Evaluate the estimator's function approximation model at a state.

        :param state: State.
        :return: Estimate, or 0.0 if no features could be extracted for the state.
        """

        # extract feature matrix
        state_feature_matrix = self.extract_features([state], False)

        # feature extractors may return a matrix with no columns if extraction was not possible
        if state_feature_matrix.shape[1] == 0:  # pragma no cover
            return 0.0

        state_values = self.model.evaluate(state_feature_matrix)

        # invert the state value back to the original space if we're scaling
        if self.scale_outcomes:
            state_values = self.value_scaler.invert_scaled_features(state_values.reshape((-1, 1))).flatten()

        # a single state was passed in, so a single value is expected back.
        assert len(state_values) == 1

        return float(state_values[0])

    def extract_features(
            self,
            states: List[MdpState],
            refit_scaler: bool
    ) -> np.ndarray:
        """
        Extract features for states.

        :param states: States.
        :param refit_scaler: Whether to refit the feature scaler before scaling the extracted features. This is
        only appropriate in settings where nonstationarity is desired (e.g., during training). During evaluation, the
        scaler should remain fixed, which means this should be False.
        :return: State-feature matrix (#states, #features).
        """

        return self.feature_extractor.extract(states, refit_scaler)

    def plot(
            self,
            pdf: Optional[PdfPages]
    ):
        """
        Plot the current estimator by delegating to the model's plot function.

        :param pdf: PDF to plot to, or None to show directly.
        """

        self.model.plot(True, pdf)

    def reset_for_new_run(
            self,
            state: MdpState
    ):
        """
        Reset for new run by resetting the feature extractor.

        :param state: State passed through to the feature extractor's reset.
        """

        self.feature_extractor.reset_for_new_run(state)

    def __getitem__(
            self,
            state: MdpState
    ) -> ApproximateValueEstimator:
        """
        Get the value estimator for a state. A new estimator object is created on each call.

        :param state: State.
        :return: Value estimator.
        """

        return ApproximateValueEstimator(self, state)

    def __len__(
            self
    ) -> int:
        """
        Get number of states defined by the estimator.

        :return: Number of states (always 1, since states are not tracked).
        """

        # a bit of a hack, as we don't actually track the number of states.
        return 1

    def __contains__(
            self,
            state: MdpState
    ) -> bool:
        """
        Check whether a state is defined by the estimator.

        :param state: State.
        :return: Always True.
        """

        return True

    def __eq__(
            self,
            other: object
    ) -> bool:
        """
        Check whether the estimator equals another. Only the models are compared.

        :param other: Other estimator.
        :return: True if equal and False otherwise.
        :raises ValueError: If the other object is not an approximate state-value estimator.
        """

        if not isinstance(other, ApproximateStateValueEstimator):
            raise ValueError(f'Expected {ApproximateStateValueEstimator}')

        return self.model == other.model

    def __ne__(
            self,
            other: object
    ) -> bool:
        """
        Check whether the estimator does not equal another.

        :param other: Other estimator.
        :return: True if not equal and False otherwise.
        :raises ValueError: If the other object is not an approximate state-value estimator.
        """

        return not (self == other)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc