/trax/rl/task.py
# coding=utf-8
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Classes for defining RL tasks in Trax."""

import collections
import os

import gin
import gym
import numpy as np

from trax import math
from trax.supervised import trainer_lib


class _TimeStep(object):
  """A single step of interaction with an RL environment.

  TimeStep stores a single step in the trajectory of an RL run:
  * observation at the beginning of the step
  * action that was taken (or None if no action was taken yet)
  * reward received when the action was taken (or None if it wasn't taken)
  * dist_inputs, the inputs of the distribution the action was sampled from
    (or None if not specified)
  * discounted return from that state (includes the reward from this step)
  """

  def __init__(self, observation, action=None, reward=None, dist_inputs=None):
    self.observation = observation
    self.action = action
    self.reward = reward
    self.dist_inputs = dist_inputs
    self.discounted_return = None


# Tuple for representing trajectories and batches of them in numpy; immutable.
TrajectoryNp = collections.namedtuple('TrajectoryNp', [
    'observations',
    'actions',
    'dist_inputs',
    'rewards',
    'returns',
    'mask',
])


class Trajectory(object):
  """A trajectory of interactions with an RL environment.

  Trajectories are created when interacting with an RL environment. They can
  be extended and sliced, and once completed they allow the returns to be
  recalculated.
  """

  def __init__(self, observation):
    # TODO(lukaszkaiser): add support for saving and loading trajectories,
    # reuse code from base_trainer.dump_trajectories and related functions.
    if observation is not None:
      self._timesteps = [_TimeStep(observation)]

  def __len__(self):
    return len(self._timesteps)

  def __str__(self):
    return str([(ts.observation, ts.action, ts.reward)
                for ts in self._timesteps])

  def __repr__(self):
    return repr([(ts.observation, ts.action, ts.reward)
                 for ts in self._timesteps])

  def __getitem__(self, key):
    t = Trajectory(None)
    t._timesteps = self._timesteps[key]  # pylint: disable=protected-access
    return t

  @property
  def timesteps(self):
    return self._timesteps

  @property
  def total_return(self):
    """Sum of all rewards in this trajectory."""
    return sum([t.reward or 0.0 for t in self._timesteps])

  @property
  def last_observation(self):
    """Return the last observation in this trajectory."""
    last_timestep = self._timesteps[-1]
    return last_timestep.observation

  def extend(self, action, dist_inputs, reward, new_observation):
    """Take action in the last state, getting reward and going to new state."""
    last_timestep = self._timesteps[-1]
    last_timestep.action = action
    last_timestep.dist_inputs = dist_inputs
    last_timestep.reward = reward
    new_timestep = _TimeStep(new_observation)
    self._timesteps.append(new_timestep)

  def calculate_returns(self, gamma):
    """Calculate discounted returns."""
    ret = 0.0
    for timestep in reversed(self._timesteps):
      cur_reward = timestep.reward or 0.0
      ret = gamma * ret + cur_reward
      timestep.discounted_return = ret

  def _default_timestep_to_np(self, ts):
    """Default way to convert timestep to numpy."""
    return math.nested_map(np.array, (
        ts.observation,
        ts.action,
        ts.dist_inputs,
        ts.reward,
        ts.discounted_return,
    ))

  def to_np(self, timestep_to_np=None):
    """Create a tuple of numpy arrays from a given trajectory."""
    observations, actions, dist_inputs, rewards, returns, mask = (
        [], [], [], [], [], []
    )
    timestep_to_np = timestep_to_np or self._default_timestep_to_np
    for timestep in self._timesteps:
      if timestep.action is None:
        obs = timestep_to_np(timestep)[0]
        observations.append(obs)
      else:
        (obs, act, dinp, rew, ret) = timestep_to_np(timestep)
        observations.append(obs)
        actions.append(act)
        dist_inputs.append(dinp)
        rewards.append(rew)
        returns.append(ret)
        mask.append(1.0)

    def stack(x):
      if not x:
        return None
      return math.nested_stack(x)

    return TrajectoryNp(*map(stack, (
        observations, actions, dist_inputs, rewards, returns, mask
    )))


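# A minimal sketch of how a Trajectory is built up and converted to numpy; the
# observations, actions, rewards and dist_inputs below are made-up toy values,
# and this helper is illustrative only (it is not used elsewhere in this file).
def _example_trajectory_usage():
  traj = Trajectory(observation=np.array([0.0, 0.0]))
  traj.extend(action=1, dist_inputs=np.array([0.3, 0.7]), reward=1.0,
              new_observation=np.array([0.1, 0.0]))
  traj.extend(action=0, dist_inputs=np.array([0.6, 0.4]), reward=0.5,
              new_observation=np.array([0.1, 0.1]))
  # Discounted returns with gamma=0.5, computed backwards over the two steps
  # that have rewards: 0.5 for the second step and 1.0 + 0.5 * 0.5 = 1.25 for
  # the first one; the final state contributes no reward.
  traj.calculate_returns(gamma=0.5)
  # to_np() stacks per-field arrays into a TrajectoryNp: observations have
  # length 3 (the final state is included), the other fields have length 2.
  return traj.to_np()

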
def play(env, policy, dm_suite=False, max_steps=None):
  """Play an episode in env taking actions according to the given policy.

  The environment is first reset and, from then on, the game proceeds. At each
  step, the policy is asked to choose an action and the environment moves
  forward. A Trajectory is built up that way and returned when the episode
  finishes, which is either when env returns `done` or when max_steps is
  reached.

  Args:
    env: the environment to play in, conforming to gym.Env or
      DeepMind suite interfaces.
    policy: a function taking a Trajectory and returning a pair consisting
      of an action (int or float) and of the inputs of the distribution this
      action was sampled from (dist_inputs, e.g. log-probabilities).
    dm_suite: whether we are using the DeepMind suite or the gym interface.
    max_steps: for how many steps to play.

  Returns:
    a completed trajectory that was just played.
  """
  terminal = False
  cur_step = 0
  if dm_suite:
    cur_trajectory = Trajectory(env.reset().observation)
    while not terminal and (max_steps is None or cur_step < max_steps):
      action, dist_inputs = policy(cur_trajectory)
      observation = env.step(action)
      cur_trajectory.extend(action, dist_inputs,
                            observation.reward,
                            observation.observation)
      cur_step += 1
      terminal = observation.step_type.last()
  else:
    cur_trajectory = Trajectory(env.reset())
    while not terminal and (max_steps is None or cur_step < max_steps):
      action, dist_inputs = policy(cur_trajectory)
      observation, reward, terminal, _ = env.step(action)
      cur_trajectory.extend(action, dist_inputs, reward, observation)
      cur_step += 1
  return cur_trajectory


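# A minimal sketch of calling play() on a gym environment with a trivial
# policy; it assumes the standard 'CartPole-v0' environment is available.
# The policy returns an (action, dist_inputs) pair; here dist_inputs is None.
def _example_play_cartpole():
  env = gym.make('CartPole-v0')
  def policy(trajectory):
    del trajectory  # A random policy ignores the trajectory so far.
    return env.action_space.sample(), None
  trajectory = play(env, policy, dm_suite=False, max_steps=10)
  return trajectory.total_return

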
def _zero_pad(x, pad, axis):
  """Helper for np.pad with 0s for single-axis case."""
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[axis] = pad  # Padding on axis.
  return np.pad(x, pad_widths, mode='constant',
                constant_values=x.dtype.type(0))


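# A tiny sketch of _zero_pad on made-up values: padding a length-3 vector with
# two zeros on the right along axis 0 gives [1, 2, 3, 0, 0].
def _example_zero_pad():
  return _zero_pad(np.array([1, 2, 3]), (0, 2), axis=0)

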
def _random_policy(action_space):
  return lambda _: (action_space.sample(), None)


def _sample_proportionally(inputs, weights):
  """Sample an element from the inputs list proportionally to weights.

  Args:
    inputs: a list, we will return one element of this list.
    weights: a list of numbers of the same length as inputs; we will sample
      the k-th input with probability weights[k] / sum(weights).

  Returns:
    an element from inputs.
  """
  l = len(inputs)
  if l != len(weights):
    raise ValueError(f'Inputs and weights must have the same length, but do not'
                     f': {l} != {len(weights)}')
  weights_sum = float(sum(weights))
  norm_weights = [w / weights_sum for w in weights]
  idx = np.random.choice(l, p=norm_weights)
  return inputs[int(idx)]


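# A small sketch of proportional sampling with made-up weights: with inputs
# ['a', 'b'] and weights [1, 3], 'a' is returned with probability 0.25 and
# 'b' with probability 0.75.
def _example_sample_proportionally():
  return _sample_proportionally(['a', 'b'], [1, 3])

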
@gin.configurable()
class RLTask:
  """An RL task: environment and a collection of trajectories."""

  def __init__(self, env=gin.REQUIRED, initial_trajectories=1, gamma=0.99,
               dm_suite=False, max_steps=None,
               timestep_to_np=None, num_stacked_frames=1,
               n_replay_epochs=1):
    r"""Configures an RL task.

    Args:
      env: Environment conforming to the gym.Env interface or a string,
        in which case `gym.make` will be called on this string to create an env.
      initial_trajectories: either a dict or a list of Trajectories to use
        at start, or an int, in which case that many trajectories are
        collected using a random policy to play in env.
      gamma: float: discount factor for calculating returns.
      dm_suite: whether we are using the DeepMind suite or the gym interface.
      max_steps: Optional int: stop all trajectories at that many steps.
      timestep_to_np: a function that turns a timestep into a numpy array
        (i.e., a tensor); if None, we just use the state of the timestep to
        represent it, but other representations (such as embeddings that
        include actions or serialized representations) can be passed here.
      num_stacked_frames: the number of stacked frames for Atari.
      n_replay_epochs: the size of the replay buffer expressed in epochs.
    """
    if isinstance(env, str):
      self._env_name = env
      if dm_suite:
        env = environments.load_from_settings(
            platform='atari',
            settings={
                'levelName': env,
                'interleaved_pixels': True,
                'zero_indexed_actions': True
            })
        env = atari_wrapper.AtariWrapper(environment=env,
                                         num_stacked_frames=num_stacked_frames)
      else:
        env = gym.make(env)
    else:
      self._env_name = type(env).__name__
    self._env = env
    self._dm_suite = dm_suite
    self._max_steps = max_steps
    self._gamma = gamma
    self._initial_trajectories = initial_trajectories
    # TODO(lukaszkaiser): find a better way to pass initial trajectories,
    # whether they are an explicit list, a file, or a number of random ones.
    if isinstance(initial_trajectories, int):
      if self._initial_trajectories > 0:
        initial_trajectories = [
            self.play(_random_policy(self.action_space))
            for _ in range(initial_trajectories)
        ]
      else:
        initial_trajectories = [
            # Whatever we gather here is intended to be removed
            # in PolicyTrainer. Here we just gather some example inputs.
            self.play(_random_policy(self.action_space))
        ]

    if isinstance(initial_trajectories, list):
      initial_trajectories = {0: initial_trajectories}
    self._timestep_to_np = timestep_to_np
    # Stored trajectories are indexed by epoch and within each epoch they
    # are stored in the order of generation so we can implement replay buffers.
    # TODO(lukaszkaiser): use dump_trajectories from BaseTrainer to allow
    # saving and reading trajectories from disk.
    self._trajectories = collections.defaultdict(list)
    self._trajectories.update(initial_trajectories)
    # When we repeatedly save, trajectories for many epochs do not change, so
    # we don't need to save them again. This keeps track of which epochs are
    # unchanged.
    self._saved_epochs_unchanged = []
    self._n_replay_epochs = n_replay_epochs
    self._n_trajectories = 0
    self._n_interactions = 0

  @property
  def env(self):
    return self._env

  @property
  def env_name(self):
    return self._env_name

  @property
  def max_steps(self):
    return self._max_steps

  @property
  def gamma(self):
    return self._gamma

  @property
  def action_space(self):
    if self._dm_suite:
      return gym.spaces.Discrete(self._env.action_spec().num_values)
    else:
      return self._env.action_space

  @property
  def observation_space(self):
    """Returns the env's observation space in a Gym interface."""
    if self._dm_suite:
      return gym.spaces.Box(
          shape=self._env.observation_spec().shape,
          dtype=self._env.observation_spec().dtype,
          low=float('-inf'),
          high=float('+inf'),
      )
    else:
      return self._env.observation_space

  @property
  def trajectories(self):
    return self._trajectories

  @property
  def timestep_to_np(self):
    return self._timestep_to_np

  @timestep_to_np.setter
  def timestep_to_np(self, ts):
    self._timestep_to_np = ts

  def _epoch_filename(self, base_filename, epoch):
    """Helper function: file name for saving the given epoch."""
    # If base is /foo/task.pkl, we save epoch 1 under /foo/task_epoch1.pkl.
    filename, ext = os.path.splitext(base_filename)
    return filename + '_epoch' + str(epoch) + ext

  def set_n_replay_epochs(self, n_replay_epochs):
    self._n_replay_epochs = n_replay_epochs

  def init_from_file(self, file_name):
    """Initialize this task from file."""
    dictionary = trainer_lib.unpickle_from_file(file_name, gzip=False)
    self._n_trajectories = dictionary['n_trajectories']
    self._n_interactions = dictionary['n_interactions']
    self._max_steps = dictionary['max_steps']
    self._gamma = dictionary['gamma']
    epochs_to_load = dictionary['all_epochs'][-self._n_replay_epochs:]

    for epoch in epochs_to_load:
      trajectories = trainer_lib.unpickle_from_file(
          self._epoch_filename(file_name, epoch), gzip=True)
      self._trajectories[epoch] = trajectories
    self._saved_epochs_unchanged = epochs_to_load

  def save_to_file(self, file_name):
    """Save this task to file."""
    # Save trajectories from new epochs first.
    epochs_to_save = [e for e in self._trajectories.keys()
                      if e not in self._saved_epochs_unchanged]
    for epoch in epochs_to_save:
      trainer_lib.pickle_to_file(self._trajectories[epoch],
                                 self._epoch_filename(file_name, epoch),
                                 gzip=True)
    # Now save the list of epochs (so the trajectories are already there,
    # even in case of preemption).
    dictionary = {'n_interactions': self._n_interactions,
                  'n_trajectories': self._n_trajectories,
                  'max_steps': self._max_steps,
                  'gamma': self._gamma,
                  'all_epochs': list(self._trajectories.keys())}
    trainer_lib.pickle_to_file(dictionary, file_name, gzip=False)

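  # A sketch of the save/restore round trip above, with a hypothetical path:
  # trajectories for each epoch go to separate gzipped pickle files (see
  # _epoch_filename), and the small top-level dictionary is written last so
  # that a preempted save still leaves the trajectory files in place.
  #
  #   task.save_to_file('/tmp/rl_task.pkl')
  #   restored = RLTask(env='CartPole-v0')
  #   restored.init_from_file('/tmp/rl_task.pkl')
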
  def play(self, policy, max_steps=None):
    """Play an episode in env taking actions according to the given policy."""
    if max_steps is None:
      max_steps = self._max_steps
    cur_trajectory = play(self._env, policy, self._dm_suite, max_steps)
    cur_trajectory.calculate_returns(self._gamma)
    return cur_trajectory

  def collect_trajectories(self, policy, n, epoch_id=1):
    """Collect n trajectories in env playing the given policy."""
    new_trajectories = [self.play(policy) for _ in range(n)]
    self._trajectories[epoch_id].extend(new_trajectories)
    # Mark that epoch epoch_id has changed.
    if epoch_id in self._saved_epochs_unchanged:
      self._saved_epochs_unchanged = [e for e in self._saved_epochs_unchanged
                                      if e != epoch_id]
    # Calculate returns.
    returns = [t.total_return for t in new_trajectories]

    # Remove epochs not intended to be in the replay buffer.
    current_trajectories = {
        key: value for key, value in self._trajectories.items()
        if key >= epoch_id - self._n_replay_epochs}
    self._trajectories = collections.defaultdict(list)
    self._trajectories.update(current_trajectories)

    self._n_trajectories += n
    self._n_interactions += sum([len(traj) for traj in new_trajectories])

    return sum(returns) / float(len(returns))

  def n_trajectories(self, epochs=None):
    # TODO(henrykm) support selection of epochs if really necessary (will
    # require a dump of a list of lengths in save_to_file).
    del epochs
    return self._n_trajectories

  def n_interactions(self, epochs=None):
    # TODO(henrykm) support selection of epochs if really necessary (will
    # require a dump of a list of lengths in save_to_file).
    del epochs
    return self._n_interactions

  def remove_epoch(self, epoch):
    """Useful when we need to remove an unwanted epoch of trajectories."""
    if epoch in self._trajectories.keys():
      self._trajectories.pop(epoch)

  def trajectory_stream(self, epochs=None, max_slice_length=None,
                        include_final_state=False,
                        sample_trajectories_uniformly=False):
    """Return a stream of random trajectory slices from the specified epochs.

    Args:
      epochs: a list of epochs to use; we use all epochs if None
      max_slice_length: maximum length of the slices of trajectories to return
      include_final_state: whether to include slices with the final state of
        the trajectory which may have no action and reward
      sample_trajectories_uniformly: whether to sample trajectories uniformly,
        or proportionally to the number of slices in each trajectory (default)

    Yields:
      random trajectory slices sampled uniformly from all slices of length
      up to max_slice_length in all specified epochs
    """
    # TODO(lukaszkaiser): add option to sample from n last trajectories.
    end_offset = 0 if include_final_state else 1
    def n_slices(t):
      """How many slices of length up to max_slice_length fit in a trajectory."""
      if not max_slice_length:
        return 1
      # A trajectory [a, b, c, end_state] will have 2 slices of length 2:
      # the slice [a, b] and the one [b, c], with end_offset; 3 without.
      return max(1, len(t) - max_slice_length + 1 - end_offset)

    while True:
      all_epochs = list(self._trajectories.keys())
      max_epoch = max(all_epochs) + 1
      # Bind the epoch indices to a new name so they can be recalculated every
      # epoch.
      epoch_indices = epochs or all_epochs
      epoch_indices = [
          # So -1 means "last".
          ep % max_epoch for ep in epoch_indices
      ]
      # Remove duplicates.
      epoch_indices = list(set(epoch_indices))

      # Sample an epoch proportionally to number of slices in each epoch.
      if len(epoch_indices) == 1:  # Skip this step if there's just 1 epoch.
        epoch_id = epoch_indices[0]
      else:
        slices_per_epoch = [sum([n_slices(t) for t in self._trajectories[ep]])
                            for ep in epoch_indices]
        epoch_id = _sample_proportionally(epoch_indices, slices_per_epoch)
      epoch = self._trajectories[epoch_id]

      # Sample a trajectory proportionally to number of slices in each one.
      if sample_trajectories_uniformly:
        slices_per_trajectory = [1] * len(epoch)
      else:
        slices_per_trajectory = [n_slices(t) for t in epoch]
      trajectory = _sample_proportionally(epoch, slices_per_trajectory)

      # Sample a slice from the trajectory.
      slice_start = np.random.randint(n_slices(trajectory))
      slice_end = slice_start + (max_slice_length or len(trajectory))
      slice_end = min(slice_end, len(trajectory) - end_offset)
      yield trajectory[slice_start:slice_end]

  def trajectory_batch_stream(self, batch_size, epochs=None,
                              max_slice_length=None,
                              min_slice_length=None,
                              include_final_state=False,
                              sample_trajectories_uniformly=False):
    """Return a stream of trajectory batches from the specified epochs.

    This function returns a stream of tuples of numpy arrays (tensors).
    If tensors have different lengths, they will be padded with 0s.

    Args:
      batch_size: the size of the batches to return
      epochs: a list of epochs to use; we use all epochs if None
      max_slice_length: maximum length of the slices of trajectories to return
      min_slice_length: minimum length of the slices of trajectories to return
      include_final_state: whether to include slices with the final state of
        the trajectory which may have no action and reward
      sample_trajectories_uniformly: whether to sample trajectories uniformly,
        or proportionally to the number of slices in each trajectory (default)

    Yields:
      batches of trajectory slices sampled uniformly from all slices of length
      at least min_slice_length and up to max_slice_length in all specified
      epochs
    """
    def pad(tensor_list):
      # Replace Nones with valid tensors.
      not_none_tensors = [t for t in tensor_list if t is not None]
      assert not_none_tensors, 'All tensors to pad are None.'
      prototype = np.zeros_like(not_none_tensors[0])
      tensor_list = [t if t is not None else prototype for t in tensor_list]

      max_len = max([t.shape[0] for t in tensor_list])
      min_len = min([t.shape[0] for t in tensor_list])
      if max_len == min_len:  # No padding needed.
        return np.array(tensor_list)

      pad_len = 2**int(np.ceil(np.log2(max_len)))
      return np.array([_zero_pad(t, (0, pad_len - t.shape[0]), axis=0)
                       for t in tensor_list])

    cur_batch = []
    for t in self.trajectory_stream(
        epochs, max_slice_length,
        include_final_state, sample_trajectories_uniformly):
      # TODO(pkozakowski): Instead sample the trajectories out of those with
      # the minimum length.
      if min_slice_length is not None and len(t) < min_slice_length:
        continue

      cur_batch.append(t)
      if len(cur_batch) == batch_size:
        obs, act, dinp, rew, ret, mask = zip(*[
            t.to_np(self._timestep_to_np) for t in cur_batch
        ])
        # Here act, rew and ret will usually have the shape
        # [batch_size, trajectory_length-1], which we call [B, L-1].
        # Observations are more complex and will usually be [B, L] + S where S
        # is the shape of the observation space (self.observation_space.shape).
        # We stop the recursion at level 1, so we pass lists of arrays into
        # pad().
        yield math.nested_map(
            pad, TrajectoryNp(obs, act, dinp, rew, ret, mask), level=1
        )
        cur_batch = []
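

# A minimal end-to-end sketch on a gym environment (assumes 'CartPole-v0' is
# available): build an RLTask, collect a few trajectories with a random
# policy, and draw one padded batch of slices from the replay buffer. This is
# illustrative only and is not used elsewhere in this file.
def _example_rl_task_usage():
  task = RLTask(env='CartPole-v0', initial_trajectories=1, max_steps=200)
  avg_return = task.collect_trajectories(
      _random_policy(task.action_space), n=5, epoch_id=1)
  batch_stream = task.trajectory_batch_stream(batch_size=4, max_slice_length=2)
  batch = next(batch_stream)  # A TrajectoryNp of [batch, length, ...] arrays.
  return avg_return, batch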