• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

NeuralEnsemble / PyNN / 968

pending completion
968

cron

travis-ci-com

GitHub
Merge pull request #761 from apdavison/pytest

Migrate test suite from nose to pytest

10 of 10 new or added lines in 5 files covered. (100.0%)

6955 of 9974 relevant lines covered (69.73%)

0.7 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.38
/pyNN/common/populations.py
1
# encoding: utf-8
2
"""
1✔
3
Common implementation of ID, Population, PopulationView and Assembly classes.
4

5
These base classes should be sub-classed by the backend-specific classes.
6

7
:copyright: Copyright 2006-2022 by the PyNN team, see AUTHORS.
8
:license: CeCILL, see LICENSE for details.
9
"""
10

11
import numpy as np
1✔
12
import logging
1✔
13
import operator
1✔
14
from itertools import chain
1✔
15
from functools import reduce
1✔
16
from collections import defaultdict
1✔
17
from pyNN import random, recording, errors, standardmodels, core, space, descriptions
1✔
18
from pyNN.models import BaseCellType
1✔
19
from pyNN.parameters import ParameterSpace, LazyArray, simplify as simplify_parameter_array
1✔
20
from pyNN.recording import files
1✔
21

22

23
# Logger for this module; uses the fixed name "PyNN" rather than __name__.
logger = logging.getLogger("PyNN")
# Short alias for the decorator used below to mark legacy API methods.
deprecated = core.deprecated
25

26

27
def is_conductance(target_cell):
    """
    Tell whether `target_cell` uses conductance-based synapses.

    Returns the cell type's `conductance_based` flag (True for
    conductance-based, False for current-based synapses) when the cell is
    local to this node and has a `celltype`. Otherwise returns None, meaning
    the synapse basis cannot be determined.
    """
    if not (hasattr(target_cell, 'local') and target_cell.local):
        return None
    if not hasattr(target_cell, 'celltype'):
        return None
    return target_cell.celltype.conductance_based
38

39

40
class IDMixin(object):
    """
    Instead of storing ids as integers, we store them as ID objects,
    which allows a syntax like:
        p[3,4].tau_m = 20.0
    where p is a Population object.

    Parameter access, positions and initial values are all delegated to the
    object stored in the ``parent`` attribute, which must be set before any
    of them is used.
    """
    # Simulator ID classes should inherit both from the base type of the ID
    # (e.g., int or long) and from IDMixin.

    def __getattr__(self, name):
        """
        Fallback attribute lookup that exposes cell parameters as attributes.

        Python calls this only when normal attribute lookup fails. Returns the
        value of the parameter `name`, or raises NonExistentParameterError if
        the cell type has no such parameter.
        """
        if name == "parent":
            # Keep this a plain Exception rather than AttributeError: the
            # `celltype` and `local` properties read `self.parent`, and an
            # AttributeError escaping a property makes Python retry the lookup
            # via __getattr__, which would call get_parameters() -> self.local
            # -> self.parent -> ... and recurse indefinitely.
            raise Exception("parent is not set")
        elif name == "set":
            errmsg = "For individual cells, set values using the parameter name directly, " \
                     "e.g. population[0].tau_m = 20.0, or use 'set' on a population view, " \
                     "e.g. population[0:1].set(tau_m=20.0)"
            raise AttributeError(errmsg)
        elif name.startswith("__") and name.endswith("__"):
            # Special/protocol names (e.g. __setstate__, probed by copy and
            # pickle, or __array_interface__, probed by numpy) are never cell
            # parameters. Report them as missing instead of doing a parameter
            # lookup. That lookup needs a parent, and without one it fails
            # with a non-AttributeError that hasattr() and
            # getattr(obj, name, default) do not catch.
            raise AttributeError("%r object has no attribute %r" % (type(self).__name__, name))
        try:
            val = self.get_parameters()[name]
        except KeyError:
            raise errors.NonExistentParameterError(name,
                                                   self.celltype.__class__.__name__,
                                                   self.celltype.get_parameter_names())
        return val

    def __setattr__(self, name, value):
        """
        Route assignments to cell parameters through set_parameters().

        ``parent`` and any name that is not a parameter of the cell type are
        stored as ordinary attributes. Checking for a parameter reads
        `self.celltype`, so `parent` must already be set for other names.
        """
        if name == "parent":
            object.__setattr__(self, name, value)
        elif self.celltype.has_parameter(name):
            self.set_parameters(**{name: value})
        else:
            object.__setattr__(self, name, value)

    def set_parameters(self, **parameters):
        """
        Set cell parameters, given as a sequence of parameter=value arguments.

        Raises errors.NotLocalError if the cell does not exist on this node.
        """
        # if some of the parameters are computed from the values of other
        # parameters, need to get and translate all parameters; delegating to
        # a one-cell view reuses the population-level logic for that
        if self.local:
            self.as_view().set(**parameters)
        else:
            raise errors.NotLocalError(
                "Cannot set parameters for a cell that does not exist on this node.")

    def get_parameters(self):
        """
        Return a dict of all cell parameters, keyed by parameter name.

        Raises errors.NotLocalError if the cell does not exist on this node.
        """
        if self.local:
            parameter_names = self.celltype.get_parameter_names()
            return dict(zip(parameter_names, self.as_view().get(parameter_names)))
        else:
            raise errors.NotLocalError(
                "Cannot obtain parameters for a cell that does not exist on this node.")

    @property
    def celltype(self):
        """The cell type of the parent population."""
        return self.parent.celltype

    @property
    def is_standard_cell(self):
        """True if the cell type is a standard (simulator-independent) PyNN model."""
        return isinstance(self.celltype, standardmodels.StandardCellType)

    def _set_position(self, pos):
        """
        Set the cell position in 3D space.

        `pos` must be a tuple or numpy array of length 3. Cell positions are
        stored in an array in the parent Population.
        """
        assert isinstance(pos, (tuple, np.ndarray))
        assert len(pos) == 3
        self.parent._set_cell_position(self, pos)

    def _get_position(self):
        """
        Return the cell position in 3D space.

        Positions are managed by the parent; this simply delegates to
        `parent._get_cell_position`.
        """
        return self.parent._get_cell_position(self)

    position = property(_get_position, _set_position)

    @property
    def local(self):
        """Whether this cell exists on the local (MPI) node, as reported by the parent."""
        return self.parent.is_local(self)

    def inject(self, current_source):
        """Inject current from a current source object into the cell."""
        current_source.inject_into([self])

    def get_initial_value(self, variable):
        """Get the initial value of a state variable of the cell."""
        return self.parent._get_cell_initial_value(self, variable)

    def set_initial_value(self, variable, value):
        """Set the initial value of a state variable of the cell."""
        self.parent._set_cell_initial_value(self, variable, value)

    def as_view(self):
        """Return a PopulationView containing just this cell."""
        index = self.parent.id_to_index(self)
        return self.parent[index:index + 1]
145

146

147
class BasePopulation(object):
    """
    Behaviour shared by populations and population views.

    This class is not usable on its own. Subclasses must supply the
    attributes and methods it relies on, including `all_cells`,
    `local_cells`, `size`, `_mask_local`, `positions`, `celltype`,
    `initial_values`, `recorder`, `label`, `annotations`, `_simulator`,
    `_assembly_class`, `id_to_index()`, `id_to_local_index()`, `_get_view()`,
    `_get_parameters()`, `_set_parameters()` and `_set_initial_value_array()`.
    """
    # Optional subset of cells. record() passes it to the recorder instead of
    # `all_cells`, and write_data/get_data/get_spike_counts pass it on as a
    # filter. None means "all cells".
    _record_filter = None

    def __getitem__(self, index):
        """
        Return either a single cell (ID object) from the Population, if `index`
        is an integer, or a subset of the cells (PopulationView object), if
        `index` is a slice, list, array or tuple.

        Note that __getitem__ is called when using [] access, e.g.
            p = Population(...)
            p[2] is equivalent to p.__getitem__(2).
            p[3:6] is equivalent to p.__getitem__(slice(3, 6))

        Raises TypeError for any other index type.
        """
        if isinstance(index, (int, np.integer)):
            return self.all_cells[index]
        elif isinstance(index, (slice, list, np.ndarray)):
            return self._get_view(index)
        elif isinstance(index, tuple):
            return self._get_view(list(index))
        else:
            raise TypeError(
                "indices must be integers, slices, lists, arrays or tuples, not %s" % type(index).__name__)

    def __len__(self):
        """Return the total number of cells in the population (all nodes)."""
        return self.size

    @property
    def local_size(self):
        """Return the number of cells in the population on the local MPI node"""
        return len(self.local_cells)  # would self._mask_local.sum() be faster?

    def __iter__(self):
        """Iterator over cell ids on the local node."""
        return iter(self.local_cells)

    @property
    def conductance_based(self):
        """
        Indicates whether the post-synaptic response is modelled as a change
        in conductance or a change in current.
        """
        return self.celltype.conductance_based

    @property
    def receptor_types(self):
        """The receptor types supported by the cell type."""
        return self.celltype.receptor_types

    def is_local(self, id):
        """
        Indicates whether the cell with the given ID exists on the local MPI node.

        `id` must belong to this population (its `parent` must be `self`).
        """
        assert id.parent is self
        index = self.id_to_index(id)
        return self._mask_local[index]

    def all(self):
        """Iterator over cell ids on all MPI nodes."""
        return iter(self.all_cells)

    def __add__(self, other):
        """
        Combine this Population/PopulationView with another one.

        `other` must be a BasePopulation instance (a Population or
        PopulationView). The result is `self._assembly_class(self, other)`,
        i.e. an Assembly.
        """
        assert isinstance(other, BasePopulation)
        return self._assembly_class(self, other)

    def _get_cell_position(self, id):
        """Return the position of cell `id`, i.e. column `index` of `self.positions`."""
        index = self.id_to_index(id)
        return self.positions[:, index]

    def _set_cell_position(self, id, pos):
        """Store `pos` as the position of cell `id` in `self.positions`."""
        index = self.id_to_index(id)
        self.positions[:, index] = pos

    @property
    def position_generator(self):  # "generator" is a misleading name, has no yield statement
        """Return a function mapping a cell index `i` to column `i` of `self.positions`."""
        def gen(i):
            return self.positions.T[i]
        return gen

    def _get_cell_initial_value(self, id, variable):
        """
        Return the initial value of `variable` for cell `id`.

        Logs a warning and returns 0.0 if `variable` has no initial value.
        """
        if variable in self.initial_values:
            assert isinstance(self.initial_values[variable], LazyArray)
            index = self.id_to_local_index(id)
            return self.initial_values[variable][index]
        else:
            logger.warning(
                "Variable '%s' is not in initial values, returning 0.0", variable)
            return 0.0

    def _set_cell_initial_value(self, id, variable, value):
        """Set the initial value of `variable` for cell `id`."""
        assert isinstance(self.initial_values[variable], LazyArray)
        index = self.id_to_local_index(id)
        self.initial_values[variable][index] = value

    def nearest(self, position):
        """
        Return the neuron closest to the specified position.

        `position` is compared against each column of `self.positions` using
        squared Euclidean distance.
        """
        # doesn't always work correctly if a position is equidistant between
        # two neurons, i.e. 0.5 should be rounded up, but it isn't always.
        # also doesn't take account of periodic boundary conditions
        pos = np.array([position] * self.positions.shape[1]).transpose()
        dist_arr = (self.positions - pos)**2
        distances = dist_arr.sum(axis=0)
        nearest = distances.argmin()
        return self[nearest]

    def sample(self, n, rng=None):
        """
        Randomly sample `n` cells from the Population, and return a
        PopulationView object.

        `n` may be a Python or numpy integer. If no `rng` is given, a new
        random.NumpyRNG is created.
        """
        assert isinstance(n, (int, np.integer))
        if not rng:
            rng = random.NumpyRNG()
        indices = rng.permutation(np.arange(len(self), dtype=int))[0:n]
        logger.debug("The %d cells selected have indices %s", n, indices)
        logger.debug("%s.sample(%s)", self.label, n)
        return self._get_view(indices)

    def get(self, parameter_names, gather=False, simplify=True):
        """
        Get the values of the given parameters for every local cell in the
        population, or, if gather=True, for all cells in the population.

        Values will be expressed in the standard PyNN units (i.e. millivolts,
        nanoamps, milliseconds, microsiemens, nanofarads, event per second).

        `parameter_names` may be a single name, in which case a single value
        (or array) is returned, or a sequence of names, in which case a list
        of values in the same order is returned. `simplify` is passed to
        `ParameterSpace.evaluate`.

        Raises errors.NonExistentParameterError if a requested parameter is
        not available.
        """
        # if all the cells have the same value for a parameter, should
        # we return just the number, rather than an array?
        if isinstance(parameter_names, str):
            parameter_names = (parameter_names,)
            return_list = False
        else:
            return_list = True
        if isinstance(self.celltype, standardmodels.StandardCellType):
            if any(name in self.celltype.computed_parameters() for name in parameter_names):
                native_names = self.celltype.get_native_names()  # need all parameters in order to calculate values
            else:
                native_names = self.celltype.get_native_names(*parameter_names)
            native_parameter_space = self._get_parameters(*native_names)
            parameter_space = self.celltype.reverse_translate(native_parameter_space)
        else:
            parameter_space = self._get_parameters(*parameter_names)
        # what if parameter space is homogeneous on some nodes but not on others?
        parameter_space.evaluate(simplify=simplify)
        # this also causes problems if the population size matches the number of MPI nodes
        parameters = dict(parameter_space.items())
        if gather and self._simulator.state.num_processes > 1:
            # seems inefficient to do it in a loop - should do as single operation
            for name in parameter_names:
                values = parameters[name]
                if isinstance(values, np.ndarray):
                    all_values = {self._simulator.state.mpi_rank: values.tolist()}
                    local_indices = np.arange(self.size,)[self._mask_local].tolist()
                    all_indices = {self._simulator.state.mpi_rank: local_indices}
                    all_values = recording.gather_dict(all_values)
                    all_indices = recording.gather_dict(all_indices)
                    if self._simulator.state.mpi_rank == 0:
                        # concatenate the per-rank lists, then reorder by cell index
                        values = reduce(operator.add, all_values.values())
                        indices = reduce(operator.add, all_indices.values())
                        idx = np.argsort(indices)
                        values = np.array(values)[idx]
                parameters[name] = values
        try:
            values = [parameters[name] for name in parameter_names]
        except KeyError as err:
            # same (name, model name, valid names) form as used by IDMixin.__getattr__
            raise errors.NonExistentParameterError(err.args[0],
                                                   self.celltype.__class__.__name__,
                                                   self.celltype.get_parameter_names()) from err
        if return_list:
            return values
        else:
            assert len(parameter_names) == 1
            return values[0]

    def set(self, **parameters):
        """
        Set one or more parameters for every cell in the population.

        Values passed to set() may be:
            (1) single values
            (2) RandomDistribution objects
            (3) lists/arrays of values of the same size as the population
            (4) mapping functions, where a mapping function accepts a single
                argument (the cell index) and returns a single value.

        Here, a "single value" may be either a single number or a list/array of
        numbers (e.g. for spike times). Values should be expressed in the
        standard PyNN units (i.e. millivolts, nanoamps, milliseconds,
        microsiemens, nanofarads, event per second).

        Does nothing if the population has no cells on the local node.

        Examples::

            p.set(tau_m=20.0, v_rest=-65)
            p.set(spike_times=[0.3, 0.7, 0.9, 1.4])
            p.set(cm=rand_distr, tau_m=lambda i: 10 + i/10.0)
        """
        # TODO: add example using of function of (x,y,z) and Population.position_generator
        if self.local_size > 0:
            if (isinstance(self.celltype, standardmodels.StandardCellType)
                    and any(name in self.celltype.computed_parameters() for name in parameters)
                    and not isinstance(self.celltype, standardmodels.cells.SpikeSourceArray)):
                # the last condition above is a bit of hack to avoid calling expand() unnecessarily
                # need to get existing parameter space of models so we can perform calculations
                native_names = self.celltype.get_native_names()
                parameter_space = self.celltype.reverse_translate(
                    self._get_parameters(*native_names))
                if self.local_size != self.size:
                    parameter_space.expand((self.size,), self._mask_local)
                parameter_space.update(**parameters)
            else:
                parameter_space = ParameterSpace(parameters,
                                                 self.celltype.get_schema(),
                                                 (self.size,),
                                                 self.celltype.__class__)
            if isinstance(self.celltype, standardmodels.StandardCellType):
                parameter_space = self.celltype.translate(parameter_space)
            assert parameter_space.shape == (self.size,), "{} != {}".format(
                parameter_space.shape, self.size)
            self._set_parameters(parameter_space)

    @deprecated("set(parametername=value_array)")
    def tset(self, parametername, value_array):
        """
        'Topographic' set. Set the value of parametername to the values in
        value_array, which must have the same dimensions as the Population.
        """
        self.set(**{parametername: value_array})

    @deprecated("set(parametername=rand_distr)")
    def rset(self, parametername, rand_distr):
        """
        'Random' set. Set the value of parametername to a value taken from
        rand_distr, which should be a RandomDistribution object.
        """
        # Note that we generate enough random numbers for all cells on all nodes
        # but use only those relevant to this node. This ensures that the
        # sequence of random numbers does not depend on the number of nodes,
        # provided that the same rng with the same seed is used on each node.
        self.set(**{parametername: rand_distr})

    def initialize(self, **initial_values):
        """
        Set initial values of state variables, e.g. the membrane potential.

        Values passed to initialize() may be:
            (1) single numeric values (all neurons set to the same value)
            (2) RandomDistribution objects
            (3) lists/arrays of numbers of the same size as the population
            (4) mapping functions, where a mapping function accepts a single
                argument (the cell index) and returns a single number.

        Values should be expressed in the standard PyNN units (i.e. millivolts,
        nanoamps, milliseconds, microsiemens, nanofarads, event per second).

        Examples::

            p.initialize(v=-70.0)
            p.initialize(v=rand_distr, gsyn_exc=0.0)
            p.initialize(v=lambda i: -65 + i/10.0)
        """
        for variable, value in initial_values.items():
            logger.debug("In Population '%s', initialising %s to %s",
                         self.label, variable, value)
            # wrap every kind of value in a float LazyArray sized to the whole population
            initial_value = LazyArray(value, shape=(self.size,), dtype=float)
            self._set_initial_value_array(variable, initial_value)
            self.initial_values[variable] = initial_value

    def find_units(self, variable):
        """
        Returns units of the specified variable or parameter, as a string.
        Works for all the recordable variables and neuron parameters of all standard models.
        """
        return self.celltype.units[variable]

    def annotate(self, **annotations):
        """Add or update entries in this population's `annotations` dict."""
        self.annotations.update(annotations)

    def can_record(self, variable):
        """Determine whether `variable` can be recorded from this population."""
        return self.celltype.can_record(variable)

    @property
    def injectable(self):
        """Whether current can be injected into cells of this type."""
        return self.celltype.injectable

    def record(self, variables, to_file=None, sampling_interval=None):
        """
        Record the specified variable or variables for all cells in the
        Population or view.

        `variables` may be either a single variable name or a list of variable
        names. For a given celltype class, `celltype.recordable` contains a list of
        variables that can be recorded for that celltype. Passing None resets
        the list of recorded variables.

        If specified, `to_file` should be either a filename or a Neo IO instance and `write_data()`
        will be automatically called when `end()` is called.

        `sampling_interval` should be a value in milliseconds, and an integer
        multiple of the simulation timestep.
        """
        if variables is None:  # reset the list of things to record
            # note that if record(None) is called on a view of a population
            # recording will be reset for the entire population, not just the view
            self.recorder.reset()
        else:
            logger.debug("%s.record('%s')", self.label, variables)
            if self._record_filter is None:
                self.recorder.record(variables, self.all_cells, sampling_interval)
            else:
                self.recorder.record(variables, self._record_filter, sampling_interval)
        # NOTE(review): only a filename (str) is handled here; any other
        # `to_file` value (including a Neo IO instance, as the docstring
        # suggests) is currently ignored — confirm intended behaviour.
        if isinstance(to_file, str):
            self.recorder.file = to_file
            self._simulator.state.write_on_end.append((self, variables, self.recorder.file))

    @deprecated("record('v')")
    def record_v(self, to_file=True):
        """
        Record the membrane potential for all cells in the Population.
        """
        self.record('v', to_file)

    @deprecated("record(['gsyn_exc', 'gsyn_inh'])")
    def record_gsyn(self, to_file=True):
        """
        Record synaptic conductances for all cells in the Population.
        """
        self.record(['gsyn_exc', 'gsyn_inh'], to_file)

    def write_data(self, io, variables='all', gather=True, clear=False, annotations=None):
        """
        Write recorded data to file, using one of the file formats supported by
        Neo.

        `io`:
            a Neo IO instance
        `variables`:
            either a single variable name or a list of variable names.
            Variables must have been previously recorded, otherwise an
            Exception will be raised.

        For parallel simulators, if `gather` is True, all data will be gathered
        to the master node and a single output file created there. Otherwise, a
        file will be written on each node, containing only data from the cells
        simulated on that node.

        If `clear` is True, recorded data will be deleted from the `Population`.

        `annotations` should be a dict containing simple data types such as
        numbers and strings. The contents will be written into the output data
        file as metadata.
        """
        logger.debug("Population %s is writing %s to %s [gather=%s, clear=%s]",
                     self.label, variables, io, gather, clear)
        self.recorder.write(variables, io, gather, self._record_filter, clear=clear,
                            annotations=annotations)

    def get_data(self, variables='all', gather=True, clear=False):
        """
        Return a Neo `Block` containing the data (spikes, state variables)
        recorded from the Population.

        `variables` - either a single variable name or a list of variable names
                      Variables must have been previously recorded, otherwise an
                      Exception will be raised.

        For parallel simulators, if `gather` is True, all data will be gathered
        to all nodes and the Neo `Block` will contain data from all nodes.
        Otherwise, the Neo `Block` will contain only data from the cells
        simulated on the local node.

        If `clear` is True, recorded data will be deleted from the `Population`.
        """
        return self.recorder.get(variables, gather, self._record_filter, clear)

    @deprecated("write_data(file, 'spikes')")
    def printSpikes(self, file, gather=True, compatible_output=True):
        """Deprecated; `compatible_output` is ignored."""
        self.write_data(file, 'spikes', gather)

    @deprecated("get_data('spikes')")
    def getSpikes(self, gather=True, compatible_output=True):
        """Deprecated; `compatible_output` is ignored."""
        return self.get_data('spikes', gather)

    @deprecated("write_data(file, 'v')")
    def print_v(self, file, gather=True, compatible_output=True):
        """Deprecated; `compatible_output` is ignored."""
        self.write_data(file, 'v', gather)

    @deprecated("get_data('v')")
    def get_v(self, gather=True, compatible_output=True):
        """Deprecated; `compatible_output` is ignored."""
        return self.get_data('v', gather)

    @deprecated("write_data(file, ['gsyn_exc', 'gsyn_inh'])")
    def print_gsyn(self, file, gather=True, compatible_output=True):
        """Deprecated; `compatible_output` is ignored."""
        self.write_data(file, ['gsyn_exc', 'gsyn_inh'], gather)

    @deprecated("get_data(['gsyn_exc', 'gsyn_inh'])")
    def get_gsyn(self, gather=True, compatible_output=True):
        """Deprecated; `compatible_output` is ignored."""
        return self.get_data(['gsyn_exc', 'gsyn_inh'], gather)

    def get_spike_counts(self, gather=True):
        """
        Returns a dict containing the number of spikes for each neuron.

        The dict keys are neuron IDs, not indices.
        """
        # arguably, we should use indices
        return self.recorder.count('spikes', gather, self._record_filter)

    @deprecated("mean_spike_count()")
    def meanSpikeCount(self, gather=True):
        """Deprecated alias of mean_spike_count()."""
        return self.mean_spike_count(gather)

    def mean_spike_count(self, gather=True):
        """
        Returns the mean number of spikes per neuron.

        Returns 0 if there are no spike counts. With `gather` True, only MPI
        rank 0 computes the mean; other ranks return NaN.
        """
        spike_counts = self.get_spike_counts(gather)
        total_spikes = sum(spike_counts.values())
        if self._simulator.state.mpi_rank == 0 or not gather:  # should maybe use allgather, and get the numbers on all nodes
            if len(spike_counts) > 0:
                return float(total_spikes) / len(spike_counts)
            else:
                return 0
        else:
            return np.nan

    def inject(self, current_source):
        """
        Connect a current source to all cells in the Population.

        Raises TypeError if the cell type is not injectable.
        """
        if not self.celltype.injectable:
            raise TypeError("Can't inject current into a spike source.")
        current_source.inject_into(self)

    # name should be consistent with saving/writing data, i.e. save_data() and save_positions() or write_data() and write_positions()
    def save_positions(self, file):
        """
        Save positions to file. The output format is ``index x y z``

        `file` may be a filename, in which case a StandardTextFile is opened
        for writing, or a file object with `write(data, metadata)` and
        `close()` methods. Only MPI rank 0 writes, and it closes the file
        afterwards. A file opened here from a filename is also closed on the
        other ranks.
        """
        opened_here = isinstance(file, str)
        if opened_here:
            file = recording.files.StandardTextFile(file, mode='w')
        cells = self.all_cells
        result = np.empty((len(cells), 4))
        result[:, 0] = np.array([self.id_to_index(cell) for cell in cells])
        result[:, 1:4] = self.positions.T
        if self._simulator.state.mpi_rank == 0:
            file.write(result, {'population': self.label})
            file.close()
        elif opened_here:
            # don't leave the handle opened above dangling on non-root ranks
            file.close()
597

598

599
class Population(BasePopulation):
    """
    A group of neurons all of the same type. "Population" is used as a generic
    term intended to include layers, columns, nuclei, etc., of cells.

    Arguments:
        `size`:
            number of cells in the Population. For backwards-compatibility,
            `size` may also be a tuple giving the dimensions of a grid,
            e.g. ``size=(10,10)`` is equivalent to ``size=100`` with ``structure=Grid2D()``.

        `cellclass`:
            a cell type: either an instance of :class:`pyNN.models.BaseCellType`
            (preferred) or a class inheriting from it.

        `cellparams`:
            a dict, or other mapping, containing parameters, which is passed to
            the neuron model constructor. Only allowed when `cellclass` is a
            class; retained for backwards compatibility, its use is deprecated.

        `structure`:
            a :class:`pyNN.space.Structure` instance, used to specify the
            positions of neurons in space.

        `initial_values`:
            a dict, or other mapping, containing initial values for the neuron
            state variables.

        `label`:
            a name for the population. One will be auto-generated if this is not
            supplied.
    """
    _nPop = 0  # number of Populations created so far; used for default labels

    def __init__(self, size, cellclass, cellparams=None, structure=None,
                 initial_values=None, label=None):
        """
        Create a population of neurons all of the same type.
        """
        if not hasattr(self, "_simulator"):
            errmsg = "`common.Population` should not be instantiated directly. " \
                     "You should import Population from a PyNN backend module, " \
                     "e.g. pyNN.nest or pyNN.neuron"
            raise Exception(errmsg)
        if not isinstance(size, (int, np.integer)):  # also allow a single integer, for a 1D population
            assert isinstance(
                size, tuple), "`size` must be an integer or a tuple of ints. You have supplied a %s" % type(size)
            # check the things inside are ints (numpy integers are accepted too,
            # as they are for a scalar `size`)
            for e in size:
                assert isinstance(
                    e, (int, np.integer)), "`size` must be an integer or a tuple of ints. Element '%s' is not an int" % str(e)

            assert structure is None, "If you specify `size` as a tuple you may not specify structure."
            if len(size) == 1:
                structure = space.Line()
            elif len(size) == 2:
                nx, ny = size
                structure = space.Grid2D(nx / float(ny))
            elif len(size) == 3:
                nx, ny, nz = size
                structure = space.Grid3D(nx / float(ny), nx / float(nz))
            else:
                raise Exception(
                    "A maximum of 3 dimensions is allowed. What do you think this is, string theory?")
            size = int(reduce(operator.mul, size))
        self.size = size
        self.label = label or 'population%d' % Population._nPop
        self._structure = structure or space.Line()
        self._positions = None
        self._is_sorted = True
        if isinstance(cellclass, BaseCellType):
            self.celltype = cellclass
            assert cellparams is None   # cellparams being retained for backwards compatibility, but use is deprecated
        elif isinstance(cellclass, type) and issubclass(cellclass, BaseCellType):
            # `cellparams` may be omitted, in which case the model defaults are used.
            # TODO: passing a class (plus cellparams) is deprecated, but no warning is emitted yet
            self.celltype = cellclass(**(cellparams or {}))
        else:
            # also reached for non-class objects, for which issubclass() would itself raise
            raise TypeError(
                "cellclass must be an instance or subclass of BaseCellType, not a %s" % type(cellclass))
        self.annotations = {}
        self.recorder = self._recorder_class(self)
        # Build the arrays of cell ids
        # Cells on the local node are represented as ID objects, other cells by integers
        # All are stored in a single numpy array for easy lookup by address
        # The local cells are also stored in a list, for easy iteration
        self._create_cells()
        self.first_id = self.all_cells[0]
        self.last_id = self.all_cells[-1]
        self.initial_values = {}
        all_initial_values = self.celltype.default_initial_values.copy()
        if initial_values:
            all_initial_values.update(initial_values)
        self.initialize(**all_initial_values)
        Population._nPop += 1

    def __repr__(self):
        return "Population(%d, %r, structure=%r, label=%r)" % (self.size, self.celltype, self.structure, self.label)

    @property
    def local_cells(self):
        """
        An array containing cell ids for the local node.
        """
        return self.all_cells[self._mask_local]

    def id_to_index(self, id):
        """
        Given the ID(s) of cell(s) in the Population, return its (their) index
        (order in the Population).

            >>> assert p.id_to_index(p[5]) == 5

        Raises ValueError if any ID lies outside [first_id, last_id]. An empty
        iterable gives an empty integer array.
        """
        if not np.iterable(id):
            if not self.first_id <= id <= self.last_id:
                raise ValueError("id should be in the range [%d,%d], actually %d" % (
                    self.first_id, self.last_id, id))
            return int(id - self.first_id)  # this assumes ids are consecutive
        else:
            if isinstance(id, PopulationView):
                id = id.all_cells
            id = np.array(id)
            if id.size == 0:
                # min()/max() are undefined for an empty array
                return id.astype(int)
            if (self.first_id > id.min()) or (self.last_id < id.max()):
                raise ValueError("ids should be in the range [%d,%d], actually [%d, %d]" % (
                    self.first_id, self.last_id, id.min(), id.max()))
            return (id - self.first_id).astype(int)  # this assumes ids are consecutive

    def id_to_local_index(self, id):
        """
        Given the ID(s) of cell(s) in the Population, return its (their) index
        (order in the Population), counting only cells on the local MPI node.
        """
        if self._simulator.state.num_processes > 1:
            return self.local_cells.tolist().index(id)          # probably very slow
            # return np.nonzero(self.local_cells == id)[0][0] # possibly faster?
            # another idea - get global index, use idx-sum(mask_local[:idx])?
        else:
            return self.id_to_index(id)

    def _get_structure(self):
        """The spatial structure of the Population."""
        return self._structure

    def _set_structure(self, structure):
        assert isinstance(structure, space.BaseStructure)
        if self._structure is None or structure != self._structure:
            self._positions = None  # setting a new structure invalidates previously calculated positions
            self._structure = structure
    structure = property(fget=_get_structure, fset=_set_structure)
    # arguably structure should be read-only, i.e. it is not possible to change it after Population creation

    def _get_positions(self):
        """
        Try to return self._positions. If it does not exist, create it and then
        return it.
        """
        if self._positions is None:
            self._positions = self.structure.generate_positions(self.size)
        assert self._positions.shape == (3, self.size)
        return self._positions

    def _set_positions(self, pos_array):
        assert isinstance(pos_array, np.ndarray)
        assert pos_array.shape == (3, self.size), "%s != %s" % (pos_array.shape, (3, self.size))
        self._positions = pos_array.copy()  # take a copy in case pos_array is changed later
        self._structure = None  # explicitly setting positions destroys any previous structure

    positions = property(_get_positions, _set_positions,
                         doc="""A 3xN array (where N is the number of neurons in the Population)
                         giving the x,y,z coordinates of all the neurons (soma, in the
                         case of non-point models).""")

    def describe(self, template='population_default.txt', engine='default'):
        """
        Returns a human-readable description of the population.

        The output may be customized by specifying a different template
        together with an associated template engine (see :mod:`pyNN.descriptions`).

        If template is None, then a dictionary containing the template context
        will be returned.
        """
        context = {
            "label": self.label,
            "celltype": self.celltype.describe(template=None),
            "structure": None,
            "size": self.size,
            "size_local": len(self.local_cells),
            "first_id": self.first_id,
            "last_id": self.last_id,
        }
        context.update(self.annotations)
        if len(self.local_cells) > 0:
            first_id = self.local_cells[0]
            context.update({
                "local_first_id": first_id,
                "cell_parameters": {}  # first_id.get_parameters(),
            })
        if self.structure:
            context["structure"] = self.structure.describe(template=None)
        return descriptions.render(engine, template, context)
796

797

798
class PopulationView(BasePopulation):
    """
    A view of a subset of neurons within a Population.

    In most ways, Populations and PopulationViews have the same behaviour, i.e.
    they can be recorded, connected with Projections, etc. It should be noted
    that any changes to neurons in a PopulationView will be reflected in the
    parent Population and vice versa.

    It is possible to have views of views.

    Arguments:
        selector:
            a slice, a list, or a numpy mask array. The mask array should either be a
            boolean array of the same size as the parent, or an integer array
            containing cell indices, i.e. if p.size == 5::

                PopulationView(p, array([False, False, True, False, True]))
                PopulationView(p, array([2,4]))
                PopulationView(p, slice(2,5,2))

            will all create the same view. Integer indices are sorted and
            de-duplicated; an array passed in by the caller is not modified.
    """

    def __init__(self, parent, selector, label=None):
        """
        Create a view of a subset of neurons within a parent Population or
        PopulationView.
        """
        if not hasattr(self, "_simulator"):
            errmsg = "`common.PopulationView` should not be instantiated directly. " \
                     "You should import PopulationView from a PyNN backend module, " \
                     "e.g. pyNN.nest or pyNN.neuron"
            raise Exception(errmsg)
        self.parent = parent
        self.mask = selector  # later we can have fancier selectors, for now we just have numpy masks
        # maybe just redefine __getattr__ instead of the following...
        self.celltype = self.parent.celltype
        # If the mask is a slice, IDs will be consecutive without duplication.
        # If not, then we need to remove duplicated IDs
        if not isinstance(self.mask, slice):
            if isinstance(self.mask, (list, tuple)):
                self.mask = np.array(self.mask)
            if self.mask.dtype == np.dtype('bool'):
                if len(self.mask) != len(self.parent):
                    raise Exception("Boolean masks should have the size of Parent Population")
                self.mask = np.arange(len(self.parent))[self.mask]
            else:
                if len(np.unique(self.mask)) != len(self.mask):
                    logger.warning(
                        "PopulationView can contain only once each ID, duplicated IDs are removed")
                    self.mask = np.unique(self.mask)
                # Sorting is needed by NEST. np.sort returns a copy, so an array
                # supplied by the caller is not reordered behind their back.
                # Maybe emit a warning or exception if mask is not already ordered?
                self.mask = np.sort(self.mask)
        self.all_cells = self.parent.all_cells[self.mask]
        idx = np.argsort(self.all_cells)
        self._is_sorted = np.all(idx == np.arange(len(self.all_cells)))
        self.size = len(self.all_cells)
        self.label = label or "view of '%s' with size %s" % (parent.label, self.size)
        self._mask_local = self.parent._mask_local[self.mask]
        self.local_cells = self.all_cells[self._mask_local]
        # min/max rather than first/last element, since all_cells need not be sorted
        # (e.g. a slice with a negative step)
        self.first_id = np.min(self.all_cells)
        self.last_id = np.max(self.all_cells)
        self.annotations = {}
        self.recorder = self.parent.recorder
        self._record_filter = self.all_cells

    def __repr__(self):
        return "PopulationView(parent=%r, selector=%r, label=%r)" % (self.parent, self.mask, self.label)

    @property
    def initial_values(self):
        # this is going to be complex - if we keep initial_values as a dict,
        # need to return a dict-like object that takes account of self.mask
        raise NotImplementedError

    @property
    def structure(self):
        """The spatial structure of the parent Population."""
        return self.parent.structure
    # should we allow setting structure for a PopulationView? Maybe if the
    # parent has some kind of CompositeStructure?

    @property
    def positions(self):
        # make positions N,3 instead of 3,N to avoid all this transposing?
        return self.parent.positions.T[self.mask].T

    def id_to_index(self, id):
        """
        Given the ID(s) of cell(s) in the PopulationView, return its/their
        index/indices (order in the PopulationView).

            >>> assert pv.id_to_index(pv[3]) == 3

        A single ID gives a single index; an iterable of IDs gives an integer
        array. Raises IndexError if any ID is not present in the view.
        """
        if not np.iterable(id):
            if self._is_sorted:
                if id not in self.all_cells:
                    raise IndexError("ID %s not present in the View" % id)
                return np.searchsorted(self.all_cells, id)
            matches = np.where(self.all_cells == id)[0]
            if len(matches) == 0:
                raise IndexError("ID %s not present in the View" % id)
            # a scalar, as in the sorted case (IDs within a view are unique)
            return matches[0]
        if isinstance(id, BasePopulation):
            id = id.all_cells
        if self._is_sorted:
            ids = np.asarray(id)
            indices = np.searchsorted(self.all_cells, ids)
            # searchsorted returns an insertion point even for absent IDs, so
            # confirm that each position really holds the requested ID
            clipped = np.minimum(indices, self.size - 1)
            absent = np.asarray((indices >= self.size) | (self.all_cells[clipped] != ids), dtype=bool)
            if absent.any():
                raise IndexError("ID %s not present in the View" % ids[absent][0])
            return indices
        result = []
        for item in id:
            data = np.where(self.all_cells == item)[0]
            if len(data) == 0:
                raise IndexError("ID %s not present in the View" % item)
            elif len(data) > 1:
                raise Exception("ID %s is duplicated in the View" % item)
            result.append(data[0])
        return np.array(result, dtype=int)

    @property
    def grandparent(self):
        """
        Returns the parent Population at the root of the tree (since the
        immediate parent may itself be a PopulationView).

        The name "grandparent" is of course a little misleading, as it could
        be just the parent, or the great, great, great, ..., grandparent.
        """
        if hasattr(self.parent, "parent"):
            return self.parent.grandparent
        else:
            return self.parent

    def index_in_grandparent(self, indices):
        """
        Given an array of indices, return the indices in the parent population
        at the root of the tree.
        """
        indices_in_parent = np.arange(self.parent.size)[self.mask][indices]
        if hasattr(self.parent, "parent"):
            return self.parent.index_in_grandparent(indices_in_parent)
        else:
            return indices_in_parent

    def index_from_parent_index(self, indices):
        """
        Given an index(indices) in the parent population, return
        the index(indices) within this view.

        `indices` may be an integer (Python or numpy) or an integer array.
        """
        # todo: add check that all indices correspond to cells that are in this view
        if isinstance(self.mask, slice):
            # NOTE(review): assumes a non-negative start and a positive step
            start = self.mask.start or 0
            step = self.mask.step or 1
            # floor division, so that the result stays usable as an index
            return (indices - start) // step
        else:
            if isinstance(indices, (int, np.integer)):
                return np.nonzero(self.mask == indices)[0][0]
            elif isinstance(indices, np.ndarray):
                # Lots of ways to do this. Some profiling is in order.
                # - https://stackoverflow.com/questions/16992713/translate-every-element-in-numpy-array-according-to-key
                # - https://stackoverflow.com/questions/3403973/fast-replacement-of-values-in-a-numpy-array
                # - https://stackoverflow.com/questions/13572448/replace-values-of-a-numpy-index-array-with-values-of-a-list
                parent_indices = self.mask  # sorted in __init__
                view_indices = np.arange(self.size)
                index = np.digitize(indices, parent_indices, right=True)
                return view_indices[index]
            else:
                raise ValueError("indices must be an integer or an array of integers")

    def __eq__(self, other):
        """
        Determine whether two views are the same.
        """
        return not self.__ne__(other)

    def __ne__(self, other):
        """
        Determine whether two views are different.
        """
        # We can't use the self.mask, as different masks can select the same cells
        # (e.g. slices vs arrays), therefore we have to use self.all_cells
        if isinstance(other, PopulationView):
            return self.parent != other.parent or not np.array_equal(self.all_cells, other.all_cells)
        elif isinstance(other, Population):
            return self.parent != other or not np.array_equal(self.all_cells, other.all_cells)
        else:
            return True

    def describe(self, template='populationview_default.txt', engine='default'):
        """
        Returns a human-readable description of the population view.

        The output may be customized by specifying a different template
        together with an associated template engine (see ``pyNN.descriptions``).

        If template is None, then a dictionary containing the template context
        will be returned.
        """
        context = {"label": self.label,
                   "parent": self.parent.label,
                   "mask": self.mask,
                   "size": self.size}
        context.update(self.annotations)
        return descriptions.render(engine, template, context)
1004

1005

1006
class Assembly(object):
1✔
1007
    """
1008
    A group of neurons, may be heterogeneous, in contrast to a Population where
1009
    all the neurons are of the same type.
1010

1011
    Arguments:
1012
        populations:
1013
            Populations or PopulationViews
1014
        kwargs:
1015
            May contain a keyword argument 'label'
1016
    """
1017
    _count = 0
1✔
1018

1019
    def __init__(self, *populations, **kwargs):
        """
        Create an Assembly of Populations and/or PopulationViews.

        The only accepted keyword argument is `label`. If it is omitted or
        None, a label of the form 'assemblyN' is generated (as Population
        does for a missing label).
        """
        if not hasattr(self, "_simulator"):
            errmsg = "`common.Assembly` should not be instantiated directly. " \
                     "You should import Assembly from a PyNN backend module, " \
                     "e.g. pyNN.nest or pyNN.neuron"
            raise Exception(errmsg)
        if kwargs:
            assert list(kwargs.keys()) == ['label']
        self.populations = []
        for p in populations:
            self._insert(p)
        label = kwargs.get('label')
        self.label = 'assembly%d' % Assembly._count if label is None else label
        assert isinstance(self.label, str), "label must be a string"
        self.annotations = {}
        Assembly._count += 1
1037

1038
    def __repr__(self):
        """Constructor-like representation listing the member populations and the label."""
        return f"Assembly(*{self.populations!r}, label={self.label!r})"
1040

1041
    def _insert(self, element):
        """
        Append a Population or PopulationView to the Assembly, in place.

        Raises TypeError if `element` is not a population. A PopulationView is
        not added if its parent, or any of its cells, is already present. A
        population already present is not added again. In these cases a warning
        is logged instead.
        """
        if not isinstance(element, BasePopulation):
            raise TypeError("argument is a %s, not a Population." % type(element).__name__)
        if isinstance(element, PopulationView):
            if element.parent in self.populations:
                logger.warning(
                    'Adding a PopulationView to an Assembly when parent Population is there is not possible')
                return
            overlaps = any(np.intersect1d(p.all_cells, element.all_cells).size > 0
                           for p in self.populations)
            if overlaps:
                # Should we automatically remove duplicated IDs ?
                logger.warning(
                    'Adding a PopulationView to an Assembly containing elements already present is not possible')
            else:
                self.populations.append(element)
        elif element in self.populations:
            logger.warning('Adding a Population twice in an Assembly is not possible')
        else:
            self.populations.append(element)
1064

1065
    @property
    def local_cells(self):
        """Array of the ids of cells on the local node, member population by member population."""
        return reduce(
            lambda accumulated, population: np.concatenate((accumulated, population.local_cells)),
            self.populations[1:],
            self.populations[0].local_cells,
        )
1071

1072
    @property
    def all_cells(self):
        """Array of the ids of all cells (every node), member population by member population."""
        return reduce(
            lambda accumulated, population: np.concatenate((accumulated, population.all_cells)),
            self.populations[1:],
            self.populations[0].all_cells,
        )
1078

1079
    def all(self):
        """Return an iterator over the ids of every cell in the Assembly, on all nodes."""
        cells = self.all_cells
        return iter(cells)
1082

1083
    @property
    def _is_sorted(self):
        """
        True if the ids in `all_cells` are in non-decreasing order.

        A single O(n) pass over adjacent pairs replaces the previous argsort
        comparison, and `all_cells` (a concatenation) is built only once.
        """
        cells = self.all_cells
        return np.all(np.asarray(cells[:-1] <= cells[1:], dtype=bool))
1087

1088
    @property
    def _homogeneous_synapses(self):
        """True when the member populations are all conductance-based or all current-based."""
        kinds = {bool(p.celltype.conductance_based) for p in self.populations}
        return len(kinds) <= 1
1092

1093
    @property
    def conductance_based(self):
        """
        `True` only if every member population models the post-synaptic
        response as a change in conductance; `False` if any uses current.
        """
        for population in self.populations:
            if not population.celltype.conductance_based:
                return False
        return True
1100

1101
    @property
    def receptor_types(self):
        """
        Return a list of receptor types that are common to all populations
        within the assembly.

        The list follows the order of the first population's receptor types,
        so the result is deterministic (a set intersection would yield an
        arbitrary, hash-dependent order).
        """
        common = list(self.populations[0].celltype.receptor_types)
        for p in self.populations[1:]:
            allowed = set(p.celltype.receptor_types)
            common = [rt for rt in common if rt in allowed]
        return common
1113

1114
    def find_units(self, variable):
        """
        Returns units of the specified variable or parameter, as a string.
        Works for all the recordable variables and neuron parameters of all standard models.

        Raises ValueError if the member populations disagree on the units.
        Returns None if the Assembly has no populations.
        """
        units = {p.find_units(variable) for p in self.populations}
        if len(units) > 1:
            raise ValueError("Inconsistent units")
        # previously the set itself was returned, contrary to the docstring
        return units.pop() if units else None
1123

1124
    @property
    def _mask_local(self):
        """Concatenation of the member populations' local-node masks, in population order."""
        return reduce(
            lambda accumulated, population: np.concatenate((accumulated, population._mask_local)),
            self.populations[1:],
            self.populations[0]._mask_local,
        )
1130

1131
    @property
    def first_id(self):
        """Smallest cell id in the Assembly."""
        cells = self.all_cells
        return cells.min()
1134

1135
    @property
    def last_id(self):
        """Largest cell id in the Assembly."""
        cells = self.all_cells
        return cells.max()
1138

1139
    def id_to_index(self, id):
        """
        Given the ID(s) of cell(s) in the Assembly, return its (their) index
        (order in the Assembly)::

            >>> assert a.id_to_index(a[5]) == 5
            >>> assert (a.id_to_index(a.all_cells[[1, 2, 3]]) == [1, 2, 3]).all()

        A single ID gives a single index; an iterable of IDs gives an integer
        array. Raises IndexError if an ID is not in the Assembly.
        """
        all_cells = self.all_cells
        n_cells = len(all_cells)
        if not np.iterable(id):
            if self._is_sorted:
                index = np.searchsorted(all_cells, id)
                # searchsorted returns an insertion point even for absent IDs
                if index >= n_cells or all_cells[index] != id:
                    raise IndexError("ID %s not present in the Assembly" % id)
                return index
            matches = np.where(all_cells == id)[0]
            if len(matches) == 0:
                raise IndexError("ID %s not present in the Assembly" % id)
            elif len(matches) > 1:
                raise Exception("ID %s is duplicated in the Assembly" % id)
            # a scalar, as in the sorted case
            return matches[0]
        if isinstance(id, (BasePopulation, Assembly)):
            id = id.all_cells
        if self._is_sorted:
            ids = np.asarray(id)
            indices = np.searchsorted(all_cells, ids)
            # check that each insertion point really holds the requested ID
            clipped = np.minimum(indices, n_cells - 1)
            absent = np.asarray((indices >= n_cells) | (all_cells[clipped] != ids), dtype=bool)
            if absent.any():
                raise IndexError("ID %s not present in the Assembly" % ids[absent][0])
            return indices
        # collect into a list (repeated np.append would be quadratic)
        result = []
        for item in id:
            data = np.where(all_cells == item)[0]
            if len(data) == 0:
                raise IndexError("ID %s not present in the Assembly" % item)
            elif len(data) > 1:
                raise Exception("ID %s is duplicated in the Assembly" % item)
            result.append(data[0])
        return np.array(result, dtype=int)
1171

1172
    @property
    def positions(self):
        """3xN array of cell positions, member populations stacked side by side in order."""
        return reduce(
            lambda accumulated, population: np.hstack((accumulated, population.positions)),
            self.populations[1:],
            self.populations[0].positions,
        )
1178

1179
    @property
    def size(self):
        """Total number of cells across all member populations."""
        total = 0
        for population in self.populations:
            total += population.size
        return total
1182

1183
    def __iter__(self):
        """
        Iterate over the cells on the local MPI node, one member population
        after another.
        """
        return chain(*map(iter, self.populations))
1190

1191
    def __len__(self):
        """Number of cells in the Assembly, counted over all nodes (same as `size`)."""
        total = self.size
        return total
1194

1195
    def __getitem__(self, index):
        """
        Where `index` is an integer, return an ID.
        Where `index` is a slice, tuple, list or numpy array, return a new Assembly
        consisting of appropriate populations and (possibly newly created)
        population views.

        Negative integer indices count from the end of the Assembly, as for
        Python sequences. Boolean lists, tuples and arrays are treated as masks.
        Raises IndexError for an out-of-range integer index.
        """
        count = 0
        boundaries = [0]
        for p in self.populations:
            count += p.size
            boundaries.append(count)
        boundaries = np.array(boundaries, dtype=int)  # start offset of each population
        size = boundaries[-1]

        if isinstance(index, (int, np.integer)):  # return an ID
            if index < 0:
                # without this, -1 would address the last cell of the *first* population
                index += size
            if not 0 <= index < size:
                raise IndexError("Assembly index out of range")
            pindex = boundaries[1:].searchsorted(index, side='right')
            return self.populations[pindex][index - boundaries[pindex]]
        elif isinstance(index, (slice, tuple, list, np.ndarray)):
            if isinstance(index, slice):
                indices = np.arange(size)[index]
            else:
                indices = np.asarray(index)
                if indices.dtype == bool:
                    # boolean mask (lists of bools included, not only arrays)
                    indices = np.arange(size)[indices]
                else:
                    indices = np.where(indices < 0, indices + size, indices)
            pindices = boundaries[1:].searchsorted(indices, side='right')
            views = [self.populations[i][indices[pindices == i] - boundaries[i]]
                     for i in np.unique(pindices)]
            return self.__class__(*views)
        else:
            raise TypeError("indices must be integers, slices, lists, arrays, not %s" %
                            type(index).__name__)
1224

1225
    def __add__(self, other):
        """
        Combine this Assembly with a Population, PopulationView or another
        Assembly using '+', producing a new Assembly, e.g.::

            a2 = a1 + p
        """
        if isinstance(other, BasePopulation):
            members = self.populations + [other]
        elif isinstance(other, Assembly):
            members = self.populations + other.populations
        else:
            raise TypeError("can only add a Population or another Assembly to an Assembly")
        return self.__class__(*members)
1238

1239
    def __iadd__(self, other):
        """
        Extend this Assembly in place with a Population, PopulationView or
        the members of another Assembly via '+=', e.g.::

            a += p
        """
        if isinstance(other, BasePopulation):
            additions = [other]
        elif isinstance(other, Assembly):
            additions = other.populations
        else:
            raise TypeError("can only add a Population or another Assembly to an Assembly")
        for element in additions:
            self._insert(element)
        return self
1254

1255
    def sample(self, n, rng=None):
        """
        Return a new Assembly made of `n` cells drawn at random, without
        replacement, from this Assembly. A new NumpyRNG is used when `rng` is
        not given.
        """
        assert isinstance(n, int)
        generator = rng if rng else random.NumpyRNG()
        shuffled = generator.permutation(np.arange(len(self), dtype=int))
        indices = shuffled[:n]
        logger.debug("The %d cells recorded have indices %s", n, indices)
        logger.debug("%s.sample(%s)", self.label, n)
        return self[indices]
1267

1268
    def initialize(self, **initial_values):
        """
        Set initial values of state variables for the neurons of every member
        population, forwarding the keyword arguments unchanged.
        """
        for population in self.populations:
            population.initialize(**initial_values)
1275

1276
    def get(self, parameter_names, gather=False, simplify=True):
        """
        Get the values of the given parameters for every local cell in the
        Assembly, or, if gather=True, for all cells in the Assembly.

        A single parameter name gives a single array; a sequence of names gives
        a list of arrays in the same order. Per-population values are joined
        with ``np.hstack`` and, if `simplify` is true, passed through
        ``simplify_parameter_array``.
        """
        single_name = isinstance(parameter_names, str)
        if single_name:
            parameter_names = (parameter_names,)

        collected = defaultdict(list)
        for population in self.populations:
            population_values = population.get(parameter_names, gather, simplify=False)
            for name, array in zip(parameter_names, population_values):
                collected[name].append(array)

        merged = {}
        for name, arrays in collected.items():
            combined = np.hstack(arrays)
            merged[name] = simplify_parameter_array(combined) if simplify else combined

        results = [merged.get(name, []) for name in parameter_names]
        if not single_name:
            return results
        assert len(parameter_names) == 1
        return results[0]
1302

1303
    def set(self, **parameters):
        """
        Set one or more parameters for every cell in the Assembly.

        Values passed to set() may be:
            (1) single values
            (2) RandomDistribution objects
            (3) mapping functions, where a mapping function accepts a single
                argument (the cell index) and returns a single value.

        Here, a "single value" may be either a single number or a list/array of
        numbers (e.g. for spike times).

        The keyword arguments are forwarded unchanged to `set()` on each member
        population, in order.
        """
        populations = self.populations
        for member in populations:
            member.set(**parameters)
1318

1319
    @deprecated("set(parametername=rand_distr)")
    def rset(self, parametername, rand_distr):
        """
        Deprecated: set the parameter named by `parametername` to values drawn
        from `rand_distr` for every cell in the Assembly.

        Equivalent to ``self.set(**{parametername: rand_distr})``.
        """
        # `parametername` holds the parameter's name, so it must be unpacked as
        # the keyword. Writing `set(parametername=...)` would pass the literal
        # keyword "parametername" instead.
        self.set(**{parametername: rand_distr})
1322

1323
    def record(self, variables, to_file=None, sampling_interval=None):
        """
        Record the specified variable or variables for all cells in the Assembly.

        `variables` may be either a single variable name or a list of variable
        names. For a given celltype class, `celltype.recordable` contains a list of
        variables that can be recorded for that celltype.

        If specified, `to_file` should be either a filename or a Neo IO instance and `write_data()`
        will be automatically called when `end()` is called.

        All three arguments are passed positionally to `record()` on each member
        population, in order.
        """
        record_args = (variables, to_file, sampling_interval)
        for member in self.populations:
            member.record(*record_args)
1336

1337
    @deprecated("record('v')")
    def record_v(self, to_file=True):
        """Deprecated: record the membrane potential ('v') from all cells in the Assembly."""
        variable = 'v'
        self.record(variable, to_file)
1341

1342
    @deprecated("record(['gsyn_exc', 'gsyn_inh'])")
    def record_gsyn(self, to_file=True):
        """Deprecated: record synaptic conductances ('gsyn_exc', 'gsyn_inh') from all cells in the Assembly."""
        conductances = ['gsyn_exc', 'gsyn_inh']
        self.record(conductances, to_file)
1346

1347
    def get_population(self, label):
        """
        Return the first Population/PopulationView in the Assembly whose label
        equals `label`.

        Raises KeyError if no member population has that label.
        """
        for candidate in self.populations:
            if label == candidate.label:
                break
        else:
            raise KeyError("Assembly does not contain a population with the label %s" % label)
        return candidate
1356

1357
    def save_positions(self, file):
        """
        Save cell positions to `file`, one row per cell: ``index x y z``.

        The index column is the cell's index within the Assembly, as given by
        `id_to_index()`. If `file` is a str, it is opened as a
        `files.StandardTextFile` in write mode. Only the MPI rank-0 node writes
        the table, with metadata ``{'assembly': label}``, and then closes the
        file.
        """
        if isinstance(file, str):
            file = files.StandardTextFile(file, mode='w')
        all_cells = self.all_cells
        table = np.empty((len(all_cells), 4))
        table[:, 0] = np.array([self.id_to_index(cell_id) for cell_id in all_cells])
        # self.positions is indexed as (coordinate, cell); transpose to one row per cell.
        table[:, 1:4] = self.positions.T
        if self._simulator.state.mpi_rank != 0:
            return
        file.write(table, {'assembly': self.label})
        file.close()
1370

1371
    @property
    def position_generator(self):
        """
        A callable mapping a cell index `i` to that cell's position column,
        ``self.positions[:, i]``.

        `self.positions` is read at call time, not when the property is accessed.
        """
        return lambda i: self.positions[:, i]
1376

1377
    def get_data(self, variables='all', gather=True, clear=False, annotations=None):
        """
        Return a Neo `Block` containing the data (spikes, state variables)
        recorded from the Assembly.

        `variables` - either a single variable name or a list of variable names.
                      Variables must have been previously recorded, otherwise an
                      Exception will be raised.
        `annotations` - optional dict of extra annotations added to the
                        returned Block.

        For parallel simulators, if `gather` is True, all data will be gathered
        to all nodes and the Neo `Block` will contain data from all nodes.
        Otherwise, the Neo `Block` will contain only data from the cells
        simulated on the local node.

        If `clear` is True, recorded data will be deleted from the `Assembly`.

        The result is the first population's Block with the other populations'
        Blocks merged into it. Its name and description are replaced by the
        Assembly's label and description.
        """
        label = self.label
        summary = self.describe()
        population_blocks = []
        for population in self.populations:
            population_blocks.append(population.get_data(variables, gather, clear))

        # Shift each population's analog-signal channel indices by the total
        # size of the populations before it, so the indices refer to positions
        # within the Assembly.
        channel_offset = 0
        for pop_block, population in zip(population_blocks, self.populations):
            for seg in pop_block.segments:
                for signal in seg.analogsignals:
                    signal.array_annotations["channel_index"] += channel_offset
            channel_offset += population.size

        # Debug dump of each block's structure before merging.
        for block_number, pop_block in enumerate(population_blocks):
            logger.debug("%d: %s", block_number, pop_block.name)
            for seg_number, seg in enumerate(pop_block.segments):
                logger.debug("  %d: %s", seg_number, seg.name)
                for signal in seg.analogsignals:
                    logger.debug("    %s %s", signal.shape, signal.name)

        merged_block = population_blocks[0]
        for extra_block in population_blocks[1:]:
            merged_block.merge(extra_block)
        merged_block.name = label
        merged_block.description = summary
        if annotations:
            merged_block.annotate(**annotations)
        return merged_block
1417

1418
    @deprecated("get_data('spikes')")
    def getSpikes(self, gather=True, compatible_output=True):
        """Deprecated alias for ``get_data('spikes', gather)``; `compatible_output` is ignored."""
        requested = 'spikes'
        return self.get_data(requested, gather)
1421

1422
    @deprecated("get_data('v')")
    def get_v(self, gather=True, compatible_output=True):
        """Deprecated alias for ``get_data('v', gather)``; `compatible_output` is ignored."""
        requested = 'v'
        return self.get_data(requested, gather)
1425

1426
    @deprecated("get_data(['gsyn_exc', 'gsyn_inh'])")
    def get_gsyn(self, gather=True, compatible_output=True):
        """Deprecated alias for ``get_data(['gsyn_exc', 'gsyn_inh'], gather)``; `compatible_output` is ignored."""
        requested = ['gsyn_exc', 'gsyn_inh']
        return self.get_data(requested, gather)
1429

1430
    def mean_spike_count(self, gather=True):
        """
        Returns the mean number of spikes per neuron.

        `gather` is forwarded to `get_spike_counts()`. With ``gather=False``,
        each node therefore averages over only the counts it obtains locally.
        With ``gather=True``, the mean is returned on the MPI rank-0 node and
        NaN on every other node.

        If no spike counts are available (an empty dict), 0.0 is returned
        instead of dividing by zero.
        """
        spike_counts = self.get_spike_counts(gather)
        if self._simulator.state.mpi_rank == 0 or not gather:  # should maybe use allgather, and get the numbers on all nodes
            if not spike_counts:
                return 0.0
            total_spikes = sum(spike_counts.values())
            return float(total_spikes) / len(spike_counts)
        else:
            return np.nan
1440

1441
    def get_spike_counts(self, gather=True):
        """
        Returns the number of spikes for each neuron.

        The counts from each population's recorder are merged, in Assembly
        order, into a single new dict. Populations whose recorder raises
        `errors.NothingToWriteError` are skipped. An Assembly with no
        populations yields an empty dict.
        """
        spike_counts = {}
        for p in self.populations:
            try:
                population_counts = p.recorder.count('spikes', gather, p._record_filter)
            except errors.NothingToWriteError:
                continue
            # Build a fresh dict rather than updating the first recorder's result
            # in place, so no dict returned by a recorder is ever mutated here.
            spike_counts.update(population_counts)
        return spike_counts
1456

1457
    def write_data(self, io, variables='all', gather=True, clear=False, annotations=None):
        """
        Write recorded data to file, using one of the file formats supported by
        Neo.

        `io`:
            a Neo IO instance, or a filename string, which is converted to an
            IO instance with `recording.get_io()`.
        `variables`:
            either a single variable name or a list of variable names.
            Variables must have been previously recorded, otherwise an
            Exception will be raised.
        `annotations`:
            optional dict of extra annotations, passed through to `get_data()`
            and added to the Neo `Block` before it is written.

        For parallel simulators, if `gather` is True, all data will be gathered
        to the master node and a single output file created there. Otherwise, a
        file will be written on each node, containing only data from the cells
        simulated on that node.

        If `clear` is True, recorded data will be deleted from the `Assembly`.
        """
        if isinstance(io, str):
            io = recording.get_io(io)
        if gather is False and self._simulator.state.num_processes > 1:
            # Give each MPI rank its own output file. Note that this modifies
            # the `filename` attribute of the IO object that was passed in.
            io.filename += '.%d' % self._simulator.state.mpi_rank
        # Lazy %-style arguments: the message is only built if DEBUG is enabled.
        logger.debug("Recorder is writing '%s' to file '%s' with gather=%s",
                     variables, io.filename, gather)
        data = self.get_data(variables, gather, clear, annotations)
        if self._simulator.state.mpi_rank == 0 or gather is False:
            logger.debug("Writing data to file %s", io)
            io.write(data)
1486

1487
    @deprecated("write_data(file, 'spikes')")
    def printSpikes(self, file, gather=True, compatible_output=True):
        """Deprecated alias for ``write_data(file, 'spikes', gather)``; `compatible_output` is ignored."""
        requested = 'spikes'
        self.write_data(file, requested, gather)
1490

1491
    @deprecated("write_data(file, 'v')")
    def print_v(self, file, gather=True, compatible_output=True):
        """Deprecated alias for ``write_data(file, 'v', gather)``; `compatible_output` is ignored."""
        requested = 'v'
        self.write_data(file, requested, gather)
1494

1495
    # The replacement call named in the deprecation notice includes the
    # required `file` argument, matching printSpikes/print_v.
    @deprecated("write_data(file, ['gsyn_exc', 'gsyn_inh'])")
    def print_gsyn(self, file, gather=True, compatible_output=True):
        """
        Deprecated: write recorded synaptic conductances ('gsyn_exc' and
        'gsyn_inh') to `file` via ``write_data(file, ['gsyn_exc', 'gsyn_inh'], gather)``.

        `compatible_output` is accepted only for backward compatibility and is
        ignored.
        """
        self.write_data(file, ['gsyn_exc', 'gsyn_inh'], gather)
1498

1499
    def inject(self, current_source):
        """
        Connect a current source to all cells in the Assembly, by calling
        `current_source.inject_into()` once for each member population, in order.
        """
        attach = current_source.inject_into
        for member in self.populations:
            attach(member)
1505

1506
    @property
    def injectable(self):
        """
        True only if every member population is injectable. This is vacuously
        True for an Assembly with no populations.
        """
        for member in self.populations:
            if not member.injectable:
                return False
        return True
1509

1510
    def describe(self, template='assembly_default.txt', engine='default'):
        """
        Returns a human-readable description of the assembly.

        The output may be customized by specifying a different template
        together with an associated template engine (see ``pyNN.descriptions``).

        If template is None, then a dictionary containing the template context
        will be returned.

        The context holds the Assembly's label and, under "populations", the
        template context (``describe(template=None)``) of each member population.
        """
        label = self.label
        population_contexts = []
        for member in self.populations:
            population_contexts.append(member.describe(template=None))
        context = {"label": label, "populations": population_contexts}
        return descriptions.render(engine, template, context)
1523

1524
    def get_annotations(self, annotation_keys, simplify=True):
        """
        Get the values of the given annotations for each population in the Assembly.

        `annotation_keys` may be a single key (a str) or a sequence of keys.
        Returns a defaultdict(list) mapping each key to the values collected from
        the member populations, in Assembly order:

        - if the last population's value is a NumPy array, the collected values
          are concatenated with ``np.hstack``;
        - otherwise they are kept as a list, one entry per population;
        - if `simplify` is True, the result is converted with ``np.array`` and
          passed through `simplify_parameter_array`.

        Raises KeyError if a population lacks one of the keys.
        """
        keys = (annotation_keys,) if isinstance(annotation_keys, str) else annotation_keys
        annotations = defaultdict(list)

        for key in keys:
            per_population = [member.annotations[key] for member in self.populations]
            annotations[key].extend(per_population)
            # NOTE(review): only the *last* population's value decides whether
            # the values are treated as arrays; mixed array/scalar annotations
            # across populations are not handled specially.
            last_is_array = bool(per_population) and isinstance(per_population[-1], np.ndarray)
            if last_is_array:
                annotations[key] = np.hstack(annotations[key])
            if simplify:
                annotations[key] = simplify_parameter_array(np.array(annotations[key]))
        return annotations
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc