# Copyright (c) 2022 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from typing import (
Any, cast, Dict, Iterable, List, Optional, overload, Sequence, Union,
TYPE_CHECKING)
import numpy
from numpy import floating
from numpy.typing import NDArray
from pyNN.descriptions import TemplateEngine
import neo # type: ignore[import]
from spinn_utilities.ranged.abstract_sized import AbstractSized, Selector
from spinn_utilities.log import FormatAdapter
from spinn_utilities.overrides import overrides
from spynnaker.pyNN.models.populations import Population
from spynnaker.pyNN.types import ViewIndices
from spynnaker.pyNN.utilities.neo_buffer_database import NeoBufferDatabase
from spynnaker.pyNN.utilities.utility_calls import get_neo_io
from spynnaker.pyNN.models.common.types import Names
if TYPE_CHECKING:
from .neo_buffer_database import Annotations
# Module logger: the standard library logger for this file, wrapped in
# spinn_utilities' FormatAdapter.
logger = FormatAdapter(logging.getLogger(__file__))

# Warning text describing what whole-population data looks like when only
# some neurons were selected for recording.
_SELECTIVE_RECORDED_MSG = (
    "Getting data on a whole population when selective recording was active "
    "will result in only the recorded neurons being returned in numerical "
    "order and without repeats.")
class DataPopulation(object):
    """
    A wrapper of a sqlite3 database to provide the Population data methods.

    Every query opens the database file through
    :py:class:`NeoBufferDatabase` and reads the data stored for the
    population with the given label.  An instance represents either the
    whole population or a view of it (a list of neuron indexes), as
    produced by indexing an instance with ``[]``.
    """
    __slots__ = (
        "__database_file",
        "__label",
        "_indexes",
        "_size")

    def __init__(self, database_file: str, label: str,
                 indexes: ViewIndices = None):
        """
        :param database_file: The database file to read data from
        :param label: The label of the population within the database
        :param indexes:
            The neuron indexes this object is restricted to,
            or ``None`` for the whole population
        """
        self.__label = label
        self.__database_file = database_file
        # Reading the size straight away also checks the inputs, so a bad
        # file or label fails fast here rather than on first use.
        with NeoBufferDatabase(self.__database_file) as db:
            size = db.get_population_metadata(label)[0]
        self._size = size
        self._indexes = indexes

    @overrides(Population.write_data)
    def write_data(self, io: Union[str, neo.baseio.BaseIO],
                   variables: Names = 'all', gather: bool = True,
                   clear: bool = False,
                   annotations: Annotations = None) -> None:
        # pylint: disable=missing-function-docstring,protected-access
        Population._check_params(gather, annotations)
        if clear:
            # Data is only read from the database here; nothing is cleared
            logger.warning("Ignoring clear as not supported in this mode")
        if isinstance(io, str):
            io = get_neo_io(io)
        # Pass the annotations through to get_data (and so on to
        # get_full_block) instead of silently dropping them.
        data = self.get_data(variables, annotations=annotations)
        # write the neo block to the file
        io.write(bl=data)

    @overrides(Population.describe, adds_typing=True)
    def describe(self, template: Optional[str] = None,
                 engine: Optional[Union[str, TemplateEngine]] = None
                 ) -> Union[str, Dict[str, Any]]:
        # pylint: disable=missing-function-docstring
        # The stored description is returned as-is; templating is not
        # available when reading from the database.
        if template is not None:
            logger.warning("Ignoring template as not supported in this mode")
        if engine is not None:
            logger.warning("Ignoring engine as not supported in this mode")
        with NeoBufferDatabase(self.__database_file) as db:
            _, _, description = db.get_population_metadata(self.label)
        return description

    @overrides(Population.get_data)
    def get_data(
            self, variables: Names = 'all',
            gather: bool = True, clear: bool = False, *,
            annotations: Optional[Dict[str, Any]] = None) -> neo.Block:
        # pylint: disable=missing-function-docstring,protected-access
        Population._check_params(gather, annotations)
        if clear:
            # Data is only read from the database here; nothing is cleared
            logger.warning("Ignoring clear as not supported in this mode")
        with NeoBufferDatabase(self.__database_file) as db:
            return db.get_full_block(
                self.__label, variables, self._indexes, annotations)

    @overrides(Population.spinnaker_get_data)
    def spinnaker_get_data(
            self, variable: str, as_matrix: bool = False,
            view_indexes: Optional[Sequence[int]] = None) -> NDArray[floating]:
        # pylint: disable=missing-function-docstring
        # Test for None and length explicitly: the truth value of a numpy
        # array with several elements is ambiguous and raises.  An empty
        # selection still means "no further selection", as before.
        if view_indexes is not None and len(view_indexes) > 0:
            # view_indexes are relative to this object, so narrow via a view
            return self[view_indexes].spinnaker_get_data(variable, as_matrix)
        with NeoBufferDatabase(self.__database_file) as db:
            return db.spinnaker_get_data(
                self.__label, variable, as_matrix, self._indexes)

    @overrides(Population.get_spike_counts)
    def get_spike_counts(self, gather: bool = True) -> Dict[int, int]:
        # pylint: disable=missing-function-docstring
        Population._check_params(gather)  # pylint: disable=protected-access
        with NeoBufferDatabase(self.__database_file) as db:
            return db.get_spike_counts(self.__label, self._indexes)

    @overrides(Population.find_units)
    def find_units(self, variable: str) -> Optional[str]:
        # pylint: disable=missing-function-docstring
        with NeoBufferDatabase(self.__database_file) as db:
            return db.find_units(self.__label, variable)

    def __len__(self) -> int:
        return self._size

    @property
    @overrides(Population.label)
    def label(self) -> str:
        # pylint: disable=missing-function-docstring
        return self.__label

    @property
    @overrides(Population.local_size)
    def local_size(self) -> int:
        # pylint: disable=missing-function-docstring
        return self._size

    @property
    @overrides(Population.size)
    def size(self) -> int:
        # pylint: disable=missing-function-docstring
        return self._size

    @overload
    def id_to_index(self, id: int) -> int:  # @ReservedAssignment
        # pylint: disable=redefined-builtin
        ...

    @overload
    def id_to_index(
            self, id: Iterable[int]) -> List[int]:  # @ReservedAssignment
        # pylint: disable=redefined-builtin
        ...

    @overrides(Population.id_to_index)
    def id_to_index(self, id: Union[int, Iterable[int]]
                    ) -> Union[int, List[int]]:  # @ReservedAssignment
        # pylint: disable=missing-function-docstring,redefined-builtin
        # assuming not called often so not caching first id
        with NeoBufferDatabase(self.__database_file) as db:
            _, first_id, _ = db.get_population_metadata(self.__label)
        # Inclusive last id; first_id + size would be one past the end and
        # would map to index == size, which index_to_id rejects.
        last_id = first_id + self._size - 1
        if not numpy.iterable(id):
            id = cast(int, id)
            if not first_id <= id <= last_id:
                raise ValueError(
                    f"id should be in the range [{first_id},{last_id}], "
                    f"actually {id}")
            return int(id - first_id)  # assume IDs are consecutive
        return [_id - first_id for _id in id]

    @overload
    def index_to_id(self, index: int) -> int:
        ...

    @overload
    def index_to_id(self, index: Iterable[int]) -> List[int]:
        ...

    @overrides(Population.index_to_id)
    def index_to_id(self, index: Union[int, Iterable[int]]
                    ) -> Union[int, List[int]]:
        # pylint: disable=missing-function-docstring
        # assuming not called often so not caching first id
        with NeoBufferDatabase(self.__database_file) as db:
            _, first_id, _ = db.get_population_metadata(self.__label)
        if not numpy.iterable(index):
            index = cast(int, index)
            # Reject negatives too; they would give ids below first_id
            if not 0 <= index < self._size:
                raise ValueError(
                    f"indexes should be in the range [0,{self._size - 1}],"
                    f" actually {index}")
            return int(index + first_id)
        # this assumes IDs are consecutive
        return [_index + first_id for _index in index]

    def __getitem__(self, index_or_slice: Selector) -> DataPopulation:
        """
        Get a view of a subset of the neurons of this object.

        :param index_or_slice: an int, a slice, or a mask array.
            The mask array should either be a Boolean array (ideally) of the
            same size as this object,
            or an integer array containing cell indices,
            i.e. if `p.size == 5` then:

            ::

                p[array([False, False, True, False, True])]
                p[array([2, 4])]
                p[slice(2, 5, 2)]

            will all select the same neurons.
            When this object is already a view, the selection is relative
            to the neurons of the view, not to the whole population.
        :type index_or_slice: slice or int or list(bool) or list(int) or
            ~numpy.ndarray(bool) or ~numpy.ndarray(int)
        :return: A new object restricted to the selected neurons
        :rtype: DataPopulation
        """
        # Positions are resolved against what this object currently covers:
        # all neurons, or only the neurons already selected by a view.
        # "is None" is used because the truth value of a numpy array is
        # ambiguous, and an empty view must not become the whole population.
        base: Sequence[int] = (
            range(self._size) if self._indexes is None
            else cast(Sequence[int], self._indexes))
        sized = AbstractSized(len(base))
        ids = sized.selector_to_ids(index_or_slice, warn=True)
        indexes = [int(base[index]) for index in ids]
        return DataPopulation(self.__database_file, self.__label, indexes)

    @overrides(Population.mean_spike_count)
    def mean_spike_count(self, gather: bool = True) -> float:
        # pylint: disable=missing-function-docstring
        Population._check_params(gather)  # pylint: disable=protected-access
        # gather has already been checked above, so the default is used here
        counts = self.get_spike_counts()
        return sum(counts.values()) / len(counts)