bindsnet/datasets/__init__.py: 109 additions, 0 deletions

@@ -6,11 +6,120 @@
import shutil
import pickle as p
import numpy as np
import scipy.io.wavfile
from scipy.fftpack import dct
Collaborator: Move this down to the from imports and correct the spacing.
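One way the suggested ordering might look (the imports above this hunk are not shown, so this is a sketch):

    import os
    import torch
    import shutil
    import pickle as p
    import numpy as np
    import scipy.io.wavfile

    from struct import unpack
    from scipy.fftpack import dct
    from urllib.request import urlretrieve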


from struct import unpack
from urllib.request import urlretrieve


class SpokenMNIST:
    '''
    Handles loading and preprocessing of the Spoken MNIST audio digits
    (free-spoken-digit-dataset). Data is divided by an 80-20 split into
    train and test sets.
    '''
Collaborator: More descriptive doc-string + punctuation.

def __init__(self, path=None):
self.data_dir = '/home/darpan/sem4/free-spoken-digit-dataset/recordings/'
Collaborator: This is a hard-coded path, and will fail on general machines. Take a look at the MNIST class for a default way of doing this.
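A minimal sketch of a default-path approach (the exact default directory is an assumption; the MNIST class convention referenced above is not shown in this diff):

    def __init__(self, path=None):
        if path is None:
            # Hypothetical default: a data/spokenMNIST directory relative to the
            # current working directory.
            path = os.path.join('data', 'spokenMNIST', 'recordings')
        self.data_dir = path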

# self.data_dir = '/mnt/nfs/work1/rkozma/dsanghavi/bindsnet/data/spokenMNIST/recordings/'
if path:
self.data_dir = path
self.files = [f for f in os.listdir(self.data_dir) if os.path.isfile(os.path.join(self.data_dir, f)) and '.wav' in f]
np.random.shuffle(self.files)
split = int(0.8*len(self.files))
self.train_files = self.files[:split]
self.test_files = self.files[split:]

def get_train(self):
'''
Gets the SpokenMNIST training log filter banks and labels.

Returns:
            audios (list of torch.Tensor or torch.cuda.Tensor): The Spoken MNIST
                training log filter banks; each element has shape (T_i, 40), where
                T_i is the number of frames in utterance i.
            labels (torch.Tensor or torch.cuda.Tensor): The Spoken MNIST training labels.
'''
audios = []
labels = []
for f in self.train_files:
filter_banks, label = self.pre_process(f)
audios.append(torch.Tensor(filter_banks))
labels.append(label)
return audios, torch.Tensor(labels)

def get_test(self):
'''
Gets the SpokenMNIST testing log filter banks and labels.

Returns:
            audios (list of torch.Tensor or torch.cuda.Tensor): The Spoken MNIST
                testing log filter banks; each element has shape (T_i, 40).
            labels (torch.Tensor or torch.cuda.Tensor): The Spoken MNIST testing labels.
'''
audios = []
labels = []
for f in self.test_files:
filter_banks, label = self.pre_process(f)
            audios.append(torch.Tensor(filter_banks))  # Convert to a tensor, matching get_train().
labels.append(label)
return audios, torch.Tensor(labels)

def pre_process(self, file):
Collaborator: Maybe preprocess instead of pre_process?

        '''
        Computes the 40-dimensional log filter banks and the digit label for a
        single .wav file.
        '''
label = int(file[0])

        sample_rate, signal = scipy.io.wavfile.read(os.path.join(self.data_dir, file))
pre_emphasis = 0.97
emphasized_signal = np.append(signal[0], signal[1:] - pre_emphasis * signal[:-1])
# Popular settings are 25 ms for the frame size and a 10 ms stride (15 ms overlap)
frame_size = 0.025
frame_stride = 0.01
frame_length, frame_step = frame_size * sample_rate, frame_stride * sample_rate # Convert from seconds to samples
signal_length = len(emphasized_signal)
frame_length = int(round(frame_length))
frame_step = int(round(frame_step))
num_frames = int(np.ceil(float(np.abs(signal_length - frame_length)) / frame_step)) # Make sure that we have at least 1 frame

pad_signal_length = num_frames * frame_step + frame_length
z = np.zeros((pad_signal_length - signal_length))
pad_signal = np.append(emphasized_signal, z) # Pad Signal to make sure that all frames have equal number of samples without truncating any samples from the original signal

indices = np.tile(np.arange(0, frame_length), (num_frames, 1)) + np.tile(np.arange(0, num_frames * frame_step, frame_step), (frame_length, 1)).T
frames = pad_signal[indices.astype(np.int32, copy=False)]
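        # For example, at an 8 kHz sample rate (an assumption about the recordings),
        # frame_length = round(0.025 * 8000) = 200 samples and
        # frame_step = round(0.01 * 8000) = 80 samples.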

# Hamming Window
frames *= np.hamming(frame_length)

# Fast Fourier Transform and Power Spectrum
NFFT = 512
mag_frames = np.absolute(np.fft.rfft(frames, NFFT)) # Magnitude of the FFT
pow_frames = ((1.0 / NFFT) * ((mag_frames) ** 2)) # Power Spectrum

# Log filter banks
nfilt = 40
low_freq_mel = 0
high_freq_mel = (2595 * np.log10(1 + (sample_rate / 2) / 700)) # Convert Hz to Mel
mel_points = np.linspace(low_freq_mel, high_freq_mel, nfilt + 2) # Equally spaced in Mel scale
hz_points = (700 * (10**(mel_points / 2595) - 1)) # Convert Mel to Hz
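        # Sanity check: at an 8 kHz sample rate (the free-spoken-digit-dataset
        # recordings; an assumption here), the Nyquist frequency of 4000 Hz maps to
        # 2595 * log10(1 + 4000 / 700) ~= 2146 mel, and the uniformly spaced mel
        # points map back to Hz points that are denser at low frequencies.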
        bins = np.floor((NFFT + 1) * hz_points / sample_rate)  # FFT bin indices of the mel points ('bins' avoids shadowing the built-in 'bin')

        fbank = np.zeros((nfilt, int(np.floor(NFFT / 2 + 1))))
        for m in range(1, nfilt + 1):
            f_m_minus = int(bins[m - 1])  # left
            f_m = int(bins[m])  # center
            f_m_plus = int(bins[m + 1])  # right

            for k in range(f_m_minus, f_m):
                fbank[m - 1, k] = (k - bins[m - 1]) / (bins[m] - bins[m - 1])
            for k in range(f_m, f_m_plus):
                fbank[m - 1, k] = (bins[m + 1] - k) / (bins[m + 1] - bins[m])
filter_banks = np.dot(pow_frames, fbank.T)
filter_banks = np.where(filter_banks == 0, np.finfo(float).eps, filter_banks) # Numerical Stability
filter_banks = 20 * np.log10(filter_banks) # dB

return filter_banks, label

Collaborator: In general, this method seems really messy / hard-coded. Perhaps some of the parameters should be arguments to the function (e.g., NFFT, frame_size, frame_stride, etc.).
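A sketch of the suggested parameterization (defaults taken from the current body):

    def preprocess(self, file, frame_size=0.025, frame_stride=0.01, NFFT=512,
                   nfilt=40, pre_emphasis=0.97):
        '''
        Same pipeline as pre_process, with the hard-coded constants lifted into
        keyword arguments so callers can tune the framing and filter banks.
        '''
        ...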




class MNIST:
'''
Handles loading and saving of the MNIST handwritten digits
bindsnet/encoding/__init__.py: 183 additions, 0 deletions

@@ -2,6 +2,91 @@
import numpy as np


def get_poisson_mixture(data, time, window):
    '''
    Generates Poisson spike trains from framed inputs. Each input frame yields a
    short Poisson spike train over a window, which is aggregated into the overall
    spike train starting at that frame's timestep. Inputs must be non-negative.
    Spike inter-arrival times are inversely proportional to input magnitude, so
    data must be scaled according to the desired spike frequency.

Collaborator: But, it's not a "mixture" per se, is it? I feel like the function header is misleading.

    Inputs:
        data (list of torch.Tensor or torch.cuda.Tensor): Tensors, one per example,
            each of shape [T_i, n_1, ..., n_k], with arbitrary sample dimensionality
            [n_1, ..., n_k].
        time (int): Length of the Poisson spike train per input variable.
        window (int): Length (in timesteps) of the per-frame spike window.

Yields:
(torch.Tensor or torch.cuda.Tensor): Tensors with shape [time, n_1, ..., n_k], with
Poisson-distributed spikes parameterized by the data values.
'''
for audio in data:
        # Shift by the minimum element so all inputs are non-negative.
        audio_shifted = audio - torch.min(audio)  # Linear shifting for now; could try exponential/polynomial.
s = get_poisson_mixture_for_example(audio_shifted, time, window)
yield torch.Tensor(s).byte()
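A hypothetical usage sketch (the time and window values are illustrative):

    dataset = SpokenMNIST()
    audios, labels = dataset.get_train()
    for spikes in get_poisson_mixture(audios, time=350, window=10):
        print(spikes.shape)  # torch.Size([350, 40])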

def get_poisson_mixture_for_example(data, time, window):
'''
    Generates an aggregated Poisson spike train for a single framed example.
    Each input frame yields a short Poisson spike train over a window, which is
    aggregated into the overall spike train starting at that frame's timestep.
    Inputs must be non-negative. Spike inter-arrival times are inversely
    proportional to input magnitude, so data must be scaled according to the
    desired spike frequency.

    Inputs:
        data (torch.Tensor or torch.cuda.Tensor): Tensor of shape [T, n_1, ..., n_k],
            where T is the number of frames in the utterance and [n_1, ..., n_k] is
            the arbitrary sample dimensionality.
        time (int): Length of the Poisson spike train per input variable.
        window (int): Length (in timesteps) of the per-frame spike window.

Returns:
(torch.Tensor or torch.cuda.Tensor): Tensors with shape [time, n_1, ..., n_k], with
Poisson-distributed spikes parameterized by the data values.
'''
    if data.shape[0] > time:
        print('Warning: more frames (%d) than timesteps (%d); extra frames will be skipped.' % (data.shape[0], time))

    spikes = np.zeros([time, data.shape[1]])  # (time, 40) for Spoken MNIST with 40-dim log filter banks.
    for i, frame in enumerate(data):
        # The copy is needed: get_poisson_for_frame inverts its input in place.
        frame = np.copy(frame)
        if i > time - window:
            break
        # For every frame, generate a small window of Poisson spikes...
        s = get_poisson_for_frame(frame, window)
        # ...and add it (deconvolution-style) starting at its temporal location.
        spikes[i:i + window] += s
    spikes[spikes > 1] = 1  # Clip overlapping contributions to binary spikes.

return spikes


def get_poisson_for_frame(datum, time):
    '''
    Generates a Poisson spike train of length `time` for a single frame. Note:
    inverts `datum` in place; pass a copy if the caller still needs it.
    '''
    # Flatten the frame, remembering its original shape.
shape, size = datum.shape, datum.size
datum = datum.ravel()

# Invert inputs (input intensity inversely
# proportional to spike inter-arrival time).
datum[datum != 0] = 1 / datum[datum != 0]

# Make spike data from Poisson sampling.
s_times = np.random.poisson(datum, [time, size])
s_times = np.cumsum(s_times, axis=0)
s_times[s_times >= time] = 0

# Create spike trains from spike times.
s = np.zeros([time, size])
for idx in range(time):
s[s_times[idx], np.arange(size)] = 1

s[0, :] = 0
s = s.reshape([time, *shape])

return s
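A tiny worked sketch of the inter-arrival logic (values illustrative):

    datum = np.array([[0.2, 0.05]])
    s = get_poisson_for_frame(np.copy(datum), time=50)
    # 1 / 0.2 = 5 and 1 / 0.05 = 20 act as mean Poisson inter-arrival times;
    # the cumulative sum turns draws into absolute spike times within 50 steps.
    print(s.shape)  # (50, 1, 2)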


def get_poisson(data, time):
'''
Generates Poisson spike trains based on input intensity. Inputs must be
@@ -47,6 +132,104 @@ def get_poisson(data, time):
yield torch.Tensor(s).byte()


def get_tfs(data, time):
    '''
    Generates spike trains based on the time-to-first-spike (TFS) scheme. First
    spike times are inversely proportional to input magnitude, so data must be
    scaled according to the desired spike timing.

    Inputs:
        data (torch.Tensor or torch.cuda.Tensor): Tensor of shape [n_samples, n_1,
            ..., n_k], with arbitrary sample dimensionality [n_1, ..., n_k].
        time (int): Length of the spike train per input variable.

    Yields:
        (torch.Tensor or torch.cuda.Tensor): Tensors of shape [time, n_1, ..., n_k],
        with first-spike times parameterized by the data values.
    '''
    # TODO: time-to-first-spike encoding is not yet implemented.
    raise NotImplementedError
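A minimal sketch of one possible per-frame TFS encoder (the normalization and the zero-input convention are assumptions, not part of this PR):

    def get_tfs_for_frame(datum, time):
        shape, size = datum.shape, datum.size
        datum = datum.ravel().astype(float)
        if datum.max() > 0:
            datum = datum / datum.max()  # Normalize to [0, 1].
        # Larger values spike earlier; zero inputs spike at the last timestep.
        s_times = np.floor((1 - datum) * (time - 1)).astype(int)
        s = np.zeros([time, size])
        s[s_times, np.arange(size)] = 1
        return s.reshape([time, *shape])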

def get_bernoulli_mixture(data, time, window=1):
    '''
    Generates aggregated Bernoulli spike trains for a list of framed examples.
    See get_poisson_mixture for the shapes of data, time, and window.
    '''
    for audio in data:
        # Shift by the minimum element so all inputs are non-negative.
        audio_shifted = audio - torch.min(audio)  # Linear shifting for now; could try exponential/polynomial.
        s = get_bernoulli_for_example(audio_shifted, time, window=window)
        yield torch.Tensor(s).byte()


def get_bernoulli_for_example(data, time, window=1):
    if data.shape[0] > time:
        print('Warning: more frames (%d) than timesteps (%d); extra frames will be skipped.' % (data.shape[0], time))

    spikes = np.zeros([time, data.shape[1]])  # (time, 40) for Spoken MNIST with 40-dim log filter banks.
    for i, frame in enumerate(data):
        # The copy is needed: get_bernoulli_for_frame normalizes its input in place.
        frame = np.copy(frame)
        if i > time - window:
            break
        # For every frame, generate a small window of Bernoulli trials...
        s = get_bernoulli_for_frame(frame, window)
        # ...and add it (deconvolution-style) starting at its temporal location.
        spikes[i:i + window] += s
    spikes[spikes > 1] = 1  # Still needed: overlapping windows can sum above 1.

return spikes

def get_bernoulli_for_frame(datum, time, max_prob=1.0):
    '''
    Generates Bernoulli-distributed spike trains based on input intensity. Inputs
    must be non-negative. Spikes correspond to successful Bernoulli trials, with
    success probability equal to the (normalized to [0, 1]) input value. Note:
    normalizes `datum` in place; pass a copy if the caller still needs it.

    Inputs:
        datum (np.ndarray): Array of shape [n_1, ..., n_k], with arbitrary
            sample dimensionality.
        time (int): Length of the Bernoulli spike train per input variable.
        max_prob (float): Maximum probability of a spike per Bernoulli trial.

    Returns:
        (np.ndarray): Binary array of shape [time, n_1, ..., n_k].
    '''
shape, size = datum.shape, datum.size
datum = datum.ravel()

# Normalize inputs and rescale (spike probability
# proportional to normalized intensity).
datum /= datum.max()
datum *= max_prob

# Make spike data from Bernoulli sampling.
s = np.random.binomial(1, datum, [time, size])
s = s.reshape([time, *shape])

    # Return Bernoulli-distributed spike trains.
return s
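A hypothetical usage sketch (shapes illustrative):

    frame = np.random.rand(40)  # One 40-dim filter-bank frame.
    s = get_bernoulli_for_frame(np.copy(frame), time=10, max_prob=0.5)
    print(s.shape)  # (10, 40)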

Collaborator: I think we can simply add functionality to get_bernoulli in which the time dimension is implicit; e.g., we could pass time=None, and based on this, discard the explicit time dimension.
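A sketch of that suggestion applied to the per-frame version (get_bernoulli itself is truncated below, so this is an assumption about its shape):

    def get_bernoulli_for_frame(datum, time=None, max_prob=1.0):
        shape, size = datum.shape, datum.size
        datum = datum.ravel() / datum.max() * max_prob
        if time is None:
            # Implicit time dimension: a single Bernoulli trial per element.
            return np.random.binomial(1, datum).reshape(shape)
        s = np.random.binomial(1, datum, [time, size])
        return s.reshape([time, *shape])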


def get_bernoulli(data, time, max_prob=1.0):
'''
Generates Bernoulli-distributed spike trains based on input intensity. Inputs must