Skip to content

Commit f2bc993

Browse files
Merge pull request #3 from Hananel-Hazan/dan
Nodes and connections, network tweaks, etc.
2 parents 1753879 + 76ef7f0 commit f2bc993

File tree

4 files changed

+590
-29
lines changed

4 files changed

+590
-29
lines changed

bindsnet/network/__init__.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
sys.path.append(os.path.abspath(os.path.join('..', 'bindsnet', 'network')))
77

8+
from nodes import Input
9+
810

911
def load_network(fname):
1012
try:
@@ -26,8 +28,8 @@ def __init__(self, dt=1):
2628
def add_layer(self, layer, name):
2729
self.layers[name] = layer
2830

29-
def add_connections(self, connections, source, target):
30-
self.connections[(source, target)] = connections
31+
def add_connection(self, connection, source, target):
32+
self.connections[(source, target)] = connection
3133

3234
def add_monitor(self, monitor, name):
3335
self.monitors[name] = monitor
@@ -53,61 +55,66 @@ def get_inputs(self):
5355
target = self.connections[key].target
5456

5557
if not key[1] in inpts:
56-
inpts[key[1]] = {}
58+
inpts[key[1]] = torch.zeros_like(torch.Tensor(target.n))
5759

58-
inpts[key[1]][key[0]] = source.s.float() @ weights
60+
inpts[key[1]] += source.s.float() @ weights
5961

6062
return inpts
6163

6264
def run(self, inpts, time):
6365
'''
6466
Run network for a single iteration.
6567
'''
68+
timesteps = int(time / self.dt)
69+
6670
# Record spikes from each population over the iteration.
6771
spikes = {}
68-
for key in self.nodes:
69-
spikes[key] = torch.zeros(int(time / self.dt), self.nodes[key].n)
72+
for key in self.layers:
73+
spikes[key] = torch.zeros(self.layers[key].n, timesteps)
7074

7175
for monitor in self.monitors:
7276
self.monitors[monitor].reset()
7377

74-
# Get inputs to all neuron nodes from their parent neuron nodes.
78+
# Get input to all layers.
7579
inpts.update(self.get_inputs())
7680

77-
# Simulate neuron and synapse activity for `time` timesteps.
78-
for timestep in range(int(time / self.dt)):
79-
# Update each layer of nodes in turn.
80-
for key in self.nodes:
81-
self.nodes[key].step(inpts[key], self.mode, self.dt)
81+
# Simulate network activity for `time` timesteps.
82+
for timestep in range(timesteps):
83+
# Update each layer of nodes.
84+
for key in self.layers:
85+
if type(self.layers[key]) is Input:
86+
self.layers[key].step(inpts[key][:, timestep], self.dt)
87+
else:
88+
self.layers[key].step(inpts[key], self.dt)
89+
90+
# Record spikes.
91+
spikes[key][:, timestep] = self.layers[key].s
8292

83-
# Record spikes from this population at this timestep.
84-
spikes[key][timestep, :] = self.nodes[key].s
85-
86-
# Update synapse weights if we're in training mode.
8793
if self.train:
94+
# Update synapse weights.
8895
for synapse in self.connections:
89-
if type(self.connections[synapse]) == connections.STDPconnections:
90-
self.connections[synapse].update()
96+
self.connections[synapse].update()
9197

92-
# Get inputs to all neuron nodes from their parent neuron nodes.
98+
# Get input to all layers.
9399
inpts.update(self.get_inputs())
94100

95101
for monitor in self.monitors:
96102
self.monitors[monitor].record()
97103

98-
# Normalize synapse weights if we're in training mode.
99-
if self.train:
100-
for synapse in self.connections:
101-
if type(self.connections[synapse]) == connections.STDPconnections:
102-
self.connections[synapse].normalize()
104+
# if self.train:
105+
# # Normalize synapse weights.
106+
# for synapse in self.connections:
107+
# if type(self.connections[synapse]) == connections.STDPconnections:
108+
# self.connections[synapse].normalize()
103109

104110
return spikes
105111

106-
def reset(self, attrs):
112+
def reset(self):
107113
'''
108-
Reset state variables.
114+
Reset state variables of objects in network.
109115
'''
110116
for layer in self.layers:
111-
for attr in attrs:
112-
if hasattr(self.nodes[layer], attr):
113-
self.nodes[layer].reset(attr)
117+
self.layers[layer].reset()
118+
119+
for connection in self.connections:
120+
self.connections[connection].reset()

bindsnet/network/connections.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import torch
2+
3+
4+
def ETH_STDP(conn, nu_pre=1e-4, nu_post=1e-2):
5+
# Post-synaptic.
6+
conn.w += nu_post * (conn.source.x.view(conn.source.n,
7+
1) * conn.target.s.float().view(1, conn.target.n))
8+
# Pre-synaptic.
9+
conn.w -= nu_pre * (conn.source.s.float().view(conn.source.n,
10+
1) * conn.target.x.view(1, conn.target.n))
11+
12+
# Ensure that weights are within [0, self.wmax].
13+
conn.w = torch.clamp(conn.w, 0, conn.wmax)
14+
15+
16+
class Connection:
17+
'''
18+
Specifies constant synapses between two populations of neurons.
19+
'''
20+
def __init__(self, source, target, update_rule=None, w=None, wmin=-1.0, wmax=1.0):
21+
'''
22+
Instantiates a Connections object, used to connect two layers of nodes.
23+
24+
Inputs:
25+
source (nodes.Nodes): A layer of nodes from which the connection originates.
26+
target (nodes.Nodes): A layer of nodes to which the connection connects.
27+
update_rule (function): Modifies connection parameters according to some rule.
28+
w (torch.FloatTensor or torch.cuda.FloatTensor): Effective strengths of synaptics.
29+
wmin (float): The minimum value on the connection weights.
30+
wmax (float): The maximum value on the connection weights.
31+
'''
32+
self.source = source
33+
self.target = target
34+
self.wmin = wmin
35+
self.wmax = wmax
36+
37+
if update_rule is None:
38+
self.update_rule = lambda : None
39+
else:
40+
self.update_rule = update_rule
41+
42+
if w is None:
43+
self.w = torch.rand(source.n, target.n)
44+
else:
45+
self.w = w
46+
47+
torch.clamp(self.w, self.wmin, self.wmax)
48+
49+
def get_weights(self):
50+
return self.w
51+
52+
def set_weights(self, w):
53+
self.w = w
54+
55+
def update(self):
56+
'''
57+
Run connection's given update rule, and clamp
58+
weights between `self.wmin` and `self.wmax`.
59+
'''
60+
self.update_rule() # Run update rule.
61+
torch.clamp(self.w, self.wmin, self.wmax) # Bound weights.
62+
63+
def reset(self):
64+
pass

bindsnet/network/nodes.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
import torch
2+
3+
from abc import ABC, abstractmethod
4+
5+
6+
class Nodes(ABC):
7+
'''
8+
Abstract base class for groups of neurons.
9+
'''
10+
def __init__(self):
11+
super().__init__()
12+
13+
@abstractmethod
14+
def step(self, inpts, dt):
15+
pass
16+
17+
def get_spikes(self):
18+
return self.s
19+
20+
def get_voltages(self):
21+
return self.v
22+
23+
def get_traces(self):
24+
return self.x
25+
26+
27+
class Input(Nodes):
28+
'''
29+
Layer of nodes with user-specified spiking behavior.
30+
'''
31+
def __init__(self, n, traces=False, trace_tc=5e-2):
32+
super().__init__()
33+
34+
self.n = n # No. of neurons.
35+
self.traces = traces # Whether to record synpatic traces.
36+
self.s = torch.zeros_like(torch.Tensor(n)) # Spike occurences.
37+
38+
if self.traces:
39+
self.x = torch.zeros_like(torch.Tensor(n)) # Firing traces.
40+
self.trace_tc = trace_tc # Rate of decay of spike trace time constant.
41+
42+
def step(self, inpts, dt):
43+
'''
44+
On each simulation step, set the spikes of the
45+
population equal to the inputs.
46+
'''
47+
# Set spike occurrences to input values.
48+
self.s = inpts
49+
50+
if self.traces:
51+
# Decay and set spike traces.
52+
self.x -= dt * self.trace_tc * self.x
53+
self.x[self.s] = 1
54+
55+
def reset(self):
56+
self.s = torch.zeros_like(torch.Tensor(n)) # Spike occurences.
57+
58+
if self.traces:
59+
self.x = torch.zeros_like(torch.Tensor(n)) # Firing traces.
60+
61+
62+
class McCullochPitts(Nodes):
63+
'''
64+
McCulloch-Pitts neuron.
65+
'''
66+
def __init__(self, n, traces=False, threshold=1.0, trace_tc=5e-2):
67+
'''
68+
Instantiates a McCulloch-Pitts layer of neurons.
69+
70+
Inputs:
71+
n (int): The number of neurons in the layer.
72+
traces (bool): Whether to record decaying spike traces.
73+
threshold (float): Value at which to record a spike.
74+
'''
75+
super().__init__()
76+
77+
self.n = n # No. of neurons.
78+
self.traces = traces # Whether to record synpatic traces.
79+
self.threshold = threshold # Spike threshold voltage.
80+
self.s = torch.zeros_like(torch.Tensor(n)) # Spike occurences.
81+
82+
if self.traces:
83+
self.x = torch.zeros_like(torch.Tensor(n)) # Firing traces.
84+
self.trace_tc = trace_tc # Rate of decay of spike trace time constant.
85+
86+
def step(self, inpts, dt):
87+
'''
88+
Runs a single simulation step.
89+
90+
Inputs:
91+
inpts (torch.FloatTensor or torch.cuda.FloatTensor): Vector of
92+
inputs to the layer, with size equal to self.n.
93+
dt (float): Simulation time step.
94+
'''
95+
self.s = inpts >= self.threshold # Check for spiking neurons.
96+
97+
if self.traces:
98+
# Decay and set spike traces.
99+
self.x -= dt * self.trace_tc * self.x
100+
self.x[self.s] = 1
101+
102+
def reset(self):
103+
self.s = torch.zeros_like(torch.Tensor(n)) # Spike occurences.
104+
105+
if self.traces:
106+
self.x = torch.zeros_like(torch.Tensor(n)) # Firing traces.
107+
108+
109+
class LIFNodes(Nodes):
110+
'''
111+
Group of leaky integrate-and-fire neurons.
112+
'''
113+
def __init__(self, n, traces=False, rest=-65.0, reset=-65.0, threshold=-52.0,
114+
refractory=5, voltage_decay=1e-2, trace_tc=5e-2):
115+
116+
super().__init__()
117+
118+
self.n = n # No. of neurons.
119+
self.traces = traces # Whether to record synpatic traces.
120+
self.rest = rest # Rest voltage.
121+
self.reset = reset # Post-spike reset voltage.
122+
self.threshold = threshold # Spike threshold voltage.
123+
self.refractory = refractory # Post-spike refractory period.
124+
self.voltage_decay = voltage_decay # Rate of decay of neuron voltage.
125+
126+
self.v = self.rest * torch.ones(n) # Neuron voltages.
127+
self.s = torch.zeros(n) # Spike occurences.
128+
129+
if traces:
130+
self.x = torch.zeros(n) # Firing traces.
131+
self.trace_tc = trace_tc # Rate of decay of spike trace time constant.
132+
133+
self.refrac_count = torch.zeros(n) # Refractory period counters.
134+
135+
def step(self, inpts, dt):
136+
# Decay voltages.
137+
self.v -= dt * self.voltage_decay * (self.v - self.rest)
138+
139+
if self.traces:
140+
# Decay spike traces.
141+
self.x -= dt * self.trace_tc * self.x
142+
143+
# Decrement refractory counters.
144+
self.refrac_count[self.refrac_count != 0] -= dt
145+
146+
# Check for spiking neurons.
147+
self.s = (self.v >= self.threshold) * (self.refrac_count == 0)
148+
self.refrac_count[self.s] = self.refractory
149+
self.v[self.s] = self.reset
150+
151+
# Integrate input and decay voltages.
152+
self.v += inpts
153+
154+
if self.traces:
155+
# Setting synaptic traces.
156+
self.x[self.s] = 1.0
157+
158+
def reset(self):
159+
self.s = torch.zeros_like(torch.Tensor(n)) # Spike occurences.
160+
self.v = self.rest * torch.ones(n) # Neuron voltages.
161+
self.refrac_count = torch.zeros(n) # Refractory period counters.
162+
163+
if self.traces:
164+
self.x = torch.zeros_like(torch.Tensor(n)) # Firing traces.

0 commit comments

Comments
 (0)