Skip to content

Commit b67207b

Browse files
Merge pull request #4 from Hananel-Hazan/dan
Adding documentation strings; adding integrate-and-fire (IF) neurons; minor tweaks.
2 parents f2bc993 + 791bd0a commit b67207b

File tree

8 files changed

+175
-45
lines changed

8 files changed

+175
-45
lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
To clone this project locally, issue
44

55
```
6-
git clone https://github.com/djsaunde/spiketorch.git #
6+
git clone https://github.com/djsaunde/spiketorch.git # clones spiketorch repository
77
```
88

99
in the directory of your choice. This will place the repository's code in a directory titled `spiketorch`.

bindsnet/analysis/__init__.py

Whitespace-only changes.

bindsnet/datasets/__init__.py

Whitespace-only changes.

bindsnet/evaluation/__init__.py

Whitespace-only changes.

bindsnet/network/__init__.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,6 @@ def run(self, inpts, time):
7272
for key in self.layers:
7373
spikes[key] = torch.zeros(self.layers[key].n, timesteps)
7474

75-
for monitor in self.monitors:
76-
self.monitors[monitor].reset()
77-
7875
# Get input to all layers.
7976
inpts.update(self.get_inputs())
8077

@@ -117,4 +114,26 @@ def reset(self):
117114
self.layers[layer].reset()
118115

119116
for connection in self.connections:
120-
self.connections[connection].reset()
117+
self.connections[connection].reset()
118+
119+
for monitor in self.monitors:
120+
self.monitors[monitor].reset()
121+
122+
class Monitor:
123+
'''
124+
Records state variables of interest.
125+
'''
126+
def __init__(self, obj, state_vars):
127+
self.obj = obj
128+
self.state_vars = state_vars
129+
self.recording = {var : torch.Tensor() for var in self.state_vars}
130+
131+
def get(self, var):
132+
return self.recording[var]
133+
134+
def record(self):
135+
for var in self.state_vars:
136+
self.recording[var] = torch.cat([self.recording[var], self.obj.__dict__[var]])
137+
138+
def reset(self):
139+
self.recording = {var : torch.Tensor() for var in self.state_vars}

bindsnet/network/monitors.py

Whitespace-only changes.

bindsnet/network/nodes.py

Lines changed: 122 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ class Input(Nodes):
2929
Layer of nodes with user-specified spiking behavior.
3030
'''
3131
def __init__(self, n, traces=False, trace_tc=5e-2):
32+
'''
33+
Instantiates a layer of input neurons.
34+
35+
Inputs:
36+
n (int): The number of neurons in the layer.
37+
traces (bool): Whether to record decaying spike traces.
38+
trace_tc (float): Time constant of spike trace decay.
39+
'''
3240
super().__init__()
3341

3442
self.n = n # No. of neurons.
@@ -43,6 +51,11 @@ def step(self, inpts, dt):
4351
'''
4452
On each simulation step, set the spikes of the
4553
population equal to the inputs.
54+
55+
Inputs:
56+
inpts (torch.FloatTensor or torch.cuda.FloatTensor): Matrix
57+
of inputs to the layer, with size equal to self.n.
58+
dt (float): Simulation time step.
4659
'''
4760
# Set spike occurrences to input values.
4861
self.s = inpts
@@ -53,6 +66,9 @@ def step(self, inpts, dt):
5366
self.x[self.s] = 1
5467

5568
def reset(self):
69+
'''
70+
Resets relevant state variables.
71+
'''
5672
self.s = torch.zeros_like(torch.Tensor(n)) # Spike occurences.
5773

5874
if self.traces:
@@ -61,7 +77,7 @@ def reset(self):
6177

6278
class McCullochPitts(Nodes):
6379
'''
64-
McCulloch-Pitts neuron.
80+
Layer of McCulloch-Pitts neurons.
6581
'''
6682
def __init__(self, n, traces=False, threshold=1.0, trace_tc=5e-2):
6783
'''
@@ -88,8 +104,8 @@ def step(self, inpts, dt):
88104
Runs a single simulation step.
89105
90106
Inputs:
91-
inpts (torch.FloatTensor or torch.cuda.FloatTensor): Vector of
92-
inputs to the layer, with size equal to self.n.
107+
inpts (torch.FloatTensor or torch.cuda.FloatTensor): Vector
108+
of inputs to the layer, with size equal to self.n.
93109
dt (float): Simulation time step.
94110
'''
95111
self.s = inpts >= self.threshold # Check for spiking neurons.
@@ -100,19 +116,108 @@ def step(self, inpts, dt):
100116
self.x[self.s] = 1
101117

102118
def reset(self):
119+
'''
120+
Resets relevant state variables.
121+
'''
103122
self.s = torch.zeros_like(torch.Tensor(n)) # Spike occurences.
104123

105124
if self.traces:
106125
self.x = torch.zeros_like(torch.Tensor(n)) # Firing traces.
107126

108127

128+
class IFNodes(Nodes):
129+
'''
130+
Layer of integrate-and-fire (IF) neurons.
131+
'''
132+
def __init__(self, n, traces=False, threshold=-52.0, reset=-65.0,
133+
refractory=5, trace_tc=5e-2):
134+
'''
135+
Instantiates a layer of IF neurons.
136+
137+
Inputs:
138+
n (int): The number of neurons in the layer.
139+
traces (bool): Whether to record decaying spike traces.
140+
threshold (float): Value at which to record a spike.
141+
reset (float): Value to which neurons are set to following a spike.
142+
refractory (int): The number of timesteps following
143+
a spike during which a neuron cannot spike again.
144+
trace_tc (float): Time constant of spike trace decay.
145+
'''
146+
147+
super().__init__()
148+
149+
self.n = n # No. of neurons.
150+
self.traces = traces # Whether to record synpatic traces.
151+
self.reset = reset # Post-spike reset voltage.
152+
self.threshold = threshold # Spike threshold voltage.
153+
self.refractory = refractory # Post-spike refractory period.
154+
155+
self.v = self.reset * torch.ones(n) # Neuron voltages.
156+
self.s = torch.zeros(n) # Spike occurences.
157+
158+
if traces:
159+
self.x = torch.zeros(n) # Firing traces.
160+
self.trace_tc = trace_tc # Rate of decay of spike trace time constant.
161+
162+
self.refrac_count = torch.zeros(n) # Refractory period counters.
163+
164+
def step(self, inpts, dt):
165+
'''
166+
Runs a single simulation step.
167+
168+
Inputs:
169+
inpts (torch.FloatTensor or torch.cuda.FloatTensor): Vector
170+
of inputs to the layer, with size equal to self.n.
171+
dt (float): Simulation time step.
172+
'''
173+
# Decrement refractory counters.
174+
self.refrac_count[self.refrac_count != 0] -= dt
175+
176+
# Check for spiking neurons.
177+
self.s = (self.v >= self.threshold) * (self.refrac_count == 0)
178+
self.refrac_count[self.s] = self.refractory
179+
self.v[self.s] = self.reset
180+
181+
# Integrate input and decay voltages.
182+
self.v += inpts
183+
184+
if self.traces:
185+
# Decay and set spike traces.
186+
self.x -= dt * self.trace_tc * self.x
187+
self.x[self.s] = 1.0
188+
189+
def reset(self):
190+
'''
191+
Resets relevant state variables.
192+
'''
193+
self.s = torch.zeros_like(torch.Tensor(n)) # Spike occurences.
194+
self.v = self.reset * torch.ones(n) # Neuron voltages.
195+
self.refrac_count = torch.zeros(n) # Refractory period counters.
196+
197+
if self.traces:
198+
self.x = torch.zeros_like(torch.Tensor(n)) # Firing traces.
199+
200+
109201
class LIFNodes(Nodes):
110202
'''
111-
Group of leaky integrate-and-fire neurons.
203+
Layer of leaky integrate-and-fire (LIF) neurons.
112204
'''
113-
def __init__(self, n, traces=False, rest=-65.0, reset=-65.0, threshold=-52.0,
205+
def __init__(self, n, traces=False, threshold=-52.0, rest=-65.0, reset=-65.0,
114206
refractory=5, voltage_decay=1e-2, trace_tc=5e-2):
207+
'''
208+
Instantiates a layer of LIF neurons.
115209
210+
Inputs:
211+
n (int): The number of neurons in the layer.
212+
traces (bool): Whether to record decaying spike traces.
213+
threshold (float): Value at which to record a spike.
214+
rest (float): Value to which neuron voltages decay.
215+
reset (float): Value to which neurons are set to following a spike.
216+
refractory (int): The number of timesteps following
217+
a spike during which a neuron cannot spike again.
218+
voltage_decay (float): Time constant of neuron voltage decay.
219+
trace_tc (float): Time constant of spike trace decay.
220+
'''
116221
super().__init__()
117222

118223
self.n = n # No. of neurons.
@@ -133,9 +238,14 @@ def __init__(self, n, traces=False, rest=-65.0, reset=-65.0, threshold=-52.0,
133238
self.refrac_count = torch.zeros(n) # Refractory period counters.
134239

135240
def step(self, inpts, dt):
136-
# Decay voltages.
137-
self.v -= dt * self.voltage_decay * (self.v - self.rest)
241+
'''
242+
Runs a single simulation step.
138243
244+
Inputs:
245+
inpts (torch.FloatTensor or torch.cuda.FloatTensor): Vector
246+
of inputs to the layer, with size equal to self.n.
247+
dt (float): Simulation time step.
248+
'''
139249
if self.traces:
140250
# Decay spike traces.
141251
self.x -= dt * self.trace_tc * self.x
@@ -152,10 +262,14 @@ def step(self, inpts, dt):
152262
self.v += inpts
153263

154264
if self.traces:
155-
# Setting synaptic traces.
265+
# Decay and set spike traces.
266+
self.x -= dt * self.trace_tc * self.x
156267
self.x[self.s] = 1.0
157268

158269
def reset(self):
270+
'''
271+
Resets relevant state variables.
272+
'''
159273
self.s = torch.zeros_like(torch.Tensor(n)) # Spike occurences.
160274
self.v = self.rest * torch.ones(n) # Neuron voltages.
161275
self.refrac_count = torch.zeros(n) # Refractory period counters.

examples/Simple test network.ipynb

Lines changed: 29 additions & 32 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)