• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

numenta / htm.java / #1206

26 Jan 2015 12:57PM UTC coverage: 14.404% (-0.005%) from 14.409%
#1206

push

David Ray
Merge pull request #168 from cogmission/network_api_work

testing the RNG is not part of the scope

714 of 4957 relevant lines covered (14.4%)

0.14 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/main/java/org/numenta/nupic/research/TemporalMemory.java
1
/* ---------------------------------------------------------------------
2
 * Numenta Platform for Intelligent Computing (NuPIC)
3
 * Copyright (C) 2014, Numenta, Inc.  Unless you have an agreement
4
 * with Numenta, Inc., for a separate license for this software code, the
5
 * following terms and conditions apply:
6
 *
7
 * This program is free software: you can redistribute it and/or modify
8
 * it under the terms of the GNU General Public License version 3 as
9
 * published by the Free Software Foundation.
10
 *
11
 * This program is distributed in the hope that it will be useful,
12
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
14
 * See the GNU General Public License for more details.
15
 *
16
 * You should have received a copy of the GNU General Public License
17
 * along with this program.  If not, see http://www.gnu.org/licenses.
18
 *
19
 * http://numenta.org/licenses/
20
 * ---------------------------------------------------------------------
21
 */
22

23
package org.numenta.nupic.research;
24

25
import java.util.ArrayList;
26
import java.util.LinkedHashMap;
27
import java.util.LinkedHashSet;
28
import java.util.List;
29
import java.util.Map;
30
import java.util.Set;
31

32
import org.numenta.nupic.Connections;
33
import org.numenta.nupic.model.Cell;
34
import org.numenta.nupic.model.Column;
35
import org.numenta.nupic.model.DistalDendrite;
36
import org.numenta.nupic.model.Synapse;
37
import org.numenta.nupic.util.SparseObjectMatrix;
38

39
/**
40
 * Temporal Memory implementation in Java
41
 * 
42
 * @author Chetan Surpur
43
 * @author David Ray
44
 */
45
public class TemporalMemory {
46
    
47
    /**
48
     * Constructs a new {@code TemporalMemory}
49
     */
50
    public TemporalMemory() {}
×
51
    
52
    /**
53
     * Uses the specified {@link Connections} object to Build the structural 
54
     * anatomy needed by this {@code TemporalMemory} to implement its algorithms.
55
     * 
56
     * @param        c                {@link Connections} object
57
     */
58
    public void init(Connections c) {
59
            SparseObjectMatrix<Column> matrix = c.getMemory() == null ?
60
                    new SparseObjectMatrix<Column>(c.getColumnDimensions()) :
61
                            c.getMemory();
62
            c.setMemory(matrix);
63
            
64
            int numColumns = matrix.getMaxIndex() + 1;
65
            int cellsPerColumn = c.getCellsPerColumn();
66
        Cell[] cells = new Cell[numColumns * cellsPerColumn];
67
        
68
        //Used as flag to determine if Column objects have been created.
×
69
        Column colZero = matrix.getObject(0);
×
70
        for(int i = 0;i < numColumns;i++) {
×
71
            Column column = colZero == null ? 
×
72
                    new Column(cellsPerColumn, i) : matrix.getObject(i);
73
            for(int j = 0;j < cellsPerColumn;j++) {
×
74
                cells[i * cellsPerColumn + j] = column.getCell(j);
×
75
            }
×
76
            //If columns have not been previously configured
77
            if(colZero == null) matrix.set(i, column);
78
        }
×
79
        //Only the TemporalMemory initializes cells so no need to test 
×
80
        c.setCells(cells);
×
81
    }
×
82
    
×
83
    /////////////////////////// CORE FUNCTIONS /////////////////////////////
×
84
    
85
    /**
86
     * Feeds input record through TM, performing inferencing and learning
×
87
     * 
88
     * @param connections                the connection memory
89
     * @param activeColumns     direct proximal dendrite input
×
90
     * @param learn             learning mode flag
×
91
     * @return                  {@link ComputeCycle} container for one cycle of inference values.
92
     */
93
    public ComputeCycle compute(Connections connections, int[] activeColumns, boolean learn) {
94
        ComputeCycle result = computeFn(connections, connections.getColumnSet(activeColumns), new LinkedHashSet<Cell>(connections.getPredictiveCells()), 
95
            new LinkedHashSet<DistalDendrite>(connections.getActiveSegments()), new LinkedHashMap<DistalDendrite, Set<Synapse>>(connections.getActiveSynapsesForSegment()), 
96
                new LinkedHashSet<Cell>(connections.getWinnerCells()), learn);
97
        
98
        connections.setActiveCells(result.activeCells());
99
        connections.setWinnerCells(result.winnerCells());
100
        connections.setPredictiveCells(result.predictiveCells());
101
        connections.setPredictedColumns(result.predictedColumns());
102
        connections.setActiveSegments(result.activeSegments());
103
        connections.setLearningSegments(result.learningSegments());
×
104
        connections.setActiveSynapsesForSegment(result.activeSynapsesForSegment());
×
105
        
×
106
        return result; 
107
    }
×
108
    
×
109
    /**
×
110
     * Functional version of {@link #compute(int[], boolean)}. 
×
111
     * This method is stateless and concurrency safe.
×
112
     * 
×
113
     * @param c                             {@link Connections} object containing state of memory members
×
114
     * @param activeColumns                 proximal dendrite input
115
     * @param prevPredictiveCells           cells predicting in t-1
×
116
     * @param prevActiveSegments            active segments in t-1
117
     * @param prevActiveSynapsesForSegment  {@link Synapse}s active in t-1
118
     * @param prevWinnerCells   `           previous winners
119
     * @param learn                         whether mode is "learning" mode
120
     * @return
121
     */
122
    public ComputeCycle computeFn(Connections c, Set<Column> activeColumns, Set<Cell> prevPredictiveCells, Set<DistalDendrite> prevActiveSegments,
123
        Map<DistalDendrite, Set<Synapse>> prevActiveSynapsesForSegment, Set<Cell> prevWinnerCells, boolean learn) {
124
        
125
        ComputeCycle cycle = new ComputeCycle();
126
        
127
        activateCorrectlyPredictiveCells(cycle, prevPredictiveCells, activeColumns);
128
        
129
        burstColumns(cycle, c, activeColumns, cycle.predictedColumns, prevActiveSynapsesForSegment);
130
        
131
        if(learn) {
132
            learnOnSegments(c, prevActiveSegments, cycle.learningSegments, prevActiveSynapsesForSegment, cycle.winnerCells, prevWinnerCells);
133
        }
134
        
×
135
        cycle.activeSynapsesForSegment = computeActiveSynapses(c, cycle.activeCells);
136
        
×
137
        computePredictiveCells(c, cycle, cycle.activeSynapsesForSegment);
138
        
×
139
        return cycle;
140
    }
×
141

×
142
    /**
143
     * Phase 1: Activate the correctly predictive cells
144
     * 
×
145
     * Pseudocode:
146
     *
×
147
     * - for each prev predictive cell
148
     *   - if in active column
×
149
     *     - mark it as active
150
     *     - mark it as winner cell
151
     *     - mark column as predicted
152
     *     
153
     * @param c                     ComputeCycle interim values container
154
     * @param prevPredictiveCells   predictive {@link Cell}s predictive cells in t-1
155
     * @param activeColumns         active columns in t
156
     */
157
    public void activateCorrectlyPredictiveCells(ComputeCycle c, Set<Cell> prevPredictiveCells, Set<Column> activeColumns) {
158
        for(Cell cell : prevPredictiveCells) {
159
            Column column = cell.getParentColumn();
160
            if(activeColumns.contains(column)) {
161
                c.activeCells.add(cell);
162
                c.winnerCells.add(cell);
163
                c.predictedColumns.add(column);
164
            }
165
        }
166
    }
167
    
×
168
    /**
×
169
     * Phase 2: Burst unpredicted columns.
×
170
     * 
×
171
     * Pseudocode:
×
172
     *
×
173
     * - for each unpredicted active column
174
     *   - mark all cells as active
×
175
     *   - mark the best matching cell as winner cell
×
176
     *     - (learning)
177
     *       - if it has no matching segment
178
     *         - (optimization) if there are prev winner cells
179
     *           - add a segment to it
180
     *       - mark the segment as learning
181
     * 
182
     * @param cycle                         ComputeCycle interim values container
183
     * @param c                             Connections temporal memory state
184
     * @param activeColumns                 active columns in t
185
     * @param predictedColumns              predicted columns in t
186
     * @param prevActiveSynapsesForSegment      LinkedHashMap of previously active segments which
187
     *                                      have had synapses marked as active in t-1     
188
     */
189
    public void burstColumns(ComputeCycle cycle, Connections c, Set<Column> activeColumns, Set<Column> predictedColumns, 
190
        Map<DistalDendrite, Set<Synapse>> prevActiveSynapsesForSegment) {
191
        
192
        Set<Column> unpred = new LinkedHashSet<Column>(activeColumns);
193
        
194
        unpred.removeAll(predictedColumns);
195
        for(Column column : unpred) {
196
            List<Cell> cells = column.getCells();
197
            cycle.activeCells.addAll(cells);
198
            
199
            Object[] bestSegmentAndCell = getBestMatchingCell(c, column, prevActiveSynapsesForSegment);
200
            DistalDendrite bestSegment = (DistalDendrite)bestSegmentAndCell[0];
201
            Cell bestCell = (Cell)bestSegmentAndCell[1];
×
202
            if(bestCell != null) {
×
203
                cycle.winnerCells.add(bestCell);
×
204
            }
×
205
            
206
            int segmentCounter = c.getSegmentCount();
×
207
            if(bestSegment == null) {
×
208
                bestSegment = bestCell.createSegment(c, segmentCounter);
×
209
                c.setSegmentCount(segmentCounter + 1);
×
210
            }
×
211
            
212
            cycle.learningSegments.add(bestSegment);
213
        }
×
214
    }
×
215
    
×
216
    /**
×
217
     * Phase 3: Perform learning by adapting segments.
218
     * <pre>
219
     * Pseudocode:
×
220
     *
×
221
     * - (learning) for each prev active or learning segment
×
222
     *   - if learning segment or from winner cell
223
     *   - strengthen active synapses
224
     *   - weaken inactive synapses
225
     *   - if learning segment
226
     *   - add some synapses to the segment
227
     *     - subsample from prev winner cells
228
     * </pre>    
229
     *     
230
     * @param c                             the Connections state of the temporal memory
231
     * @param prevActiveSegments                        the Set of segments active in the previous cycle.
232
     * @param learningSegments                                the Set of segments marked as learning {@link #burstColumns(ComputeCycle, Connections, Set, Set, Map)}
233
     * @param prevActiveSynapseSegments                the map of segments which were previously active to their associated {@link Synapse}s.
234
     * @param winnerCells                                        the Set of all winning cells ({@link Cell}s with the most active synapses)
235
     * @param prevWinnerCells                                the Set of cells which were winners during the last compute cycle
236
     */        
237
    public void learnOnSegments(Connections c, Set<DistalDendrite> prevActiveSegments, Set<DistalDendrite> learningSegments,
238
        Map<DistalDendrite, Set<Synapse>> prevActiveSynapseSegments, Set<Cell> winnerCells, Set<Cell> prevWinnerCells) {
239
        
240
            double permanenceIncrement = c.getPermanenceIncrement();
241
            double permanenceDecrement = c.getPermanenceDecrement();
242
                    
243
        List<DistalDendrite> prevAndLearning = new ArrayList<DistalDendrite>(prevActiveSegments);
244
        prevAndLearning.addAll(learningSegments);
245
        
246
        for(DistalDendrite dd : prevAndLearning) {
247
            boolean isLearningSegment = learningSegments.contains(dd);
×
248
            boolean isFromWinnerCell = winnerCells.contains(dd.getParentCell());
×
249
            
250
            Set<Synapse> activeSynapses = new LinkedHashSet<Synapse>(dd.getConnectedActiveSynapses(prevActiveSynapseSegments, 0));
×
251
            
×
252
            if(isLearningSegment || isFromWinnerCell) {
253
                dd.adaptSegment(c, activeSynapses, permanenceIncrement, permanenceDecrement);
×
254
            }
×
255
            
×
256
            int synapseCounter = c.getSynapseCount();  
257
            if(isLearningSegment) {
×
258
                int n = c.getMaxNewSynapseCount() - activeSynapses.size();
259
                Set<Cell> learnCells = dd.pickCellsToLearnOn(c, n, prevWinnerCells, c.getRandom());
×
260
                for(Cell sourceCell : learnCells) {
×
261
                    dd.createSynapse(c, sourceCell, c.getInitialPermanence(), synapseCounter);
262
                    synapseCounter += 1;
263
                }
×
264
                c.setSynapseCount(synapseCounter);
×
265
            }
×
266
        }
×
267
    }
×
268
    
×
269
    /**
×
270
     * Phase 4: Compute predictive cells due to lateral input on distal dendrites.
×
271
     *
×
272
     * Pseudocode:
273
     *
×
274
     * - for each distal dendrite segment with activity >= activationThreshold
×
275
     *   - mark the segment as active
276
     *   - mark the cell as predictive
277
     * 
278
     * @param c                 the Connections state of the temporal memory
279
     * @param cycle                                the state during the current compute cycle
280
     * @param activeSegments
281
     */
282
    public void computePredictiveCells(Connections c, ComputeCycle cycle, Map<DistalDendrite, Set<Synapse>> activeDendrites) {
283
        for(DistalDendrite dd : activeDendrites.keySet()) {
284
            Set<Synapse> connectedActive = dd.getConnectedActiveSynapses(activeDendrites, c.getConnectedPermanence());
285
            if(connectedActive.size() >= c.getActivationThreshold()) {
286
                cycle.activeSegments.add(dd);
287
                cycle.predictiveCells.add(dd.getParentCell());
288
            }
289
        }
290
    }
×
291
    
×
292
    /**
×
293
     * Forward propagates activity from active cells to the synapses that touch
×
294
     * them, to determine which synapses are active.
×
295
     * 
296
     * @param   c           the connections state of the temporal memory
×
297
     * @param cellsActive
×
298
     * @return 
299
     */
300
    public Map<DistalDendrite, Set<Synapse>> computeActiveSynapses(Connections c, Set<Cell> cellsActive) {
301
        Map<DistalDendrite, Set<Synapse>> activesSynapses = new LinkedHashMap<DistalDendrite, Set<Synapse>>();
302
        
303
        for(Cell cell : cellsActive) {
304
            for(Synapse s : cell.getReceptorSynapses(c)) {
305
                Set<Synapse> set = null;
306
                if((set = activesSynapses.get(s.getSegment())) == null) {
307
                    activesSynapses.put((DistalDendrite)s.getSegment(), set = new LinkedHashSet<Synapse>());
308
                }
×
309
                set.add(s);
310
            }
×
311
        }
×
312
        
×
313
        return activesSynapses;
×
314
    }
×
315
    
316
    /**
×
317
     * Called to start the input of a new sequence.
×
318
     * 
×
319
     * @param   connections   the Connections state of the temporal memory
320
     */
×
321
    public void reset(Connections connections) {
322
        connections.getActiveCells().clear();
323
        connections.getPredictiveCells().clear();
324
        connections.getActiveSegments().clear();
325
        connections.getActiveSynapsesForSegment().clear();
326
        connections.getWinnerCells().clear();
327
    }
328
    
329
    
×
330
    /////////////////////////// HELPER FUNCTIONS ///////////////////////////
×
331
    
×
332
    /**
×
333
     * Gets the cell with the best matching segment
×
334
     * (see `TM.getBestMatchingSegment`) that has the largest number of active
×
335
     * synapses of all best matching segments.
336
     * 
337
     * @param c                                                                        encapsulated memory and state
338
     * @param column                                                        {@link Column} within which to search for best cell
339
     * @param prevActiveSynapsesForSegment                a {@link DistalDendrite}'s previously active {@link Synapse}s
340
     * @return                an object array whose first index contains a segment, and the second contains a cell
341
     */
342
    public Object[] getBestMatchingCell(Connections c, Column column, Map<DistalDendrite, Set<Synapse>> prevActiveSynapsesForSegment) {
343
        Object[] retVal = new Object[2];
344
        Cell bestCell = null;
345
        DistalDendrite bestSegment = null;
346
        int maxSynapses = 0;
347
        for(Cell cell : column.getCells()) {
348
            DistalDendrite dd = getBestMatchingSegment(c, cell, prevActiveSynapsesForSegment);
349
            if(dd != null) {
350
                Set<Synapse> connectedActiveSynapses = dd.getConnectedActiveSynapses(prevActiveSynapsesForSegment, 0);
×
351
                if(connectedActiveSynapses.size() > maxSynapses) {
×
352
                    maxSynapses = connectedActiveSynapses.size();
×
353
                    bestCell = cell;
×
354
                    bestSegment = dd;
×
355
                }
×
356
            }
×
357
        }
×
358
        
×
359
        if(bestCell == null) {
×
360
            bestCell = column.getLeastUsedCell(c, c.getRandom());
×
361
        }
×
362
        
363
        retVal[0] = bestSegment;
364
        retVal[1] = bestCell;
×
365
        return retVal;
366
    }
×
367
    
×
368
    /**
369
     * Gets the segment on a cell with the largest number of activate synapses,
370
     * including all synapses with non-zero permanences.
×
371
     * 
×
372
     * @param c                                                                        encapsulated memory and state
×
373
     * @param column                                                        {@link Column} within which to search for best cell
374
     * @param activeSynapseSegments                                a {@link DistalDendrite}'s active {@link Synapse}s
375
     * @return        the best segment
376
     */
377
    public DistalDendrite getBestMatchingSegment(Connections c, Cell cell, Map<DistalDendrite, Set<Synapse>> activeSynapseSegments) {
378
        int maxSynapses = c.getMinThreshold();
379
        DistalDendrite bestSegment = null;
380
        for(DistalDendrite dd : cell.getSegments(c)) {
381
            Set<Synapse> activeSyns = dd.getConnectedActiveSynapses(activeSynapseSegments, 0);
382
            if(activeSyns.size() >= maxSynapses) {
383
                maxSynapses = activeSyns.size();
384
                bestSegment = dd;
385
            }
×
386
        }
×
387
        return bestSegment;
×
388
    }
×
389
    
×
390
    /**
×
391
     * Returns the column index given the cells per column and
×
392
     * the cell index passed in.
393
     * 
×
394
     * @param c                                {@link Connections} memory
×
395
     * @param cellIndex                the index where the requested cell resides
396
     * @return
397
     */
398
    protected int columnForCell(Connections c, int cellIndex) {
399
        return cellIndex / c.getCellsPerColumn();
400
    }
401
    
402
    /**
403
     * Returns the cell at the specified index.
404
     * @param index
405
     * @return
406
     */
×
407
    public Cell getCell(Connections c, int index) {
408
        return c.getCells()[index];
409
    }
410
    
411
    /**
412
     * Returns a {@link LinkedHashSet} of {@link Cell}s from a 
413
     * sorted array of cell indexes.
414
     *  
415
     * @param`c                                the {@link Connections} object
×
416
     * @param cellIndexes   indexes of the {@link Cell}s to return
417
     * @return
418
     */
419
    public LinkedHashSet<Cell> getCells(Connections c, int[] cellIndexes) {
420
            LinkedHashSet<Cell> cellSet = new LinkedHashSet<Cell>();
421
        for(int cell : cellIndexes) {
422
            cellSet.add(getCell(c, cell));
423
        }
424
        return cellSet;
425
    }
426
    
427
    /**
×
428
     * Returns a {@link LinkedHashSet} of {@link Column}s from a 
×
429
     * sorted array of Column indexes.
×
430
     *  
431
     * @param cellIndexes   indexes of the {@link Column}s to return
×
432
     * @return
433
     */
434
    public LinkedHashSet<Column> getColumns(Connections c, int[] columnIndexes) {
435
            return c.getColumnSet(columnIndexes);
436
    }
437
 }
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc