• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

NeuralEnsemble / PyNN / 987

pending completion
987

push

travis-ci-com

web-flow
Merge pull request #772 from apdavison/flake8

fix or ignore flake8 errors and warnings; add flake8 checks to CI

362 of 362 new or added lines in 57 files covered. (100.0%)

7503 of 10585 relevant lines covered (70.88%)

0.71 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/pyNN/utility/plotting.py
1
"""
2
Simple tools for plotting Neo-format data.
3

4
These tools are intended for quickly producing basic plots with simple
5
formatting. If you need to produce more complex and/or publication-quality
6
figures, it will probably be easier to use matplotlib or another plotting
7
package directly rather than trying to extend this module.
8

9
:copyright: Copyright 2006-2022 by the PyNN team, see AUTHORS.
10
:license: CeCILL, see LICENSE for details.
11

12
"""
13

14
import sys
×
15
from collections import defaultdict
×
16
from itertools import repeat
×
17
from os import path, makedirs
×
18
import matplotlib.pyplot as plt
×
19
import matplotlib.gridspec as gridspec
×
20
import numpy as np
×
21
from quantities import ms
×
22
from neo import AnalogSignal, IrregularlySampledSignal, SpikeTrain
×
23
from neo.core.spiketrainlist import SpikeTrainList
×
24

25

26
# Default Matplotlib rcParams applied by `Figure` when no `settings` option is
# given (or it is None): thin (0.5) lines and axes, small fonts, and 150 dpi
# when saving to file.
DEFAULT_FIG_SETTINGS = {
    'lines.linewidth': 0.5,
    'axes.linewidth': 0.5,
    'axes.labelsize': 'small',
    'legend.fontsize': 'small',
    'font.size': 8,
    'savefig.dpi': 150,
}
34

35

36
def handle_options(ax, options):
    """
    Apply generic axis options to `ax`, removing them from `options`.

    Recognised keys:
        `xticks` / `yticks`: tick labels on that axis are hidden unless the
            key is present and its value is not ``False``.
        `xlabel` / `ylabel`: axis label text.
        `xlim` / `ylim`: axis limits.

    `options` is modified in place: every recognised key is popped, so the
    remaining entries can be passed straight on to a Matplotlib plotting call.
    """
    for axis in ("x", "y"):
        # A missing key behaves like an explicit False: tick labels are hidden.
        if options.pop(axis + "ticks", False) is False:
            plt.setp(getattr(ax, "get_" + axis + "ticklabels")(), visible=False)
        label_key = axis + "label"
        if label_key in options:
            getattr(ax, "set_" + label_key)(options.pop(label_key))
    for axis in ("y", "x"):
        limit_key = axis + "lim"
        if limit_key in options:
            getattr(ax, "set_" + limit_key)(options.pop(limit_key))
49

50

51
def plot_signal(ax, signal, index=None, label='', **options):
    """
    Plot a single channel from an AnalogSignal.

    Arguments:
        `ax`: the Matplotlib Axes to draw into.
        `signal`: a Neo AnalogSignal.
        `index`: the column of `signal` to plot. If None, `signal` is plotted
            as-is (it is expected to contain a single channel).
        `label`: prefix for the legend entry, which reads
            "<label> (Neuron <channel>)", or "Neuron <channel>" if `label` is
            empty.
        `options`: axis options understood by `handle_options()`;
            ``ylabel="auto"`` builds the y-axis label from the signal name and
            units. Remaining options are passed to `ax.plot()`.
    """
    if options.get("ylabel") == "auto":
        options["ylabel"] = "%s (%s)" % (signal.name,
                                         signal.units._dimensionality.string)
    handle_options(ax, options)
    channel_indices = signal.array_annotations.get("channel_index")
    if index is None:
        # Use the (first) channel index of the single-channel signal, falling
        # back to 0 when there is no channel_index annotation. Indexing avoids
        # evaluating the whole annotation array for truthiness.
        if channel_indices is not None and len(channel_indices) > 0:
            channel = channel_indices[0]
        else:
            channel = 0
    else:
        channel = channel_indices[index] if channel_indices is not None else index
        signal = signal[:, index]
    if label:
        label = "%s (Neuron %d)" % (label, channel)
    else:
        label = "Neuron %d" % channel
    ax.plot(signal.times.rescale(ms), signal.magnitude, label=label, **options)
    ax.legend()
67

68

69
def plot_signals(ax, signal_array, label_prefix='', **options):
    """
    Plot all channels in an AnalogSignal in a single panel.

    Channels are plotted in increasing order of their "channel_index" array
    annotation when it is present (and labelled "Neuron <channel_index>"),
    otherwise in column order (and labelled by column number).
    `label_prefix`, if given, is prepended to each legend label.

    Options handled here, in addition to those of `handle_options()`:
        `ylabel="auto"`: build the y-axis label from the signal name and units.
        `y_offset`: shift column i vertically by ``i * y_offset`` to separate
            the traces. It is added to the signal, so it presumably needs units
            compatible with the signal's — confirm for your data.
        `legend`: whether to draw a legend (default True).
    Any remaining options are passed to `ax.plot()`.
    """
    if options.get("ylabel") == "auto":
        options["ylabel"] = "%s (%s)" % (signal_array.name,
                                         signal_array.units._dimensionality.string)
    handle_options(ax, options)
    offset = options.pop("y_offset", None)
    show_legend = options.pop("legend", True)
    annotations = signal_array.array_annotations
    has_channel_index = "channel_index" in annotations
    if has_channel_index:
        channel_iterator = annotations["channel_index"].argsort()
    else:
        channel_iterator = range(signal_array.shape[1])
    for i in channel_iterator:
        if has_channel_index:
            channel = annotations["channel_index"][i]
            if label_prefix:
                label = "%s (Neuron %d)" % (label_prefix, channel)
            else:
                label = "Neuron %d" % channel
        elif label_prefix:
            label = "%s (%d)" % (label_prefix, i)
        else:
            label = str(i)
        signal = signal_array[:, i]
        if offset:
            # Create a new signal rather than adding in place: `signal` is a
            # view onto `signal_array`, so `+=` would alter the caller's data.
            signal = signal + i * offset
        ax.plot(signal.times.rescale(ms), signal.magnitude, label=label, **options)
    if show_legend:
        ax.legend()
101

102

103
def plot_spiketrains(ax, spiketrains, label='', **options):
    """
    Plot all spike trains in a Segment in a raster plot.

    Each spike train is drawn as a row of black dots at the height given by
    its 'source_index' annotation. The x-limits default to the t_start/t_stop
    of the first spike train (so `spiketrains` must not be empty) and can be
    overridden through `handle_options()` options. Remaining options are
    passed to `ax.plot()`. If `label` is given, it is shown in a box in the
    top-right corner of the panel.
    """
    ax.set_xlim(spiketrains[0].t_start, spiketrains[0].t_stop)
    handle_options(ax, options)
    max_index = 0
    min_index = sys.maxsize
    for spiketrain in spiketrains:
        source_index = spiketrain.annotations['source_index']
        ax.plot(spiketrain,
                np.ones_like(spiketrain) * source_index,
                'k.', **options)
        max_index = max(max_index, source_index)
        min_index = min(min_index, source_index)
    ax.set_ylabel("Neuron index")
    ax.set_ylim(-0.5 + min_index, max_index + 0.5)
    if label:
        # Draw on `ax` itself; plt.text would target whichever axes is current.
        ax.text(0.95, 0.95, label,
                transform=ax.transAxes, ha='right', va='top',
                bbox=dict(facecolor='white', alpha=1.0))
123

124

125
def plot_spiketrainlist(ax, spiketrains, label='', **options):
    """
    Plot all spike trains in a SpikeTrainList in a raster plot.

    Spikes are taken from the list's multiplexed (channel_id, spike_time)
    representation and drawn as black dots, one row per channel id. The
    x-limits default to the list's t_start/t_stop and can be overridden through
    `handle_options()` options; remaining options are passed to `ax.plot()`.
    If `label` is given, it is shown in a box in the top-right corner.
    """
    ax.set_xlim(spiketrains.t_start, spiketrains.t_stop)
    handle_options(ax, options)
    channel_ids, spike_times = spiketrains.multiplexed
    ax.plot(spike_times, channel_ids, 'k.', **options)
    ax.set_ylabel("Neuron index")
    all_ids = list(spiketrains.all_channel_ids)
    if all_ids:  # max()/min() would raise on an empty collection
        ax.set_ylim(-0.5 + min(all_ids), max(all_ids) + 0.5)
    if label:
        # Draw on `ax` itself; plt.text would target whichever axes is current.
        ax.text(0.95, 0.95, label,
                transform=ax.transAxes, ha='right', va='top',
                bbox=dict(facecolor='white', alpha=1.0))
141

142

143
def plot_array_as_image(ax, arr, label='', **options):
    """
    Plot a 2D numpy array as an image (pseudocolor mesh) in `ax`.

    The axes aspect ratio is set to 'equal'. Options understood by
    `handle_options()` are applied to the axes. `legend` (default True)
    controls whether a colorbar is drawn. Remaining options are passed to
    `ax.pcolormesh()`. If `label` is given, it is shown in a box in the
    top-right corner.
    """
    handle_options(ax, options)
    show_legend = options.pop("legend", True)
    # Use the Axes/Figure methods so the image and colorbar belong to `ax`,
    # not to whichever axes pyplot considers current.
    mesh = ax.pcolormesh(arr, **options)
    ax.set_aspect('equal')
    if label:
        ax.text(0.95, 0.95, label,
                transform=ax.transAxes, ha='right', va='top',
                bbox=dict(facecolor='white', alpha=1.0))
    if show_legend:
        ax.figure.colorbar(mesh, ax=ax)
157

158

159
def scatterplot(ax, data_table, label='', **options):
    """
    Draw a scatter plot of a DataTable's (x, y) points in `ax`.

    Options understood by `handle_options()` are applied to the axes. If
    `show_fit` is true, the fitted curve (`data_table.y_fit`, which requires a
    prior `fit_curve()` call) is drawn as a black line. Remaining options are
    passed to `ax.scatter()`. If `label` is given, it is shown in a box in the
    top-right corner.
    """
    handle_options(ax, options)
    # Use Axes methods so everything is drawn in `ax`, not the current axes.
    if options.pop("show_fit", False):
        ax.plot(data_table.x, data_table.y_fit, 'k-')
    ax.scatter(data_table.x, data_table.y, **options)
    if label:
        ax.text(0.95, 0.95, label,
                transform=ax.transAxes, ha='right', va='top',
                bbox=dict(facecolor='white', alpha=1.0))
168

169

170
def plot_hist(ax, histogram, label='', **options):
    """
    Draw a Histogram as a bar chart in `ax`.

    Options understood by `handle_options()` are applied to the axes; any
    other options are ignored. If `label` is given, it is shown in a box in
    the top-right corner.
    """
    handle_options(ax, options)
    # Iterating the Histogram computes it (setting `bin_width`) and yields
    # (left bin edge, count) pairs.
    bars = list(histogram)
    if bars:
        left_edges, counts = zip(*bars)
        # Matplotlib centres bars on x by default; the x values here are left
        # bin edges, so align on the edge to place each bar over its bin.
        ax.bar(left_edges, counts, width=histogram.bin_width,
               align='edge', color=None)
    if label:
        ax.text(0.95, 0.95, label,
                transform=ax.transAxes, ha='right', va='top',
                bbox=dict(facecolor='white', alpha=1.0))
178

179

180
def variable_names(segment):
    """
    Return the set of names of the AnalogSignals in `segment`.

    PyNN uses the signal name to hold the recorded variable's name, so this is
    the set of variables recorded in the segment.
    """
    return {signal.name for signal in segment.analogsignals}
186

187

188
class Figure(object):
    """
    Provide simple, declarative specification of multi-panel figures.

    Example::

      Figure(
          Panel(segment.filter(name="v")[0], ylabel="Membrane potential (mV)"),
          Panel(segment.spiketrains, xlabel="Time (ms)"),
          title="Network activity",
      ).save("figure3.png")

    Valid options are:
        `settings`:
            for figure settings, e.g. {'font.size': 9}. If omitted or None,
            DEFAULT_FIG_SETTINGS is used. The settings are applied to the
            global Matplotlib rcParams.
        `size`:
            (width, height) of the figure in inches; defaults to
            (6, 2 * number_of_panels + 1.2).
        `annotations`:
            a (multi-line) string to be printed at the bottom of the figure.
            Ignored if empty or None.
        `title`:
            a string to be printed at the top of the figure.
    """

    def __init__(self, *panels, **options):
        n_panels = len(panels)
        settings = options.get("settings")
        if settings is None:
            settings = DEFAULT_FIG_SETTINGS
        # NOTE: this updates the global rcParams, so the settings also apply
        # to figures created afterwards in the same session.
        plt.rcParams.update(settings)
        width, height = options.get("size", (6, 2 * n_panels + 1.2))
        # Always create a fresh figure. A fixed figure number would return an
        # already-open figure with that number (keeping its contents and
        # ignoring figsize), so successive Figures would draw over each other.
        self.fig = plt.figure(figsize=(width, height))
        gs = gridspec.GridSpec(n_panels, 1)
        annotations = options.get("annotations")
        if annotations:
            gs.update(bottom=1.2 / height)  # leave space for annotations
        gs.update(top=1 - 0.8 / height, hspace=0.25)

        for i, panel in enumerate(panels):
            # plt.subplot adds the axes to the current figure, i.e. self.fig
            panel.plot(plt.subplot(gs[i, 0]))

        if "title" in options:
            self.fig.text(0.5, 1 - 0.5 / height, options["title"],
                          ha="center", va="top", fontsize="large")
        if annotations:
            self.fig.text(0.01, 0.01, annotations, fontsize=6,
                          verticalalignment='bottom')

    def save(self, filename):
        """
        Save the figure to file. The format is taken from the file extension.

        Missing parent directories of `filename` are created first.
        """
        dirname = path.dirname(filename)
        if dirname:
            # exist_ok avoids failing if the directory already exists or is
            # created between a check and the makedirs call.
            makedirs(dirname, exist_ok=True)
        self.fig.savefig(filename)

    def show(self):
        """Display the figure using `plt.show()`."""
        plt.show()
244

245

246
class Panel(object):
    """
    Represents a single panel in a multi-panel figure.

    A panel is drawn into a Matplotlib Axes or Subplot instance. A data item
    may be an AnalogSignal, an IrregularlySampledSignal, a list of SpikeTrains,
    a SpikeTrainList, a 2D NumPy array, a DataTable or a Histogram. The Panel
    will automatically choose an appropriate representation. Multiple data
    items may be plotted in the same panel.

    Valid options are any valid Matplotlib formatting options that should be
    applied to the Axes/Subplot, plus in addition:

        `data_labels`:
            a list of strings of the same length as the number of data items.
        `line_properties`:
            a list of dicts containing Matplotlib formatting options, of the
            same length as the number of data items.

    Data items beyond the length of `data_labels` or `line_properties` are not
    plotted.
    """

    def __init__(self, *data, **options):
        self.data = list(data)
        self.options = options
        self.data_labels = options.pop("data_labels", repeat(None))
        self.line_properties = options.pop("line_properties", repeat({}))

    def plot(self, axes):
        """
        Plot the Panel's data in the provided Axes/Subplot instance.

        Raises TypeError for an unsupported data type and ValueError for a
        NumPy array that is not 2-dimensional.
        """
        for datum, label, properties in zip(self.data, self.data_labels, self.line_properties):
            # Merge into a new dict: the line_properties dicts belong to the
            # caller and may be shared (e.g. the same list passed to several
            # panels, or the single dict from repeat({})), so updating them in
            # place would leak this panel's options elsewhere.
            plot_options = dict(properties)
            plot_options.update(self.options)
            if isinstance(datum, DataTable):
                scatterplot(axes, datum, label=label, **plot_options)
            elif isinstance(datum, Histogram):
                plot_hist(axes, datum, label=label, **plot_options)
            elif isinstance(datum, (AnalogSignal, IrregularlySampledSignal)):
                plot_signals(axes, datum, label_prefix=label, **plot_options)
            elif isinstance(datum, list) and len(datum) > 0 and isinstance(datum[0], SpikeTrain):
                plot_spiketrains(axes, datum, label=label, **plot_options)
            elif isinstance(datum, SpikeTrainList):
                plot_spiketrainlist(axes, datum, label=label, **plot_options)
            elif isinstance(datum, np.ndarray):
                if datum.ndim == 2:
                    plot_array_as_image(axes, datum, label=label, **plot_options)
                else:
                    raise ValueError("Can't handle arrays with %s dimensions" % datum.ndim)
            else:
                raise TypeError("Can't handle type %s" % type(datum))
295

296

297
def comparison_plot(segments, labels, title='', annotations=None,
                    fig_settings=None, with_spikes=True):
    """
    Given a list of segments, plot all the data they contain so as to be able
    to compare them.

    One panel is created per (signal name, channel) pair, overlaying that
    channel's signal from every segment. Segment k is drawn in colour
    'bcgmkr'[k % 6] with line width 2 * (n_segments - k) - 1, presumably so
    that earlier (thicker) traces stay visible beneath later ones. Signals
    sharing a name are rescaled to the units of the first one encountered.
    If `with_spikes` is True and the first segment has spike trains, one
    raster panel per segment is appended.

    Arguments:
        `segments`: a list of Neo Segments.
        `labels`: one label per segment, used in the legends.
        `title`, `annotations`, `fig_settings`: passed to `Figure` as
            `title`, `annotations` and `settings`.
        `with_spikes`: whether to add spike raster panels.

    Return a Figure instance. Raises ValueError if there is nothing to plot.
    """
    # set().union(...) also works for an empty list of segments
    variables_to_plot = set().union(*(variable_names(s) for s in segments))
    print("Plotting the following variables: %s" % ", ".join(variables_to_plot))

    # group signals by variable name and channel
    n_seg = len(segments)
    by_var_and_channel = defaultdict(lambda: defaultdict(list))
    line_properties = []
    units = {}
    for k, (segment, label) in enumerate(zip(segments, labels)):
        lw = 2 * (n_seg - k) - 1
        col = 'bcgmkr'[k % 6]
        line_properties.append({"linewidth": lw, "color": col})
        for array in segment.analogsignals:
            # rescale signals to the same units, for a given variable name
            if array.name not in units:
                units[array.name] = array.units
            elif array.units != units[array.name]:
                array = array.rescale(units[array.name])
            channel_indices = array.array_annotations["channel_index"]
            for i in channel_indices.argsort():
                by_var_and_channel[array.name][channel_indices[i]].append(array[:, i])

    # each panel plots the signals for a given variable and channel
    panels = []
    for by_channel in by_var_and_channel.values():
        for array_list in by_channel.values():
            ylabel = array_list[0].name
            if ylabel:
                ylabel += " ({})".format(array_list[0].dimensionality)
            panels.append(
                Panel(*array_list,
                      line_properties=line_properties,
                      yticks=True,
                      ylabel=ylabel,
                      data_labels=labels))
    if with_spikes and segments and len(segments[0].spiketrains) > 0:
        panels += [Panel(segment.spiketrains, data_labels=[label])
                   for segment, label in zip(segments, labels)]
    if not panels:
        raise ValueError("Nothing to plot: the segments contain no analog "
                         "signals or spike trains")
    # only the bottom panel shows the time axis
    panels[-1].options["xticks"] = True
    panels[-1].options["xlabel"] = "Time (ms)"
    fig = Figure(*panels,
                 title=title,
                 settings=fig_settings,
                 annotations=annotations)
    return fig
350

351

352
class DataTable(object):
    """
    A lightweight encapsulation of x, y data for scatterplots.

    After `fit_curve()` has been called, the `y_fit` property gives the fitted
    function evaluated at `x`.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def fit_curve(self, f, p0, **fitting_parameters):
        """
        Fit ``f(x, *params)`` to the data using `scipy.optimize.curve_fit`.

        `p0` is the initial guess for the parameters. Extra keyword arguments
        are passed through to `curve_fit`. The function and fit results are
        stored on the instance (they are used by `y_fit`), and the tuple
        ``(popt, pcov)`` is returned.
        """
        # Imported locally, presumably to keep scipy optional for the module.
        from scipy.optimize import curve_fit
        self._f, self._p0 = f, p0
        fit_result = curve_fit(f, self.x, self.y, p0, **fitting_parameters)
        self._popt, self._pcov = fit_result
        return self._popt, self._pcov

    @property
    def y_fit(self):
        """The fitted function evaluated at `x` (needs a prior `fit_curve()`)."""
        fitted_function, params = self._f, self._popt
        return fitted_function(self.x, *params)
369

370

371
class Histogram(object):
    """
    A lightweight encapsulation of histogram data.

    The histogram is computed lazily, on first iteration or an explicit call
    to `evaluate()`, using int(sqrt(number of data points)) bins (at least
    one). After evaluation the counts are in `values`, the bin edges in `bins`
    and the (uniform) bin width in `bin_width`.
    """

    def __init__(self, data):
        self.data = data
        self.evaluated = False

    def evaluate(self):
        """Compute the histogram, unless this has already been done."""
        if not self.evaluated:
            # np.histogram requires a positive number of bins, but
            # int(sqrt(n)) is 0 for empty data.
            n_bins = max(1, int(np.sqrt(len(self.data))))
            self.values, self.bins = np.histogram(self.data, bins=n_bins)
            self.bin_width = self.bins[1] - self.bins[0]
            self.evaluated = True

    def __iter__(self):
        """Iterate over the bars of the histogram as (left bin edge, count) pairs."""
        self.evaluate()
        yield from zip(self.bins[:-1], self.values)
390

391

392
def isi_histogram(segment):
    """
    Return a Histogram of the inter-spike intervals of all spike trains in
    `segment`, pooled together.

    ``np.array(st)`` converts each spike train to a plain array, so the
    intervals carry no units; presumably they are in the spike trains' own
    time units (confirm for your data).
    """
    isis_per_train = [np.diff(np.array(st)) for st in segment.spiketrains]
    # np.concatenate raises on an empty list (a segment with no spike trains)
    all_isis = np.concatenate(isis_per_train) if isis_per_train else np.array([])
    return Histogram(all_isis)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc