• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SpiNNakerManchester / sPyNNaker / 6942302574

21 Nov 2023 10:15AM UTC coverage: 63.153% (+1.6%) from 61.529%
6942302574

Pull #1342

github

Christian-B
receptor_type
Pull Request #1342: Type Annotations and Checking

1960 of 4482 branches covered (0.0%)

Branch coverage included in aggregate %.

4148 of 5078 new or added lines in 233 files covered. (81.69%)

193 existing lines in 78 files now uncovered.

12725 of 18771 relevant lines covered (67.79%)

0.68 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

73.29
/spynnaker/pyNN/models/spike_source/spike_source_array_vertex.py
1
# Copyright (c) 2017 The University of Manchester
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     https://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
from __future__ import annotations
1✔
15
from collections import Counter
1✔
16
import logging
1✔
17
import numpy
1✔
18
from numpy.typing import ArrayLike, NDArray
1✔
19
from pyNN.space import Grid2D, Grid3D, BaseStructure
1✔
20
from typing import (
1✔
21
    Collection, List, Optional, Sequence, Tuple, Union, TYPE_CHECKING)
22
from typing_extensions import TypeAlias, TypeGuard
1✔
23
from spinn_utilities.log import FormatAdapter
1✔
24
from spinn_utilities.overrides import overrides
1✔
25
from spinn_utilities.config_holder import get_config_int
1✔
26
from spinn_utilities.ranged.abstract_sized import Selector
1✔
27
from pacman.model.graphs.common import Slice
1✔
28
from pacman.model.resources import AbstractSDRAM
1✔
29
from pacman.model.partitioner_splitters import AbstractSplitterCommon
1✔
30
from spinn_front_end_common.utility_models import ReverseIpTagMultiCastSource
1✔
31
from spynnaker.pyNN.data import SpynnakerDataView
1✔
32
from spynnaker.pyNN.utilities import constants
1✔
33
from spynnaker.pyNN.models.common import PopulationApplicationVertex
1✔
34
from spynnaker.pyNN.models.common.types import Names
1✔
35
from spynnaker.pyNN.models.abstract_models import SupportsStructure
1✔
36
from spynnaker.pyNN.utilities.buffer_data_type import BufferDataType
1✔
37
from spynnaker.pyNN.utilities.ranged import SpynnakerRangedList
1✔
38
from spynnaker.pyNN.models.common import ParameterHolder
1✔
39
from spynnaker.pyNN.models.common.types import Spikes
1✔
40
from .spike_source_array_machine_vertex import SpikeSourceArrayMachineVertex
1✔
41
if TYPE_CHECKING:
42
    from .spike_source_array import SpikeSourceArray
43

44
logger = FormatAdapter(logging.getLogger(__name__))
1✔
45

46
# Cutoff to warn too many spikes sent at one time
47
TOO_MANY_SPIKES = 100
1✔
48

49
_Number: TypeAlias = Union[int, float]
1✔
50

51
_SingleList: TypeAlias = Union[
1✔
52
    Sequence[_Number], NDArray[numpy.integer]]
53
_DoubleList: TypeAlias = Union[
1✔
54
    Sequence[Sequence[_Number]], NDArray[numpy.integer]]
55

56

57
def _is_double_list(value: Spikes) -> TypeGuard[_DoubleList]:
1✔
58
    return not isinstance(value, (float, int)) and bool(len(value)) and \
1✔
59
        hasattr(value[0], "__len__")
60

61

62
def _is_single_list(value: Spikes) -> TypeGuard[_SingleList]:
1✔
63
    # USE _is_double_list first!
64
    return not isinstance(value, (float, int)) and bool(len(value))
1✔
65

66

67
def _is_singleton(value: Spikes) -> TypeGuard[_Number]:
1✔
68
    return isinstance(value, (float, int))
1✔
69

70

71
def _as_numpy_ticks(
1✔
72
        times: ArrayLike, time_step: float) -> NDArray[numpy.int64]:
73
    return numpy.ceil(
1✔
74
        numpy.floor(numpy.array(times) * 1000.0) / time_step).astype("int64")
75

76

77
def _send_buffer_times(
        spike_times: Spikes, time_step: float) -> Union[
            NDArray[numpy.int64], List[NDArray[numpy.int64]]]:
    """
    Convert spike times, in whatever supported form, into send-buffer
    tick arrays.

    :param spike_times: scalar, list, or list-of-lists of times (ms)
    :param time_step: simulation time step, in microseconds
    :return: one tick array shared by all neurons, or one tick array
        per neuron for a double list; an empty list when there is
        nothing to send
    """
    # Order matters: a double list also passes the single-list test
    if _is_double_list(spike_times):
        return [
            _as_numpy_ticks(neuron_times, time_step)
            for neuron_times in spike_times]
    if _is_single_list(spike_times):
        return _as_numpy_ticks(spike_times, time_step)
    if _is_singleton(spike_times):
        # A lone scalar is treated as a one-entry spike train
        return _as_numpy_ticks([spike_times], time_step)
    # Empty input: nothing to send
    return []
89

90

91
class SpikeSourceArrayVertex(
        ReverseIpTagMultiCastSource, PopulationApplicationVertex,
        SupportsStructure):
    """
    Model for play back of spikes.

    The spike times are converted into send-buffer tick arrays and
    delegated to :py:class:`ReverseIpTagMultiCastSource` for delivery.
    The only settable parameter is ``spike_times``; the only recordable
    variable is ``spikes``.
    """
    __slots__ = (
        "__model_name",
        "__model",
        "__structure",
        "_spike_times",
        "__n_colour_bits")

    #: ID of the recording region used for recording transmitted spikes.
    SPIKE_RECORDING_REGION_ID = 0

    def __init__(
            self, n_neurons: int, spike_times: Spikes, label: str,
            max_atoms_per_core: int, model: SpikeSourceArray,
            splitter: Optional[AbstractSplitterCommon],
            n_colour_bits: Optional[int]):
        """
        :param n_neurons: number of source neurons (keys) in the population
        :param spike_times: scalar, list, or list-of-lists of times (ms);
            ``None`` is treated as no spikes
        :param label: label for the vertex
        :param max_atoms_per_core: maximum atoms placed on one core
        :param model: the PyNN model object that created this vertex
        :param splitter: the splitter to use, if already chosen
        :param n_colour_bits: number of colour bits, or ``None`` to read
            the ``Simulation``/``n_colour_bits`` configuration value
        """
        # pylint: disable=too-many-arguments
        self.__model_name = "SpikeSourceArray"
        self.__model = model
        self.__structure: Optional[BaseStructure] = None

        # Normalise "no spikes" to an empty list
        if spike_times is None:
            spike_times = []
        # A double list gives one value per neuron; any other form is
        # shared by every neuron (hence use_list_as_value)
        self._spike_times = SpynnakerRangedList(
            n_neurons, spike_times,
            use_list_as_value=not _is_double_list(spike_times))

        time_step = SpynnakerDataView.get_simulation_time_step_us()

        super().__init__(
            n_keys=n_neurons, label=label,
            max_atoms_per_core=max_atoms_per_core,
            send_buffer_times=_send_buffer_times(spike_times, time_step),
            send_buffer_partition_id=constants.SPIKE_PARTITION_ID,
            splitter=splitter)

        # Warn early if this array risks sending too many spikes at once
        self._check_spike_density(spike_times)
        # Do colouring
        if n_colour_bits is None:
            self.__n_colour_bits = get_config_int(
                "Simulation", "n_colour_bits")
        else:
            self.__n_colour_bits = n_colour_bits

    @overrides(ReverseIpTagMultiCastSource.create_machine_vertex)
    def create_machine_vertex(
            self, vertex_slice: Slice, sdram: AbstractSDRAM,
            label: Optional[str] = None) -> SpikeSourceArrayMachineVertex:
        """
        Create the machine vertex for a slice of this source.

        Recording state is copied from this application vertex; a
        supplied SDRAM requirement is checked against the machine
        vertex's own calculation rather than trusted.
        """
        send_buffer_times = self._filtered_send_buffer_times(vertex_slice)
        machine_vertex = SpikeSourceArrayMachineVertex(
            label=label, app_vertex=self, vertex_slice=vertex_slice,
            eieio_params=self._eieio_params,
            send_buffer_times=send_buffer_times)
        machine_vertex.enable_recording(self._is_recording)
        # Known issue with ReverseIPTagMulticastSourceMachineVertex
        if sdram:
            assert sdram == machine_vertex.sdram_required
        return machine_vertex

    def _check_spike_density(self, spike_times: Spikes):
        """
        Warn if the spike times risk sending too many spikes at the same
        simulation time.

        :param spike_times: the spike times, in any supported form
        """
        if _is_double_list(spike_times):
            self._check_density_double_list(spike_times)
        elif _is_single_list(spike_times):
            self._check_density_single_list(spike_times)
        elif _is_singleton(spike_times):
            # single scalar time: density check is not applied
            pass
        else:
            logger.warning("SpikeSourceArray has no spike times")

    def _check_density_single_list(self, spike_times: _SingleList):
        """
        Density check for one spike-time list shared by all neurons.

        Every neuron uses this same list, so each occurrence of a time
        fires ``n_atoms`` spikes simultaneously.

        :param spike_times: the shared list of spike times
        """
        counter = Counter(spike_times)
        top = counter.most_common(1)
        val, count = top[0]
        if count * self.n_atoms > TOO_MANY_SPIKES:
            if self.n_atoms > 1:
                logger.warning(
                    "Danger of SpikeSourceArray sending too many spikes "
                    "at the same time. "
                    "This is because ({}) neurons share the same spike list",
                    self.n_atoms)
            else:
                logger.warning(
                    "Danger of SpikeSourceArray sending too many spikes "
                    "at the same time. "
                    "For example at time {}, {} spikes will be sent",
                    val, count * self.n_atoms)

    def _check_density_double_list(self, spike_times: _DoubleList):
        """
        Density check for a per-neuron list of spike-time lists.

        :param spike_times: one list of spike times per neuron
        """
        # NOTE(review): assumes at least one neuron has a spike time; an
        # all-empty double list would leave the counter empty and make
        # top[0] raise IndexError — confirm callers cannot pass that
        counter: Counter = Counter()
        for neuron_id in range(0, self.n_atoms):
            counter.update(spike_times[neuron_id])
        top = counter.most_common(1)
        val, count = top[0]
        if count > TOO_MANY_SPIKES:
            logger.warning(
                "Danger of SpikeSourceArray sending too many spikes "
                "at the same time. "
                "For example at time {}, {} spikes will be sent",
                val, count)

    @overrides(SupportsStructure.set_structure)
    def set_structure(self, structure: BaseStructure):
        # Stored only; consulted later by atoms_shape
        self.__structure = structure

    @property
    @overrides(ReverseIpTagMultiCastSource.atoms_shape)
    def atoms_shape(self) -> Tuple[int, ...]:
        # A 2D/3D grid structure dictates the shape; otherwise defer to
        # the superclass (flat) shape
        if isinstance(self.__structure, (Grid2D, Grid3D)):
            return self.__structure.calculate_size(self.n_atoms)
        return super().atoms_shape

    def _to_early_spikes_single_list(self, spike_times: _SingleList):
        """
        Checks if there is one or more spike_times before the current time.

        Logs a warning for the first one found

        :param list(int) spike_times:
        """
        current_time = SpynnakerDataView.get_current_run_time_ms()
        for i in range(len(spike_times)):
            if spike_times[i] < current_time:
                logger.warning(
                    "SpikeSourceArray {} has spike_times that are lower than "
                    "the current time {} For example {} - "
                    "these will be ignored.",
                    self, current_time, float(spike_times[i]))
                return

    def _check_spikes_double_list(self, spike_times: _DoubleList):
        """
        Checks if there is one or more spike_times before the current time.

        Logs a warning for the first one found

        :param iterable(int) spike_times:
        """
        current_time = SpynnakerDataView.get_current_run_time_ms()
        for neuron_id in range(0, self.n_atoms):
            id_times = spike_times[neuron_id]
            for i in range(len(id_times)):
                if id_times[i] < current_time:
                    logger.warning(
                        "SpikeSourceArray {} has spike_times that are lower "
                        "than the current time {} For example {} - "
                        "these will be ignored.",
                        self, current_time, float(id_times[i]))
                    return

    def __set_spike_buffer_times(self, spike_times: Spikes):
        """
        Set the spike source array's buffer spike times.
        """
        time_step = SpynnakerDataView.get_simulation_time_step_us()
        # warn the user if they are asking for a spike time out of range
        if _is_double_list(spike_times):
            self._check_spikes_double_list(spike_times)
        elif _is_single_list(spike_times):
            self._to_early_spikes_single_list(spike_times)
        elif _is_singleton(spike_times):
            self._to_early_spikes_single_list([spike_times])
        else:
            # in case of empty list do not check
            pass
        self.send_buffer_times = _send_buffer_times(spike_times, time_step)
        self._check_spike_density(spike_times)

    def __read_parameter(self, name: str, selector: Selector):
        # pylint: disable=unused-argument
        # This can only be spike times
        return self._spike_times.get_values(selector)

    @overrides(PopulationApplicationVertex.get_parameter_values)
    def get_parameter_values(
            self, names: Names, selector: Selector = None) -> ParameterHolder:
        """
        Get parameter values; only "spike_times" is supported.
        """
        self._check_parameters(names, {"spike_times"})
        return ParameterHolder(names, self.__read_parameter, selector)

    @overrides(PopulationApplicationVertex.set_parameter_values)
    def set_parameter_values(
            self, name: str, value: Spikes, selector: Selector = None):
        """
        Set parameter values; only "spike_times" is supported.
        """
        self._check_parameters(name, {"spike_times"})
        # NOTE(review): the send buffer is rebuilt from *value* for all
        # neurons even when a selector picks a subset — confirm intended
        self.__set_spike_buffer_times(value)
        self._spike_times.set_value_by_selector(
            selector, value, use_list_as_value=not _is_double_list(value))

    @overrides(PopulationApplicationVertex.get_parameters)
    def get_parameters(self) -> List[str]:
        return ["spike_times"]

    @overrides(PopulationApplicationVertex.get_units)
    def get_units(self, name: str) -> str:
        # "spikes" is a dimensionless event count; times are milliseconds
        if name == "spikes":
            return ""
        if name == "spike_times":
            return "ms"
        raise KeyError(f"Units for {name} unknown")

    @overrides(PopulationApplicationVertex.get_recordable_variables)
    def get_recordable_variables(self) -> List[str]:
        return ["spikes"]

    @overrides(PopulationApplicationVertex.get_buffer_data_type)
    def get_buffer_data_type(self, name: str) -> BufferDataType:
        if name == "spikes":
            return BufferDataType.EIEIO_SPIKES
        raise KeyError(f"Cannot record {name}")

    @overrides(PopulationApplicationVertex.get_neurons_recording)
    def get_neurons_recording(
            self, name: str, vertex_slice: Slice) -> NDArray[numpy.integer]:
        # All neurons in the slice are recorded; there is no sub-selection
        if name != "spikes":
            raise KeyError(f"Cannot record {name}")
        return vertex_slice.get_raster_ids()

    @overrides(PopulationApplicationVertex.set_recording)
    def set_recording(self, name: str, sampling_interval=None, indices=None):
        """
        Turn on recording of "spikes".

        Sampling intervals and index subsets are not supported; they are
        warned about and ignored.
        """
        if name != "spikes":
            raise KeyError(f"Cannot record {name}")
        if sampling_interval is not None:
            logger.warning("Sampling interval currently not supported for "
                           "SpikeSourceArray so being ignored")
        if indices is not None:
            logger.warning("Indices currently not supported for "
                           "SpikeSourceArray so being ignored")
        self.enable_recording(True)
        # Recording changes the machine graph, so mapping must be redone
        SpynnakerDataView.set_requires_mapping()

    @overrides(PopulationApplicationVertex.set_not_recording)
    def set_not_recording(
            self, name: str, indices: Optional[Collection[int]] = None):
        """
        Turn off recording of "spikes"; index subsets are ignored.
        """
        if name != "spikes":
            raise KeyError(f"Cannot record {name}")
        if indices is not None:
            logger.warning("Indices currently not supported for "
                           "SpikeSourceArray so being ignored")
        self.enable_recording(False)

    @overrides(PopulationApplicationVertex.get_recording_variables)
    def get_recording_variables(self) -> List[str]:
        if self._is_recording:
            return ["spikes"]
        return []

    @overrides(PopulationApplicationVertex.get_sampling_interval_ms)
    def get_sampling_interval_ms(self, name: str) -> float:
        if name != "spikes":
            raise KeyError(f"Cannot record {name}")
        # NOTE(review): the method name says milliseconds but the value
        # comes from the microseconds getter — confirm the units here
        return SpynnakerDataView.get_simulation_time_step_us()

    @overrides(PopulationApplicationVertex.get_recording_region)
    def get_recording_region(self, name: str) -> int:
        if name != "spikes":
            raise KeyError(f"Cannot record {name}")
        return self.SPIKE_RECORDING_REGION_ID

    @overrides(PopulationApplicationVertex.get_data_type)
    def get_data_type(self, name: str) -> None:
        # Spikes are events, so there is no per-sample data type
        if name != "spikes":
            raise KeyError(f"Cannot record {name}")
        return None

    def describe(self):
        """
        Returns a human-readable description of the cell or synapse type.

        The output may be customised by specifying a different template
        together with an associated template engine
        (see :py:mod:`pyNN.descriptions`).

        If template is `None`, then a dictionary containing the template
        context will be returned.
        """
        # NOTE(review): "default_initial_values" is filled from
        # default_parameters — confirm the model has no separate
        # default_initial_values that should be used instead
        return {
            "name": self.__model_name,
            "default_parameters": self.__model.default_parameters,
            "default_initial_values": self.__model.default_parameters,
            "parameters": self.get_parameter_values(
                self.__model.default_parameters),
        }

    @property
    @overrides(PopulationApplicationVertex.n_colour_bits)
    def n_colour_bits(self) -> int:
        return self.__n_colour_bits
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc