• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

harvardnlp / namedtensor / 189

pending completion
189

Pull #41

travis-ci

web-flow
.
Pull Request #41: Split out nn and distributions into their own directories

20 of 20 new or added lines in 11 files covered. (100.0%)

153 of 954 relevant lines covered (16.04%)

0.16 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

22.13
/namedtensor/core.py
1
from .schema import _Schema
1✔
2
from einops import rearrange
1✔
3

4

5
def assert_match(*tensors):
    """Assert that every named dimension shared across *tensors* has one size.

    Dimensions of raw size 1 are skipped (presumably to allow broadcasting).
    On a conflict, raises ``AssertionError`` whose message joins each
    tensor's ``_sizes`` representation.
    """
    seen = {}
    conflict = False
    for tensor in tensors:
        raw_shape = tensor.vshape
        for idx, name in tensor._schema.enum_all():
            size = raw_shape[idx]
            if size == 1:
                # size-1 dims never count as a mismatch
                continue
            if name not in seen:
                seen[name] = size
            elif seen[name] != size:
                conflict = True
    assert not conflict, " ".join(str(t._sizes) for t in tensors)
19

20

21
class NamedTensorBase:
    """Base wrapper pairing a raw tensor with a schema of dimension names.

    Attributes:
        tensor: The raw tensor data
        dims: Tuple of unique dimension names associated with this array.
        ndim: Number of dimensions
        sizes: The raw dimension sizes
        shape: Ordered mapping from dimension names to lengths.
    """

    def __init__(self, tensor, names, mask=0):
        """Wrap `tensor`, naming each raw dimension.

        Args:
            tensor: raw tensor-like object (must expose ``.shape`` / ``.size()``).
            names: one dimension name per raw dim, in order.
            mask: mask index forwarded to the schema (0 = unmasked).
        """
        self._tensor = tensor
        self._schema = _Schema.build(names, mask)
        # Every raw dimension must have exactly one name.
        assert len(self._tensor.shape) == len(self._schema._names), (
            "Tensor has %d dim, but only %d names"
            % (len(self._tensor.shape), len(self._schema._names))
        )

    @property
    def dims(self):
        "Return the dim names for the tensor."
        return tuple(self._schema._names)

    @property
    def vshape(self):
        "The raw dim sizes for the tensor."
        return tuple(self._tensor.size())

    @property
    def shape(self):
        "The ordered dict of available dimensions."
        return self._schema.ordered_dict(self._tensor.size())

    def __repr__(self):
        return "NamedTensor(\n\t%s,\n\t%s)" % (
            self._tensor,
            self._schema._names,
        )

    def size(self, dim):
        "Return the raw size of the named dimension `dim`."
        i = self._schema.get(dim)
        return self._tensor.size(i)

    def assert_size(self, **kwargs):
        # Fixed docstring: previous one was copy-pasted from size().
        "Assert each named dim has the given size; returns self for chaining."
        for dim, v in kwargs.items():
            i = self._schema.get(dim)
            assert self._tensor.size(i) == v, (
                "Size of %s should be %d, got %d"
                % (dim, v, self._tensor.size(i))
            )
        return self

    @property
    def values(self):
        "The raw underlying tensor object."
        return self._tensor

    def _new(self, tensor, drop=None, add=None, updates=None, mask=None):
        """Build a same-class wrapper around `tensor` with an adjusted schema.

        `drop` removes names, `updates` renames them, `add` appends new
        names, and `mask` (when not None) replaces the current mask index.
        """
        # `updates=None` sentinel avoids the shared mutable-default pitfall
        # of the previous `updates={}` signature; behavior is unchanged.
        return self.__class__(
            tensor,
            self._schema.drop(drop).update(updates or {})._names
            + (() if not add else add),
            self._schema._masked if mask is None else mask,
        )

    def _to_einops(self):
        "Render the current dim names as an einops pattern string."
        return self._schema._to_einops()

    def mask_to(self, name):
        "Return a copy masked up to dim `name`; an empty name clears the mask."
        if name == "":
            return self._new(self._tensor, mask=0)
        else:
            # Mask indices are 1-based in the schema; 0 means "no mask".
            return self._new(self._tensor, mask=self._schema.get(name) + 1)

    def stack(self, dims, name):
        "Stack any number of existing dimensions into a single new dimension."
        for dim in dims:
            # get() validates the name exists before we rearrange.
            self._schema.get(dim)
        return self._merge(dims, name)

    def split(self, dim, names, **dim_sizes):
        # Fixed garbled docstring ("Split an of existing dimension").
        "Split an existing dimension into new dimensions."
        return self._split(dim, names, dim_sizes)

    def rename(self, dim, name):
        "Rename a dimension."
        return self._split(dim, (name,), {})

    def transpose(self, *dims):
        # Fixed stale docstring that referred to xarray's "DataArray".
        "Return a new named tensor with `dims` moved to the end, in order."
        for dim in dims:
            self._schema.get(dim)  # validate names early
        # Unmentioned dims keep their relative order, followed by `dims`.
        to_dims = (
            tuple(d for d in self._schema._names if d not in dims) + dims
        )
        recipe = "%s -> %s" % (self._to_einops(), " ".join(to_dims))
        tensor = rearrange(self._tensor, recipe)
        return self.__class__(tensor, to_dims)

    # TODO: fix arg names — here `names` are merged INTO the new dim `dim`,
    # the reverse of _split's (dim, names) convention.
    def _merge(self, names, dim):
        "Collapse the dims in `names` into a single new dim `dim` (einops merge)."
        s = ""
        ex = []
        first = True
        for d in self._schema._names:
            if d not in names:
                s += " " + d
                ex.append(d)
            elif first:
                # Emit the merged group once, at the position of the first
                # participating dim; later members are consumed by the group.
                s += " (" + " ".join(names) + ")"
                ex.append(dim)
                first = False
        tensor = rearrange(
            self._tensor, "%s -> %s" % (self._schema._to_einops(), s)
        )
        return self.__class__(tensor, ex)

    def _split(self, dim, names, size_dict):
        """Split dim `dim` into `names` (einops unmerge).

        `size_dict` supplies sizes for new dims that einops cannot infer.
        """
        query = ""
        ex = []
        # enum_all yields (index, name); the index is unused here.
        for _, d in self._schema.enum_all():
            if d != dim:
                query += " " + d
                ex.append(d)
            else:
                # Source side declares `dim` as a group of the new names.
                query += " (" + " ".join(names) + ")"
                ex += names

        tensor = rearrange(
            self._tensor,
            "%s -> %s" % (query, " ".join(ex)),
            **{d: size_dict[d] for d in names if d in size_dict}
        )
        return self.__class__(tensor, ex)

    def _rearrange(self, term):
        "Apply the einops recipe `current -> term`; `term` must not group dims."
        assert ")" not in term
        recipe = "%s -> %s" % (self._to_einops(), term)
        tensor = rearrange(self._tensor, recipe)
        return self.__class__(tensor, term)

    def __len__(self):
        return len(self._tensor)

    def _promote(self, dims):
        "Move dims to the front of the line"
        # NOTE(review): `dims` appears to be a space-separated string whose
        # first token is discarded (`dims.split()[1:]`), and the remaining
        # tokens end up LAST in the new order — confirm against callers.
        term = " ".join(
            [d for d in self._schema._names if d not in dims]
            + dims.split()[1:]
        )
        return self._rearrange(term)

    def _force_order(self, names):
        "Force self into the dim order `names`, adding 1-size dims as needed."
        s = ""
        ex = []
        for d in names:
            ex.append(d)
            if d not in self._schema._names:
                # "()" introduces a new length-1 axis on the einops RHS.
                s += " ()"
            else:
                s += " " + d
        tensor = rearrange(self._tensor, "%s -> %s" % (self._to_einops(), s))
        return self.__class__(tensor, ex)

    def _broadcast_order(self, other):
        "Return a shared dim order (list) that works for self and other."
        # Dims unique to `other` come first, then all of self's dims.
        order = [
            d for d in other._schema._names if d not in self._schema._names
        ]
        order.extend(self._schema._names)
        return order

    def _mask_broadcast_order(self, main):
        """
        If broadcasting possible from self (mask) to main, outputs a shared order.
        Otherwise errors and prints dimensions that exist in mask but not in main.
        """
        to_be_broadcasted = set(self._schema._names)
        broadcasted_to = set(main._schema._names)

        diff = to_be_broadcasted.difference(broadcasted_to)

        # Fixed typo in the error message ("Attemped" -> "Attempted").
        assert len(diff) == 0, (
            "Attempted to broadcast mask but unable to broadcast dimensions %s"
            % ", ".join(diff)
        )

        return self._broadcast_order(main)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2024 Coveralls, Inc