# Copyright (c) 2017 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import Counter
import logging
from typing import (Any, Collection, Dict, List, Optional, Sequence, Tuple,
Union, TYPE_CHECKING)
import numpy
from numpy.typing import ArrayLike, NDArray
from typing_extensions import TypeAlias, TypeGuard
from pyNN.space import Grid2D, Grid3D, BaseStructure
from spinn_utilities.log import FormatAdapter
from spinn_utilities.overrides import overrides
from spinn_utilities.config_holder import get_config_int
from spinn_utilities.ranged.abstract_sized import Selector
from pacman.model.graphs.common import Slice
from pacman.model.partitioner_splitters import AbstractSplitterCommon
from pacman.model.resources import AbstractSDRAM
from spinn_front_end_common.utility_models import ReverseIpTagMultiCastSource
from spynnaker.pyNN.data import SpynnakerDataView
from spynnaker.pyNN.models.abstract_models import SupportsStructure
from spynnaker.pyNN.models.common import (
ParameterHolder, PopulationApplicationVertex)
from spynnaker.pyNN.models.common.types import (Names, Spikes)
from spynnaker.pyNN.utilities.buffer_data_type import BufferDataType
from spynnaker.pyNN.utilities.ranged import SpynnakerRangedList
from .spike_source_array_machine_vertex import SpikeSourceArrayMachineVertex
if TYPE_CHECKING:
from .spike_source_array import SpikeSourceArray
logger = FormatAdapter(logging.getLogger(__name__))
# Cut-off above which to warn that too many spikes may be sent at one time
TOO_MANY_SPIKES = 100
_Number: TypeAlias = Union[int, float]
_SingleList: TypeAlias = Union[
Sequence[_Number], NDArray[numpy.integer]]
_DoubleList: TypeAlias = Union[
Sequence[Sequence[_Number]], NDArray[numpy.integer]]
def _is_double_list(value: Spikes) -> TypeGuard[_DoubleList]:
    # True if value is a list of per-neuron spike lists
    return not isinstance(value, (float, int)) and bool(len(value)) and \
        hasattr(value[0], "__len__")

def _is_single_list(value: Spikes) -> TypeGuard[_SingleList]:
    # True if value is a flat list of spike times; a double list also
    # passes this check, so apply _is_double_list first!
    return not isinstance(value, (float, int)) and bool(len(value))

def _is_singleton(value: Spikes) -> TypeGuard[_Number]:
    # True if value is a single spike time shared by all neurons
    return isinstance(value, (float, int))
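# Classification examples (illustrative); the checks overlap, which is why
# _is_double_list must be applied before _is_single_list:
#   [[1.0, 2.0], [3.0]] -> double list (one spike list per neuron)
#   [1.0, 2.0]          -> single list (shared by all neurons)
#   1.0                 -> singleton (one spike time for all neurons)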
def _as_numpy_ticks(
        times: ArrayLike, time_step: float) -> NDArray[numpy.int64]:
    # Convert spike times in milliseconds into whole timer ticks, given
    # the simulation time step in microseconds
    return numpy.ceil(
        numpy.floor(numpy.array(times) * 1000.0) / time_step).astype("int64")
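# Worked example (illustrative): with a 1 ms time step (time_step=1000.0
# microseconds), a spike at 2.5 ms maps to
# ceil(floor(2.5 * 1000.0) / 1000.0) = ceil(2.5) = 3 ticks, i.e. it is
# rounded up to the next whole time step.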
def _send_buffer_times(
spike_times: Spikes, time_step: float) -> Union[
NDArray[numpy.int64], List[NDArray[numpy.int64]]]:
# Convert to ticks
if _is_double_list(spike_times):
return [_as_numpy_ticks(times, time_step) for times in spike_times]
elif _is_single_list(spike_times):
return _as_numpy_ticks(spike_times, time_step)
elif _is_singleton(spike_times):
return _as_numpy_ticks([spike_times], time_step)
else:
return []
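# Example (illustrative) of the three accepted shapes, with a 1 ms step:
#   _send_buffer_times([[1.0], [2.0]], 1000.0) -> [array([1]), array([2])]
#   _send_buffer_times([1.0, 2.0], 1000.0)     -> array([1, 2])
#   _send_buffer_times(5.0, 1000.0)            -> array([5])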
class SpikeSourceArrayVertex(
ReverseIpTagMultiCastSource, PopulationApplicationVertex,
SupportsStructure):
"""
    Model for playback of spikes.
"""
__slots__ = (
"__model_name",
"__model",
"__structure",
"_spike_times",
"__n_colour_bits")
#: ID of the recording region used for recording transmitted spikes.
SPIKE_RECORDING_REGION_ID = 0
def __init__(
self, n_neurons: int, spike_times: Spikes, label: str,
max_atoms_per_core: Union[int, Tuple[int, ...]],
model: SpikeSourceArray,
splitter: Optional[AbstractSplitterCommon],
n_colour_bits: Optional[int]):
        """
        :param n_neurons: the number of spike source neurons
        :param spike_times: the times of spikes in milliseconds (a
            singleton, a single list, or one list per neuron)
        :param label: the label of the population
        :param max_atoms_per_core: the maximum number of atoms per core
        :param model: the PyNN model that created this vertex
        :param splitter: the splitter to use, if any
        :param n_colour_bits: the number of colour bits, or None to read
            the value from the configuration
        """
        # pylint: disable=too-many-arguments
self.__model_name = "SpikeSourceArray"
self.__model = model
self.__structure: Optional[BaseStructure] = None
if spike_times is None:
spike_times = []
self._spike_times = SpynnakerRangedList(
n_neurons, spike_times,
use_list_as_value=not _is_double_list(spike_times))
time_step = SpynnakerDataView.get_simulation_time_step_us()
super().__init__(
n_keys=n_neurons, label=label,
max_atoms_per_core=max_atoms_per_core,
send_buffer_times=_send_buffer_times(spike_times, time_step),
splitter=splitter)
self._check_spike_density(spike_times)
        # Set up colouring: use the configured number of colour bits
        # unless a value was given explicitly
if n_colour_bits is None:
self.__n_colour_bits = get_config_int(
"Simulation", "n_colour_bits")
else:
self.__n_colour_bits = n_colour_bits
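    # Note: instances are normally created indirectly through PyNN, e.g.
    # (a hedged sketch) sim.Population(2, sim.SpikeSourceArray(
    # spike_times=[[1.0], [11.0]])), which supplies one spike list per
    # neuron (a double list).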
@overrides(ReverseIpTagMultiCastSource.create_machine_vertex)
def create_machine_vertex(
self, vertex_slice: Slice, sdram: AbstractSDRAM,
label: Optional[str] = None) -> SpikeSourceArrayMachineVertex:
send_buffer_times = self._filtered_send_buffer_times(vertex_slice)
machine_vertex = SpikeSourceArrayMachineVertex(
label=label, app_vertex=self, vertex_slice=vertex_slice,
eieio_params=self._eieio_params,
send_buffer_times=send_buffer_times)
machine_vertex.enable_recording(self._is_recording)
        # Known issue with ReverseIPTagMulticastSourceMachineVertex:
        # the sdram passed in may be unset, so only check it when given
if sdram:
assert sdram == machine_vertex.sdram_required
return machine_vertex
def _check_spike_density(self, spike_times: Spikes) -> None:
if _is_double_list(spike_times):
self._check_density_double_list(spike_times)
elif _is_single_list(spike_times):
self._check_density_single_list(spike_times)
elif _is_singleton(spike_times):
pass
else:
logger.warning("SpikeSourceArray has no spike times")
def _check_density_single_list(self, spike_times: _SingleList) -> None:
counter = Counter(spike_times)
top = counter.most_common(1)
val, count = top[0]
if count * self.n_atoms > TOO_MANY_SPIKES:
if self.n_atoms > 1:
logger.warning(
"Danger of SpikeSourceArray sending too many spikes "
"at the same time. "
"This is because ({}) neurons share the same spike list",
self.n_atoms)
else:
logger.warning(
"Danger of SpikeSourceArray sending too many spikes "
"at the same time. "
"For example at time {}, {} spikes will be sent",
val, count * self.n_atoms)
def _check_density_double_list(self, spike_times: _DoubleList) -> None:
counter: Counter = Counter()
for neuron_id in range(0, self.n_atoms):
counter.update(spike_times[neuron_id])
top = counter.most_common(1)
val, count = top[0]
if count > TOO_MANY_SPIKES:
logger.warning(
"Danger of SpikeSourceArray sending too many spikes "
"at the same time. "
"For example at time {}, {} spikes will be sent",
val, count)
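    # Example (illustrative): with spike_times=[[1.0]] * 200 for 200
    # neurons, time 1.0 ms accumulates 200 coincident spikes, exceeding
    # TOO_MANY_SPIKES (100) and triggering the warning above.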
@overrides(SupportsStructure.set_structure)
def set_structure(self, structure: BaseStructure) -> None:
self.__structure = structure
@property
@overrides(ReverseIpTagMultiCastSource.atoms_shape)
def atoms_shape(self) -> Tuple[int, ...]:
if isinstance(self.__structure, (Grid2D, Grid3D)):
return self.__structure.calculate_size(self.n_atoms)
return super().atoms_shape
    def _check_spikes_single_list(self, spike_times: _SingleList) -> None:
        """
        Checks if there are one or more spike times before the current
        time, and logs a warning for the first one found.

        :param list(int) spike_times: the spike times shared by all neurons
        """
        current_time = SpynnakerDataView.get_current_run_time_ms()
        for spike_time in spike_times:
            if spike_time < current_time:
                logger.warning(
                    "SpikeSourceArray {} has spike_times that are lower "
                    "than the current time {}, for example {}; "
                    "these will be ignored.",
                    self, current_time, float(spike_time))
                return
    def _check_spikes_double_list(self, spike_times: _DoubleList) -> None:
        """
        Checks if there are one or more spike times before the current
        time, and logs a warning for the first one found.

        :param iterable(int) spike_times: one list of spike times per neuron
        """
        current_time = SpynnakerDataView.get_current_run_time_ms()
        for neuron_id in range(0, self.n_atoms):
            id_times = spike_times[neuron_id]
            for id_time in id_times:
                if id_time < current_time:
                    logger.warning(
                        "SpikeSourceArray {} has spike_times that are "
                        "lower than the current time {}, for example {}; "
                        "these will be ignored.",
                        self, current_time, float(id_time))
                    return
def __set_spike_buffer_times(self, spike_times: Spikes) -> None:
"""
Set the spike source array's buffer spike times.
"""
time_step = SpynnakerDataView.get_simulation_time_step_us()
        # Warn the user about any spike times before the current time
if _is_double_list(spike_times):
self._check_spikes_double_list(spike_times)
        elif _is_single_list(spike_times):
            self._check_spikes_single_list(spike_times)
        elif _is_singleton(spike_times):
            self._check_spikes_single_list([spike_times])
else:
# in case of empty list do not check
pass
self.send_buffer_times = _send_buffer_times(spike_times, time_step)
self._check_spike_density(spike_times)
def __read_parameter(self, name: str, selector: Selector) -> Sequence:
# pylint: disable=unused-argument
# This can only be spike times
return self._spike_times.get_values(selector)
@overrides(PopulationApplicationVertex.get_parameter_values)
def get_parameter_values(
self, names: Names, selector: Selector = None) -> ParameterHolder:
self._check_parameters(names, {"spike_times"})
return ParameterHolder(names, self.__read_parameter, selector)
@overrides(PopulationApplicationVertex.set_parameter_values)
def set_parameter_values(
self, name: str, value: Spikes, selector: Selector = None) -> None:
self._check_parameters(name, {"spike_times"})
self.__set_spike_buffer_times(value)
self._spike_times.set_value_by_selector(
selector, value, use_list_as_value=not _is_double_list(value))
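    # Example (illustrative) PyNN-level call that reaches this method:
    #   pop.set(spike_times=[4.0, 5.0])
    # which replaces the spike list shared by all neurons in pop.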
@overrides(PopulationApplicationVertex.get_parameters)
def get_parameters(self) -> List[str]:
return ["spike_times"]
@overrides(PopulationApplicationVertex.get_units)
def get_units(self, name: str) -> str:
if name == "spikes":
return ""
if name == "spike_times":
return "ms"
raise KeyError(f"Units for {name} unknown")
@overrides(PopulationApplicationVertex.get_recordable_variables)
def get_recordable_variables(self) -> List[str]:
return ["spikes"]
@overrides(PopulationApplicationVertex.get_buffer_data_type)
def get_buffer_data_type(self, name: str) -> BufferDataType:
if name == "spikes":
return BufferDataType.EIEIO_SPIKES
raise KeyError(f"Cannot record {name}")
@overrides(PopulationApplicationVertex.get_neurons_recording)
def get_neurons_recording(
self, name: str, vertex_slice: Slice) -> NDArray[numpy.integer]:
if name != "spikes":
raise KeyError(f"Cannot record {name}")
return vertex_slice.get_raster_ids()
@overrides(PopulationApplicationVertex.set_recording)
def set_recording(
self, name: str, sampling_interval: Optional[float] = None,
indices: Optional[Collection[int]] = None) -> None:
if name != "spikes":
raise KeyError(f"Cannot record {name}")
if sampling_interval is not None:
logger.warning("Sampling interval currently not supported for "
"SpikeSourceArray so being ignored")
if indices is not None:
logger.warning("Indices currently not supported for "
"SpikeSourceArray so being ignored")
self.enable_recording(True)
SpynnakerDataView.set_requires_mapping()
@overrides(PopulationApplicationVertex.set_not_recording)
def set_not_recording(self, name: str,
indices: Optional[Collection[int]] = None) -> None:
if name != "spikes":
raise KeyError(f"Cannot record {name}")
if indices is not None:
logger.warning("Indices currently not supported for "
"SpikeSourceArray so being ignored")
self.enable_recording(False)
@overrides(PopulationApplicationVertex.get_recording_variables)
def get_recording_variables(self) -> List[str]:
if self._is_recording:
return ["spikes"]
return []
@overrides(PopulationApplicationVertex.get_sampling_interval_ms)
def get_sampling_interval_ms(self, name: str) -> float:
if name != "spikes":
raise KeyError(f"Cannot record {name}")
        return SpynnakerDataView.get_simulation_time_step_ms()
@overrides(PopulationApplicationVertex.get_recording_region)
def get_recording_region(self, name: str) -> int:
if name != "spikes":
raise KeyError(f"Cannot record {name}")
return self.SPIKE_RECORDING_REGION_ID
@overrides(PopulationApplicationVertex.get_data_type)
def get_data_type(self, name: str) -> None:
if name != "spikes":
raise KeyError(f"Cannot record {name}")
return None
def describe(
self) -> Dict[str, Union[str, ParameterHolder, Dict[str, Any]]]:
"""
        Returns a human-readable description of the cell or synapse type.

        As this method takes no template, it returns the dictionary
        containing the template context that a template engine would
        render (see :py:mod:`pyNN.descriptions`).
"""
return {
"name": self.__model_name,
"default_parameters": self.__model.default_parameters,
"default_initial_values": self.__model.default_parameters,
"parameters": self.get_parameter_values(
self.__model.default_parameters),
}
@property
@overrides(PopulationApplicationVertex.n_colour_bits)
def n_colour_bits(self) -> int:
return self.__n_colour_bits