# Copyright (c) 2017 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Plotting tools to be used together with
https://github.com/NeuralEnsemble/PyNN/blob/master/pyNN/utility/plotting.py
"""
from types import ModuleType
from neo import SpikeTrain, Block, Segment, AnalogSignal
import numpy as np
import quantities
plt: ModuleType
try:
from pyNN.utility.plotting import repeat
import matplotlib.pyplot # type: ignore[import]
plt = matplotlib.pyplot
_matplotlib_missing = False
except ImportError:
_matplotlib_missing = True
def _handle_options(axes, options):
    """
    Apply, and remove, the options that ``axes.plot`` cannot accept.

    ``axes.plot`` raises if it receives options it does not understand, so
    every option handled here is popped from ``options``.

    * Tick labels on an axis are hidden unless ``xticks`` / ``yticks`` is
      present with a value other than ``False``.
    * Axis labels come from ``xlabel`` / ``ylabel``, defaulting to
      ``"Time (ms)"`` and ``"Neuron index"``.
    * ``ylim`` and ``xlim`` are applied only when present.

    :param ~matplotlib.axes.Axes axes: An Axes in a matplotlib figure
    :param dict options:
        All options the plotter can be configured with; modified in place
    """
    show_x_ticks = options.pop("xticks", False)
    if show_x_ticks is False:
        plt.setp(axes.get_xticklabels(), visible=False)
    axes.set_xlabel(options.pop("xlabel", "Time (ms)"))

    show_y_ticks = options.pop("yticks", False)
    if show_y_ticks is False:
        plt.setp(axes.get_yticklabels(), visible=False)
    axes.set_ylabel(options.pop("ylabel", "Neuron index"))

    # Limits are only touched when explicitly requested (y before x).
    for limit in ("ylim", "xlim"):
        if limit in options:
            getattr(axes, f"set_{limit}")(options.pop(limit))
def _plot_spikes(axes, spike_times, neurons, label='', **options):
    """
    Plots the spikes based on two parallel sequences.

    :param ~matplotlib.axes.Axes axes: An Axes in a matplotlib figure
    :param spike_times: Time of each spike; entry ``i`` pairs with
        ``neurons[i]``
    :param neurons: Neuron ID of each spike
    :param str label: Label for the graph
    :param options: plotting options, passed on to ``axes.plot``
    """
    if len(neurons):
        # Reduce in native code rather than iterating in Python.
        min_index = np.min(neurons)
        max_index = np.max(neurons)
        axes.plot(spike_times, neurons, 'b.', **options)
        # Pad by half a row so the extreme neurons are not drawn on the
        # plot border.
        axes.set_ylim(-0.5 + min_index, max_index + 0.5)
    if label:
        # Attach the label to the given axes; plt.text would attach it to
        # whichever axes is current, which need not be ``axes``.
        axes.text(0.95, 0.95, label,
                  transform=axes.transAxes, ha='right', va='top',
                  bbox=dict(facecolor='white', alpha=1.0))
[docs]
def plot_spiketrains(axes, spiketrains, label='', **options):
    """
    Plot all spike trains in a Segment in a raster plot.

    :param ~matplotlib.axes.Axes axes: An Axes in a matplotlib figure
    :param list(~neo.core.SpikeTrain) spiketrains: List of spike times;
        each train's ``source_index`` annotation gives its neuron ID
    :param str label: Label for the graph
    :param options: plotting options; a ``ylim`` option, if given, takes
        precedence over the automatic neuron-range limits
    """
    # pylint: disable=c-extension-no-member
    # Default x-range: 0 to the first train's t_stop in ms;
    # _handle_options may override it with a user xlim.
    axes.set_xlim(0, spiketrains[0].t_stop / quantities.ms)
    # _plot_spikes resets the y-limits to the neuron range when there are
    # spikes, so hold back any user-requested ylim and apply it afterwards.
    user_ylim = options.pop("ylim", None)
    _handle_options(axes, options)
    # One neuron ID per spike, in the same order as the concatenated times.
    neurons = np.concatenate(
        [np.repeat(train.annotations['source_index'], len(train))
         for train in spiketrains])
    spike_times = np.concatenate(spiketrains, axis=0)
    _plot_spikes(axes, spike_times, neurons, label=label, **options)
    if user_ylim is not None:
        axes.set_ylim(user_ylim)
[docs]
def plot_spikes_numpy(axes, spikes, label='', **options):
    """
    Plot all spikes.

    :param ~matplotlib.axes.Axes axes: An Axes in a matplotlib figure
    :param ~numpy.ndarray spikes: sPyNNaker7 format numpy array of spikes;
        column 0 is the neuron ID and column 1 the spike time
    :param str label: Label for the graph
    :param options: plotting options; a ``ylim`` option, if given, takes
        precedence over the automatic neuron-range limits
    """
    # _plot_spikes resets the y-limits to the neuron range when there are
    # spikes, so hold back any user-requested ylim and apply it afterwards.
    user_ylim = options.pop("ylim", None)
    _handle_options(axes, options)
    neurons = spikes[:, 0]
    spike_times = spikes[:, 1]
    _plot_spikes(axes, spike_times, neurons, label=label, **options)
    if user_ylim is not None:
        axes.set_ylim(user_ylim)
def _heat_plot(axes, neurons, times, values, label='', **options):
    """
    Plots three parallel sequences of neurons, times and values as a heat map.

    :param ~matplotlib.axes.Axes axes: An Axes in a matplotlib figure
    :param neurons: Neuron ID of each value; used directly as the row index
        of the heat map, so expected to be non-negative ints
    :param times: Time of each value; used directly as the column index,
        so expected to be non-negative ints
    :param values: Values to plot
    :param str label: Label for the graph
    :param options: plotting options
    """
    _handle_options(axes, options)
    # One row per neuron and one column per time index; cells that receive
    # no value stay NaN.
    info_array = np.full((np.max(neurons) + 1, np.max(times) + 1), np.nan)
    info_array[neurons, times] = values
    heat_map = axes.imshow(info_array, cmap='hot', interpolation='none',
                           origin='lower', aspect='auto')
    # Pass ax explicitly so the colour bar is placed beside this axes rather
    # than the current one (the default in older matplotlib releases).
    axes.figure.colorbar(heat_map, ax=axes)
    if label:
        # Attach the label to the given axes; plt.text would attach it to
        # whichever axes is current, which need not be ``axes``.
        axes.text(0.95, 0.95, label,
                  transform=axes.transAxes, ha='right', va='top',
                  bbox=dict(facecolor='white', alpha=1.0))
[docs]
def heat_plot_numpy(axes, data, label='', **options):
    """
    Plot a spynnaker7-format array of samples as a heat map.

    :param ~matplotlib.axes.Axes axes: An Axes in a matplotlib figure
    :param ~numpy.ndarray data: numpy array of values in spynnaker7 format:
        column 0 holds the neuron ID, column 1 the time and column 2 the
        value
    :param str label: Label for the graph
    :param options: plotting options
    """
    neuron_col, time_col, value_col = 0, 1, 2
    # Neuron IDs and times become array indices in the heat map, so they
    # are truncated to ints; the values are passed through unchanged.
    _heat_plot(axes,
               data[:, neuron_col].astype(int),
               data[:, time_col].astype(int),
               data[:, value_col],
               label=label, **options)
[docs]
def heat_plot_neo(axes, signal_array, label='', **options):
    """
    Plot a Neo analog signal as a heat map of neuron against sample.

    :param ~matplotlib.axes.Axes axes: An Axes in a matplotlib figure
    :param ~neo.core.AnalogSignal signal_array:
        Neo Signal array object; one column per neuron
    :param str label:
        Label for the graph; if ``None``, the signal's name is used
    :param options: plotting options
    """
    if label is None:
        label = signal_array.name
    column_ids = list(range(signal_array.shape[-1]))
    # Express each sample time as a whole number of sampling periods so it
    # can serve as a column index in the heat map.
    sample_steps = np.rint(
        (signal_array.times / signal_array.sampling_period).magnitude
    ).astype(int)
    data = signal_array.magnitude
    # Flatten neuron-major: every sample of neuron 0, then neuron 1, ...
    neuron_ids = np.repeat(column_ids, len(sample_steps))
    step_ids = np.tile(sample_steps, len(column_ids))
    samples = np.concatenate([data[:, column] for column in column_ids])
    _heat_plot(axes, neuron_ids, step_ids, samples, label=label, **options)
[docs]
def plot_segment(axes, segment, label='', **options):
    """
    Plot one Segment either as a spike raster or as a heat map.

    With a ``name`` option, ``'spikes'`` draws the spike trains and any other
    value draws (as a heat map) the first object returned by
    ``segment.filter(name=...)``.
    Without ``name``, spike trains are plotted if the Segment has any
    (unless it also has more than one analog signal); otherwise its single
    analog signal is plotted. Several analog signals require ``name``.

    .. note::
        Method signature defined by PyNN plotting.
        This allows mixing of this plotting tool and PyNN's.

    :param ~matplotlib.axes.Axes axes: An Axes in a matplotlib figure
    :param ~neo.core.Segment segment: Data for one run to plot
    :param str label: Label for the graph
    :param options: plotting options
    :raises ValueError:
        if the data to plot is ambiguous or the Segment holds no data
    """
    ambiguous = "please specify data to plot using name="
    signals = segment.analogsignals

    if "name" in options:
        wanted = options.pop("name")
        if wanted != 'spikes':
            heat_plot_neo(axes, segment.filter(name=wanted)[0],
                          label=label, **options)
            return
        plot_spiketrains(axes, segment.spiketrains, label=label, **options)
        return

    n_signals = len(signals)
    if segment.spiketrains:
        # NOTE(review): spikes plus exactly one analog signal still plots
        # the spikes without requiring name=; confirm this is intended.
        if n_signals > 1:
            raise ValueError(ambiguous)
        plot_spiketrains(axes, segment.spiketrains, label=label, **options)
        return
    if not n_signals:
        raise ValueError("Block does not appear to hold any data")
    if n_signals > 1:
        raise ValueError(ambiguous)
    heat_plot_neo(axes, signals[0], label=label, **options)
[docs]
class SpynnakerPanel(object):
    """
    Represents a single panel in a multi-panel figure.

    Compatible with :py:class:`pyNN.utility.plotting.Frame` and
    can be mixed with :py:class:`pyNN.utility.plotting.Panel`.

    Unlike :py:class:`pyNN.utility.plotting.Panel`,
    spikes are plotted faster and
    other data is plotted as a heat map.

    A panel is a Matplotlib Axes or Subplot instance. A data item may be an
    :py:class:`~neo.core.AnalogSignal`, or a list of
    :py:class:`~neo.core.SpikeTrain`\\ s. The Panel will
    automatically choose an appropriate representation. Multiple data items
    may be plotted in the same panel.

    Valid options are any valid Matplotlib formatting options that should be
    applied to the Axes/Subplot, plus in addition:

    `data_labels`:
        a list of strings of the same length as the number of data items.
    `line_properties`:
        a list of dictionaries containing Matplotlib formatting options,
        of the same length as the number of data items.

    Whole Neo Objects can be passed in as long as they
    contain a single Segment/run (or the ``run`` option selects one)
    and only contain one type of data.
    Whole Segments can be passed in only if they only contain one type of data.
    2D numpy arrays with 2 columns are plotted as spikes and with 3 columns
    as a heat map.
    """

    def __init__(self, *data, **options):
        """
        :param data: One or more data series to be plotted.
        :type data: list(~neo.core.SpikeTrain) or ~neo.core.AnalogSignal
            or ~numpy.ndarray or ~neo.core.Block or ~neo.core.Segment
        :param options: Any additional information.
        :raises ImportError: if matplotlib could not be imported
        """
        if _matplotlib_missing:
            raise ImportError("No matplotlib module found")
        self.data = list(data)
        self.options = options
        self.data_labels = options.pop("data_labels", repeat(None))
        self.line_properties = options.pop("line_properties", repeat({}))

    def plot(self, axes):
        """
        Plot the Panel's data in the provided Axes/Subplot instance.

        :param ~matplotlib.axes.Axes axes: An Axes in a matplotlib figure
        :raises ValueError: if a data item is of an unsupported type or shape
        """
        for datum, label, line_props in zip(self.data, self.data_labels,
                                            self.line_properties):
            # Build a fresh dict per item: panel-wide options still override
            # per-line properties, but neither the caller's dicts nor the
            # single shared default dict from repeat({}) are mutated.
            properties = {**line_props, **self.options}
            # Support lists length one
            # for example result of segments[0].filter(name='v')
            if isinstance(datum, list):
                if not datum:
                    raise ValueError("Can't handle empty list")
                if len(datum) == 1 and not isinstance(datum[0], SpikeTrain):
                    datum = datum[0]
            if isinstance(datum, list):
                self.__plot_list(axes, datum, label, properties)
            # AnalogSignal is also an ndarray, but its data layout differs,
            # so it must be tested before the generic ndarray case.
            elif isinstance(datum, AnalogSignal):
                heat_plot_neo(axes, datum, label=label, **properties)
            elif isinstance(datum, np.ndarray):
                self.__plot_array(axes, datum, label, properties)
            elif isinstance(datum, Block):
                self.__plot_block(axes, datum, label, properties)
            elif isinstance(datum, Segment):
                plot_segment(axes, datum, label=label, **properties)
            else:
                raise ValueError(f"Can't handle type {type(datum)}; "
                                 f"consider using pyNN.utility.plotting")

    @staticmethod
    def __plot_list(axes, datum, label, properties):
        """Plot a non-empty list, which must hold SpikeTrains."""
        if not isinstance(datum[0], SpikeTrain):
            # Report the element type; type(datum) would always be list.
            raise ValueError(
                f"Can't handle lists of type {type(datum[0])}")
        plot_spiketrains(axes, datum, label=label, **properties)

    @staticmethod
    def __plot_array(axes, datum, label, properties):
        """
        Plot a 2D numpy array: 2 columns as spikes, 3 as a heat map.
        """
        # Use the shape rather than len(datum[0]): the latter raises
        # IndexError for a 0-row array and misreads non-2D arrays.
        if datum.ndim != 2:
            raise ValueError(
                f"Can't handle ndarray with {datum.ndim} dimensions; "
                "expected a 2D array")
        n_columns = datum.shape[1]
        if n_columns == 2:
            plot_spikes_numpy(axes, datum, label=label, **properties)
        elif n_columns == 3:
            heat_plot_numpy(axes, datum, label=label, **properties)
        else:
            raise ValueError(
                f"Can't handle ndarray with {n_columns} columns")

    @staticmethod
    def __plot_block(axes, datum, label, properties):
        """
        Plot one Segment of a Block: the one selected by the ``run`` option
        if given, otherwise the Block's only Segment.
        """
        if "run" in properties:
            # Popped so that "run" is not forwarded as a plotting option.
            run = int(properties.pop("run"))
            if len(datum.segments) <= run:
                raise ValueError(
                    f"Block only has {len(datum.segments)} segments")
            segment = datum.segments[run]
        elif len(datum.segments) != 1:
            raise ValueError(f"Block has {len(datum.segments)} segments "
                             "please specify one to plot using run=")
        else:
            segment = datum.segments[0]
        plot_segment(axes, segment, label=label, **properties)