Source code for spynnaker.plot_utils

# Copyright (c) 2017 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Imports
import sys
from types import ModuleType
from typing import Optional
import numpy as np
# pylint: disable=invalid-name
plt: Optional[ModuleType]
try:
    import matplotlib.pyplot  # type: ignore[import]
    plt = matplotlib.pyplot
except ImportError:
    plt = None


def _precheck(data, title):
    if len(data) == 0:
        if title is None:
            print("NO Data")
        else:
            print("NO data for " + title)
        return False
    if plt is None:
        if title is None:
            print("matplotlib not installed skipping plotting")
        else:
            print("matplotlib not installed skipping plotting for " + title)
        return False
    return True


def line_plot(data_sets, title=None):
    """
    Build a line plot or plots.

    Each data set is drawn in its own subplot, with one line per distinct
    value found in column 0. Column 1 is used for the X axis and column 2
    for the Y axis.

    :param data_sets: Numpy array of data, or list of numpy arrays of data
    :type data_sets: ~numpy.ndarray or list(~numpy.ndarray)
    :param title: The title for the plot
    :type title: str or None
    """
    if not _precheck(data_sets, title):
        return
    print("Setting up line graph")
    if isinstance(data_sets, np.ndarray):
        data_sets = [data_sets]

    print(f"Setting up {len(data_sets)} sets of line plots")
    (numrows, numcols) = _grid(len(data_sets))
    # enumerate yields (index, item); the index selects the subplot slot
    for index, data in enumerate(data_sets):
        plt.subplot(numrows, numcols, index + 1)
        for neuron in np.unique(data[:, 0]):
            # Boolean mask picks out the rows belonging to this neuron
            rows = data[data[:, 0] == neuron]
            plt.plot(rows[:, 1], rows[:, 2])

        min_data = data[:, 2].min()
        max_data = data[:, 2].max()
        # Pad the Y range by 10% so lines do not touch the frame
        adjust = (max_data - min_data) * 0.1
        plt.axis([data[:, 1].min(), data[:, 1].max(),
                  min_data - adjust, max_data + adjust])
    if title is not None:
        plt.title(title)
    plt.show()
def heat_plot(data_sets, ylabel=None, title=None):
    """
    Build a heat map plot or plots.

    Each data set is drawn in its own subplot. Columns 0 and 1 are cast to
    integers and used as the row and column of the heat map cell, and
    column 2 supplies the cell value. Cells with no value are left as NaN.

    :param data_sets: Numpy array of data, or list of numpy arrays of data
    :type data_sets: ~numpy.ndarray or list(~numpy.ndarray)
    :param ylabel: The label for the Y axis
    :type ylabel: str or None
    :param title: The title for the plot
    :type title: str or None
    """
    if not _precheck(data_sets, title):
        return
    if isinstance(data_sets, np.ndarray):
        data_sets = [data_sets]

    print(f"Setting up {len(data_sets)} sets of heat graph")
    (numrows, numcols) = _grid(len(data_sets))
    # enumerate yields (index, item); the index selects the subplot slot
    for index, data in enumerate(data_sets):
        plt.subplot(numrows, numcols, index + 1)
        neurons = data[:, 0].astype(int)
        times = data[:, 1].astype(int)
        info = data[:, 2]
        # Start from an all-NaN grid so unrecorded cells stay NaN
        info_array = np.full((neurons.max() + 1, times.max() + 1), np.nan)
        info_array[neurons, times] = info
        plt.xlabel("Time (ms)")
        plt.ylabel(ylabel)
        plt.imshow(info_array, cmap='hot', interpolation='none',
                   aspect='auto')
        plt.colorbar()
    if title is not None:
        plt.title(title)
    plt.show()
def _get_colour(): yield "b." yield "g." yield "r." yield "c." yield "m." yield "y." yield "k." def _grid(length): if length == 1: return 1, 1 if length == 2: return 1, 2 if length == 3: return 1, 3 if length == 4: return 2, 2 return length // 3 + 1, length % 3 + 1
def plot_spikes(spikes, title="spikes"):
    """
    Build a spike plot or plots.

    Each set of spikes is drawn in its own subplot. Within a set, element 0
    of each entry is plotted on the Y axis (neuron ID) and element 1 on the
    X axis (time). The axis limits are computed over all sets combined,
    padded by 5%.

    :param spikes: Numpy array of spikes, or list of numpy arrays of spikes
    :type spikes: ~numpy.ndarray or list(~numpy.ndarray)
    :param str title: The title for the plot
    """
    if not _precheck(spikes, title):
        return

    if isinstance(spikes, np.ndarray):
        spikes = [spikes]

    # Keep the palette as a list so it can be reused cyclically when there
    # are more spike sets than colours
    colours = list(_get_colour())

    min_time = sys.maxsize
    max_time = 0
    min_spike = sys.maxsize
    max_spike = 0

    print(f"Plotting {len(spikes)} set of spikes")
    (numrows, numcols) = _grid(len(spikes))
    # enumerate yields (index, item); the index selects the subplot slot
    for index, single_spikes in enumerate(spikes):
        # pylint: disable=nested-min-max
        plt.subplot(numrows, numcols, index + 1)
        spike_time = [i[1] for i in single_spikes]
        spike_id = [i[0] for i in single_spikes]
        min_time = min(min_time, min(spike_time))
        max_time = max(max_time, max(spike_time))
        min_spike = min(min_spike, min(spike_id))
        max_spike = max(max_spike, max(spike_id))
        plt.plot(spike_time, spike_id, colours[index % len(colours)])
    plt.xlabel("Time (ms)")
    plt.ylabel("Neuron ID")
    plt.title(title)
    # Pad both ranges by 5% so points on the edge remain visible
    time_diff = (max_time - min_time) * 0.05
    min_time = min_time - time_diff
    max_time = max_time + time_diff
    spike_diff = (max_spike - min_spike) * 0.05
    min_spike = min_spike - spike_diff
    max_spike = max_spike + spike_diff
    plt.axis([min_time, max_time, min_spike, max_spike])
    plt.show()
# Manual test harness; runs only when this module is executed directly.
if __name__ == "__main__":
    spike_data = np.loadtxt("spikes.csv", delimiter=',')
    plot_spikes(spike_data)

    # Load a second copy and shift the first column of every row by 5
    doubled_spike_data = np.loadtxt("spikes.csv", delimiter=',')
    for shifted_row in doubled_spike_data:
        shifted_row[0] += 5
    plot_spikes([spike_data, doubled_spike_data])