Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 154 additions & 97 deletions bindsnet/analysis/plotting.py
Original file line number Diff line number Diff line change
@@ -1,115 +1,184 @@
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import numpy as np

from mpl_toolkits.axes_grid1 import make_axes_locatable


plt.ion()

def plot_input(image, inpt, ims=None, figsize=(12, 6)):
def plot_input(image, inpt, ims=None, figsize=(10, 5)):
'''
Plots a two-dimensional image and its corresponding spike-train representation.

Inputs:
image (torch.Tensor or torch.cuda.Tensor): A two-dimensional
array of floating point values depicting an input image.
inpt (torch.Tensor or torch.cuda.Tensor): A two-dimensional array of
floating point values depicting an image's spike-train encoding.
ims (list(matplotlib.image.AxesImage)): Used for re-drawing the input plots.
figsize (tuple(int)): Horizontal, vertical figure size in inches.

Returns:
(list(matplotlib.image.AxesImage)): Used for re-drawing the input plots.
'''
if not ims:
f, axes = plt.subplots(1, 2, figsize=figsize)
fig, axes = plt.subplots(1, 2, figsize=figsize)
ims = axes[0].imshow(image, cmap='binary'), axes[1].imshow(inpt, cmap='binary')

axes[0].set_title('Current image')
axes[1].set_title('Poisson spiking representation')
f.tight_layout()
axes[1].set_xlabel('Simulation time'); axes[1].set_ylabel('Neuron index')
axes[1].set_aspect('auto')

for ax in axes:
ax.set_xticks(()); ax.set_yticks(())

fig.tight_layout()
else:
ims[0].set_data(image)
ims[1].set_data(inpt)

return ims

def plot_spikes(data, ims=None, time=None, figsize=(12, 7)):
"""
Plot spikes for any group of neuron

Inputs:
data (dict): Contains spiking data for groups of neurons of interest.

ims (matplotlib.figure.Figure): Figure to plot on. Otherwise, a new
figure is created.

time (tuple): Plot spiking activity of neurons between the given range
of time.

Default is the entire simulation time.

Ex: time = (40, 80) will plot spiking activity of
neurons from 40 ms to 80 ms. Plotting ticks are multiples
of 5.

figsize (tuple): Figure size.

Returns:
Nothing
"""

def plot_spikes(spikes, ims=None, axes=None, time=None, figsize=(12, 7)):
'''
Plot spikes for any group of neurons.

Inputs:
spikes (dict(torch.Tensor or torch.cuda.Tensor)): Contains
spiking data for groups of neurons of interest.
ims (list(matplotlib.image.AxesImage)): Used for re-drawing the spike plots.
axes (list(matplotlib.axes.Axes)): Used for re-drawing the spike plots.
time (tuple(int)): Plot spiking activity of neurons between the given range
of time. Default is the entire simulation time. For example, time =
(40, 80) will plot spiking activity of neurons from 40 ms to 80 ms.
figsize (tuple(int)): Horizontal, vertical figure size in inches.

Returns:
(list(matplotlib.image.AxesImage)): Used for re-drawing the spike plots.
(list(matplotlib.axes.Axes)): Used for re-drawing the spike plots.
'''
n_subplots = len(spikes.keys())

n_subplots = len(data.keys())
# Confirm that only 2 values for time were given
if time is not None:
assert(len(time) == 2)
assert(time[0] < time[1])

else: # Set it for entire duration
for key in data.keys():
time = (0, data[key].shape[1])
n = data[key].shape[0] # plot for a certain set of neurons?
break

if not ims:
locs, ticks = [t for t in range(0, time[1]-time[0]+5, 5)], [t for t in range(time[0], time[1]+5, 5)]
if n_subplots == 1: # Plotting only one image
plt.figure(figsize=figsize)
for key in data.keys():
ims = plt.imshow(data[key][:, time[0]:time[1]], cmap='binary')
plt.title('%s spikes from t = %1.2f ms to %1.2f ms'%(key, time[0], time[1]))
plt.xlabel('Time (ms)'); plt.ylabel('Neuron index')

plt.xticks(locs,ticks)

else: # Multiple subplots
f, axes = plt.subplots(n_subplots, 1, figsize=figsize)
plt.setp(axes, xticks=locs, xticklabels=ticks)

# Plot each layer at a time
for plot_ind, layer_data in enumerate(data.items()):
ims = axes[plot_ind].imshow(layer_data[1][:, time[0]:time[1]], cmap='binary')
axes[plot_ind].set_title('%s spikes from t = %1.2f ms to %1.2f ms'%(layer_data[0], time[0], time[1]))
# axes[plot_ind].axis('off')

f.tight_layout()
# Confirm only 2 values for time were given
if time is not None:
assert(len(time) == 2)
assert(time[0] < time[1])

else: # Set it for entire duration
for key in spikes.keys():
time = (0, spikes[key].shape[1])
break

if not ims:
fig, axes = plt.subplots(n_subplots, 1, figsize=figsize)
ims = []

if n_subplots == 1: # Plotting only one image
for key in spikes.keys():
ims.append(axes.imshow(spikes[key][:, time[0]:time[1]], cmap='binary'))
plt.title('%s spikes from t = %1.2f ms to %1.2f ms' % (key, time[0], time[1]))
plt.xlabel('Time (ms)'); plt.ylabel('Neuron index')

else: # Plot each layer at a time
for i, datum in enumerate(spikes.items()):
ims.append(axes[i].imshow(datum[1][:, time[0]:time[1]], cmap='binary'))
axes[i].set_title('%s spikes from t = %1.2f ms to %1.2f ms' % (datum[0], time[0], time[1]))

plt.setp(axes, xticks=[], yticks=[], xlabel='Simulation time', ylabel='Neuron index')

for ax in axes:
ax.set_aspect('auto')

plt.tight_layout()

else: #plotting figure given
assert(len(ims) == n_subplots)
for plot_ind, layer_data in enumerate(data.items()):
if time is None:
ims[plot_ind].set_data(layer_data[1])
ims[plot_ind].set_title('%s spikes from t = %1.2f ms to %1.2f ms'%(layer_data[0], time[0], time[1]))
else:#plot for given time
ims[plot_ind].set_data(layer_data[1][time[0], time[1]])
ims[plot_ind].set_title('%s spikes from t = %1.2f ms to %1.2f ms'%(layer_data[0], time[0], time[1]))
else: # Plotting figure given
assert(len(ims) == n_subplots)
for i, datum in enumerate(spikes.items()):
if time is None:
ims[i].set_data(datum[1])
axes[i].set_title('%s spikes from t = %1.2f ms to %1.2f ms' % (datum[0], time[0], time[1]))
else: # Plot for given time
ims[i].set_data(datum[1][time[0]:time[1]])
axes[i].set_title('%s spikes from t = %1.2f ms to %1.2f ms' % (datum[0], time[0], time[1]))

return ims, axes

def plot_weights(weights, assignments, wmax=1, ims=None, figsize=(10, 6)):
if not ims:
f, axes = plt.subplots(1, 2, figsize=figsize)

def plot_weights(weights, wmin=0.0, wmax=1.0, im=None, figsize=(6, 6)):
'''
Plot a (possibly reshaped) connection weight matrix.

Inputs:
weights (torch.Tensor or torch.cuda.Tensor): Weight matrix of Connection object.
wmin (float): Minimum allowed weight value.
wmax (float): Maximum allowed weight value.
im (matplotlib.image.AxesImage): Used for re-drawing the weights plot.
figsize (tuple(int)): Horizontal, vertical figure size in inches.

Returns:
(matplotlib.image.AxesImage): Used for re-drawing the weights plot.
'''
if not im:
fig, ax = plt.subplots(figsize=figsize)

im = ax.imshow(weights, cmap='hot_r', vmin=wmin, vmax=wmax)
div = make_axes_locatable(ax)
cax = div.append_axes("right", size="5%", pad=0.05)

ax.set_xticks(()); ax.set_yticks(())

plt.colorbar(im, cax=cax)
fig.tight_layout()
else:
im.set_data(weights)

return im


def plot_assignments(assignments, im=None, figsize=(6, 6)):
'''
Plot the two-dimensional neuron assignments.

Inputs:
assignments (torch.Tensor or torch.cuda.Tensor): Matrix of neuron label assignments.
im (matplotlib.image.AxesImage): Used for re-drawing the assignments plot.
figsize (tuple(int)): Horizontal, vertical figure size in inches.

Returns:
(matplotlib.image.AxesImage): Used for re-drawing the assigments plot.
'''
if not im:
fig, ax = plt.subplots(figsize=figsize)

color = plt.get_cmap('RdBu', 11)
ims = axes[0].imshow(weights, cmap='hot_r', vmin=0, vmax=wmax), axes[1].matshow(assignments, cmap=color, vmin=-1.5, vmax=9.5)
divs = make_axes_locatable(axes[0]), make_axes_locatable(axes[1])
caxs = divs[0].append_axes("right", size="5%", pad=0.05), divs[1].append_axes("right", size="5%", pad=0.05)
im = ax.matshow(assignments, cmap=color, vmin=-1.5, vmax=9.5)
div = make_axes_locatable(ax)
cax = div.append_axes("right", size="5%", pad=0.05)

plt.colorbar(ims[0], cax=caxs[0])
plt.colorbar(ims[0], cax=caxs[1], ticks=np.arange(-1, 10))
f.tight_layout()
plt.colorbar(im, cax=cax, ticks=np.arange(-1, 10))
fig.tight_layout()
else:
ims[0].set_data(weights)
ims[1].set_data(assignments)
im.set_data(assignments)

return ims
return im


def plot_performance(performances, ax=None, figsize=(6, 6)):
'''
Plot training accuracy curves.

Inputs:
performances (dict(list(float))): Lists of training accuracy estimates per voting scheme.
ax (matplotlib.axes.Axes): Used for re-drawing the performance plot.
figsize (tuple(int)): Horizontal, vertical figure size in inches.

Returns:
(matplotlib.axes.Axes): Used for re-drawing the performance plot.
'''
if not ax:
_, ax = plt.subplots(figsize=figsize)
else:
Expand All @@ -124,16 +193,4 @@ def plot_performance(performances, ax=None, figsize=(6, 6)):
ax.set_ylabel('Accuracy')
ax.legend()

return ax


def plot_voltages(exc, inh, axes=None, figsize=(8, 8)):
if axes is None:
_, axes = plt.subplots(2, 1, figsize=figsize)
axes[0].set_title('Excitatory voltages')
axes[1].set_title('Inhibitory voltages')

axes[0].clear(); axes[1].clear()
axes[0].plot(exc), axes[1].plot(inh)

return axes
return ax
8 changes: 4 additions & 4 deletions bindsnet/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from struct import unpack
from urllib.request import urlretrieve


class MNIST:
'''
Handles loading and saving of the MNIST handwritten digits
Expand Down Expand Up @@ -74,7 +74,7 @@ def get_train(self):
p.dump(labels, open(os.path.join(self.path, MNIST.train_labels_pickle), 'wb'))
else:
# Load label data from disk if it has already been processed.
print('Loading labels from serialized object file.\n')
print('Loading training labels from serialized object file.\n')
labels = p.load(open(os.path.join(self.path, MNIST.train_labels_pickle), 'rb'))

return torch.Tensor(images), torch.Tensor(labels)
Expand All @@ -97,7 +97,7 @@ def get_test(self):
p.dump(images, open(os.path.join(self.path, MNIST.test_images_pickle), 'wb'))
else:
# Load image data from disk if it has already been processed.
print('Loading images from serialized object file.\n')
print('Loading test images from serialized object file.\n')
images = p.load(open(os.path.join(self.path, MNIST.test_images_pickle), 'rb'))

if not os.path.isfile(os.path.join(self.path, MNIST.test_labels_pickle)):
Expand All @@ -110,7 +110,7 @@ def get_test(self):
p.dump(labels, open(os.path.join(self.path, MNIST.test_labels_pickle), 'wb'))
else:
# Load label data from disk if it has already been processed.
print('Loading labels from serialized object file.\n')
print('Loading test labels from serialized object file.\n')
labels = p.load(open(os.path.join(self.path, MNIST.test_labels_pickle), 'rb'))

return torch.Tensor(images), torch.Tensor(labels)
Expand Down
2 changes: 1 addition & 1 deletion bindsnet/network/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def run(self, inpts, time):

return spikes

def reset(self):
def _reset(self):
'''
Reset state variables of objects in network.
'''
Expand Down
2 changes: 1 addition & 1 deletion bindsnet/network/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def step(self, inpts, dt):
self.refrac_count[self.refrac_count != 0] -= dt

# Check for spiking neurons.
self.s = (self.v >= self.threshold) * (self.refrac_count == 0)
self.s = (self.v >= self.threshold + self.theta) * (self.refrac_count == 0)
self.refrac_count[self.s] = self.refractory
self.v[self.s] = self.reset
self.theta[self.s] += self.theta_plus
Expand Down
Loading