• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

numenta / htm.java / #1206

26 Jan 2015 12:57PM UTC coverage: 14.404% (-0.005%) from 14.409%
#1206

push

David Ray
Merge pull request #168 from cogmission/network_api_work

testing the RNG is not part of the scope

714 of 4957 relevant lines covered (14.4%)

0.14 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

58.05
/src/main/java/org/numenta/nupic/encoders/ScalarEncoder.java
1
/* ---------------------------------------------------------------------
2
 * Numenta Platform for Intelligent Computing (NuPIC)
3
 * Copyright (C) 2014, Numenta, Inc.  Unless you have an agreement
4
 * with Numenta, Inc., for a separate license for this software code, the
5
 * following terms and conditions apply:
6
 *
7
 * This program is free software: you can redistribute it and/or modify
8
 * it under the terms of the GNU General Public License version 3 as
9
 * published by the Free Software Foundation.
10
 *
11
 * This program is distributed in the hope that it will be useful,
12
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
14
 * See the GNU General Public License for more details.
15
 *
16
 * You should have received a copy of the GNU General Public License
17
 * along with this program.  If not, see http://www.gnu.org/licenses.
18
 *
19
 * http://numenta.org/licenses/
20
 * ---------------------------------------------------------------------
21
 */
22

23
package org.numenta.nupic.encoders;
24

25
import gnu.trove.list.TDoubleList;
26
import gnu.trove.list.array.TDoubleArrayList;
27
import org.numenta.nupic.Connections;
28
import org.numenta.nupic.FieldMetaType;
29
import org.numenta.nupic.util.ArrayUtils;
30
import org.numenta.nupic.util.Condition;
31
import org.numenta.nupic.util.MinMax;
32
import org.numenta.nupic.util.SparseObjectMatrix;
33
import org.numenta.nupic.util.Tuple;
34

35
import java.util.ArrayList;
36
import java.util.Arrays;
37
import java.util.HashMap;
38
import java.util.List;
39
import java.util.Map;
40

41

42
/**
43
 * DOCUMENTATION TAKEN DIRECTLY FROM THE PYTHON VERSION:
44
 * 
45
 * A scalar encoder encodes a numeric (floating point) value into an array
46
 * of bits. The output is 0's except for a contiguous block of 1's. The
47
 * location of this contiguous block varies continuously with the input value.
48
 *
49
 * The encoding is linear. If you want a nonlinear encoding, just transform
50
 * the scalar (e.g. by applying a logarithm function) before encoding.
51
 * It is not recommended to bin the data as a pre-processing step, e.g.
52
 * "1" = $0 - $.20, "2" = $.21-$0.80, "3" = $.81-$1.20, etc., as this
53
 * removes a lot of information and prevents nearby values from overlapping
54
 * in the output. Instead, use a continuous transformation that scales
55
 * the data (a piecewise transformation is fine).
56
 *
57
 *
58
 * Parameters:
59
 * -----------------------------------------------------------------------------
60
 * w --        The number of bits that are set to encode a single value - the
61
 *             "width" of the output signal
62
 *             restriction: w must be odd to avoid centering problems.
63
 *
64
 * minval --   The minimum value of the input signal.
65
 *
66
 * maxval --   The upper bound of the input signal
67
 *
68
 * periodic -- If true, then the input value "wraps around" such that minval = maxval
69
 *             For a periodic value, the input must be strictly less than maxval,
70
 *             otherwise maxval is a true upper bound.
71
 *
72
 * There are three mutually exclusive parameters that determine the overall size of
73
 * the output. Only one of these should be specified to the constructor:
74
 *
75
 * n      --      The number of bits in the output. Must be greater than or equal to w
76
 * radius --      Two inputs separated by more than the radius have non-overlapping
77
 *                representations. Two inputs separated by less than the radius will
78
 *                in general overlap in at least some of their bits. You can think
79
 *                of this as the radius of the input.
80
 * resolution --  Two inputs separated by greater than, or equal to the resolution are guaranteed
81
 *                 to have different representations.
82
 *
83
 * Note: radius and resolution are specified w.r.t the input, not output. w is
84
 * specified w.r.t. the output.
85
 *
86
 * Example:
87
 * day of week.
88
 * w = 3
89
 * Minval = 1 (Monday)
90
 * Maxval = 8 (Monday)
91
 * periodic = true
92
 * n = 14
93
 * [equivalently: radius = 1.5 or resolution = 0.5]
94
 *
95
 * The following values would encode midnight -- the start of the day
96
 * monday (1)   -> 11000000000001
97
 * tuesday(2)   -> 01110000000000
98
 * wednesday(3) -> 00011100000000
99
 * ...
100
 * sunday (7)   -> 10000000000011
101
 *
102
 * Since the resolution is 12 hours, we can also encode noon, as
103
 * monday noon  -> 11100000000000
104
 * monday midnight (end of monday, i.e. start of tuesday) -> 01110000000000
105
 * tuesday noon -> 00111000000000
106
 * etc.
107
 *
108
 *
109
 * It may not be natural to specify "n", especially with non-periodic
110
 * data. For example, consider encoding an input with a range of 1-10
111
 * (inclusive) using an output width of 5.  If you specify resolution =
112
 * 1, this means that inputs of 1 and 2 have different outputs, though
113
 * they overlap, but 1 and 1.5 might not have different outputs.
114
 * This leads to a 14-bit representation like this:
115
 *
116
 * 1 ->  11111000000000  (14 bits total)
117
 * 2 ->  01111100000000
118
 * ...
119
 * 10->  00000000011111
120
 * [resolution = 1; n=14; radius = 5]
121
 *
122
 * You could specify resolution = 0.5, which gives
123
 * 1   -> 11111000... (22 bits total)
124
 * 1.5 -> 011111.....
125
 * 2.0 -> 0011111....
126
 * [resolution = 0.5; n=22; radius=2.5]
127
 *
128
 * You could specify radius = 1, which gives
129
 * 1   -> 111110000000....  (50 bits total)
130
 * 2   -> 000001111100....
131
 * 3   -> 000000000011111...
132
 * ...
133
 * 10  ->                           .....000011111
134
 * [radius = 1; resolution = 0.2; n=50]
135
 *
136
 *
137
 * An N/M encoding can also be used to encode a binary value,
138
 * where we want more than one bit to represent each state.
139
 * For example, we could have: w = 5, minval = 0, maxval = 1,
140
 * radius = 1 (which is equivalent to n=10)
141
 * 0 -> 1111100000
142
 * 1 -> 0000011111
143
 *
144
 *
145
 * Implementation details:
146
 * --------------------------------------------------------------------------
147
 * range = maxval - minval
148
 * h = (w-1)/2  (half-width)
149
 * resolution = radius / w
150
 * n = w * range/radius (periodic)
151
 * n = w * range/radius + 2 * h (non-periodic)
152
 * 
153
 * @author metaware
154
 */
155
public class ScalarEncoder extends Encoder<Double> {
156
        /**
         * Creates an unconfigured {@code ScalarEncoder}.
         *
         * The constructor is package-private. Client code is expected to obtain
         * instances through {@link #builder()}.
         */
        ScalarEncoder() {
                super();
        }
1✔
160
        
161
        /**
162
         * Returns a builder for building ScalarEncoders. 
163
         * This builder may be reused to produce multiple builders
164
         * 
165
         * @return a {@code ScalarEncoder.Builder}
166
         */
167
        public static Encoder.Builder<ScalarEncoder.Builder, ScalarEncoder> builder() {
168
                return new ScalarEncoder.Builder();
1✔
169
        }
170
        
171
        /**
172
         * Returns true if the underlying encoder works on deltas
173
         */
174
        public boolean isDelta() {
175
                return false;
176
        }
×
177
        
178
        /**
179
         * w -- number of bits to set in output
180
     * minval -- minimum input value
181
     * maxval -- maximum input value (input is strictly less if periodic == True)
182
         *
183
     * Exactly one of n, radius, resolution must be set. "0" is a special
184
     * value that means "not set".
185
         *
186
     * n -- number of bits in the representation (must be > w)
187
     * radius -- inputs separated by more than, or equal to this distance will have non-overlapping
188
     * representations
189
     * resolution -- inputs separated by more than, or equal to this distance will have different
190
     * representations
191
         * 
192
     * name -- an optional string which will become part of the description
193
         *
194
     * clipInput -- if true, non-periodic inputs smaller than minval or greater
195
     * than maxval will be clipped to minval/maxval
196
         *
197
     * forced -- if true, skip some safety checks (for compatibility reasons), default false
198
         */
199
        public void init() {
200
                if(getW() % 2 == 0) {
201
                        throw new IllegalStateException(
1✔
202
                                "W must be an odd number (to eliminate centering difficulty)");
×
203
                }
204
                
205
                setHalfWidth((getW() - 1) / 2);
206
                
1✔
207
                // For non-periodic inputs, padding is the number of bits "outside" the range,
208
            // on each side. I.e. the representation of minval is centered on some bit, and
209
            // there are "padding" bits to the left of that centered bit; similarly with
210
            // bits to the right of the center bit of maxval
211
                setPadding(isPeriodic() ? 0 : getHalfWidth());
212
                
1✔
213
                if(!Double.isNaN(getMinVal()) && !Double.isNaN(getMinVal())) {
214
                        if(getMinVal() >= getMaxVal()) {
1✔
215
                                throw new IllegalStateException("maxVal must be > minVal");
1✔
216
                        }
×
217
                        setRangeInternal(getMaxVal() - getMinVal());
218
                }
1✔
219
                
220
                // There are three different ways of thinking about the representation. Handle
221
            // each case here.
222
                initEncoder(getW(), getMinVal(), getMaxVal(), getN(), getRadius(), getResolution());
223
                
1✔
224
                //nInternal represents the output area excluding the possible padding on each side
225
                setNInternal(getN() - 2 * getPadding());
226
                
1✔
227
                if(getName() == null) {
228
                        if((getMinVal() % ((int)getMinVal())) > 0 ||
1✔
229
                            (getMaxVal() % ((int)getMaxVal())) > 0) {
1✔
230
                                setName("[" + getMinVal() + ":" + getMaxVal() + "]");
1✔
231
                        }else{
×
232
                                setName("[" + (int)getMinVal() + ":" + (int)getMaxVal() + "]");
233
                        }
1✔
234
                }
235
                
236
                //Checks for likely mistakes in encoder settings
237
                if(!isForced()) {
238
                        checkReasonableSettings();
1✔
239
                }
×
240
        description.add(new Tuple(2, (name = getName()) == "None" ? "[" + (int)getMinVal() + ":" + (int)getMaxVal() + "]" : name, 0));
241
        }
1✔
242
        
1✔
243
        /**
244
         * There are three different ways of thinking about the representation. 
245
     * Handle each case here.
246
     * 
247
         * @param c
248
         * @param minVal
249
         * @param maxVal
250
         * @param n
251
         * @param radius
252
         * @param resolution
253
         */
254
        public void initEncoder(int w, double minVal, double maxVal, int n, double radius, double resolution) {
255
                if(n != 0) {
256
                        if(minVal != 0 && maxVal != 0) {
1✔
257
                            if(!isPeriodic()) {
1✔
258
                                        setResolution(getRangeInternal() / (getN() - getW()));
1✔
259
                                }else{
1✔
260
                                        setResolution(getRangeInternal() / getN());
261
                                }
×
262
                                
263
                                setRadius(getW() * getResolution());
264
                                
1✔
265
                                if(isPeriodic()) {
266
                                        setRange(getRangeInternal());
1✔
267
                                }else{
×
268
                                        setRange(getRangeInternal() + getResolution());
269
                                }
1✔
270
                        }
271
                }else{
272
                        if(radius != 0) {
273
                                setResolution(getRadius() / w);
1✔
274
                        }else if(resolution != 0) {
1✔
275
                                setRadius(getResolution() * w);
1✔
276
                        }else{
1✔
277
                                throw new IllegalStateException(
278
                                        "One of n, radius, resolution must be specified for a ScalarEncoder");
×
279
                        }
280
                        
281
                        if(isPeriodic()) {
282
                                setRange(getRangeInternal());
1✔
283
                        }else{
×
284
                                setRange(getRangeInternal() + getResolution());
285
                        }
1✔
286
                        
287
                        double nFloat = w * (getRange() / getRadius()) + 2 * getPadding();
288
                        setN((int)Math.ceil(nFloat));
1✔
289
                }
1✔
290
        }
291

1✔
292
        /**
293
         * Return the bit offset of the first bit to be set in the encoder output.
294
     * For periodic encoders, this can be a negative number when the encoded output
295
     * wraps around.
296
     * 
297
         * @param c                        the memory
298
         * @param input                the input data
299
         * @return                        an encoded array
300
         */
301
        public Integer getFirstOnBit(double input) {
302
                if(input == SENTINEL_VALUE_FOR_MISSING_DATA) {
303
                        return null;
1✔
304
                }else{
×
305
                        if(input < getMinVal()) {
306
                                if(clipInput() && !isPeriodic()) {
1✔
307
                                        if(getVerbosity() > 0) {
×
308
                                                System.out.println("Clipped input " + getName() +
×
309
                                                        "=" + input + " to minval " + getMinVal());
×
310
                                        }
×
311
                                        input = getMinVal();
312
                                }else{
×
313
                                        throw new IllegalStateException("input (" + input +") less than range (" +
314
                                                getMinVal() + " - " + getMaxVal());
×
315
                                }
×
316
                        }
317
                }
318
                
319
                if(isPeriodic()) {
320
                        if(input >= getMaxVal()) {
1✔
321
                                throw new IllegalStateException("input (" + input +") greater than periodic range (" +
×
322
                                        getMinVal() + " - " + getMaxVal());
×
323
                        }
×
324
                }else{
325
                        if(input > getMaxVal()) {
326
                                if(clipInput()) {
1✔
327
                                        if(getVerbosity() > 0) {
×
328
                                                System.out.println("Clipped input " + getName() + "=" + input + " to maxval " + getMaxVal());
×
329
                                        }
×
330
                                        
331
                                        input = getMaxVal();
332
                                }else{
×
333
                                        throw new IllegalStateException("input (" + input +") greater than periodic range (" +
334
                                                getMinVal() + " - " + getMaxVal());
×
335
                                }
×
336
                        }
337
                }
338
                
339
                int centerbin;
340
                if(isPeriodic()) {
341
                        centerbin = (int)((int)((input - getMinVal()) *  getNInternal() / getRange())) + getPadding();
1✔
342
                }else{
×
343
                        centerbin = (int)((int)(((input - getMinVal()) + getResolution()/2) / getResolution())) + getPadding();
344
                }
1✔
345
                
346
                int minbin = centerbin - getHalfWidth();
347
                return minbin;
1✔
348
        }
349
        
350
        /**
351
         * Check if the settings are reasonable for the SpatialPooler to work
352
         * @param c
353
         */
354
        public void checkReasonableSettings() {
355
                if(getW() < 21) {
×
356
                        throw new IllegalStateException(
×
357
                                "Number of bits in the SDR (%d) must be greater than 2, and recommended >= 21 (use forced=True to override)");
358
                }
359
        }
×
360
        
361
        /**
362
         * {@inheritDoc}
363
         */
364
        @Override
365
        public List<FieldMetaType> getDecoderOutputFieldTypes() {
366
                return Arrays.asList(new FieldMetaType[] { FieldMetaType.FLOAT });
1✔
367
        }
368
        
369
        /**
370
         * Should return the output width, in bits.
371
         */
372
        public int getWidth() {
373
                return getN();
374
        }
1✔
375
        
376
        /**
377
         * {@inheritDoc}
378
         * NO-OP
379
         */
380
        @Override
381
        public int[] getBucketIndices(String input) { return null; }
382
        
×
383
        /**
384
         * Returns the bucket indices.
385
         * 
386
         * @param        input         
387
         */
388
        @Override
389
        public int[] getBucketIndices(double input) {
390
                int minbin = getFirstOnBit(input);
391
                
1✔
392
                //For periodic encoders, the bucket index is the index of the center bit
393
                int bucketIdx;
394
                if(isPeriodic()) {
395
                        bucketIdx = minbin + getHalfWidth();
1✔
396
                        if(bucketIdx < 0) {
×
397
                                bucketIdx += getN();
×
398
                        }
×
399
                }else{//for non-periodic encoders, the bucket index is the index of the left bit
400
                        bucketIdx = minbin;
401
                }
1✔
402
                
403
                return new int[] { bucketIdx };
404
        }
1✔
405
        
406
        /**
407
         * Encodes inputData and puts the encoded value into the numpy output array,
408
     * which is a 1-D array of length returned by {@link Connections#getW()}.
409
         *
410
     * Note: The numpy output array is reused, so clear it before updating it.
411
         * @param inputData Data to encode. This should be validated by the encoder.
412
         * @param output 1-D array of same length returned by {@link Connections#getW()}
413
     * 
414
         * @return
415
         */
416
        @Override
417
        public void encodeIntoArray(Double input, int[] output) {
418
                if(Double.isNaN(input)) {
419
                        Arrays.fill(output, 0);
1✔
420
                        return;
1✔
421
                }
1✔
422
                
423
                Integer bucketVal = getFirstOnBit(input);
424
                if(bucketVal != null) {
1✔
425
                        int bucketIdx = bucketVal;
1✔
426
                        Arrays.fill(output, 0);
1✔
427
                        int minbin = bucketIdx;
1✔
428
                        int maxbin = minbin + 2*getHalfWidth();
1✔
429
                        if(isPeriodic()) {
1✔
430
                                if(maxbin >= getN()) {
1✔
431
                                        int bottombins = maxbin - getN() + 1;
×
432
                                        int[] range = ArrayUtils.range(0, bottombins);
×
433
                                        ArrayUtils.setIndexesTo(output, range, 1);
×
434
                                        maxbin = getN() - 1;
×
435
                                }
×
436
                                if(minbin < 0) {
437
                                        int topbins = -minbin;
×
438
                                        ArrayUtils.setIndexesTo(
×
439
                                                output, ArrayUtils.range(getN() - topbins, getN()), 1);
×
440
                                        minbin = 0;
×
441
                                }
×
442
                        }
443
                        
444
                        ArrayUtils.setIndexesTo(output, ArrayUtils.range(minbin, maxbin + 1), 1);
445
                }
1✔
446
                
447
                if(getVerbosity() >= 2) {
448
                        System.out.println("");
1✔
449
                        System.out.println("input: " + input);
×
450
                        System.out.println("range: " + getMinVal() + " - " + getMaxVal());
×
451
                        System.out.println("n:" + getN() + "w:" + getW() + "resolution:" + getResolution() +
×
452
                                "radius:" + getRadius() + "periodic:" + isPeriodic());
×
453
                        System.out.println("output: " + Arrays.toString(output));
×
454
                        System.out.println("input desc: " + decode(output, ""));
×
455
                }
×
456
        }
457

1✔
458
        public DecodeResult decode(int[] encoded, String parentFieldName) {
459
                // For now, we simply assume any top-down output greater than 0
460
            // is ON. Eventually, we will probably want to incorporate the strength
461
            // of each top-down output.
462
                if(encoded == null || encoded.length < 1) { 
463
                        return null;
464
                }
465
                int[] tmpOutput = Arrays.copyOf(encoded, encoded.length);
466
                
467
                // ------------------------------------------------------------------------
468
            // First, assume the input pool is not sampled 100%, and fill in the
469
            //  "holes" in the encoded representation (which are likely to be present
470
            //  if this is a coincidence that was learned by the SP).
471

472
            // Search for portions of the output that have "holes"
473
                int maxZerosInARow = getHalfWidth();
1✔
474
                for(int i = 0;i < maxZerosInARow;i++) {
×
475
                        int[] searchStr = new int[i + 3];
476
                        Arrays.fill(searchStr, 1);
1✔
477
                        ArrayUtils.setRangeTo(searchStr, 1, -1, 0);
478
                        int subLen = searchStr.length;
479
                        
480
                        // Does this search string appear in the output?
481
                        if(isPeriodic()) {
482
                                for(int j = 0;j < getN();j++) {
483
                                        int[] outputIndices = ArrayUtils.range(j, j + subLen);
484
                                        outputIndices = ArrayUtils.modulo(outputIndices, getN());
1✔
485
                                        if(Arrays.equals(searchStr, ArrayUtils.sub(tmpOutput, outputIndices))) {
1✔
486
                                                ArrayUtils.setIndexesTo(tmpOutput, outputIndices, 1);
1✔
487
                                        }
1✔
488
                                }
1✔
489
                        }else{
1✔
490
                                for(int j = 0;j < getN() - subLen + 1;j++) {
491
                                        if(Arrays.equals(searchStr, ArrayUtils.sub(tmpOutput, ArrayUtils.range(j, j + subLen)))) {
492
                                                ArrayUtils.setRangeTo(tmpOutput, j, j + subLen, 1);
1✔
493
                                        }
×
494
                                }
×
495
                        }
×
496
                }
×
497
                
×
498
                if(getVerbosity() >= 2) {
499
                        System.out.println("raw output:" + Arrays.toString(
500
                                ArrayUtils.sub(encoded, ArrayUtils.range(0, getN()))));
501
                        System.out.println("filtered output:" + Arrays.toString(tmpOutput));
1✔
502
                }
1✔
503
                
×
504
                // ------------------------------------------------------------------------
505
            // Find each run of 1's.
506
                int[] nz = ArrayUtils.where(tmpOutput, new Condition.Adapter<Integer>() {
507
                        public boolean eval(int n) {
508
                                return n > 0;
509
                        }
1✔
510
                });
×
511
                List<Tuple> runs = new ArrayList<Tuple>(); //will be tuples of (startIdx, runLength)
×
512
                Arrays.sort(nz);
×
513
                int[] run = new int[] { nz[0], 1 };
514
                int i = 1;
515
                while(i < nz.length) {
516
                        if(nz[i] == run[0] + run[1]) {
517
                                run[1] += 1;
1✔
518
                        }else{
519
                                runs.add(new Tuple(2, run[0], run[1]));
520
                                run = new int[] { nz[i], 1 };
1✔
521
                        }
522
                        i += 1;
523
                }
1✔
524
                runs.add(new Tuple(2, run[0], run[1]));
1✔
525
                
1✔
526
                // If we have a periodic encoder, merge the first and last run if they
1✔
527
            // both go all the way to the edges
1✔
528
                if(isPeriodic() && runs.size() > 1) {
1✔
529
                        int l = runs.size() - 1;
1✔
530
                        if(((Integer)runs.get(0).get(0)) == 0 && ((Integer)runs.get(l).get(0)) + ((Integer)runs.get(l).get(1)) == getN()) {
531
                                runs.set(l, new Tuple(2, 
×
532
                                        (Integer)runs.get(l).get(0),  
×
533
                                                ((Integer)runs.get(l).get(1)) + ((Integer)runs.get(0).get(1)) ));
534
                                runs = runs.subList(1, runs.size());
1✔
535
                        }
536
                }
1✔
537
                
538
                // ------------------------------------------------------------------------
539
            // Now, for each group of 1's, determine the "left" and "right" edges, where
540
            // the "left" edge is inset by halfwidth and the "right" edge is inset by
1✔
541
            // halfwidth.
×
542
            // For a group of width w or less, the "left" and "right" edge are both at
×
543
            // the center position of the group.
×
544
                int left = 0;
×
545
                int right = 0;
×
546
                List<MinMax> ranges = new ArrayList<MinMax>();
×
547
                for(Tuple tupleRun : runs) {
548
                        int start = (Integer)tupleRun.get(0);
549
                        int runLen = (Integer)tupleRun.get(1);
550
                        if(runLen <= getW()) {
551
                                left = right = start + runLen / 2;
552
                        }else{
553
                                left = start + getHalfWidth();
554
                                right = start + runLen - 1 - getHalfWidth();
555
                        }
556
                        
1✔
557
                        double inMin, inMax;
1✔
558
                        // Convert to input space.
1✔
559
                        if(!isPeriodic()) {
1✔
560
                                inMin = (left - getPadding()) * getResolution() + getMinVal();
1✔
561
                                inMax = (right - getPadding()) * getResolution() + getMinVal();
1✔
562
                        }else{
1✔
563
                                inMin = (left - getPadding()) * getRange() / getNInternal() + getMinVal();
1✔
564
                                inMax = (right - getPadding()) * getRange() / getNInternal() + getMinVal();
565
                        }
×
566
                        // Handle wrap-around if periodic
×
567
                        if(isPeriodic()) {
568
                                if(inMin >= getMaxVal()) {
569
                                        inMin -= getRange();
570
                                        inMax -= getRange();
571
                                }
1✔
572
                        }
1✔
573
                        
1✔
574
                        // Clip low end
575
                        if(inMin < getMinVal()) {
×
576
                                inMin = getMinVal();
×
577
                        }
578
                        if(inMax < getMinVal()) {
579
                                inMax = getMinVal();
1✔
580
                        }
×
581
                        
×
582
                        // If we have a periodic encoder, and the max is past the edge, break into
×
583
                        //         2 separate ranges
584
                        if(isPeriodic() && inMax >= getMaxVal()) {
585
                                ranges.add(new MinMax(inMin, getMaxVal()));
586
                                ranges.add(new MinMax(getMinVal(), inMax - getRange()));
587
                        }else{
1✔
588
                                if(inMax > getMaxVal()) {
×
589
                                        inMax = getMaxVal();
590
                                }
1✔
591
                                if(inMin > getMaxVal()) {
×
592
                                        inMin = getMaxVal();
593
                                }
594
                                ranges.add(new MinMax(inMin, inMax));
595
                        }
596
                }
1✔
597
                
×
598
                String desc = generateRangeDescription(ranges);
×
599
                String fieldName;
600
                // Return result
1✔
601
                if(!parentFieldName.isEmpty()) {
×
602
                        fieldName = String.format("%s.%s", parentFieldName, getName());
603
                }else{
1✔
604
                        fieldName = getName();
×
605
                }
606
                
1✔
607
                RangeList inner = new RangeList(ranges, desc);
608
                Map<String, RangeList> fieldsDict = new HashMap<String, RangeList>();
1✔
609
                fieldsDict.put(fieldName, inner);
610
                
1✔
611
                return new DecodeResult(fieldsDict, Arrays.asList(new String[] { fieldName }));
612
        }
613
        
1✔
614
        /**
×
615
         * Generate description from a text description of the ranges
616
         * 
1✔
617
         * @param        ranges                A list of {@link MinMax}es.
618
         */
619
        public String generateRangeDescription(List<MinMax> ranges) {
1✔
620
                StringBuilder desc = new StringBuilder();
1✔
621
                int numRanges = ranges.size();
1✔
622
                for(int i = 0;i < numRanges;i++) {
623
                        if(ranges.get(i).min() != ranges.get(i).max()) {
1✔
624
                                desc.append(String.format("%.2f-%.2f", ranges.get(i).min(), ranges.get(i).max()));
625
                        }else{
626
                                desc.append(String.format("%.2f", ranges.get(i).min()));
627
                        }
628
                        if(i < numRanges - 1) {
629
                                desc.append(", ");
630
                        }
631
                }
632
                return desc.toString();
1✔
633
        }
1✔
634
        
1✔
635
        /**
1✔
636
         * Return the internal topDownMapping matrix used for handling the
×
637
     * bucketInfo() and topDownCompute() methods. This is a matrix, one row per
638
     * category (bucket) where each row contains the encoded output for that
1✔
639
     * category.
640
     * 
1✔
641
         * @param c                the connections memory
×
642
         * @return                the internal topDownMapping
643
         */
644
        public SparseObjectMatrix<int[]> getTopDownMapping() {
1✔
645
                
646
                if(topDownMapping == null) {
647
                        //The input scalar value corresponding to each possible output encoding
648
                        if(isPeriodic()) {
649
                                setTopDownValues(
650
                                        ArrayUtils.arange(getMinVal() + getResolution() / 2.0, 
651
                                                getMaxVal(), getResolution()));
652
                        }else{
653
                                //Number of values is (max-min)/resolutions
654
                                setTopDownValues(
655
                                        ArrayUtils.arange(getMinVal(), getMaxVal() + getResolution() / 2.0, 
656
                                                getResolution()));
657
                        }
658
                }
1✔
659
                
660
                //Each row represents an encoded output pattern
1✔
661
                int numCategories = getTopDownValues().length;
×
662
                SparseObjectMatrix<int[]> topDownMapping;
×
663
                setTopDownMapping(
×
664
                        topDownMapping = new SparseObjectMatrix<int[]>(
665
                                new int[] { numCategories }));
666
                
1✔
667
                double[] topDownValues = getTopDownValues();
1✔
668
                int[] outputSpace = new int[getN()];
1✔
669
                double minVal = getMinVal();
670
                double maxVal = getMaxVal();
671
                for(int i = 0;i < numCategories;i++) {
672
                        double value = topDownValues[i];
673
                        value = Math.max(value, minVal);
1✔
674
                        value = Math.min(value, maxVal);
675
                        encodeIntoArray(value, outputSpace);
1✔
676
                        topDownMapping.set(i, Arrays.copyOf(outputSpace, outputSpace.length));
677
                }
678
                
679
                return topDownMapping;
1✔
680
        }
1✔
681
        
1✔
682
        /**
1✔
683
         * {@inheritDoc}
1✔
684
         * 
1✔
685
         * @param <S>        the input value, in this case a double
1✔
686
         * @return        a list of one input double
1✔
687
         */
1✔
688
        @Override
1✔
689
        public <S> TDoubleList getScalars(S d) {
690
                TDoubleList retVal = new TDoubleArrayList();
691
                retVal.add((Double)d);
1✔
692
                return retVal;
693
        }
694
        
695
        /**
696
         * Returns a list of items, one for each bucket defined by this encoder.
697
     * Each item is the value assigned to that bucket, this is the same as the
698
     * EncoderResult.value that would be returned by getBucketInfo() for that
699
     * bucket and is in the same format as the input that would be passed to
700
     * encode().
701
         * 
702
     * This call is faster than calling getBucketInfo() on each bucket individually
×
703
     * if all you need are the bucket values.
×
704
         *
×
705
         * @param        returnType                 class type parameter so that this method can return encoder
706
     *                                                         specific value types
707
     * 
708
     * @return list of items, each item representing the bucket value for that
709
     *        bucket.
710
         */
711
        @SuppressWarnings("unchecked")
712
        @Override
713
        public <S> List<S> getBucketValues(Class<S> t) {
714
                if(bucketValues == null) {
715
                        SparseObjectMatrix<int[]> topDownMapping = getTopDownMapping();
716
                        int numBuckets = topDownMapping.getMaxIndex() + 1;
717
                        bucketValues = new ArrayList<Double>();
718
                        for(int i = 0;i < numBuckets;i++) {
719
                                ((List<Double>)bucketValues).add((Double)getBucketInfo(new int[] { i }).get(0).get(1));
720
                        }
721
                }
722
                return (List<S>)bucketValues;
723
        }
724
        
725
        /**
726
         * {@inheritDoc}
1✔
727
         */
1✔
728
        @Override
1✔
729
        public List<EncoderResult> getBucketInfo(int[] buckets) {
1✔
730
                SparseObjectMatrix<int[]> topDownMapping = getTopDownMapping();
1✔
731
                
1✔
732
                //The "category" is simply the bucket index
733
                int category = buckets[0];
734
                int[] encoding = topDownMapping.getObject(category);
1✔
735
                
736
                //Which input value does this correspond to?
737
                double inputVal;
738
                if(isPeriodic()) {
739
                        inputVal = getMinVal() + getResolution() / 2 + category * getResolution();
740
                }else{
741
                        inputVal = getMinVal() + category * getResolution();
742
                }
1✔
743
                
744
                return Arrays.asList(
745
                        new EncoderResult[] { 
1✔
746
                                new EncoderResult(inputVal, inputVal, encoding) });
1✔
747
                        
748
        }
749
        
750
        /**
1✔
751
         * {@inheritDoc}
×
752
         */
753
        @Override
1✔
754
        public List<EncoderResult> topDownCompute(int[] encoded) {
755
                //Get/generate the topDown mapping table
756
                SparseObjectMatrix<int[]> topDownMapping = getTopDownMapping();
1✔
757
                
758
                // See which "category" we match the closest.
759
                int category = ArrayUtils.argmax(rightVecProd(topDownMapping, encoded));
760
                
761
                return getBucketInfo(new int[] { category });
762
        }
763
        
764
        /**
765
         * Returns a list of {@link Tuple}s which in this case is a list of
1✔
766
         * key value parameter values for this {@code ScalarEncoder}
767
         * 
768
         * @return        a list of {@link Tuple}s
1✔
769
         */
770
        public List<Tuple> dict() {
1✔
771
                List<Tuple> l = new ArrayList<Tuple>();
772
                l.add(new Tuple(2, "maxval", getMaxVal()));
773
                l.add(new Tuple(2, "bucketValues", getBucketValues(Double.class)));
774
                l.add(new Tuple(2, "nInternal", getNInternal()));
775
                l.add(new Tuple(2, "name", getName()));
776
                l.add(new Tuple(2, "minval", getMinVal()));
777
                l.add(new Tuple(2, "topDownValues", Arrays.toString(getTopDownValues())));
778
                l.add(new Tuple(2, "verbosity", getVerbosity()));
779
                l.add(new Tuple(2, "clipInput", clipInput()));
780
                l.add(new Tuple(2, "n", getN()));
×
781
                l.add(new Tuple(2, "padding", getPadding()));
×
782
                l.add(new Tuple(2, "range", getRange()));
×
783
                l.add(new Tuple(2, "periodic", isPeriodic()));
×
784
                l.add(new Tuple(2, "radius", getRadius()));
×
785
                l.add(new Tuple(2, "w", getW()));
×
786
                l.add(new Tuple(2, "topDownMappingM", getTopDownMapping()));
×
787
                l.add(new Tuple(2, "halfwidth", getHalfWidth()));
×
788
                l.add(new Tuple(2, "resolution", getResolution()));
×
789
                l.add(new Tuple(2, "rangeInternal", getRangeInternal()));
×
790
                
×
791
                return l;
×
792
        }
×
793

×
794
        /**
×
795
         * Returns a {@link EncoderBuilder} for constructing {@link ScalarEncoder}s
×
796
         * 
×
797
         * The base class architecture is put together in such a way where boilerplate
×
798
         * initialization can be kept to a minimum for implementing subclasses, while avoiding
×
799
         * the mistake-proneness of extremely long argument lists.
800
         * 
×
801
         * @see ScalarEncoder.Builder#setStuff(int)
802
         */
803
        public static class Builder extends Encoder.Builder<ScalarEncoder.Builder, ScalarEncoder> {
804
                private Builder() {}
805

806
                @Override
807
                public ScalarEncoder build() {
808
                        //Must be instantiated so that super class can initialize 
809
                        //boilerplate variables.
810
                        encoder = new ScalarEncoder();
811
                        
812
                        //Call super class here
813
                        super.build();
1✔
814
                        
815
                        ////////////////////////////////////////////////////////
816
                        //  Implementing classes would do setting of specific //
817
                        //  vars here together with any sanity checking       //
818
                        ////////////////////////////////////////////////////////
819
                        
1✔
820
                        ((ScalarEncoder)encoder).init();
821
                        
822
                        return (ScalarEncoder)encoder;
1✔
823
                }
824
                
825
                /**
826
                 * Never called - just here as an example of specialization for a specific 
827
                 * subclass of Encoder.Builder
828
                 * 
829
                 * Example specific method!!
1✔
830
                 * 
831
                 * @param stuff
1✔
832
                 * @return
833
                 */
834
                public ScalarEncoder.Builder setStuff(int stuff) {
835
                        return this;
836
                }
837
        }
838
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc