@@ -29,6 +29,14 @@ class Input(Nodes):
2929 Layer of nodes with user-specified spiking behavior.
3030 '''
3131 def __init__ (self , n , traces = False , trace_tc = 5e-2 ):
32+ '''
33+ Instantiates a layer of input neurons.
34+
35+ Inputs:
36+ n (int): The number of neurons in the layer.
37+ traces (bool): Whether to record decaying spike traces.
38+ trace_tc (float): Time constant of spike trace decay.
39+ '''
3240 super ().__init__ ()
3341
3442 self .n = n # No. of neurons.
@@ -43,6 +51,11 @@ def step(self, inpts, dt):
4351 '''
4452 On each simulation step, set the spikes of the
4553 population equal to the inputs.
54+
55+ Inputs:
56+ inpts (torch.FloatTensor or torch.cuda.FloatTensor): Matrix
57+ of inputs to the layer, with size equal to self.n.
58+ dt (float): Simulation time step.
4659 '''
4760 # Set spike occurrences to input values.
4861 self .s = inpts
@@ -53,6 +66,9 @@ def step(self, inpts, dt):
5366 self .x [self .s ] = 1
5467
5568 def reset (self ):
69+ '''
70+ Resets relevant state variables.
71+ '''
5672 self .s = torch .zeros_like (torch .Tensor (n )) # Spike occurences.
5773
5874 if self .traces :
@@ -61,7 +77,7 @@ def reset(self):
6177
6278class McCullochPitts (Nodes ):
6379 '''
64- McCulloch-Pitts neuron .
80+ Layer of McCulloch-Pitts neurons .
6581 '''
6682 def __init__ (self , n , traces = False , threshold = 1.0 , trace_tc = 5e-2 ):
6783 '''
@@ -88,8 +104,8 @@ def step(self, inpts, dt):
88104 Runs a single simulation step.
89105
90106 Inputs:
91- inpts (torch.FloatTensor or torch.cuda.FloatTensor): Vector of
92- inputs to the layer, with size equal to self.n.
107+ inpts (torch.FloatTensor or torch.cuda.FloatTensor): Vector
108+ of inputs to the layer, with size equal to self.n.
93109 dt (float): Simulation time step.
94110 '''
95111 self .s = inpts >= self .threshold # Check for spiking neurons.
@@ -100,19 +116,108 @@ def step(self, inpts, dt):
100116 self .x [self .s ] = 1
101117
102118 def reset (self ):
119+ '''
120+ Resets relevant state variables.
121+ '''
103122 self .s = torch .zeros_like (torch .Tensor (n )) # Spike occurences.
104123
105124 if self .traces :
106125 self .x = torch .zeros_like (torch .Tensor (n )) # Firing traces.
107126
108127
128+ class IFNodes (Nodes ):
129+ '''
130+ Layer of integrate-and-fire (IF) neurons.
131+ '''
132+ def __init__ (self , n , traces = False , threshold = - 52.0 , reset = - 65.0 ,
133+ refractory = 5 , trace_tc = 5e-2 ):
134+ '''
135+ Instantiates a layer of IF neurons.
136+
137+ Inputs:
138+ n (int): The number of neurons in the layer.
139+ traces (bool): Whether to record decaying spike traces.
140+ threshold (float): Value at which to record a spike.
141+ reset (float): Value to which neurons are set to following a spike.
142+ refractory (int): The number of timesteps following
143+ a spike during which a neuron cannot spike again.
144+ trace_tc (float): Time constant of spike trace decay.
145+ '''
146+
147+ super ().__init__ ()
148+
149+ self .n = n # No. of neurons.
150+ self .traces = traces # Whether to record synpatic traces.
151+ self .reset = reset # Post-spike reset voltage.
152+ self .threshold = threshold # Spike threshold voltage.
153+ self .refractory = refractory # Post-spike refractory period.
154+
155+ self .v = self .reset * torch .ones (n ) # Neuron voltages.
156+ self .s = torch .zeros (n ) # Spike occurences.
157+
158+ if traces :
159+ self .x = torch .zeros (n ) # Firing traces.
160+ self .trace_tc = trace_tc # Rate of decay of spike trace time constant.
161+
162+ self .refrac_count = torch .zeros (n ) # Refractory period counters.
163+
164+ def step (self , inpts , dt ):
165+ '''
166+ Runs a single simulation step.
167+
168+ Inputs:
169+ inpts (torch.FloatTensor or torch.cuda.FloatTensor): Vector
170+ of inputs to the layer, with size equal to self.n.
171+ dt (float): Simulation time step.
172+ '''
173+ # Decrement refractory counters.
174+ self .refrac_count [self .refrac_count != 0 ] -= dt
175+
176+ # Check for spiking neurons.
177+ self .s = (self .v >= self .threshold ) * (self .refrac_count == 0 )
178+ self .refrac_count [self .s ] = self .refractory
179+ self .v [self .s ] = self .reset
180+
181+ # Integrate input and decay voltages.
182+ self .v += inpts
183+
184+ if self .traces :
185+ # Decay and set spike traces.
186+ self .x -= dt * self .trace_tc * self .x
187+ self .x [self .s ] = 1.0
188+
189+ def reset (self ):
190+ '''
191+ Resets relevant state variables.
192+ '''
193+ self .s = torch .zeros_like (torch .Tensor (n )) # Spike occurences.
194+ self .v = self .reset * torch .ones (n ) # Neuron voltages.
195+ self .refrac_count = torch .zeros (n ) # Refractory period counters.
196+
197+ if self .traces :
198+ self .x = torch .zeros_like (torch .Tensor (n )) # Firing traces.
199+
200+
109201class LIFNodes (Nodes ):
110202 '''
111- Group of leaky integrate-and-fire neurons.
203+ Layer of leaky integrate-and-fire (LIF) neurons.
112204 '''
113- def __init__ (self , n , traces = False , rest = - 65 .0 , reset = - 65.0 , threshold = - 52 .0 ,
205+ def __init__ (self , n , traces = False , threshold = - 52 .0 , rest = - 65.0 , reset = - 65 .0 ,
114206 refractory = 5 , voltage_decay = 1e-2 , trace_tc = 5e-2 ):
207+ '''
208+ Instantiates a layer of LIF neurons.
115209
210+ Inputs:
211+ n (int): The number of neurons in the layer.
212+ traces (bool): Whether to record decaying spike traces.
213+ threshold (float): Value at which to record a spike.
214+ rest (float): Value to which neuron voltages decay.
215+ reset (float): Value to which neurons are set to following a spike.
216+ refractory (int): The number of timesteps following
217+ a spike during which a neuron cannot spike again.
218+ voltage_decay (float): Time constant of neuron voltage decay.
219+ trace_tc (float): Time constant of spike trace decay.
220+ '''
116221 super ().__init__ ()
117222
118223 self .n = n # No. of neurons.
@@ -133,9 +238,14 @@ def __init__(self, n, traces=False, rest=-65.0, reset=-65.0, threshold=-52.0,
133238 self .refrac_count = torch .zeros (n ) # Refractory period counters.
134239
135240 def step (self , inpts , dt ):
136- # Decay voltages.
137- self . v -= dt * self . voltage_decay * ( self . v - self . rest )
241+ '''
242+ Runs a single simulation step.
138243
244+ Inputs:
245+ inpts (torch.FloatTensor or torch.cuda.FloatTensor): Vector
246+ of inputs to the layer, with size equal to self.n.
247+ dt (float): Simulation time step.
248+ '''
139249 if self .traces :
140250 # Decay spike traces.
141251 self .x -= dt * self .trace_tc * self .x
@@ -152,10 +262,14 @@ def step(self, inpts, dt):
152262 self .v += inpts
153263
154264 if self .traces :
155- # Setting synaptic traces.
265+ # Decay and set spike traces.
266+ self .x -= dt * self .trace_tc * self .x
156267 self .x [self .s ] = 1.0
157268
158269 def reset (self ):
270+ '''
271+ Resets relevant state variables.
272+ '''
159273 self .s = torch .zeros_like (torch .Tensor (n )) # Spike occurences.
160274 self .v = self .rest * torch .ones (n ) # Neuron voltages.
161275 self .refrac_count = torch .zeros (n ) # Refractory period counters.
0 commit comments