1+ import torch
2+
3+ from abc import ABC , abstractmethod
4+
5+
6+ class Nodes (ABC ):
7+ '''
8+ Abstract base class for groups of neurons.
9+ '''
10+ def __init__ (self ):
11+ super ().__init__ ()
12+
13+ @abstractmethod
14+ def step (self , inpts , dt ):
15+ pass
16+
17+ def get_spikes (self ):
18+ return self .s
19+
20+ def get_voltages (self ):
21+ return self .v
22+
23+ def get_traces (self ):
24+ return self .x
25+
26+
27+ class Input (Nodes ):
28+ '''
29+ Layer of nodes with user-specified spiking behavior.
30+ '''
31+ def __init__ (self , n , traces = False , trace_tc = 5e-2 ):
32+ super ().__init__ ()
33+
34+ self .n = n # No. of neurons.
35+ self .traces = traces # Whether to record synpatic traces.
36+ self .s = torch .zeros_like (torch .Tensor (n )) # Spike occurences.
37+
38+ if self .traces :
39+ self .x = torch .zeros_like (torch .Tensor (n )) # Firing traces.
40+ self .trace_tc = trace_tc # Rate of decay of spike trace time constant.
41+
42+ def step (self , inpts , dt ):
43+ '''
44+ On each simulation step, set the spikes of the
45+ population equal to the inputs.
46+ '''
47+ # Set spike occurrences to input values.
48+ self .s = inpts
49+
50+ if self .traces :
51+ # Decay and set spike traces.
52+ self .x -= dt * self .trace_tc * self .x
53+ self .x [self .s ] = 1
54+
55+ def reset (self ):
56+ self .s = torch .zeros_like (torch .Tensor (n )) # Spike occurences.
57+
58+ if self .traces :
59+ self .x = torch .zeros_like (torch .Tensor (n )) # Firing traces.
60+
61+
62+ class McCullochPitts (Nodes ):
63+ '''
64+ McCulloch-Pitts neuron.
65+ '''
66+ def __init__ (self , n , traces = False , threshold = 1.0 , trace_tc = 5e-2 ):
67+ '''
68+ Instantiates a McCulloch-Pitts layer of neurons.
69+
70+ Inputs:
71+ n (int): The number of neurons in the layer.
72+ traces (bool): Whether to record decaying spike traces.
73+ threshold (float): Value at which to record a spike.
74+ '''
75+ super ().__init__ ()
76+
77+ self .n = n # No. of neurons.
78+ self .traces = traces # Whether to record synpatic traces.
79+ self .threshold = threshold # Spike threshold voltage.
80+ self .s = torch .zeros_like (torch .Tensor (n )) # Spike occurences.
81+
82+ if self .traces :
83+ self .x = torch .zeros_like (torch .Tensor (n )) # Firing traces.
84+ self .trace_tc = trace_tc # Rate of decay of spike trace time constant.
85+
86+ def step (self , inpts , dt ):
87+ '''
88+ Runs a single simulation step.
89+
90+ Inputs:
91+ inpts (torch.FloatTensor or torch.cuda.FloatTensor): Vector of
92+ inputs to the layer, with size equal to self.n.
93+ dt (float): Simulation time step.
94+ '''
95+ self .s = inpts >= self .threshold # Check for spiking neurons.
96+
97+ if self .traces :
98+ # Decay and set spike traces.
99+ self .x -= dt * self .trace_tc * self .x
100+ self .x [self .s ] = 1
101+
102+ def reset (self ):
103+ self .s = torch .zeros_like (torch .Tensor (n )) # Spike occurences.
104+
105+ if self .traces :
106+ self .x = torch .zeros_like (torch .Tensor (n )) # Firing traces.
107+
108+
109+ class LIFNodes (Nodes ):
110+ '''
111+ Group of leaky integrate-and-fire neurons.
112+ '''
113+ def __init__ (self , n , traces = False , rest = - 65.0 , reset = - 65.0 , threshold = - 52.0 ,
114+ refractory = 5 , voltage_decay = 1e-2 , trace_tc = 5e-2 ):
115+
116+ super ().__init__ ()
117+
118+ self .n = n # No. of neurons.
119+ self .traces = traces # Whether to record synpatic traces.
120+ self .rest = rest # Rest voltage.
121+ self .reset = reset # Post-spike reset voltage.
122+ self .threshold = threshold # Spike threshold voltage.
123+ self .refractory = refractory # Post-spike refractory period.
124+ self .voltage_decay = voltage_decay # Rate of decay of neuron voltage.
125+
126+ self .v = self .rest * torch .ones (n ) # Neuron voltages.
127+ self .s = torch .zeros (n ) # Spike occurences.
128+
129+ if traces :
130+ self .x = torch .zeros (n ) # Firing traces.
131+ self .trace_tc = trace_tc # Rate of decay of spike trace time constant.
132+
133+ self .refrac_count = torch .zeros (n ) # Refractory period counters.
134+
135+ def step (self , inpts , dt ):
136+ # Decay voltages.
137+ self .v -= dt * self .voltage_decay * (self .v - self .rest )
138+
139+ if self .traces :
140+ # Decay spike traces.
141+ self .x -= dt * self .trace_tc * self .x
142+
143+ # Decrement refractory counters.
144+ self .refrac_count [self .refrac_count != 0 ] -= dt
145+
146+ # Check for spiking neurons.
147+ self .s = (self .v >= self .threshold ) * (self .refrac_count == 0 )
148+ self .refrac_count [self .s ] = self .refractory
149+ self .v [self .s ] = self .reset
150+
151+ # Integrate input and decay voltages.
152+ self .v += inpts
153+
154+ if self .traces :
155+ # Setting synaptic traces.
156+ self .x [self .s ] = 1.0
157+
158+ def reset (self ):
159+ self .s = torch .zeros_like (torch .Tensor (n )) # Spike occurences.
160+ self .v = self .rest * torch .ones (n ) # Neuron voltages.
161+ self .refrac_count = torch .zeros (n ) # Refractory period counters.
162+
163+ if self .traces :
164+ self .x = torch .zeros_like (torch .Tensor (n )) # Firing traces.
0 commit comments