google / trax / build 650 (push, via travis-ci), pending completion

Commit by Copybara-Service:
Remove new_rng(s) methods and _set_rng_recursive which now becomes just self.rng setter.

PiperOrigin-RevId: 311373864

37 of 37 new or added lines in 9 files covered (100.0%).
2703 of 11015 relevant lines covered (24.54%).
0.25 hits per line.

Source File

/trax/layers/convolution.py (27.69% of lines covered)

In the line-by-line report, module-level statements (the module docstring, imports, and class/def headers) each record one hit; every method and function body is uncovered.

# coding=utf-8
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Trax convolution layers."""

import functools
import itertools
import operator

from trax import math
from trax.layers import base
from trax.layers import initializers as init
from trax.math import numpy as np


class Conv(base.Layer):
  """Layer constructor function for a general convolution layer."""

  def __init__(self, filters, kernel_size, strides=None, padding='VALID',
               dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
               kernel_initializer=None,
               bias_initializer=init.RandomNormalInitializer(1e-6)):
    super(Conv, self).__init__()
    self._filters = filters
    self._kernel_size = kernel_size
    self._padding = padding
    self._dimension_numbers = dimension_numbers
    self._lhs_spec, self._rhs_spec, self._out_spec = dimension_numbers
    self._one = (1,) * len(kernel_size)
    self._strides = strides or self._one
    self._bias_initializer = bias_initializer
    rhs_spec = self._rhs_spec
    self._kernel_initializer = kernel_initializer
    if kernel_initializer is None:
      self._kernel_initializer = init.GlorotNormalInitializer(
          rhs_spec.index('O'), rhs_spec.index('I'))

  def _check_nhwc(self):
    msg = 'Convolutions on more than 4 dimensions only supported in NHWC.'
    assert self._lhs_spec == self._out_spec == 'NHWC', msg

  def forward(self, x, weights):
    w, b = weights
    x_shape = list(x.shape)
    if len(x_shape) > 4:
      self._check_nhwc()
      new_batch_dim = functools.reduce(operator.mul, x_shape[:-3])
      x = np.reshape(x, [new_batch_dim] + x_shape[-3:])
    res = math.conv(
        x, w, self._strides, self._padding, self._dimension_numbers,
        self._one) + b
    if len(x_shape) > 4:
      res = np.reshape(res, x_shape[:-3] + list(res.shape[-3:]))
    return res

  def _kernel_shape(self, input_shape):
    """Helper to calculate the kernel shape."""
    kernel_size_iter = iter(self._kernel_size)
    return [self._filters if c == 'O' else
            input_shape[self._lhs_spec.index('C')] if c == 'I' else
            next(kernel_size_iter) for c in self._rhs_spec]

  def new_weights(self, input_signature):
    input_shape = input_signature.shape
    if len(input_shape) > 4:
      self._check_nhwc()
      new_batch_dim = functools.reduce(operator.mul, input_shape[:-3])
      input_shape = [new_batch_dim] + list(input_shape[-3:])
    kernel_shape = self._kernel_shape(input_shape)
    bias_shape = [self._filters if c == 'C' else 1 for c in self._out_spec]
    bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
    rng1, rng2 = math.random.split(self.rng, 2)
    w = self._kernel_initializer(kernel_shape, rng1)
    b = self._bias_initializer(bias_shape, rng2)
    return (w, b)
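The shape bookkeeping in Conv can be seen concretely with a standalone sketch (illustration only, not part of convolution.py): with the default dimension_numbers ('NHWC', 'HWIO', 'NHWC'), the kernel shape is read off the rhs spec 'HWIO' and the bias collapses to one value per output channel. The concrete numbers below (filters=32, kernel_size=(3, 3), an input of shape (8, 28, 28, 3)) are invented for the example.

import itertools

lhs_spec, rhs_spec, out_spec = 'NHWC', 'HWIO', 'NHWC'
filters, kernel_size = 32, (3, 3)
input_shape = (8, 28, 28, 3)  # batch, height, width, channels (invented)

# Mirror of Conv._kernel_shape: walk the rhs spec character by character.
kernel_size_iter = iter(kernel_size)
kernel_shape = [filters if c == 'O' else
                input_shape[lhs_spec.index('C')] if c == 'I' else
                next(kernel_size_iter) for c in rhs_spec]

# Mirror of the bias-shape logic in Conv.new_weights: a broadcastable shape,
# then leading 1s dropped.
bias_shape = [filters if c == 'C' else 1 for c in out_spec]
bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))

print(kernel_shape)  # [3, 3, 3, 32] -> height, width, in channels, out channels
print(bias_shape)    # (32,)         -> one bias per output channel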
class CausalConv(Conv):
  """Causal (masked) convolution for [batch x time x depth] sequences.

  Maintains causality along time axis. Used in language modeling tasks.
  """

  def __init__(self,
               filters,
               kernel_width=3,
               kernel_initializer=None,
               bias_initializer=init.RandomNormalInitializer(1e-6)):
    super(CausalConv, self).__init__(
        filters=filters,
        kernel_size=(kernel_width,),
        strides=None,
        padding='VALID',
        dimension_numbers=('NWC', 'WIO', 'NWC'),
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer)

  def forward(self, x, weights):
    assert self._padding == 'VALID'
    # Left pad with 0s. Applying an unmasked valid convolution on top of this
    # yields a causal convolution.
    # TODO(ddohan): Support strided and dilated convolutions.
    rate = 1
    effective_kernel_size = int((self._kernel_size[0] - 1) * rate + 1)
    pad = effective_kernel_size - 1
    x_leftpad = np.pad(x, pad_width=[[0, 0], [pad, 0], [0, 0]], mode='constant')
    return super(CausalConv, self).forward(x_leftpad, weights)
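The left-pad-then-VALID-convolve trick in CausalConv.forward can be checked with a standalone NumPy sketch (illustration only, not part of convolution.py; it uses classic numpy rather than the trax math backend): padding kernel_width - 1 zeros on the left of the time axis keeps the output the same length as the input, and output step t only depends on inputs at steps <= t.

import numpy

kernel_width = 3
x = numpy.arange(1.0, 7.0)         # one sequence: [1, 2, 3, 4, 5, 6]
kernel = numpy.ones(kernel_width)  # a 3-tap sum kernel (symmetric, so the flip
                                   # performed by numpy.convolve is harmless)

pad = kernel_width - 1             # left pad only, so no future steps leak in
x_leftpad = numpy.pad(x, (pad, 0), mode='constant')

y = numpy.convolve(x_leftpad, kernel, mode='valid')
print(y)  # [ 1.  3.  6.  9. 12. 15.] -- y[t] = x[t] + x[t-1] + x[t-2], zeros before t=0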
def Conv1d(filters, kernel_size, stride=1, padding='VALID',
           kernel_initializer=None,
           bias_initializer=init.RandomNormalInitializer(1e-6)):
  return Conv(filters, (kernel_size,), strides=(stride,), padding=padding,
              dimension_numbers=('NWC', 'WIO', 'NWC'),
              kernel_initializer=kernel_initializer,
              bias_initializer=bias_initializer)