google / trax / 159 (pending completion)

Pull Request #544: Log L2 norm of gradient while training.
travis-ci · web-flow · Merge 8efdd57ec into 4d99ad496

12 of 12 new or added lines in 4 files covered. (100.0%)
2690 of 10870 relevant lines covered (24.75%)
0.25 hits per line

Source File: /trax/supervised/trainer_lib.py (16.86%)

# coding=utf-8
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trax main training functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import gzip as gzip_lib
import itertools
import os
import pickle
import random
import sys
import time

from absl import logging

import gin

import jax
import numpy
import six
import tensorflow.compat.v2 as tf
from trax import history as trax_history
from trax import jaxboard
from trax import layers as tl
from trax import lr_schedules as lr
from trax import math
from trax import optimizers as trax_opt
from trax.math import numpy as np
from trax.math import random as jax_random
from trax.shapes import ShapeDtype
from trax.supervised import inputs as trax_inputs


# TODO(afrozm): Maybe flatten everything from OptState into TrainerState.
TrainerState = collections.namedtuple('_TrainerState', [
    'step',         # Current training step number.
    'opt_state',    # OptState.
    'history',      # trax.history.History.
    'model_state',  # Auxiliary state of the model.
])


OptState = collections.namedtuple('_OptState', [
    'weights',     # Model weights.
    'slots',       # Per-parameter optimizer state, e.g. gradient moments.
    'opt_params',  # Optimizer (hyper)parameters, e.g. learning rate, momentum.
])


_DEFAULT_METRICS = {
    'loss': tl.CrossEntropyLoss(),
    'accuracy': tl.AccuracyScalar(),
    'sequence_accuracy': tl.SequenceAccuracyScalar(),
    'neg_log_perplexity': tl.Serial(tl.CrossEntropyLoss(), tl.Negate()),
    'weights_per_batch_per_core': tl.SumOfWeights(),
}


class Trainer(object):
  """Trax trainer.

  A trainer allows one to run training steps, train for full epochs,
  save the training state, and access evaluation data.
  """

  def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
               output_dir=None, random_seed=None, n_devices=None,
               checkpoints_at=None, should_save_checkpoints=True,
               should_write_summaries=True, nontrainable_param_map=None,
               id_to_mask=None,
               metrics=None, checkpoint_highest=None, checkpoint_lowest=None):

    self._is_chief, self._n_devices, rng = (
        self._init_host_and_devices(n_devices, random_seed))
    self._should_save_checkpoints = should_save_checkpoints and self._is_chief
    self._checkpoints_at = checkpoints_at or []
    self._should_write_summaries = should_write_summaries
    if not output_dir:
      self._should_save_checkpoints = False
      self._should_write_summaries = False
    self._checkpoint_highest = checkpoint_highest
    self._checkpoint_lowest = checkpoint_lowest
    self._id_to_mask = id_to_mask
    self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    # Inputs is either an Inputs instance or a function that returns it.
    self._inputs = inputs
    if callable(inputs):  # If we pass a function, e.g., through gin, call it.
      self._inputs = inputs()
    # Mask id_to_mask and add weights if needed.
    # TODO(lukaszkaiser, jonni): move this out of Trainer to input processing.
    self._inputs = _add_weights_and_mask(self._inputs, id_to_mask)
    # Initialize the learning rate to a dummy value. It will be set in reset().
    opt = optimizer(learning_rate=0.0)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')

    # Setup state.
    rng, init_rng = jax_random.split(rng)
    self._rngs = np.stack(jax_random.split(rng, self._n_devices))

    def new_opt_state_and_model_state(shape_dtype, rng):
      """Returns optimizer and model states suitable for training a model."""
      # Combine inputs and targets on the stack.
      shapes, dtypes = shape_dtype
      input_signature = tuple(ShapeDtype(s, d)
                              for (s, d) in zip(shapes, dtypes))
      # We need to create a new model instance and not reuse `model_train` here,
      # because `m.initialize` puts cached parameter values in `m` and hence the
      # next call of `m.initialize` will give wrong results.
      m = tl.Serial(model(mode='train'), loss_fn)
      m._set_rng_recursive(rng)  # pylint: disable=protected-access
      weights, state = m.init(input_signature)
      (slots, opt_params) = opt.tree_init(weights)
      return (OptState(weights, slots, opt_params), state)

    if _is_jit_init():
      # JIT parameter initialization to avoid memory fragmentation
      new_opt_state_and_model_state = math.jit(new_opt_state_and_model_state,
                                               static_argnums=(0,))
    self._new_opt_state_and_model_state = (
        lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
            self._inputs.example_shape_dtype, init_rng))

    # Arrange and initialize metrics layers.
    self._metrics = list(sorted(self._metrics_dict.keys()))
    metrics_layers = [self._metrics_dict[m] for m in self._metrics]
    metrics_in_parallel = tl.Branch(*metrics_layers)
    metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
    example_signature = tuple(
        ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype)
    )
    model_predict_eval.init(example_signature)
    output_signature = model_predict_eval.output_signature(example_signature)
    m_weights, m_state = metrics_in_parallel.init(output_signature)
    self._metrics_weights = self._for_n_devices(m_weights)
    self._metrics_state = self._for_n_devices(m_state)

    # Jit model_predict and update so they're fast.
    self._jit_eval = _jit_predict_fn(
        model_predict_eval, metrics_in_parallel, self._n_devices)
    self._jit_update_fn = _jit_update_fn(
        model_train, loss_fn, opt, self._n_devices)

    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    # TODO(pkozakowski): "Learning rate schedules" are currently able to
    # control all optimizer parameters and model state, so let's rename them
    # accordingly.
    self._lr_schedule = lr_schedule

    if nontrainable_param_map is None:
      nontrainable_param_map = {}
    self._nontrainable_param_map = nontrainable_param_map

    # Those fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._lr_fn = None
    self._opt_state = None
    self._step = None
    self._model_state = None
    self.reset(output_dir)

  @property
  def n_devices(self):
    return self._n_devices

  @property
  def step(self):
    return self._step

  @property
  def model_weights(self):
    # Currently we need to pick [0] as we ignore loss weights (empty).
    weights = self._opt_state.weights[0]
    if self.n_devices > 1:
      unreplicate = lambda x: x[0]
      weights = math.nested_map(unreplicate, weights)
    return weights

  @model_weights.setter
  def model_weights(self, weights):
    new_model_weights = self._for_n_devices(weights)
    if isinstance(self._opt_state.weights, list):
      self._opt_state.weights[0] = new_model_weights
    else:  # weights are a tuple, need to re-create
      new_weights = [new_model_weights] + list(self._opt_state.weights[1:])
      self._opt_state = self._opt_state._replace(weights=new_weights)

  @property
  def model_state(self):
    # Currently we need to pick [0] as we ignore loss state (empty).
    state = self._model_state[0]
    if self.n_devices > 1:
      unreplicate = lambda x: x[0]
      state = math.nested_map(unreplicate, state)
    return state

  @model_state.setter
  def model_state(self, state):
    new_model_state = self._for_n_devices(state)
    if isinstance(self._model_state, list):
      self._model_state[0] = new_model_state
    else:  # weights are a tuple, need to re-create
      self._model_state = [new_model_state] + list(self._model_state[1:])

  @property
  def state(self):
    return TrainerState(
        opt_state=self._opt_state, step=self._step, history=self._history,
        model_state=self._model_state)

  @property
  def nontrainable_params(self):
    # TODO(afrozm): Give further thought to this name.
    # TODO(lukaszkaiser): it makes no sense to use an accelerator (e.g. TPU)
    # in op-by-op mode just to compute the learning rate. However, there
    # should be a cleaner approach than forcibly swapping out the backend.
    with math.use_backend('numpy'):
      return self._lr_fn(self._step)

  def reset(self, output_dir, init_checkpoint=None):
    """Reset the model parameters.

    Restores the parameters from the given output_dir if a checkpoint exists,
    otherwise randomly initializes them.

    Does not re-jit the model.

    Args:
      output_dir: Output directory.
      init_checkpoint: Initial checkpoint to use (default $output_dir/model.pkl)
    """
    self.close()
    self._output_dir = output_dir
    if output_dir is not None:
      tf.io.gfile.makedirs(output_dir)
    else:
      assert not self._should_save_checkpoints
      assert not self._should_write_summaries

    # Create summary writers and history.
    if self._should_write_summaries:
      self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'train'),
                                              enable=self._is_chief)
      self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'eval'),
                                             enable=self._is_chief)

    # Reset the train and eval streams.
    self._train_stream = _repeat_stream(self._inputs.train_stream,
                                        self._n_devices)
    # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
    #   set by adding a padding and stopping the stream when too large.
    self._eval_stream = _repeat_stream(
        self._inputs.eval_stream, self._n_devices)
    self._train_eval_stream = _repeat_stream(
        self._inputs.train_eval_stream, self._n_devices)

    # Restore the training state.
    if output_dir is not None:
      state = load_trainer_state(output_dir, init_checkpoint)
    else:
      state = TrainerState(step=None, opt_state=None,
                           history=trax_history.History(), model_state=None)
    self._step = state.step or 0
    history = state.history
    self._lr_fn = self._lr_schedule(history)
    self._history = history
    if state.opt_state:
      opt_state = state.opt_state
      model_state = state.model_state
    else:
      opt_state, model_state = self._new_opt_state_and_model_state()
      model_state = self._for_n_devices(model_state)
    self._opt_state = OptState(*self._for_n_devices(opt_state))
    self._model_state = model_state
    if not state.opt_state and self._should_save_checkpoints:
      self.save_state(keep=False)

    self.update_nontrainable_params()

  def train_epoch(self, n_steps, n_eval_steps):
    """Runs `n_steps` of training, with periodic logging, saving, and evals."""
    # TODO(jonni): Clarify how this method relates to the stricter notion of
    # epoch (training for as many steps as needed for a full pass through the
    # training data).
    print()  # Add visual separator in logs for start of training epoch.
    start_time = time.time()

    for _ in range(n_steps):
      batch = next(self._train_stream)
      if self.n_devices > 1:  # TODO(lukaszkaiser): use everywhere if possible.
        batch = _reshape_by_device(batch, self.n_devices)
      self.train_step(batch)
      if self._should_save_now():
        self.save_state(keep=True)
      if self._should_log_now():
        for (name, value) in self.nontrainable_params.items():
          self._train_sw.scalar('training/{}'.format(name), value)

    # At end of n_steps, do bookkeeping, run evals, and save state.
    elapsed_time = time.time() - start_time
    self.log_step('Ran %d train steps in %0.2f secs' % (n_steps, elapsed_time))
    if self._train_sw and n_steps > 1:
      self._train_sw.scalar('training/steps per second',
                            n_steps / elapsed_time, step=self._step)
      self._train_sw.flush()
    self.evaluate(n_eval_steps)
    if self._eval_sw:
      self._eval_sw.flush()
    if self._should_save_checkpoints:
      self.save_state(keep=False)
    if self._should_save_checkpoints and self._current_step_is_best(high=True):
      self.save_state(keep=False, prefix='highest_' + self._checkpoint_highest)
    if self._should_save_checkpoints and self._current_step_is_best(high=False):
      self.save_state(keep=False, prefix='lowest_' + self._checkpoint_lowest)

  def train_step(self, batch):
    """Run one training step and update self._opt_state."""
    # Calculate the current optimizer parameters.
    # TODO(pkozakowski): Optimizer parameters get polluted with model state,
    # which doesn't break anything but is weird. Filter it out.
    opt_param_updates = self._for_n_devices(
        math.nested_map(np.array, self.nontrainable_params))
    opt_state = self._opt_state
    opt_state.opt_params.update(opt_param_updates)

    # Run the update.
    (weights, slots, stat), self._model_state, self._rngs = self._jit_update_fn(
        self._step, opt_state, batch, self._model_state, self._rngs)
    self._model_state = self._map_to_state_dicts(self._state_dicts_update)
    self._opt_state = opt_state._replace(weights=weights, slots=slots)
    if self._should_log_now():
      for name, value in stat.items():
        self._train_sw.scalar('training/' + name, value, step=self._step)
    self._step += 1

  def evaluate(self, n_eval_steps):
    """Evaluate the model and log metrics."""
    _, rng = jax_random.split(self._rngs[0])
    # TODO(lukaszkaiser): both model state and parameters by default include
    # the loss layer. Currently, we access the pure-model parameters by just
    # indexing, [0] here. But we should make it more explicit in a better API.
    weights = (self._opt_state[0][0], self._metrics_weights)
    state = (self._model_state[0], self._metrics_state)
    self.log_step('Evaluation')
    train_eval_slice = itertools.islice(self._train_eval_stream, n_eval_steps)
    train_metrics, _ = self.evaluation_round(train_eval_slice, weights, state,
                                             rng)
    self.log_metrics(train_metrics, self._train_sw, 'train')
    eval_slice = itertools.islice(self._eval_stream, n_eval_steps)
    eval_metrics, _ = self.evaluation_round(eval_slice, weights, state, rng)
    self.log_metrics(eval_metrics, self._eval_sw, 'eval')
    self.log_step('Finished evaluation')

    # Save the optimizer weights in the history
    for (name, value) in self.nontrainable_params.items():
      self._history.append('train', 'training/{}'.format(name), self._step,
                           value)

  def evaluation_round(self, inputs_stream, weights, state, rng):
    """Evaluate.

    Args:
      inputs_stream: iterable of inputs to evaluate on.
      weights: weights for each f in eval_fns.
      state: state for each f in eval_fns.
      rng: random number generator.

    Returns:
      metrics: dict from metric name to metric value averaged over the number of
        inputs.
      state: end state for `predict_fn`.
    """
    metrics = collections.defaultdict(float)
    count = 0
    for inp in inputs_stream:
      count += 1
      rng, subrng = jax_random.split(rng)
      metric_values, _ = self._jit_eval(inp, weights, state, subrng)
      try:
        metric_values = list(metric_values)
      except (TypeError, IndexError):
        metric_values = [float(metric_values)]
      for m, v in zip(self._metrics, metric_values):
        metrics[m] += v
    return {m: v / count for (m, v) in six.iteritems(metrics)}, state

  def update_model_state(self, key, value):
    """Updates model state based on nontrainable_params."""
    # Translate model state keys to nontrainable param names.
    if key in self._nontrainable_param_map:
      p_name = self._nontrainable_param_map[key]
    else:
      # If a key is not in mapping, it stays the same.
      p_name = key
    if p_name in self.nontrainable_params:
      if self._step == 0:
        log('Mapping model state key {} to nontrainable param {}.'
            ''.format(key, p_name))
        return self._for_n_devices(np.array(self.nontrainable_params[p_name]))
    return value

  def update_nontrainable_params(self):
    self._lr_fn = self._lr_schedule(self._history)

  def save_gin(self):
    assert self._output_dir is not None
    config_path = os.path.join(self._output_dir, 'config.gin')
    config_str = gin.operative_config_str()
    with tf.io.gfile.GFile(config_path, 'w') as f:
      f.write(config_str)
    sw = self._train_sw
    if sw:
      sw.text('gin_config',
              jaxboard.markdownify_operative_config_str(config_str))

  def _save_state_dict(self, trainer_state_dict, weights_file):
    pickle_to_file(trainer_state_dict, weights_file)
    log('Model saved to %s' % weights_file, stdout=False)

  def save_state(self, keep, prefix='model'):
    """Save trainer state given a possibly replicated opt_state."""
    opt_state = self._opt_state
    if self.n_devices > 1:
      first_replica = lambda x: x[0]
      opt_state = OptState(*math.nested_map(first_replica, opt_state))
    # This line, while optional, allows JAX to transfer arrays from the device
    # to the host in parallel, which is particularly important for cloud TPU.
    if math.backend_name() == 'jax':
      opt_state = jax.device_get(opt_state)
    step, history, model_state = self._step, self._history, self._model_state
    output_dir = self._output_dir

    weights_file = os.path.join(output_dir, prefix + '.pkl')

    # This dict will be stored as the model.
    trainer_state_dict = make_trainer_state_dict(step,
                                                 opt_state,
                                                 history,
                                                 model_state)
    self._save_state_dict(trainer_state_dict, weights_file)

    if keep:
      weights_file = os.path.join(output_dir, '{}_{}.pkl'.format(prefix, step))
      self._save_state_dict(trainer_state_dict, weights_file)

  def save_computation_graphs(self, save_backward_graph):
    """Dump computation graphs to files."""
    if self.n_devices != 1:
      return  # TODO(lukaszkaiser): make this work with more devices.
    batch = next(self._train_stream)
    output_dir = self._output_dir
    if self.n_devices > 1:
      batch = _reshape_by_device(batch, self.n_devices)
    weights = self._opt_state[0][0]
    forward_computation = jax.xla_computation(self._model_predict_eval)(
        batch, weights=weights, state=self._model_state[0],
        rng=self._rngs[0])
    with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.txt'), 'w') as f:
      f.write(forward_computation.GetHloText())
    with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.dot'), 'w') as f:
      f.write(forward_computation.GetHloDotGraph())
    backward_computation = jax.xla_computation(self._jit_update_fn)(
        self._step, self._opt_state, batch, self._model_state,
        self._rngs)
    with tf.io.gfile.GFile(os.path.join(output_dir, 'backward.txt'), 'w') as f:
      f.write(backward_computation.GetHloText())
    if save_backward_graph:  # Backward graphs can be large so we guard it.
      with tf.io.gfile.GFile(
          os.path.join(output_dir, 'backward.dot'), 'w') as f:
        f.write(backward_computation.GetHloDotGraph())

  def log_step(self, step_message):
    log('Step % 6d: %s' % (self.step, step_message))

  def log_metrics(self, metrics, summ_writer, log_prefix):
    """Log metrics to summary writer and history."""
    history = self._history
    rjust_len = max([0] + [len(name) for name in metrics])
    for name, value in six.iteritems(metrics):
      self.log_step('%s %s | % .8f' % (
          log_prefix.ljust(5), name.rjust(rjust_len), value))
      full_name = 'metrics/' + name
      if history:
        history.append(log_prefix, full_name, self.step, value)
      if summ_writer:
        summ_writer.scalar(full_name, value, self.step)

  def print_n_weights(self):
    """Prints the total count of trainable weights."""
    opt_state = self._opt_state
    sizes = _sizes(opt_state.weights)
    if self.n_devices > 1:
      unreplicate = lambda x: x[0]
      single_weights = math.nested_map(unreplicate, opt_state.weights)
      sizes = _sizes(single_weights)
    total_size = _nested_reduce(sum, sizes)
    self.log_step('Total number of trainable weights: %d' % total_size)

  def _init_host_and_devices(self, n_devices=None, random_seed=None):
    """Initializes host and device attributes for this trainer.

    Args:
      n_devices: Number of devices this trainer will use. If `None`, get the
          number from the backend.
      random_seed: Random seed as the starting point for all random numbers used
          by the trainer. If `None`, calculate one from system time and host id.

    Returns:
      is_chief: True if this trainer has special chief responsibilities.
      n_devices: The passed in value of n_devices or a computed default.
      random_seed: The passed in value of random_seed or a computed default.
    """
    if math.backend_name() == 'jax':
      host_id = jax.host_id()
      host_count = jax.host_count()
    else:
      host_id = 0
      host_count = 1
    is_chief = (host_id == 0)

    device_count = math.device_count()
    n_devices = n_devices or device_count
    # TODO(lukaszkaiser): remove this restriction when possible.
    if n_devices != device_count and math.backend_name() == 'jax':
      raise ValueError('JAX cannot work yet with n_devices != all devices: '
                       '%d != %d' % (n_devices, device_count))

    if random_seed is None and host_count > 1:
      random_seed = int(1e6 * (host_id + time.time())) % 2**32
    return is_chief, n_devices, init_random_number_generators(random_seed)

  def _map_to_state_dicts(self, f):
    """Map the function f to all dicts in model state."""
    # TODO(jonni): Can we replace _nested_map with math.nested_map?
    def _nested_map(f, x):
      if isinstance(x, list):
        return [_nested_map(f, y) for y in x]
      if isinstance(x, tuple):
        return tuple([_nested_map(f, y) for y in x])
      if isinstance(x, dict) and len(x) == 1:
        return f(x)
      return x
    return _nested_map(f, self._model_state)

  def _state_dicts_update(self, state_dict):
    assert len(state_dict.keys()) == 1
    key = list(state_dict.keys())[0]
    value = state_dict[key]
    return {key: self.update_model_state(key, value)}

  def _should_save_now(self):
    return self._should_save_checkpoints and self._step in self._checkpoints_at

  def _current_step_is_best(self, high):
    """Is the current step the best (highest if high, else lowest)."""
    metric = self._checkpoint_highest if high else self._checkpoint_lowest
    if metric is None:
      return False
    # History is a list of pairs (step, value).
    history = self._history.get('eval', 'metrics/' + metric)
    sequence = [float(i[1]) for i in history]  # Just the values.
    best = max(sequence) if high else min(sequence)  # Best value.
    last_is_best = float(history[-1][1]) == best  # Is last the best?
    cur_step = history[-1][0] == self._step  # Is last the current step?
    return cur_step and last_is_best

  def _should_log_now(self):
    return (self._train_sw is not None
            and (self._step == 1 or self._step % 10 == 0))

  def _for_n_devices(self, x):
    """Replicates/broadcasts `x` for n devices if `self.n_devices > 1`."""
    return tl.for_n_devices(x, self.n_devices)  # pylint: disable=protected-access

  def close(self):
    if self._train_sw is not None:
      self._train_sw.close()
      self._train_sw = None
    if self._eval_sw is not None:
      self._eval_sw.close()
      self._eval_sw = None


@gin.configurable(blacklist=['output_dir'])
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.CrossEntropyLoss(),
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.Adafactor,
          lr_schedule=lr.MultifactorSchedule,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          random_seed=None,
          save_graphs=True,
          save_backward_graph=False,
          nontrainable_param_map=None,
          id_to_mask=None,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          custom_train_fn=None):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
      rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    save_backward_graph: bool, if True, save backward graph to file too.
    nontrainable_param_map: dict, mapping from model nontrainable parameter
      names to control names in PolicySchedule.
    id_to_mask: id to mask out (None by default).
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save the checkpoint highest at this metric.
    checkpoint_lowest: save the checkpoint lowest at this metric.
    custom_train_fn: custom train function to call, entirely bypassing this one

  Returns:
    trax.TrainerState
  """
  if custom_train_fn is not None:
    return custom_train_fn(output_dir, model=model)

  n_devices = num_devices()
  # TODO(lukaszkaiser): remove has_weights and id_to_mask (configure loss).
  trainer = trainer_class(model, loss_fn, optimizer, lr_schedule, inputs,
                          output_dir,
                          random_seed=random_seed, n_devices=n_devices,
                          checkpoints_at=checkpoints_at,
                          nontrainable_param_map=nontrainable_param_map,
                          metrics=metrics, id_to_mask=id_to_mask,
                          checkpoint_lowest=checkpoint_lowest,
                          checkpoint_highest=checkpoint_highest)

  epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
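  # With the defaults (steps=1000, eval_frequency=100, eval_steps=10) this
  # schedule yields epochs of 1, 99, 100, 100, ... steps, so the first eval
  # runs after a single training step and then every 100 steps.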
  trainer.log_step('Starting training using %d devices' % trainer.n_devices)
  trainer.print_n_weights()

  try:
    for epoch_steps in epochs(steps, trainer.step, epoch_steps):
      trainer.train_epoch(epoch_steps, eval_steps)

      # Update nontrainable parameters with new history
      trainer.update_nontrainable_params()

      # Bookkeeping we do at the first step
      if trainer.step == 1:
        # Save computation graph (single-device only for now)
        if (save_graphs and math.backend_name() == 'jax'):
          trainer.save_computation_graphs(save_backward_graph)

        # Save Gin config
        trainer.save_gin()

    trainer.log_step('Training done')
  except Exception as e:
    raise e
  finally:
    trainer.close()
  return trainer.state
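
# A minimal invocation sketch (the output path and the model/inputs callables
# below are illustrative placeholders; in practice they are usually bound
# through gin):
#
#   state = train('/tmp/my_run',
#                 model=my_model_fn,    # e.g. a trax model constructor
#                 inputs=my_inputs_fn,
#                 steps=1000, eval_steps=10, eval_frequency=100)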


@gin.configurable
def num_devices(value=None):
  """Returns how many devices to use (if None, default, use all available)."""
  return value


@gin.configurable
def _is_jit_init(value=None):
  if value is None:
    value = math.backend_name() == 'jax'
  return value


@gin.configurable
def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True):
  """Returns a (JIT-compiled) function that computes updates for one step."""
  model_and_loss = tl.Serial(predict_fn, loss_fn)
  # Gradients are always wrt. the first argument, so putting weights first.
  def model_and_loss_call(weights, batch, state, rng):
    res = model_and_loss(batch, weights=weights, state=state, rng=rng)
    return res, model_and_loss.state
  if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
    def single_update(i, opt_state, batch, state, rng):
      weights, slots, opt_params = opt_state
      rng, subrng = jax_random.split(rng[0])
      grad_fn = math.grad(model_and_loss_call, has_aux=True)
      grads, state = grad_fn(weights, batch, state, rng)
      return optimizer.tree_update(
          i, grads, weights, slots, opt_params), state, [subrng]
    return math.jit(single_update) if jit else single_update

  # Else, for n_devices > 1:
  @functools.partial(math.pmap, axis_name='batch')
  def mapped_update(i, opt_state, batch, state, rng):
    """This is a multi-device version of the update function above."""
    # We assume all tensors have the first dimension = n_devices.
    weights, slots, opt_params = opt_state
    rng, subrng = jax_random.split(rng)
    grad_fn = math.grad(model_and_loss_call, has_aux=True)
    grads, state = grad_fn(weights, batch, state, rng)
    # We do a psum(1.0) here instead of `n_devices` since `n_devices` is just
    # the number of devices on this host machine, however psum goes over all
    # devices of all hosts (ex: a TPU pod) and we need to be averaging over all
    # of them.
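    # For example, with 8 devices per host on 4 hosts (32 devices in total,
    # an illustrative configuration), psum(np.array(1.0), 'batch') evaluates
    # to 32.0 on every device, so the division below turns the cross-device
    # gradient sum into a global mean.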
    grads = jax.tree_util.tree_map(
        lambda g: math.psum(g, 'batch') / math.psum(np.array(1.0), 'batch'),
        grads)
    return optimizer.tree_update(
        i, grads, weights, slots, opt_params), state, subrng

  def update(i, opt_state, batch, state, rng):
    return mapped_update(np.repeat(i, n_devices), opt_state, batch, state, rng)

  return update


@gin.configurable
def _jit_predict_fn(model_predict, metric_fn, n_devices, jit=True):
  """Returns a JIT-compiled predict function (unless jit=False)."""
  model = tl.Serial(model_predict, metric_fn)
  if not jit:
    return model.pure_fn

  return tl.jit_forward(model.pure_fn, n_devices)


@gin.configurable
def _jit_compute_loss_fn(predict_fn, loss_fn, n_devices, jit=True):
  """Returns a (JIT-compiled) function that computes the loss for one step."""
  if n_devices == 1:  # TODO(lukaszkaiser): remove branch when not needed.
    def single_compute_loss(opt_state, batch, state, rng):
      rng, subrng = jax_random.split(rng[0])
      loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng)
      return loss_val, state, [subrng]
    return math.jit(single_compute_loss) if jit else single_compute_loss

  # Else, for n_devices > 1:
  @functools.partial(math.pmap, axis_name='batch')
  def mapped_compute_loss(opt_state, batch, state, rng):
    """This is a multi-device version of the update function above."""
    # We assume all tensors have the first dimension = n_devices.
    rng, subrng = jax_random.split(rng)
    loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng)
    return loss_val, state, subrng

  def compute_loss(opt_state, batch, state, rng):
    return mapped_compute_loss(
        opt_state, _reshape_by_device(batch, n_devices), state, rng)

  return compute_loss


def log(s, stdout=True):
  logging.info(s)
  if stdout:
    print(s)
    sys.stdout.flush()


def epochs(total_steps, steps_to_skip, epoch_steps):
  """Generates the number of steps in each epoch before reaching total_steps.

  Args:
    total_steps: int, total number of steps.
    steps_to_skip: int, number of steps to skip because of a restart.
    epoch_steps: iterable of int, numbers of steps in each epoch.

  Yields:
    epoch_steps: int, number of steps in this epoch
  """
  steps_to_go = total_steps - steps_to_skip
  epoch_steps = iter(epoch_steps)

  # Remove the desired number of steps from the stream.
  for steps_this_epoch in epoch_steps:
    if steps_this_epoch > steps_to_skip:
      # Put back the number of steps left in the unfinished epoch.
      epoch_steps = itertools.chain(
          [steps_this_epoch - steps_to_skip], epoch_steps)
    if steps_this_epoch >= steps_to_skip:
      break
    steps_to_skip -= steps_this_epoch

  # Yield the remaining steps per epoch up to total_steps.
  for steps_this_epoch in epoch_steps:
    steps_this_epoch = min(steps_this_epoch, steps_to_go)
    yield steps_this_epoch
    steps_to_go -= steps_this_epoch
    if steps_to_go == 0:
      break
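
# For example (illustrative numbers): epochs(total_steps=10, steps_to_skip=3,
# epoch_steps=[1, 4, 5]) skips the whole 1-step epoch plus 2 steps of the
# 4-step epoch, then yields 2 and 5.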


def make_trainer_state_dict(step,
                            opt_state,
                            history,
                            model_state):
  """Creates a trainer state dictionary to save to disk.

  Args:
    step: int, a step number
    opt_state: OptState namedtuple
    history: `trax.history.History`, the history object.
    model_state: A nested structure of the model state.

  Returns:
    A dictionary with the fields of TrainerState and OptState flattened.
  """

  return {
      'step': step,
      'weights': opt_state.weights[0],
      'loss_weights': opt_state.weights[1],
      'slots': opt_state.slots,
      'opt_params': opt_state.opt_params,
      'history': history,
      'state': model_state[0],
      'loss_state': model_state[1],
      'version_timestamp': 'Jan-13-2020'  # To update in the future if needed.
  }


def trainer_state_from_dict(trainer_state_dict):
  """Given the trainer state dictionary, returns `TrainerState`."""
  # TODO(afrozm): This becomes simpler if OptState is flattened into
  # TrainerState.
  step = trainer_state_dict['step']
  history = trainer_state_dict['history']
  # TODO(lukaszkaiser): remove the first branch after everyone ports to 'state'.
  if 'model_state' in trainer_state_dict:
    model_state = trainer_state_dict['model_state']
  else:
    model_state = (trainer_state_dict['state'],
                   trainer_state_dict['loss_state'])
  weights = trainer_state_dict['weights']
  # TODO(lukaszkaiser): remove the next 2 lines after 'loss_weights' is in use.
  if 'loss_weights' in trainer_state_dict:
    weights = (weights, trainer_state_dict['loss_weights'])
  opt_state = OptState(
      weights=weights,
      slots=trainer_state_dict['slots'],
      opt_params=trainer_state_dict['opt_params'])
  return TrainerState(step=step, opt_state=OptState(*opt_state),
                      history=history, model_state=model_state)


def load_trainer_state(output_dir, weights_file=None):
  """Returns a TrainerState instance loaded from the given `output_dir`."""
  if weights_file is None:
    weights_file = os.path.join(output_dir, 'model.pkl')
    if not tf.io.gfile.exists(weights_file):
      return TrainerState(step=None, opt_state=None,
                          history=trax_history.History(), model_state=None)
  elif not tf.io.gfile.exists(weights_file):
    raise ValueError('File not found: %s' % weights_file)

  with tf.io.gfile.GFile(weights_file, 'rb') as f:
    trainer_state_dict = pickle.load(f)
  trainer_state = trainer_state_from_dict(trainer_state_dict)
  log('Model loaded from %s at step %d' % (weights_file, trainer_state.step))
  logging.debug('From loaded model : history = %s', trainer_state.history)
  return trainer_state


def init_random_number_generators(seed=None):
  """Initializes random generators for Python, NumPy, TensorFlow, and JAX."""
  # Seed Python random (None as seed is okay), then use it to seed the others.
  random.seed(seed)
  if seed is None:
    seed = random.randint(0, 2**31 - 1)
  numpy.random.seed(seed)
  tf.random.set_seed(seed)
  return jax_random.get_prng(seed)


def _reshape_by_device(x, n_devices):
  """Reshapes possibly nested x into a shape (n_devices, ...)."""
  return tl.reshape_by_device(x, n_devices)  # pylint: disable=protected-access


def _nested_reduce(f, x):
  """Fold the function f to the nested structure x (dicts, tuples, lists)."""
  if isinstance(x, list):
    return f([_nested_reduce(f, y) for y in x])
  if isinstance(x, tuple):
    return f([_nested_reduce(f, y) for y in x])
  return x


def _sizes(x):
  """Get a structure of sizes for a structure of nested arrays."""
  def size(x):
    try:
      return x.size
    except Exception:  # pylint: disable=broad-except
      return 0
  return math.nested_map(size, x)


def _repeat_stream(stream, n_devices):
  """Repeat a stream indefinitely."""
  while True:
    for example in stream(n_devices):
      yield example


def pickle_to_file(obj, file_path, gzip=False):
  """Pickle obj to file_path with gzipping and failure protection."""
  # Pickle to tmp file and overwrite to prevent writing partial files.
  tmp_file_path = file_path + '._tmp_'
  with tf.io.gfile.GFile(tmp_file_path, 'wb') as f:
    if not gzip:
      pickle.dump(obj, f)
    else:
      with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf:
        pickle.dump(obj, gzipf)
  # Moving a file is much less error-prone than pickling large files.
  tf.io.gfile.rename(tmp_file_path, file_path, overwrite=True)


def unpickle_from_file(file_path, gzip=False):
  """Unpickle obj from file_path with gzipping."""
  with tf.io.gfile.GFile(file_path, 'rb') as f:
    if not gzip:
      obj = pickle.load(f)
    else:
      with gzip_lib.GzipFile(fileobj=f, compresslevel=2) as gzipf:
        obj = pickle.load(gzipf)
  return obj
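
# These two helpers are symmetric; a roundtrip sketch (the path below is an
# illustrative placeholder):
#
#   pickle_to_file({'step': 3}, '/tmp/state.pkl', gzip=True)
#   assert unpickle_from_file('/tmp/state.pkl', gzip=True) == {'step': 3}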


def _add_weights_and_mask(inputs, id_to_mask):
  """Add weights to inputs without weights and masks by id if requested.

  Each of the (train, eval, train_eval) streams of inputs is augmented in
  the following way:
  * if the stream consists of pairs (inputs, targets), a loss mask is added
    that is created as a tensor of ones of the same shape as targets
  * if id_to_mask is not None, and the stream (after the previous point) has
    triples (inputs, targets, weights), the weights are multiplied by a 0/1 mask
    that is 0 iff targets is equal to id_to_mask (1 otherwise).

  Args:
    inputs: a trax_inputs.Inputs object to operate on
    id_to_mask: int or None, id to pad in targets if not None

  Returns:
    a trax_inputs.Inputs object with augmented streams
  """
  def _with_masks(input_stream):
    """Create masks for the given stream."""
    for example in input_stream:
      if len(example) > 3 or len(example) < 2:
        assert id_to_mask is None, 'Cannot automatically mask this stream.'
        yield example
      else:
        if len(example) == 2:
          weights = numpy.ones_like(example[1]).astype(numpy.float32)
        else:
          weights = example[2].astype(numpy.float32)
        mask = 1.0 - numpy.equal(example[1], id_to_mask).astype(np.float32)
        weights *= mask
        yield (example[0], example[1], weights)
  return trax_inputs.Inputs(
      train_stream=lambda n: _with_masks(inputs.train_stream(n)),
      eval_stream=lambda n: _with_masks(inputs.eval_stream(n)),
      train_eval_stream=lambda n: _with_masks(inputs.train_eval_stream(n)))
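
# For example (illustrative values): with targets = [3, 0, 7, 0] and
# id_to_mask = 0, the generated weights are [1., 0., 1., 0.], so the padded
# positions contribute nothing to the loss.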