• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

gorgonia / tensor / c35555fa6dc585780eef107d742faf55ba7c2c21

09 Apr 2024 01:50AM UTC coverage: 21.6% (+0.04%) from 21.565%
c35555fa6dc585780eef107d742faf55ba7c2c21

push

github

web-flow
Fix #140 (#141)

* Fix #140

+ Fix SortIndex()
+ Add SortIndexStable()

* `any` is not supported in Go1.15

---------

Co-authored-by: Chewxy <chewxy@gmail.com>

20 of 25 new or added lines in 1 file covered. (80.0%)

33 existing lines in 3 files now uncovered.

13186 of 61046 relevant lines covered (21.6%)

15821.76 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

63.28
/defaultenginefloat32.go
1
package tensor
2

3
import (
4
        "github.com/pkg/errors"
5
        "gorgonia.org/tensor/internal/execution"
6
        "gorgonia.org/tensor/internal/storage"
7

8
        "gorgonia.org/vecf32"
9
)
10

11
func handleFuncOptsF32(expShape Shape, o DataOrder, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) {
783✔
12
        fo := ParseFuncOpts(opts...)
783✔
13

783✔
14
        reuseT, incr := fo.IncrReuse()
783✔
15
        safe = fo.Safe()
783✔
16
        toReuse = reuseT != nil
783✔
17

783✔
18
        if toReuse {
857✔
19
                var ok bool
74✔
20
                if reuse, ok = reuseT.(DenseTensor); !ok {
74✔
21
                        returnOpOpt(fo)
×
22
                        err = errors.Errorf("Cannot reuse a different type of Tensor in a *Dense-Scalar operation. Reuse is of %T", reuseT)
×
23
                        return
×
24
                }
×
25
                if reuse.len() != expShape.TotalSize() && !expShape.IsScalar() {
74✔
26
                        returnOpOpt(fo)
×
27
                        err = errors.Errorf(shapeMismatch, reuse.Shape(), expShape)
×
28
                        err = errors.Wrapf(err, "Cannot use reuse: shape mismatch")
×
29
                        return
×
30
                }
×
31

32
                if !incr && reuse != nil {
110✔
33
                        reuse.setDataOrder(o)
36✔
34
                        // err = reuse.reshape(expShape...)
36✔
35
                }
36✔
36

37
        }
38
        returnOpOpt(fo)
783✔
39
        return
783✔
40
}
41

42
func prepDataVSF32(a Tensor, b interface{}, reuse Tensor) (dataA *storage.Header, dataB float32, dataReuse *storage.Header, ait, iit Iterator, useIter bool, err error) {
2✔
43
        // get data
2✔
44
        dataA = a.hdr()
2✔
45
        switch bt := b.(type) {
2✔
46
        case float32:
2✔
47
                dataB = bt
2✔
48
        case *float32:
×
49
                dataB = *bt
×
50
        default:
×
51
                err = errors.Errorf("b is not a float32: %T", b)
×
52
                return
×
53
        }
54
        if reuse != nil {
4✔
55
                dataReuse = reuse.hdr()
2✔
56
        }
2✔
57

58
        if a.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) {
2✔
59
                ait = a.Iterator()
×
60
                if reuse != nil {
×
61
                        iit = reuse.Iterator()
×
62
                }
×
63
                useIter = true
×
64
        }
65
        return
2✔
66
}
67

68
func (e Float32Engine) checkThree(a, b Tensor, reuse Tensor) error {
787✔
69
        if !a.IsNativelyAccessible() {
1,162✔
70
                return errors.Errorf(inaccessibleData, a)
375✔
71
        }
375✔
72
        if !b.IsNativelyAccessible() {
412✔
73
                return errors.Errorf(inaccessibleData, b)
×
74
        }
×
75

76
        if reuse != nil && !reuse.IsNativelyAccessible() {
412✔
77
                return errors.Errorf(inaccessibleData, reuse)
×
78
        }
×
79

80
        if a.Dtype() != Float32 {
412✔
81
                return errors.Errorf("Expected a to be of Float32. Got %v instead", a.Dtype())
×
82
        }
×
83
        if a.Dtype() != b.Dtype() || (reuse != nil && b.Dtype() != reuse.Dtype()) {
412✔
84
                return errors.Errorf("Expected a, b and reuse to have the same Dtype. Got %v, %v and %v instead", a.Dtype(), b.Dtype(), reuse.Dtype())
×
85
        }
×
86
        return nil
412✔
87
}
88

89
func (e Float32Engine) checkTwo(a Tensor, reuse Tensor) error {
2✔
90
        if !a.IsNativelyAccessible() {
2✔
91
                return errors.Errorf(inaccessibleData, a)
×
92
        }
×
93
        if reuse != nil && !reuse.IsNativelyAccessible() {
2✔
94
                return errors.Errorf(inaccessibleData, reuse)
×
95
        }
×
96

97
        if a.Dtype() != Float32 {
2✔
98
                return errors.Errorf("Expected a to be of Float32. Got %v instead", a.Dtype())
×
99
        }
×
100

101
        if reuse != nil && reuse.Dtype() != a.Dtype() {
2✔
102
                return errors.Errorf("Expected reuse to be the same as a. Got %v instead", reuse.Dtype())
×
103
        }
×
104
        return nil
2✔
105
}
106

107
// Float32Engine is an execution engine that is optimized to only work with float32s. It assumes all data will are float32s.
108
//
109
// Use this engine only as form of optimization. You should probably be using the basic default engine for most cases.
110
type Float32Engine struct {
111
        StdEng
112
}
113

114
// makeArray allocates a slice for the array
115
func (e Float32Engine) makeArray(arr *array, t Dtype, size int) {
13,109✔
116
        if t != Float32 {
13,109✔
117
                panic("Float32Engine only creates float32s")
×
118
        }
119
        if size < 0 {
13,109✔
120
                panic("Cannot have negative sizes when making array")
×
121
        }
122
        arr.Header.Raw = make([]byte, size*4)
13,109✔
123
        arr.t = t
13,109✔
124
}
125

126
func (e Float32Engine) FMA(a, x, y Tensor) (retVal Tensor, err error) {
4✔
127
        reuse := y
4✔
128
        if err = e.checkThree(a, x, reuse); err != nil {
5✔
129
                return nil, errors.Wrap(err, "Failed checks")
1✔
130
        }
1✔
131

132
        var dataA, dataB, dataReuse *storage.Header
3✔
133
        var ait, bit, iit Iterator
3✔
134
        var useIter bool
3✔
135
        if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, x, reuse); err != nil {
3✔
136
                return nil, errors.Wrap(err, "Float32Engine.FMA")
×
137
        }
×
138
        if useIter {
4✔
UNCOV
139
                err = execution.MulIterIncrF32(dataA.Float32s(), dataB.Float32s(), dataReuse.Float32s(), ait, bit, iit)
1✔
UNCOV
140
                retVal = reuse
1✔
UNCOV
141
                return
1✔
UNCOV
142
        }
1✔
143

144
        vecf32.IncrMul(dataA.Float32s(), dataB.Float32s(), dataReuse.Float32s())
2✔
145
        retVal = reuse
2✔
146
        return
2✔
147
}
148

149
func (e Float32Engine) FMAScalar(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) {
2✔
150
        reuse := y
2✔
151
        if err = e.checkTwo(a, reuse); err != nil {
2✔
152
                return nil, errors.Wrap(err, "Failed checks")
×
153
        }
×
154

155
        var ait, iit Iterator
2✔
156
        var dataTensor, dataReuse *storage.Header
2✔
157
        var scalar float32
2✔
158
        var useIter bool
2✔
159
        if dataTensor, scalar, dataReuse, ait, iit, useIter, err = prepDataVSF32(a, x, reuse); err != nil {
2✔
160
                return nil, errors.Wrapf(err, opFail, "Float32Engine.FMAScalar")
×
161
        }
×
162
        if useIter {
2✔
163
                err = execution.MulIterIncrVSF32(dataTensor.Float32s(), scalar, dataReuse.Float32s(), ait, iit)
×
164
                retVal = reuse
×
165
        }
×
166

167
        execution.MulIncrVSF32(dataTensor.Float32s(), scalar, dataReuse.Float32s())
2✔
168
        retVal = reuse
2✔
169
        return
2✔
170
}
171

172
// Add performs a + b elementwise. Both a and b must have the same shape.
173
// Acceptable FuncOpts are: UseUnsafe(), WithReuse(T), WithIncr(T)
174
func (e Float32Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
1,263✔
175
        if a.RequiresIterator() || b.RequiresIterator() {
1,743✔
176
                return e.StdEng.Add(a, b, opts...)
480✔
177
        }
480✔
178

179
        var reuse DenseTensor
783✔
180
        var safe, toReuse, incr bool
783✔
181
        if reuse, safe, toReuse, incr, err = handleFuncOptsF32(a.Shape(), a.DataOrder(), opts...); err != nil {
783✔
182
                return nil, errors.Wrap(err, "Unable to handle funcOpts")
×
183
        }
×
184
        if err = e.checkThree(a, b, reuse); err != nil {
1,157✔
185
                return nil, errors.Wrap(err, "Failed checks")
374✔
186
        }
374✔
187

188
        var hdrA, hdrB, hdrReuse *storage.Header
409✔
189
        var dataA, dataB, dataReuse []float32
409✔
190

409✔
191
        if hdrA, hdrB, hdrReuse, _, _, _, _, _, err = prepDataVV(a, b, reuse); err != nil {
409✔
192
                return nil, errors.Wrapf(err, "Float32Engine.Add")
×
193
        }
×
194
        dataA = hdrA.Float32s()
409✔
195
        dataB = hdrB.Float32s()
409✔
196
        if hdrReuse != nil {
450✔
197
                dataReuse = hdrReuse.Float32s()
41✔
198
        }
41✔
199

200
        switch {
409✔
201
        case incr:
19✔
202
                vecf32.IncrAdd(dataA, dataB, dataReuse)
19✔
203
                retVal = reuse
19✔
204
        case toReuse:
22✔
205
                copy(dataReuse, dataA)
22✔
206
                vecf32.Add(dataReuse, dataB)
22✔
207
                retVal = reuse
22✔
208
        case !safe:
358✔
209
                vecf32.Add(dataA, dataB)
358✔
210
                retVal = a
358✔
211
        default:
10✔
212
                ret := a.Clone().(headerer)
10✔
213
                vecf32.Add(ret.hdr().Float32s(), dataB)
10✔
214
                retVal = ret.(Tensor)
10✔
215
        }
216
        return
409✔
217
}
218

219
func (e Float32Engine) Inner(a, b Tensor) (retVal float32, err error) {
×
220
        var A, B []float32
×
221
        var AD, BD *Dense
×
222
        var ok bool
×
223

×
224
        if AD, ok = a.(*Dense); !ok {
×
225
                return 0, errors.Errorf("a is not a *Dense")
×
226
        }
×
227
        if BD, ok = b.(*Dense); !ok {
×
228
                return 0, errors.Errorf("b is not a *Dense")
×
229
        }
×
230

231
        A = AD.Float32s()
×
232
        B = BD.Float32s()
×
233
        retVal = whichblas.Sdot(len(A), A, 1, B, 1)
×
234
        return
×
235
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc