• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

NeuralEnsemble / PyNN / 970

pending completion
970

push

travis-ci-com

GitHub
Merge pull request #763 from jiegec/fix-plot-spiketrains

Fix quantities error and x-y order in plotting

fixes #765

2 of 2 new or added lines in 1 file covered. (100.0%)

6956 of 9976 relevant lines covered (69.73%)

0.7 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/pyNN/utility/plotting.py
1
"""
2
Simple tools for plotting Neo-format data.
3

4
These tools are intended for quickly producing basic plots with simple
5
formatting. If you need to produce more complex and/or publication-quality
6
figures, it will probably be easier to use matplotlib or another plotting
7
package directly rather than trying to extend this module.
8

9
:copyright: Copyright 2006-2022 by the PyNN team, see AUTHORS.
10
:license: CeCILL, see LICENSE for details.
11

12
"""
13

14
import sys
×
15
from collections import defaultdict
×
16
from numbers import Number
×
17
from itertools import repeat
×
18
from os import path, makedirs
×
19
import matplotlib.pyplot as plt
×
20
import matplotlib.gridspec as gridspec
×
21
import numpy as np
×
22
from quantities import ms
×
23
from neo import AnalogSignal, IrregularlySampledSignal, SpikeTrain
×
24
from neo.core.spiketrainlist import SpikeTrainList
×
25

26

27
# Matplotlib rcParams that Figure applies when it is not given explicit
# `settings`: thin lines, small fonts and a moderate output resolution.
DEFAULT_FIG_SETTINGS = dict([
    ('lines.linewidth', 0.5),
    ('axes.linewidth', 0.5),
    ('axes.labelsize', 'small'),
    ('legend.fontsize', 'small'),
    ('font.size', 8),
    ('savefig.dpi', 150),
])
35

36

37
def handle_options(ax, options):
×
38
    if "xticks" not in options or options.pop("xticks") is False:
×
39
        plt.setp(ax.get_xticklabels(), visible=False)
×
40
    if "xlabel" in options:
×
41
        ax.set_xlabel(options.pop("xlabel"))
×
42
    if "yticks" not in options or options.pop("yticks") is False:
×
43
        plt.setp(ax.get_yticklabels(), visible=False)
×
44
    if "ylabel" in options:
×
45
        ax.set_ylabel(options.pop("ylabel"))
×
46
    if "ylim" in options:
×
47
        ax.set_ylim(options.pop("ylim"))
×
48
    if "xlim" in options:
×
49
        ax.set_xlim(options.pop("xlim"))
×
50

51

52
def plot_signal(ax, signal, index=None, label='', **options):
    """
    Plot a single channel from an AnalogSignal.

    :param ax: Matplotlib Axes to draw into.
    :param signal: the signal to plot. If `index` is None the whole signal is
        plotted (presumably a single-channel signal); otherwise only column
        `index` is plotted.
    :param index: column of `signal` to plot, or None.
    :param label: legend label prefix; " (Neuron <n>)" is appended, where <n>
        is taken from the "channel_index" array annotation when available
        (falling back to 0, or to `index` when a column is selected).
    :param options: a `ylabel` of "auto" is replaced by "<name> (<units>)";
        options are then processed by `handle_options` and the remainder is
        passed to `ax.plot`.
    """
    if options.get("ylabel") == "auto":
        options["ylabel"] = "%s (%s)" % (signal.name,
                                         signal.units._dimensionality.string)
    handle_options(ax, options)
    channel_indices = signal.array_annotations.get("channel_index")
    if index is None:
        # Evaluating `channel_indices or 0` on an array raises ValueError for
        # more than one element, so pick the first annotated index explicitly.
        if channel_indices is not None and len(channel_indices) > 0:
            channel = channel_indices[0]
        else:
            channel = 0
    else:
        channel = channel_indices[index] if channel_indices is not None else index
        signal = signal[:, index]
    label = "%s (Neuron %d)" % (label, channel)
    ax.plot(signal.times.rescale(ms), signal.magnitude, label=label, **options)
    ax.legend()
68

69

70
def plot_signals(ax, signal_array, label_prefix='', **options):
    """
    Plot all channels in an AnalogSignal in a single panel.

    Channels are drawn in ascending order of their "channel_index" array
    annotation when present (labelled "Neuron <channel>"), otherwise in column
    order (labelled by column number).

    Options consumed here rather than forwarded to Matplotlib:
        `ylabel`: "auto" builds the label from the signal name and units.
        `y_offset`: if non-zero, the channel in column ``i`` is shifted
            vertically by ``i * y_offset`` (added to the signal, so it must
            be compatible with the signal's units).
        `legend`: whether to draw a legend (default True).
    The remaining options go through `handle_options`, then to `ax.plot`.
    The caller's `signal_array` is never modified.
    """
    if options.get("ylabel") == "auto":
        options["ylabel"] = "%s (%s)" % (signal_array.name,
                                         signal_array.units._dimensionality.string)
    handle_options(ax, options)
    offset = options.pop("y_offset", None)
    show_legend = options.pop("legend", True)
    channel_indices = signal_array.array_annotations.get("channel_index")
    if channel_indices is not None:
        channel_iterator = channel_indices.argsort()
    else:
        channel_iterator = range(signal_array.shape[1])
    for i in channel_iterator:
        if channel_indices is not None:
            channel = channel_indices[i]
            if label_prefix:
                label = "%s (Neuron %d)" % (label_prefix, channel)
            else:
                label = "Neuron %d" % channel
        elif label_prefix:
            label = "%s (%d)" % (label_prefix, i)
        else:
            label = str(i)
        signal = signal_array[:, i]
        if offset:
            # Out-of-place addition: signal_array[:, i] is a view, so "+="
            # would shift the caller's data as well.
            signal = signal + i * offset
        ax.plot(signal.times.rescale(ms), signal.magnitude, label=label, **options)
    if show_legend:
        ax.legend()
102

103

104
def plot_spiketrains(ax, spiketrains, label='', **options):
    """
    Plot all spike trains in a Segment in a raster plot.

    :param ax: Matplotlib Axes to draw into.
    :param spiketrains: non-empty list of SpikeTrains, each carrying a
        'source_index' annotation used as its row (y value).
    :param label: if given, drawn in a box in the top-right corner.
    :param options: processed by `handle_options`; the remainder is passed
        to `ax.plot`.

    Spike times and the x-range (taken from the first train's t_start and
    t_stop) are both converted to ms so that they are consistent with each
    other and with the signal panels.
    """
    first = spiketrains[0]
    ax.set_xlim(float(first.t_start.rescale(ms)), float(first.t_stop.rescale(ms)))
    handle_options(ax, options)
    source_indices = []
    for spiketrain in spiketrains:
        source_index = spiketrain.annotations['source_index']
        spike_times = spiketrain.rescale(ms).magnitude
        ax.plot(spike_times, np.full(spike_times.shape, source_index),
                'k.', **options)
        source_indices.append(source_index)
    ax.set_ylabel("Neuron index")
    ax.set_ylim(min(source_indices) - 0.5, max(source_indices) + 0.5)
    if label:
        # draw on `ax` itself rather than pyplot's current axes
        ax.text(0.95, 0.95, label,
                transform=ax.transAxes, ha='right', va='top',
                bbox=dict(facecolor='white', alpha=1.0))
124

125

126
def plot_spiketrainlist(ax, spiketrains, label='', **options):
    """
    Plot all spike trains in a SpikeTrainList in a raster plot.

    Spike times (x, in ms) are plotted against channel ids (y). The x-range
    is taken from the list's t_start/t_stop, also converted to ms, and the
    y-range covers all channel ids. If `label` is given it is drawn in a box
    in the top-right corner. Options are processed by `handle_options`; the
    remainder is passed to `ax.plot`.
    """
    ax.set_xlim(float(spiketrains.t_start.rescale(ms)),
                float(spiketrains.t_stop.rescale(ms)))
    handle_options(ax, options)
    channel_ids, spike_times = spiketrains.multiplexed
    # NOTE(review): guard in case `multiplexed` returns a plain, unit-less
    # array (e.g. for an empty list) — confirm against the Neo version used.
    if hasattr(spike_times, "rescale"):
        spike_times = spike_times.rescale(ms)
    all_ids = spiketrains.all_channel_ids
    ax.plot(np.asarray(spike_times), channel_ids, 'k.', **options)
    ax.set_ylabel("Neuron index")
    ax.set_ylim(min(all_ids) - 0.5, max(all_ids) + 0.5)
    if label:
        # draw on `ax` itself rather than pyplot's current axes
        ax.text(0.95, 0.95, label,
                transform=ax.transAxes, ha='right', va='top',
                bbox=dict(facecolor='white', alpha=1.0))
142

143

144
def plot_array_as_image(ax, arr, label='', **options):
    """
    Plots a 2D numpy array as an image (pseudocolour mesh) in `ax`.

    The aspect ratio is fixed to 'equal'. Option `legend` (default True)
    controls whether a colorbar is drawn. Other options are processed by
    `handle_options` and the remainder is passed to `ax.pcolormesh`. If
    `label` is given it is drawn in a box in the top-right corner.
    """
    handle_options(ax, options)
    show_legend = options.pop("legend", True)
    mesh = ax.pcolormesh(arr, **options)
    ax.set_aspect('equal')
    if label:
        ax.text(0.95, 0.95, label,
                transform=ax.transAxes, ha='right', va='top',
                bbox=dict(facecolor='white', alpha=1.0))
    if show_legend:
        # bind the colorbar to this mesh and axes explicitly instead of
        # relying on pyplot's notion of the "current image"
        plt.colorbar(mesh, ax=ax)
158

159

160
def scatterplot(ax, data_table, label='', **options):
    """
    Draw a scatterplot of a DataTable's (x, y) data in `ax`.

    Option `show_fit` (default False) additionally draws the fitted curve,
    `data_table.y_fit` (available only after `fit_curve` has been called),
    as a black line. Other options are processed by `handle_options` and the
    remainder is passed to `ax.scatter`. If `label` is given it is drawn in a
    box in the top-right corner.
    """
    handle_options(ax, options)
    if options.pop("show_fit", False):
        ax.plot(data_table.x, data_table.y_fit, 'k-')
    ax.scatter(data_table.x, data_table.y, **options)
    if label:
        ax.text(0.95, 0.95, label,
                transform=ax.transAxes, ha='right', va='top',
                bbox=dict(facecolor='white', alpha=1.0))
169

170

171
def plot_hist(ax, histogram, label='', **options):
    """
    Draw a Histogram as a bar chart in `ax`.

    Each bar spans one bin, from its left edge to left edge + bin width, and
    all bars are drawn in a single call (hence in a single colour). Options
    are used only for axis formatting via `handle_options`; any others are
    ignored. If `label` is given it is drawn in a box in the top-right corner.
    """
    handle_options(ax, options)
    # iterating a Histogram evaluates it, which also sets `bin_width`
    bars = list(histogram)
    if bars:
        left_edges, counts = zip(*bars)
        # x values are left bin edges, so bars must be edge-aligned
        # (Matplotlib centres bars on x by default)
        ax.bar(left_edges, counts, width=histogram.bin_width,
               align='edge', color=None)
    if label:
        ax.text(0.95, 0.95, label,
                transform=ax.transAxes, ha='right', va='top',
                bbox=dict(facecolor='white', alpha=1.0))
179

180

181
def variable_names(segment):
    """
    Return the set of names of the AnalogSignals in `segment`.

    PyNN uses each signal's name for the recorded variable, so this is the
    set of variables present in the segment.
    """
    return {signal.name for signal in segment.analogsignals}
187

188

189
class Figure(object):
    """
    Provide simple, declarative specification of multi-panel figures.

    Example::

      Figure(
          Panel(segment.filter(name="v")[0], ylabel="Membrane potential (mV)"),
          Panel(segment.spiketrains, xlabel="Time (ms)"),
          title="Network activity",
      ).save("figure3.png")

    Valid options are:
        `settings`:
            for figure settings, e.g. {'font.size': 9}. Applied globally via
            ``plt.rcParams.update``; DEFAULT_FIG_SETTINGS is used when absent
            or None.
        `size`:
            (width, height) in inches; defaults to (6, 2 * n_panels + 1.2).
        `annotations`:
            a (multi-line) string to be printed at the bottom of the figure.
        `title`:
            a string to be printed at the top of the figure.
    """

    def __init__(self, *panels, **options):
        n_panels = len(panels)
        settings = options.get("settings")
        if settings is None:
            settings = DEFAULT_FIG_SETTINGS
        plt.rcParams.update(settings)
        width, height = options.get("size", (6, 2 * n_panels + 1.2))
        # Matplotlib figure number 1 is still reused (so repeatedly creating
        # Figures does not accumulate open figures), but it is cleared and
        # resized: otherwise axes/text from a previous Figure remain and the
        # requested size is ignored for an already-existing figure.
        self.fig = plt.figure(1, figsize=(width, height), clear=True)
        self.fig.set_size_inches(width, height)
        title = options.get("title")
        # comparison_plot passes annotations=None; only reserve space for
        # annotations that are actually present
        annotations = options.get("annotations")
        gs = gridspec.GridSpec(n_panels, 1)
        if annotations:
            gs.update(bottom=1.2 / height)  # leave space for annotations
        gs.update(top=1 - 0.8 / height, hspace=0.25)

        for i, panel in enumerate(panels):
            panel.plot(plt.subplot(gs[i, 0]))

        if title:
            self.fig.text(0.5, 1 - 0.5 / height, title,
                          ha="center", va="top", fontsize="large")
        if annotations:
            self.fig.text(0.01, 0.01, annotations, fontsize=6, verticalalignment='bottom')

    def save(self, filename):
        """
        Save the figure to file. The format is taken from the file extension.

        Missing parent directories are created first.
        """
        dirname = path.dirname(filename)
        if dirname:
            # exist_ok avoids the check-then-create race of exists()+makedirs()
            makedirs(dirname, exist_ok=True)
        self.fig.savefig(filename)

    def show(self):
        """Display the figure using pyplot's interactive backend."""
        plt.show()
245

246

247
class Panel(object):
    """
    Represents a single panel in a multi-panel figure.

    A panel is a Matplotlib Axes or Subplot instance. A data item may be an
    AnalogSignal, an IrregularlySampledSignal, a list of SpikeTrains, a
    SpikeTrainList, a 2D NumPy array (drawn as an image), a DataTable (drawn
    as a scatterplot) or a Histogram. The Panel will automatically choose an
    appropriate representation. Multiple data items may be plotted in the
    same panel.

    Valid options are any valid Matplotlib formatting options that should be
    applied to the Axes/Subplot, plus in addition:

        `data_labels`:
            a list of strings of the same length as the number of data items.
        `line_properties`:
            a list of dicts containing Matplotlib formatting options, of the
            same length as the number of data items.

    If either list is shorter than the number of data items, the remaining
    items are plotted without a label / extra properties. Panel options take
    precedence over per-item line properties.
    """

    def __init__(self, *data, **options):
        self.data = list(data)
        self.options = options
        self.data_labels = options.pop("data_labels", repeat(None))
        self.line_properties = options.pop("line_properties", repeat({}))

    def plot(self, axes):
        """
        Plot the Panel's data in the provided Axes/Subplot instance.

        Raises TypeError for unsupported data types and ValueError for NumPy
        arrays that are not two-dimensional (both are Exception subclasses).
        """
        def padded(values, filler):
            # yield `values`, then `filler` forever, so that zip() below never
            # silently drops data items when labels/properties run short
            yield from values
            yield from repeat(filler)

        for datum, label, properties in zip(self.data,
                                            padded(self.data_labels, None),
                                            padded(self.line_properties, {})):
            # Merge into a new dict: updating `properties` in place would leak
            # this panel's options into dicts shared with other panels.
            properties = dict(properties, **self.options)
            if isinstance(datum, DataTable):
                scatterplot(axes, datum, label=label, **properties)
            elif isinstance(datum, Histogram):
                plot_hist(axes, datum, label=label, **properties)
            elif isinstance(datum, (AnalogSignal, IrregularlySampledSignal)):
                plot_signals(axes, datum, label_prefix=label, **properties)
            elif isinstance(datum, list) and len(datum) > 0 and isinstance(datum[0], SpikeTrain):
                plot_spiketrains(axes, datum, label=label, **properties)
            elif isinstance(datum, SpikeTrainList):
                plot_spiketrainlist(axes, datum, label=label, **properties)
            elif isinstance(datum, np.ndarray):
                if datum.ndim == 2:
                    plot_array_as_image(axes, datum, label=label, **properties)
                else:
                    raise ValueError("Can't handle arrays with %s dimensions" % datum.ndim)
            else:
                raise TypeError("Can't handle type %s" % type(datum))
296

297

298
def comparison_plot(segments, labels, title='', annotations=None,
                    fig_settings=None, with_spikes=True):
    """
    Given a list of segments, plot all the data they contain so as to be able
    to compare them.

    One panel is created per (variable, channel) pair, overlaying that
    channel's signal from every segment; signals sharing a variable name are
    rescaled to the units of the first one encountered. Channels are ordered
    by their "channel_index" array annotation, or by column when it is
    absent. If `with_spikes` is true and the first segment has spike trains,
    one raster panel per segment is appended.

    :param segments: sequence of Neo Segments.
    :param labels: one label per segment, used in legends.
    :param title, annotations, fig_settings: passed on to Figure.

    Return a Figure instance.

    Raises ValueError if the segments contain nothing to plot.
    """
    variables_to_plot = set().union(*(variable_names(s) for s in segments))
    print("Plotting the following variables: %s"
          % ", ".join(sorted(str(v) for v in variables_to_plot)))

    # group signal arrays by name
    n_seg = len(segments)
    by_var_and_channel = defaultdict(lambda: defaultdict(list))
    line_properties = []
    units = {}
    for k, (segment, label) in enumerate(zip(segments, labels)):
        # earlier segments get thicker lines, so later ones stay visible on top
        lw = 2 * (n_seg - k) - 1
        col = 'bcgmkr'[k % 6]
        line_properties.append({"linewidth": lw, "color": col})
        for array in segment.analogsignals:
            # rescale signals to the same units, for a given variable name
            if array.name not in units:
                units[array.name] = array.units
            elif array.units != units[array.name]:
                array = array.rescale(units[array.name])
            channel_indices = array.array_annotations.get("channel_index")
            if channel_indices is None:
                columns = [(i, i) for i in range(array.shape[1])]
            else:
                columns = [(i, channel_indices[i]) for i in np.argsort(channel_indices)]
            for i, channel in columns:
                by_var_and_channel[array.name][channel].append(array[:, i])
    # each panel plots the signals for a given variable and channel.
    panels = []
    for by_channel in by_var_and_channel.values():
        for array_list in by_channel.values():
            ylabel = array_list[0].name
            if ylabel:
                ylabel += " ({})".format(array_list[0].dimensionality)
            panels.append(
                Panel(*array_list,
                      # per-panel copies, so nothing done to these dicts while
                      # plotting one panel can affect another
                      line_properties=[dict(p) for p in line_properties],
                      yticks=True,
                      ylabel=ylabel,
                      data_labels=labels))
    if with_spikes and n_seg > 0 and len(segments[0].spiketrains) > 0:
        panels += [Panel(segment.spiketrains, data_labels=[label])
                   for segment, label in zip(segments, labels)]
    if not panels:
        raise ValueError("The segments contain no signals or spike trains to plot")
    panels[-1].options["xticks"] = True
    panels[-1].options["xlabel"] = "Time (ms)"
    fig = Figure(*panels,
                 title=title,
                 settings=fig_settings,
                 annotations=annotations)
    return fig
351

352

353
class DataTable(object):
    """
    A lightweight encapsulation of x, y data for scatterplots.

    A curve can be fitted to the data with `fit_curve`; its values at `x`
    are then available as `y_fit`.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def fit_curve(self, f, p0, **fitting_parameters):
        """
        Fit ``f(x, *params)`` to the data with scipy.optimize.curve_fit.

        :param f: model function taking x followed by the parameters.
        :param p0: initial parameter guess.
        :param fitting_parameters: extra keyword arguments for curve_fit.
        :return: (popt, pcov) as returned by curve_fit; both are also stored
            on the instance for `y_fit`.
        """
        from scipy.optimize import curve_fit
        # the model and initial guess are recorded before fitting is attempted
        self._f, self._p0 = f, p0
        optimal, covariance = curve_fit(f, self.x, self.y, p0, **fitting_parameters)
        self._popt, self._pcov = optimal, covariance
        return self._popt, self._pcov

    @property
    def y_fit(self):
        """Fitted model evaluated at `x` (requires a prior `fit_curve` call)."""
        model, params = self._f, self._popt
        return model(self.x, *params)
370

371

372
class Histogram(object):
    """
    A lightweight encapsulation of histogram data.

    Binning is computed lazily, on first iteration or explicit `evaluate()`,
    using ``int(sqrt(N))`` equal-width bins for N data points (at least one
    bin, so empty data is accepted). After evaluation, `values` holds the
    counts, `bins` the bin edges and `bin_width` the width of each bin.
    """

    def __init__(self, data):
        self.data = data
        self.evaluated = False

    def evaluate(self):
        """Compute the histogram once; subsequent calls do nothing."""
        if not self.evaluated:
            # np.histogram rejects bins=0, which int(sqrt(N)) gives for N == 0
            n_bins = max(1, int(np.sqrt(len(self.data))))
            self.values, self.bins = np.histogram(self.data, bins=n_bins)
            self.bin_width = self.bins[1] - self.bins[0]
            self.evaluated = True

    def __iter__(self):
        """Iterate over the bars of the histogram as (left bin edge, count) pairs."""
        self.evaluate()
        for left_edge, count in zip(self.bins[:-1], self.values):
            yield (left_edge, count)
391

392

393
def isi_histogram(segment):
    """
    Return a Histogram of the inter-spike intervals of all spike trains in
    `segment`.

    Intervals are computed per spike train from the raw spike-time values
    (no unit conversion) and pooled. A segment without spike trains gives a
    Histogram over an empty array instead of making np.concatenate raise.
    """
    isis_per_train = [np.diff(np.array(st)) for st in segment.spiketrains]
    all_isis = np.concatenate(isis_per_train) if isis_per_train else np.array([])
    return Histogram(all_isis)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc