
google / trax / 305 (pending completion)

Pull Request #564: Make arguments positional in forward_with_state.
Merge d8a0dda50 into 0bdc85030 (committed by web-flow, via travis-ci)

23 of 23 new or added lines in 11 files covered (100.0%)

2677 of 10836 relevant lines covered (24.7%)

0.25 hits per line

Source File

/trax/layers/base.py (48.25%)
1
# coding=utf-8
2
# Copyright 2020 The Trax Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
# Lint as: python3
17
"""Base layer class."""
1✔
18

19
import copy
1✔
20
import inspect
1✔
21
import pickle
1✔
22
import traceback
1✔
23

24
import jax
1✔
25
import numpy as np
1✔
26
import tensorflow as tf
1✔
27

28
from trax import math
1✔
29
from trax.math import nested_map
1✔
30
from trax.math import numpy as jnp
1✔
31
from trax.shapes import ShapeDtype
1✔
32
from trax.shapes import signature
1✔
33

34

35
EMPTY_WEIGHTS = ()
1✔
36
EMPTY_STATE = ()
1✔
37

38

39
class Layer(object):
1✔
40
  """Base class for composable layers in a deep learning network.
41

42
  Layers are the basic building blocks for deep learning models. A Trax layer
43
  computes a function from zero or more inputs to zero or more outputs,
44
  optionally using trainable weights (common) and non-parameter state (not
45
  common). Authors of new layer subclasses typically override at most two
46
  methods of the base `Layer` class:
47

48
    forward(inputs, weights):
49
      Computes this layer's output as part of a forward pass through the model.
50

51
    new_weights(self, input_signature):
52
      Returns new weights suitable for inputs with the given signature.
53

54
  A small subset of layer types are combinators -- they organize the computation
55
  of their sublayers, e.g., applying their sublayers in series or in parallel.
56

57
  All layers have the following properties, with default values implemented
58
  in the base `Layer` class:
59

60
    - n_in: int (default 1)
61
    - n_out: int (default 1)
62
    - weights: tuple (default empty -- the layer has no weights)
63
    - state: tuple (default empty -- the layer has no non-parameter state)
64
    - sublayers: tuple (default empty -- the layer has no sublayers)
65

66
  The inputs to a layer are tensors, packaged according to how many there are:
67

68
    - n_in = 0: an empty tuple ()
69
    - n_in = 1: one tensor (NOT wrapped in a tuple)
70
    - n_in > 1: a tuple of tensors
71

72
  (The special treatment of the single-input case is meant to simplify the
73
  work of layer writers; this design choice may be revisited in the future.)
74

75
  The outputs from a layer are also tensors, packaged the same as layer inputs:
76

77
    - n_out = 0: an empty tuple ()
78
    - n_out = 1: the tensor (NOT wrapped in a tuple)
79
    - n_out > 1: a tuple of tensors
80

81
  The Trax runtime maintains a data stack with which layer calls are composed.
82
  For more complex data network architectures, possibly involving multiple data
83
  flows, one can view each layer as a function from stack state to stack state,
84
  where the function's inputs are a slice from the stack, and the function's
85
  outputs are spliced back into the stack.
86
  """
87

88
  def __init__(self, n_in=1, n_out=1, name=None):
1✔
89
    """Creates a partially initialized, unconnected layer instance.
90

91
    Args:
92
      n_in: Number of inputs expected by this layer.
93
      n_out: Number of outputs promised by this layer.
94
      name: Class-like name for this layer; for use in debugging.
95
    """
96
    self._n_in = n_in
1✔
97
    self._n_out = n_out
1✔
98
    self._name = name or self.__class__.__name__
1✔
99
    self._sublayers = ()  # Default is no sublayers.
1✔
100
    self._input_signature = None
1✔
101
    self._rng = None
1✔
102
    self._weights = EMPTY_WEIGHTS  # cached weights
1✔
103
    self._state = EMPTY_STATE
1✔
104
    # record root call site for custom error messages:
105
    frame = _find_frame(inspect.currentframe())
1✔
106
    # Turns out that frame can mutate in time, so we just copy what we need.
107
    self._caller = {'filename': copy.copy(frame.f_code.co_filename),
1✔
108
                    'lineno': int(frame.f_lineno)}
109
    del frame  # Just in case.
1✔
110
    self._init_finished = False
1✔
111
    self._jit_cache = {}
1✔
112

113
  def __repr__(self):
114
    def indent_string(x):
115
      return '  ' + x.replace('\n', '\n  ')
116
    name_str = self._name
117
    n_in, n_out = self.n_in, self.n_out
118
    if n_in != 1: name_str += f'_in{n_in}'
119
    if n_out != 1: name_str += f'_out{n_out}'
120
    objs = self.sublayers
121
    if objs:
122
      objs_str = '\n'.join(indent_string(str(x)) for x in objs)
123
      return f'{name_str}[\n{objs_str}\n]'
124
    else:
125
      return name_str
126

127
  def __call__(self, x, weights=None, state=None, rng=None, n_accelerators=0):
1✔
128
    """Makes Layer instances callable; for use in tests or interactive settings.
129

130
    This convenience method helps library users play with, test, or otherwise
131
    probe the behavior of layers outside of a full training environment. It
132
    presents the layer as a callable function from inputs to outputs, with the
133
    option of manually specifying weights and non-parameter state per individual
134
    call. For convenience, weights and non-parameter state are cached per layer
135
    instance, starting from default values of `EMPTY_WEIGHTS` and `EMPTY_STATE`,
136
    and acquiring non-empty values either by initialization or from values
137
    explicitly provided via the weights and state keyword arguments.
138

139
    Args:
140
      x: 0 or more input tensors, formatted the same as the inputs to
141
          Layer.forward.
142
      weights: Weights or None; if None, use self's cached weights value.
143
      state: State or None; if None, use self's cached state value.
144
      rng: rng object or None; if None, use a default computed from an
145
          integer 0 seed.
146
      n_accelerators: Number of accelerators to target.
147

148
    Returns:
149
      0 or more output tensors, formatted the same as the outputs from
150
          Layer.forward.
151
    """
152
    weights = self.weights if weights is None else weights
1✔
153
    state = self.state if state is None else state
1✔
154
    rng = self._rng if rng is None else rng
1✔
155
    rng = math.random.get_prng(0) if rng is None else rng
1✔
156

157
    forward_w_s_r = self.pure_fn
1✔
158
    # TODO(lukaszkaiser): n_accelerators is experimental, to decide on API
159
    if n_accelerators:
1✔
160
      if n_accelerators not in self._jit_cache:
×
161
        self._jit_cache[n_accelerators] = (
×
162
            jit_forward(forward_w_s_r, n_accelerators))
163
      forward_w_s_r = self._jit_cache[n_accelerators]
×
164
    outputs, new_state = forward_w_s_r(x, weights, state, rng)
1✔
165
    self.state = new_state
1✔
166
    self.weights = weights
1✔
167
    return outputs
1✔
168

169
  def forward(self, inputs, weights):
1✔
170
    """Computes this layer's output as part of a forward pass through the model.
171

172
    Authors of new Layer subclasses should override this method to define the
173
    forward computation that their layer performs, unless they need to use
174
    local non-trainable state or randomness, in which case they should
175
    override `forward_with_state` instead.
176

177
    Args:
178
      inputs: Input tensors, matching the number (n_in) expected by this
179
          layer, packaged as a single positional arg. Specifically:
180
            - n_in = 0: an empty tuple or empty list
181
            - n_in = 1: a tensor (NOT wrapped in a tuple)
182
            - n_in > 1: a tuple or list of tensors, with n_in items
183
      weights: A tuple or list of trainable weights, with one element for this
184
          layer if this layer has no sublayers, or one for each sublayer if
185
          this layer has sublayers. If a layer (or sublayer) has no trainable
186
          weights, the corresponding weights element is an empty tuple.
187

188
    Returns:
189
      Tensors, matching the number (n_out) promised by this layer.
190
      Specifically:
191
        - n_out = 0: an empty tuple
192
        - n_out = 1: one tensor (NOT wrapped in a tuple)
193
        - n_out > 1: a tuple of tensors, with n_out items
194
    """
195
    raise NotImplementedError
196

197
  def forward_with_state(self, inputs, weights, state, rng):
1✔
198
    """Computes this layer's output as part of a forward pass through the model.
199

200
    Authors of new Layer subclasses should override this method to define the
201
    forward computation that their layer performs only if their layer uses
202
    local state or randomness. Otherwise override `forward` instead.
203

204
    Args:
205
      inputs: Input tensors, matching the number (n_in) expected by this
206
          layer. Specifically:
207
            - n_in = 0: an empty tuple or empty list
208
            - n_in = 1: a tensor (NOT wrapped in a tuple)
209
            - n_in > 1: a tuple or list of tensors, with n_in items
210
      weights: A tuple or list of trainable weights, with one element for this
211
          layer if this layer has no sublayers, or one for each sublayer if
212
          this layer has sublayers. If a layer (or sublayer) has no trainable
213
          weights, the corresponding weights element is an empty tuple.
214
      state: Layer-specific non-parameter state that can update between batches.
215
      rng: Single-use random number generator (JAX PRNG key).
216

217
    Returns:
218
      A tuple of (tensors, state). The tensors match the number (n_out) promised
219
      by this layer, and are formatted according to that number, specifically:
220
        - n_out = 0: an empty tuple
221
        - n_out = 1: one tensor (NOT wrapped in a tuple)
222
        - n_out > 1: a tuple of tensors, with n_out items
223
    """
224
    del rng
1✔
225
    return self.forward(inputs, weights), state
1✔
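
Editor's sketch (not part of base.py): the subclassing pattern described in the docstrings above, using a hypothetical AddBias layer that overrides forward and new_weights, then is initialized and called directly. Assumes trax and a working backend are installed.

import numpy as np
from trax.layers import base
from trax.shapes import ShapeDtype

class AddBias(base.Layer):
  """Hypothetical example layer: adds one learned scalar bias to its input."""

  def forward(self, inputs, weights):
    # n_in = n_out = 1, so `inputs` is a single tensor (not a tuple).
    return inputs + weights

  def new_weights(self, input_signature):
    # A single scalar weight, initialized to zero.
    return np.zeros((), dtype=input_signature.dtype)

layer = AddBias()
weights, state = layer.init(ShapeDtype((2, 3), np.float32))
outputs = layer(np.ones((2, 3), dtype=np.float32))  # uses the cached weights
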
226

227
  def new_weights(self, input_signature):
1✔
228
    """Returns new weights suitable for inputs with the given signature.
229

230
    Authors of new Layer subclasses should override this method if their layer
231
    uses trainable weights. The default implementation works for layers that
232
    have no weights. Layers that have trainable state should override the
233
    `new_weights_and_state` method instead.
234

235
    Args:
236
      input_signature: A ShapeDtype instance (if this layer takes one input)
237
          or a list/tuple of ShapeDtype instances; signatures of inputs.
238
    """
239
    del input_signature
1✔
240
    return EMPTY_WEIGHTS
1✔
241

242
  def new_weights_and_state(self, input_signature):
1✔
243
    """Returns a (weights, state) pair suitable for initializing this layer.
244

245
    Authors of new Layer subclasses should override this method if their layer
246
    uses trainable weights or has non-parameter state that gets updated
247
    between batches. The default implementation works for layers that have
248
    no weights or state.
249

250
    Args:
251
      input_signature: A ShapeDtype instance (if this layer takes one input)
252
          or a list/tuple of ShapeDtype instances.
253
    """
254
    return self.new_weights(input_signature), EMPTY_STATE
1✔
255

256
  @property
1✔
257
  def has_backward(self):
258
    """Returns True if this layer provides its own (custom) backward pass code.
259

260
    A layer subclass that provides custom backward pass code (for custom
261
    gradients) must override this method to return True.
262
    """
263
    return False
1✔
264

265
  def backward(self, inputs, output, grad, weights, state, new_state, rng):
1✔
266
    """Custom backward pass to propagate gradients in a custom way.
267

268
    Args:
269
      inputs: Input tensors; can be a (possibly nested) tuple.
270
      output: The result of running this layer on inputs.
271
      grad: gradient signal (called cotangent in jax) computed based on
272
        subsequent layers. The structure and shape must match output.
273
      weights: layer weights
274
      state: start state.
275
      new_state: end state computed by running the layer
276
      rng: Single-use random number generator (JAX PRNG key).
277

278
    Returns:
279
      The custom gradient signal for the input. Note that we need to return
280
      a gradient for each argument of forward, so it will usually be a tuple
281
      of signals: the gradient for inputs and weights.
282
    """
283
    raise NotImplementedError
284

285
  # End of public subclassing interface.
286
  # Begin public callable interface.
287

288
  def init(self, input_signature, rng=None):
1✔
289
    """Initializes this layer and its sublayers recursively.
290

291
    This method is designed to initialize each layer instance once, even if the
292
    same layer instance occurs in multiple places in the network. This enables
293
    weight sharing to be implemented as layer sharing.
294

295
    Args:
296
      input_signature: `ShapeDtype` instance (if this layer takes one input)
297
          or list/tuple of `ShapeDtype` instances.
298
      rng: Single-use random number generator (JAX PRNG key). If none is
299
          provided, a default rng based on the integer seed 0 will be used.
300

301
    Returns:
302
      A (weights, state) tuple, in which weights contains newly created weights
303
          on the first call and `EMPTY_WEIGHTS` on all subsequent calls.
304
    """
305
    try:
1✔
306
      if self._rng is None:
1✔
307
        rng = math.random.get_prng(0) if rng is None else rng
1✔
308
        self._set_rng_recursive(rng)
1✔
309
      # Initialize weights once; store them for use when this layer is called.
310
      # Needs to call new_weights_and_state regardless of _init_finished because
311
      # state also needs to be initialized. After jitting, graph pruning should
312
      # be able to remove unnecessary computation.
313
      # TODO(lukaszkaiser): Revisit this decision and see whether layers sharing
314
      #   weights should also share states.
315
      weights, state = self.new_weights_and_state(input_signature)
1✔
316
      if not self._init_finished:
1✔
317
        self._init_finished = True
1✔
318
        self._weights = weights
1✔
319
        self._state = state
1✔
320
        return (weights, state)
1✔
321
      else:
322
        return (EMPTY_WEIGHTS, state)
×
323
    except Exception as e:
×
324
      name, trace = self._name, _short_traceback(skip=3)
×
325
      raise LayerError(name, 'init', self._caller,
×
326
                       input_signature, trace) from e
327

328
  def init_from_file(self, file_name, weights_only=False):
1✔
329
    """Initializes this layer and its sublayers from a file.
330

331
    We assume that the file is a pickled dictionary that contains the fields
332
    'weights' and 'state' with structures corresponding to this layer's weights
333
    and state. Note that the pickled dictionary is allowed to contain other
334
    fields too, but these two are required to init.
335

336
    Args:
337
      file_name: the name of the file to initialize from.
338
      weights_only: if True, initialize only the weights, not state.
339
    """
340
    with tf.io.gfile.GFile(file_name, 'rb') as f:
×
341
      dictionary = pickle.load(f)
×
342
    self.weights = dictionary['weights']
×
343
    if not weights_only:
×
344
      self.state = dictionary['state']
×
345

346
  def new_rng(self):
1✔
347
    """Returns a new single-use random number generator (JAX PRNG key)."""
348
    self._rng, rng = math.random.split(self._rng)
1✔
349
    return rng
1✔
350

351
  def new_rngs(self, n):
1✔
352
    """Returns `n` single-use random number generators (JAX PRNG keys).
353

354
    Args:
355
      n: The number of rngs to return; must be an integer > 0.
356

357
    Returns:
358
      A tuple of `n` rngs. Successive calls will yield continually new values.
359
    """
360
    if n < 1:
1✔
361
      raise ValueError(f"Requested number of new rng's ({n}) less than 1.")
×
362
    rngs = math.random.split(self._rng, n + 1)
1✔
363
    self._rng = rngs[0]
1✔
364
    return tuple(rngs[1:])
1✔
365

366
  # End of public callable methods.
367
  # Methods and properties below are reserved for internal use.
368

369
  @property
1✔
370
  def n_in(self):
371
    """Returns how many tensors this layer expects as input."""
372
    return self._n_in
1✔
373

374
  @property
1✔
375
  def n_out(self):
376
    """Returns how many tensors this layer promises as output."""
377
    return self._n_out
1✔
378

379
  @property
1✔
380
  def sublayers(self):
381
    """Returns a tuple containing this layer's sublayers; may be empty."""
382
    return self._sublayers
1✔
383

384
  @property
1✔
385
  def input_signature(self):
386
    """Returns this layer's input signature.
387

388
    An input signature is a ShapeDtype instance (if the layer takes one input)
389
    or a tuple of ShapeDtype instances.
390
    """
391
    return self._input_signature
×
392

393
  @property
1✔
394
  def weights(self):
395
    """Returns this layer's weights.
396

397
    Depending on the layer, the weights can be in the form of:
398
      - an empty tuple
399
      - a tensor (ndarray)
400
      - a nested structure of tuples and tensors
401
    TODO(jonni): Simplify this picture (and underlying implementation).
402
    """
403
    return self._weights
1✔
404

405
  @weights.setter
1✔
406
  def weights(self, weights):
407
    self._weights = weights
1✔
408

409
  @property
1✔
410
  def state(self):
411
    """Returns a tuple containing this layer's state; may be empty."""
412
    return self._state
1✔
413

414
  @state.setter
1✔
415
  def state(self, state):
416
    self._state = state
1✔
417

418
  def pure_fn(self, x, weights, state, rng):
1✔
419
    """Applies this layer as a pure function with no optional args.
420

421
    This method exposes the layer's computation as a pure function. This is
422
    esp. useful for JIT compilation. Do not override, use `forward` instead.
423

424
    Args:
425
      x: See Layer.forward_with_state inputs.
426
      weights: See Layer.forward_with_state.
427
      state: See Layer.forward_with_state.
428
      rng: See Layer.forward_with_state.
429

430
    Returns:
431
      See Layer.forward_with_state.
432
    """
433
    try:
1✔
434
      # If weights are nothing, we may be reusing this layer.
435
      # Use the cached weights to calculate the value.
436
      # Note: to make sure jit tracers can decide this branch in python we use
437
      # `weights is EMPTY_WEIGHTS` instead of, e.g., `not weights` or
438
      # `weights == EMPTY_WEIGHTS`.
439
      if weights is EMPTY_WEIGHTS:  # pylint: disable=literal-comparison
1✔
440
        weights = self._weights
1✔
441
      else:
442
        # In this case, we're called for the first time: cache weights.
443
        self._weights = weights
1✔
444

445
      if not self.has_backward:
1✔
446
        outputs, s = (
1✔
447
            self.forward_with_state(x, weights, state, rng))
448
      else:
449
        outputs, s = self._do_custom_gradients(x, weights, state, rng=rng)
×
450
      self._state = s
1✔
451
      return outputs, s
1✔
452

453
    except Exception as e:
×
454
      name, trace = self._name, _short_traceback()
×
455
      raise LayerError(name, 'pure_fn',
×
456
                       self._caller, signature(x), trace) from e
457

458
  def output_signature(self, input_signature):
1✔
459
    """Returns output signature this layer would give for `input_signature`."""
460
    return self._forward_abstract(input_signature)[0]  # output only, not state
×
461

462
  def _forward_abstract(self, input_signature):
1✔
463
    """Computes shapes and dtypes this layer would produce in a forward pass.
464

465
    Args:
466
      input_signature: ShapeDtype instance (if this layer takes one input)
467
          or list/tuple of ShapeDtype instances.
468

469
    Returns:
470
      Tuple of (output, state).
471

472
      The output part of the tuple is a ShapeDtype instance representing the
473
      shape and type of the output (if this layer has one output) or a tuple
474
      of ShapeDtype instances (if this layer has more than one output).
475
    """
476
    try:
1✔
477
      # Note: By using rng_signature in place of an rng, we avoid computing and
478
      # permanently storing in global memory a large number of dropout masks.
479
      # TODO(jonni): Check if using an rng still carries this cost.
480
      rng_signature = ShapeDtype((2,), np.uint32)
1✔
481
      weight_signature = nested_map(signature, self.weights)
1✔
482
      forward_infer_shapes = math.abstract_eval(self.forward_with_state)
1✔
483
      return forward_infer_shapes(
1✔
484
          input_signature, weight_signature, self.state, rng_signature)
485
    except Exception as e:
×
486
      name, trace = self._name, _short_traceback(skip=3)
×
487
      raise LayerError(name, '_forward_abstract', self._caller, input_signature,
×
488
                       trace) from e
489

490
  # pylint: disable=protected-access
491
  def _set_rng_recursive(self, rng):
1✔
492
    """Sets the rng (JAX PRNG key) for this layer and sublayers, recursively."""
493
    self._rng = rng
1✔
494
    sublayers = self.sublayers
1✔
495
    if sublayers:
1✔
496
      rngs = math.random.split(rng, len(sublayers))
1✔
497
      for sublayer, rng in zip(sublayers, rngs):
1✔
498
        sublayer._set_rng_recursive(rng)
1✔
499

500
  def _set_input_signature_recursive(self, input_signature):
1✔
501
    """Sets input_signatures for this layer and sublayers, recursively.
502

503
    General combinators (those that can take multiple sublayers) must override
504
    this method to calculate and set input signatures for the sublayers. (See
505
    the `Serial` class in combinators.py for an example.)
506

507
    Args:
508
      input_signature: A `ShapeDtype` instance (if this layer takes one input)
509
          or a list/tuple of `ShapeDtype` instances
510
    """
511
    self._input_signature = input_signature
×
512

513
    # Handle the special case of a single immediate sublayer (which may in turn
514
    # have its own sublayers).
515
    sublayers = self.sublayers
×
516
    if sublayers and len(sublayers) == 1:
×
517
      sublayers[0]._set_input_signature_recursive(input_signature)
×
518
    if sublayers and len(sublayers) > 1:
×
519
      raise ValueError('A layer class whose instances can have more than one '
×
520
                       'sublayer must override the input_signature property '
521
                       'setter.')
522
  # pylint: enable=protected-access
523

524
  def replicate(self, n_accelerators):
1✔
525
    """Replicate weights and state for use on n accelerators. Experimental."""
526
    if n_accelerators > 1:
×
527
      self.weights = for_n_devices(self.weights, n_accelerators)
×
528
      self.state = for_n_devices(self.state, n_accelerators)
×
529

530
  def unreplicate(self, unreplicate_state=False):
1✔
531
    """Unreplicate weights and optionally state. Experimental."""
532
    self.weights = math.nested_map(self.weights, lambda x: x[0])
×
533
    if unreplicate_state:
×
534
      self.state = math.nested_map(self.state, lambda x: x[0])
×
535

536
  def _do_custom_gradients(self, x, weights, state, rng):
1✔
537
    """Calls this layer for a forward pass, but with custom gradients."""
538
    assert math.backend_name() == 'jax', (
×
539
        'Custom gradients are only supported in JAX for now.')
540

541
    # See this link for how custom transformations are defined in JAX:
542
    # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
543
    @jax.custom_transforms
×
544
    def _do_forward(y, weights):
545
      res = self.forward_with_state(y, weights, state, rng)
×
546
      return res
×
547

548
    # This is the custom gradient (vector-jacobian product in JAX) function.
549
    # For the exact specification of this custom transformation see this link:
550
    # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
551
    def do_forward_vjp(y, weights):
×
552
      """Custom gradient (vjp) function."""
553
      output, new_state = self.forward_with_state(y, weights, state, rng)
×
554
      def vjpfun(grad):
×
555
        grad = grad[0]  # Ignore dummy gradient wrt state.
×
556
        res = self.backward(y, output, grad, weights, state, new_state, rng)
×
557
        return res
×
558
      return (output, new_state), vjpfun
×
559

560
    jax.defvjp_all(_do_forward, do_forward_vjp)
×
561
    output, state = _do_forward(x, weights)
×
562
    state = jax.lax.stop_gradient(state)
×
563
    return output, state
×
564

565

566
def layer(n_in=1, n_out=1, name=None):
1✔
567
  """Decorator for creating simple layers.  DEPRECATED; use base.Fn instead."""
568

569
  def _build_layer_class(raw_fn):
×
570
    """Returns a Layer class whose callable instances execute the function."""
571

572
    def _init(self, **kwargs):
×
573
      self._kwargs = kwargs  # pylint: disable=protected-access
×
574
      Layer.__init__(self, n_in=n_in, n_out=n_out, name=name)
×
575

576
    def _forward(self, inputs, weights):
×
577
      """Uses this layer as part of a forward pass through the model."""
578
      del weights
×
579
      _validate_forward_input(inputs, n_in)
×
580
      raw_output = raw_fn(inputs, **self._kwargs)  # pylint: disable=protected-access
×
581
      output = () if _is_empty(raw_output) else raw_output
×
582
      return output
×
583

584
    # Set docstrings and create the class.
585
    _forward.__doc__ = raw_fn.__doc__
×
586
    # Note: None.__doc__ is None
587
    cls = type(raw_fn.__name__, (Layer,),
×
588
               {'__init__': _init,
589
                'forward': _forward})
590
    return cls
×
591

592
  return _build_layer_class
×
593

594

595
class PureLayer(Layer):
1✔
596
  """Pure function from inputs to outputs, packaged as neural network layer.
597

598
  The `PureLayer` class represents the simplest kinds of layers: layers with
599
  no trainable weights and no randomness, hence pure functions from inputs to
600
  outputs.
601
  """
602

603
  def __init__(self, forward_fn, n_in=1, n_out=1, name='PureLayer'):
1✔
604
    """Creates an unconnected PureLayer instance.
605

606
    Args:
607
      forward_fn: Pure function from input tensors to output tensors, where
608
          inputs and outputs are packaged as specified for `forward`.
609
      n_in: Number of inputs expected by this layer.
610
      n_out: Number of outputs promised by this layer.
611
      name: Class-like name for this layer; for use only in debugging.
612
    """
613
    super().__init__(n_in, n_out, name)
1✔
614
    self._forward_fn = forward_fn
1✔
615

616
  def forward(self, inputs, weights):
1✔
617
    """Overrides `Layer.forward`.
618

619
    Args:
620
      inputs: Input tensors, matching the number (n_in) expected by this layer.
621
      weights: Trainable weights in general, but this subclass doesn't use
622
          weights, so the only acceptable value is an empty tuple/list.
623

624
    Returns:
625
      Tensors, matching the number (n_out) promised by this layer.
626

627
    Raises:
628
      ValueError: If weights is other than an empty tuple/list.
629
    """
630
    _validate_forward_input(inputs, self.n_in)
1✔
631
    raw_output = self._forward_fn(inputs)
1✔
632
    output = () if _is_empty(raw_output) else raw_output
1✔
633
    return output
1✔
634

635

636
def Fn(name, f, n_out=1):  # pylint: disable=invalid-name
1✔
637
  """Returns a layer with no weights that applies the function f.
638

639
  `f` can take and return any number of arguments, and takes only positional
640
  arguments -- no default or keyword arguments. It often uses JAX-numpy (jnp).
641
  The following, for example, would create a layer that takes two inputs and
642
  returns two outputs -- element-wise sums and maxima:
643

644
      Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)
645

646
  The layer's number of inputs (`n_in`) is automatically set to number of
647
  positional arguments in `f`, but you must explicitly set the number of
648
  outputs (`n_out`) whenever it's not the default value 1.
649

650
  Args:
651
    name: Class-like name for the resulting layer; for use in debugging.
652
    f: Pure function from input tensors to output tensors, where each input
653
        tensor is a separate positional arg, e.g.:
654
            f(x0, x1) --> x0 + x1
655
        Output tensors must be packaged as specified for `Layer.forward`.
656
    n_out: Number of outputs promised by the layer; default value 1.
657

658
  Returns:
659
    Layer executing the function f.
660
  """
661
  # Inspect the function f to restrict to no-defaults and no-kwargs functions.
662
  argspec = inspect.getfullargspec(f)
1✔
663
  if argspec.defaults is not None:
1✔
664
    raise ValueError('Function has default arguments (not allowed).')
×
665
  if argspec.varkw is not None:
1✔
666
    raise ValueError('Function has keyword arguments (not allowed).')
×
667
  if argspec.varargs is not None:
1✔
668
    raise ValueError('Function has variable args (not allowed).')
×
669

670
  def _forward(xs):  # pylint: disable=invalid-name
1✔
671
    if not isinstance(xs, (tuple, list)):
1✔
672
      xs = (xs,)
1✔
673
    return f(*xs)
1✔
674

675
  n_in = len(argspec.args)
1✔
676
  name = name or 'Fn'
1✔
677
  return PureLayer(_forward, n_in=n_in, n_out=n_out, name=name)
1✔
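
Editor's sketch (not part of base.py): the SumAndMax layer from the Fn docstring, called directly on concrete arrays.

import numpy as np
from trax.layers import base
from trax.math import numpy as jnp

sum_and_max = base.Fn(
    'SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)
x0 = np.array([1.0, 2.0, 3.0])
x1 = np.array([3.0, 1.0, 2.0])
sums, maxima = sum_and_max((x0, x1))  # n_in is inferred as 2 from the lambda
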
678

679

680
class LayerError(Exception):
1✔
681
  """Exception raised in the layer stack.
682

683
  Attributes:
684
    message: the message corresponding to this exception.
685
  """
686

687
  def __init__(self, layer_name, function_name, caller,
1✔
688
               input_signature, traceback_string):
689
    self._layer_name = layer_name
×
690
    self._function_name = function_name
×
691
    self._caller = caller  # Python inspect object with init caller info.
×
692
    self._traceback = traceback_string
×
693
    self._input_signature = input_signature
×
694
    super(LayerError, self).__init__(self.message)
×
695

696
  @property
1✔
697
  def message(self):
698
    """Create error message."""
699
    prefix = 'Exception passing through layer '
×
700
    prefix += '%s (in %s):\n' % (self._layer_name, self._function_name)
×
701
    short_path = '[...]/' + '/'.join(
×
702
        self._caller['filename'].split('/')[-3:])
703
    caller = '  layer created in file %s, line %d\n' % (short_path,
×
704
                                                        self._caller['lineno'])
705
    shapes_str = '  layer input shapes: %s\n\n' % str(self._input_signature)
×
706
    return prefix + caller + shapes_str + self._traceback
×
707

708

709
def check_shape_agreement(layer_obj, input_signature):
1✔
710
  """Compares the layer's __call__ output to its _foward_abstract shape output.
711

712
  This function helps test layer mechanics and inter-layer connections that
713
  aren't dependent on specific data values.
714

715
  Args:
716
    layer_obj: A layer object.
717
    input_signature: A `ShapeDtype` instance (if `layer_obj` takes one input)
718
        or a list/tuple of ShapeDtype instances.
719

720
  Returns:
721
    A tuple representing either a single shape (if the layer has one output) or
722
    a tuple of shape tuples (if the layer has more than one output).
723
  """
724
  weights, state = layer_obj.init(input_signature)
×
725
  output_signature, _ = layer_obj._forward_abstract(input_signature)  # pylint: disable=protected-access
×
726
  if isinstance(output_signature, tuple):
×
727
    shape_output = tuple(x.shape for x in output_signature)
×
728
  else:
729
    shape_output = output_signature.shape
×
730

731
  rng1, rng2 = layer_obj.new_rngs(2)
×
732
  random_input = _random_values(input_signature, rng1)
×
733
  call_output = layer_obj(random_input, weights=weights, state=state, rng=rng2)
×
734
  call_output_shape = _shapes(call_output)
×
735

736
  msg = '_forward_abstract shape output %s != __call__ output shape %s' % (
×
737
      shape_output, call_output_shape)
738
  assert shape_output == call_output_shape, msg
×
739
  # TODO(jonni): Remove this assert? It makes test logs harder to read.
740
  return shape_output
×
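
Editor's sketch (not part of base.py): a test-style use of check_shape_agreement on a simple Fn layer, assuming the default JAX backend.

import numpy as np
from trax.layers import base
from trax.shapes import ShapeDtype

add = base.Fn('Add', lambda x0, x1: x0 + x1)
signature = (ShapeDtype((3, 4), np.float32), ShapeDtype((3, 4), np.float32))
assert base.check_shape_agreement(add, signature) == (3, 4)
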
741

742

743
def _validate_forward_input(x, n_in):
1✔
744
  if n_in != 1:
1✔
745
    if not isinstance(x, (tuple, list)):
1✔
746
      raise TypeError(
×
747
          f'Expected input to be a tuple or list; instead got {type(x)}.')
748
    if len(x) != n_in:
1✔
749
      raise ValueError(f'Input tuple length ({len(x)}) does not equal required '
×
750
                       f'number of inputs ({n_in}).')
751

752

753
def _is_empty(container):
1✔
754
  if container is None:
1✔
755
    raise ValueError('Argument "container" is None.')
×
756
  return isinstance(container, (list, tuple)) and len(container) == 0  # pylint: disable=g-explicit-length-test
1✔
757

758

759
def _find_frame(frame):
1✔
760
  """Find the frame with the caller on the stack."""
761
  # TODO(lukaszkaiser): rewrite this function in a systematic way.
762
  # We want to find the first place where the layer was called
763
  # that is *not* an __init__ function of an inheriting layer.
764
  # We also need to exclude a few decorator functions.
765
  while frame.f_code.co_name in ['__init__', 'gin_wrapper', '_validate',
1✔
766
                                 '_validate_forward_inputs', '_init']:
767
    # We only skip __init__ in internal layers, return otherwise.
768
    try:
1✔
769
      dirname = frame.f_code.co_filename.split('/')[-2]
1✔
770
    except IndexError:
×
771
      # Notebook cells have dummy filenames that do not contain any slashes
772
      dirname = frame.f_code.co_filename
×
773
    if dirname != 'layers' and frame.f_code.co_name == '__init__':
1✔
774
      return frame
×
775
    # If we are in an init, move up.
776
    frame = frame.f_back
1✔
777
  return frame
1✔
778

779

780
def _shorten_file_path(line):
1✔
781
  """Shorten file path in error lines for more readable tracebacks."""
782
  start = line.lower().find('file')
×
783
  if start < 0:
×
784
    return line
×
785
  first_quote = line.find('"', start)
×
786
  if first_quote < 0:
×
787
    return line
×
788
  second_quote = line.find('"', first_quote + 1)
×
789
  if second_quote < 0:
×
790
    return line
×
791
  path = line[first_quote + 1:second_quote]
×
792
  new_path = '/'.join(path.split('/')[-3:])
×
793
  return line[:first_quote] + '[...]/' + new_path + line[second_quote + 1:]
×
794

795

796
def _short_traceback(skip=3):
1✔
797
  """Cleaned-up form of traceback."""
798
  counter, res = 0, []
×
799
  # Skipping 3 lines by default: the top (useless) and self-call.
800
  # In python 3, we need to set chain to False (it doesn't exist in python 2).
801
  lines = traceback.format_exc(chain=False).splitlines()[skip:]  # pylint: disable=unexpected-keyword-arg
×
802
  for l in lines:
×
803
    if l.startswith('trax.layers.base.LayerError'):
×
804
      l = l[len('trax.layers.base.'):]  # Remove the trax.layers.base prefix.
×
805
    res.append(_shorten_file_path(l))
×
806
    if counter % 2 == 1:
×
807
      res.append('')
×
808
    counter += 1
×
809
    # If we see a LayerError, the traceback has already been processed.
810
    if l.startswith('LayerError'):
×
811
      # Skip 4 back except last as these are internal base-layer calls.
812
      res = res[:-4] + [res[-1]]
×
813
      res += lines[counter:]
×
814
      break
×
815
  return '\n'.join(res)
×
816

817

818
def _random_values(input_signature, rng):
1✔
819
  """Creates random floats or ints of the given shape.
820

821
  Args:
822
    input_signature: A `ShapeDtype` instance (if `layer_obj` takes one input)
823
        or a list/tuple of ShapeDtype instances.
824
    rng: Single-use random number generator (JAX PRNG key).
825

826
  Returns:
827
    Random values with the shape and type specified.
828
  """
829
  if isinstance(input_signature, ShapeDtype):
×
830
    shape, dtype = input_signature.shape, input_signature.dtype
×
831
    if np.issubdtype(dtype, np.integer):
×
832
      return math.random.bernoulli(rng, 0.5, shape).astype(np.int32)
×
833
    else:
834
      return math.random.uniform(rng, shape, minval=-1.0, maxval=1.0)
×
835
  elif isinstance(input_signature, (list, tuple)):
×
836
    return tuple(_random_values(x, rng) for x in input_signature)
×
837
  else:
838
    raise TypeError(type(input_signature))
×
839

840

841
def _shapes(x):
1✔
842
  """Get a structure of shapes for a structure of nested arrays."""
843
  def shape(x):
×
844
    try:
×
845
      return tuple([int(i) for i in x.shape])
×
846
    except Exception:  # pylint: disable=broad-except
×
847
      return ()
×
848
  return tuple(nested_map(shape, x))
×
849

850

851
def jit_forward(forward, n_devices):
1✔
852
  """Returns a JIT-compiled forward function running on n_devices."""
853
  model_predict = _accelerate(forward, n_devices)
×
854
  if n_devices == 1:
×
855
    return model_predict
×
856

857
  def predict(x, weights, state, rng):
×
858
    """Predict function jited and parallelized as requested."""
859
    res, state = _combine_devices(model_predict(
×
860
        reshape_by_device(x, n_devices),
861
        weights,
862
        state,
863
        jnp.stack(math.random.split(rng, n_devices))))
864
    return math.nested_map(lambda y: jnp.mean(y, axis=0), res), state
×
865

866
  return predict
×
867

868

869
def _combine_devices(x_tuple):
1✔
870
  """Combine multi-device tensors into a single batch."""
871
  def f(x):
×
872
    if len(x.shape) < 2:
×
873
      return x  # No extra batch dimension: use devices as batch, so return.
×
874
    batch_size = x.shape[0] * x.shape[1]
×
875
    return math.numpy.reshape(x, [batch_size] + list(x.shape[2:]))
×
876
  return math.nested_map(f, x_tuple)
×
877

878

879
def _accelerate(f, n_devices):
1✔
880
  """JITed version of f running on n_devices."""
881
  if n_devices == 1:
×
882
    return math.jit(f)
×
883

884
  return math.pmap(f, axis_name='batch')
×
885

886

887
def reshape_by_device(x, n_devices):
1✔
888
  """Reshapes possibly nested x into a shape (n_devices, ...)."""
889
  def f(x):
×
890
    x_shape = list(x.shape)
×
891
    batch_size = x_shape[0]
×
892
    batch_size_per_device = batch_size // n_devices
×
893
    if batch_size_per_device * n_devices != batch_size:
×
894
      raise ValueError(f'Number of devices ({n_devices}) does not evenly '
×
895
                       f'divide batch size ({batch_size}).')
896
    new_shape_prefix = [n_devices, batch_size_per_device]
×
897
    return math.numpy.reshape(x, new_shape_prefix + x_shape[1:])
×
898
  return math.nested_map(f, x)
×
899

900

901
def for_n_devices(x, n_devices):
1✔
902
  """Replicates/broadcasts `x` for n_devices."""
903
  def f(x):
×
904
    if n_devices > 1 and math.backend_name() == 'jax':
×
905
      return _multi_device_put(x)
×
906
    elif n_devices > 1:
×
907
      return jnp.broadcast_to(x, (n_devices,) + x.shape)
×
908
    else:
909
      return x
×
910
  return math.nested_map(f, x)
×
911

912

913
def _multi_device_put(x, devices=None):
1✔
914
  """Memory efficient multi-device replication / broadcast in JAX.
915

916
  JAX uses a ShardedDeviceArray class that holds a list of device buffers
917
  on separate devices for use with pmap'd computations.  Sharded arrays
918
  are explicitly used to eliminate unnecessary inter-device transfer of
919
  memory buffers between use in pmap'd computations.  The JAX API currently
920
  does not have a multi-device 'put' function that copies a buffer onto
921
  N devices in a memory-efficient fashion, so we implement our own here.
922

923
  Args:
924
    x: jax DeviceArray or numpy ndarray to be replicated.
925
    devices: a jax.devices() list or subset thereof of devices to
926
      replicate onto.  Should match the list passed to any pmaps
927
      ingesting the replicated array.
928

929
  Returns:
930
    A ShardedDeviceArray with
931
    dtype = x.dtype and shape = (n_devices,) + x.shape
932
    that's backed by replicated device_buffers on each local device.
933
  """
934
  # Convert _FilledConstants that don't have device_buffer, etc.
935
  if type(x) != jax.xla.DeviceArray:  # pylint: disable=unidiomatic-typecheck
×
936
    x = jnp.array(x)
×
937
  # Calculate the abstract shape of the replicated array.
938
  if not devices:
×
939
    devices = jax.local_devices()
×
940
  n_devices = len(devices)
×
941
  x_aval = jax.xla.abstractify(x)
×
942
  broadcast_x_aval = jax.abstract_arrays.ShapedArray(
×
943
      (n_devices,) + x_aval.shape,
944
      x_aval.dtype)
945
  # Create copies of the underlying device buffer for each local device.
946
  broadcast_buffers = [
×
947
      jax.interpreters.xla.device_put(x, dv)
948
      for dv in devices
949
  ]
950
  return jax.pxla.ShardedDeviceArray(broadcast_x_aval, broadcast_buffers)
×