Commit 08f4bd0

Refactor code for improved readability and consistency in weight handling and warnings
1 parent e9dfcbf · commit 08f4bd0

4 files changed (+29 −32 lines)

bindsnet/models/models.py

Lines changed: 10 additions & 20 deletions
@@ -9,7 +9,11 @@
 from bindsnet.learning.MCC_learning import PostPre as MMCPostPre
 from bindsnet.network import Network
 from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes
-from bindsnet.network.topology import Connection, LocalConnection, MulticompartmentConnection
+from bindsnet.network.topology import (
+    Connection,
+    LocalConnection,
+    MulticompartmentConnection,
+)
 from bindsnet.network.topology_features import Weight


@@ -181,30 +185,23 @@ def __init__(
             device=device,
             pipeline=[
                 Weight(
-                    'weight',
+                    "weight",
                     w,
                     value_dtype=w_dtype,
                     range=[wmin, wmax],
                     norm=norm,
                     reduction=reduction,
                     nu=nu,
-                    learning_rule=MMCPostPre
+                    learning_rule=MMCPostPre,
                 )
-            ]
+            ],
         )
         w = self.exc * torch.diag(torch.ones(self.n_neurons))
         exc_inh_conn = MulticompartmentConnection(
             source=exc_layer,
             target=inh_layer,
             device=device,
-            pipeline=[
-                Weight(
-                    'weight',
-                    w,
-                    value_dtype=w_dtype,
-                    range=[0, self.exc]
-                )
-            ]
+            pipeline=[Weight("weight", w, value_dtype=w_dtype, range=[0, self.exc])],
         )
         w = -self.inh * (
             torch.ones(self.n_neurons, self.n_neurons)
@@ -214,14 +211,7 @@ def __init__(
             source=inh_layer,
             target=exc_layer,
             device=device,
-            pipeline=[
-                Weight(
-                    'weight',
-                    w,
-                    value_dtype=w_dtype,
-                    range=[-self.inh, 0]
-                )
-            ]
+            pipeline=[Weight("weight", w, value_dtype=w_dtype, range=[-self.inh, 0])],
         )

         # Add to network
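
For reference, the compact pipeline form adopted above can be read as the following standalone sketch. Only MulticompartmentConnection, Weight, LIFNodes, and the Weight argument names come from this diff; the layer sizes, the 22.5 excitatory strength, float32 weights, and the "cpu" device are illustrative assumptions.

# Minimal sketch of the compact Weight-pipeline style used in this commit.
# Layer sizes, the 22.5 strength, and the "cpu" device are placeholder assumptions.
import torch
from bindsnet.network.nodes import LIFNodes
from bindsnet.network.topology import MulticompartmentConnection
from bindsnet.network.topology_features import Weight

n_neurons, exc = 100, 22.5
exc_layer = LIFNodes(n=n_neurons)
inh_layer = LIFNodes(n=n_neurons)

# Diagonal excitatory -> inhibitory weights; range bounds the feature to [0, exc].
w = exc * torch.diag(torch.ones(n_neurons))
exc_inh_conn = MulticompartmentConnection(
    source=exc_layer,
    target=inh_layer,
    device="cpu",
    pipeline=[Weight("weight", w, value_dtype=torch.float32, range=[0, exc])],
)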

bindsnet/network/topology.py

Lines changed: 3 additions & 1 deletion
@@ -144,7 +144,9 @@ def reset_state_variables(self) -> None:
     @staticmethod
     def cast_dtype_if_needed(w, w_dtype):
         if w.dtype != w_dtype:
-            warnings.warn(f"Provided w has data type {w.dtype} but parameter w_dtype is {w_dtype}")
+            warnings.warn(
+                f"Provided w has data type {w.dtype} but parameter w_dtype is {w_dtype}"
+            )
             return w.to(dtype=w_dtype)
         else:
             return w
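
As a usage sketch of the cast-and-warn behavior reformatted above, restated here as a free function rather than the library's static method; the tensor shape and target dtypes are illustrative assumptions.

# Standalone restatement of the cast_dtype_if_needed pattern from the hunk above.
import warnings

import torch


def cast_dtype_if_needed(w: torch.Tensor, w_dtype: torch.dtype) -> torch.Tensor:
    if w.dtype != w_dtype:
        warnings.warn(
            f"Provided w has data type {w.dtype} but parameter w_dtype is {w_dtype}"
        )
        return w.to(dtype=w_dtype)
    return w


w = torch.rand(784, 100)                      # torch.rand defaults to float32
w16 = cast_dtype_if_needed(w, torch.float16)  # emits a UserWarning, returns a float16 copy
w32 = cast_dtype_if_needed(w, torch.float32)  # dtype already matches, returned unchanged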

bindsnet/network/topology_features.py

Lines changed: 10 additions & 7 deletions
@@ -127,7 +127,9 @@ def __init__(
     @staticmethod
     def cast_dtype_if_needed(value, value_dtype):
         if value.dtype != value_dtype:
-            warnings.warn(f"Provided value has data type {value.dtype} but parameter w_dtype is {value_dtype}")
+            warnings.warn(
+                f"Provided value has data type {value.dtype} but parameter w_dtype is {value_dtype}"
+            )
             return value.to(dtype=value_dtype)
         else:
             return value
@@ -431,11 +433,7 @@ def __init__(
         # Send boolean to tensor (priming wont work if it's not a tensor)
         value = torch.tensor(value)

-        super().__init__(
-            name=name,
-            value=value,
-            value_dtype=torch.bool
-        )
+        super().__init__(name=name, value=value, value_dtype=torch.bool)

         self.name = name
         self.value = value
@@ -705,7 +703,12 @@ def __init__(
         """

         # Note: parent_feature will override value. See abstract constructor
-        super().__init__(name=name, value=value, value_dtype=value_dtype, parent_feature=parent_feature)
+        super().__init__(
+            name=name,
+            value=value,
+            value_dtype=value_dtype,
+            parent_feature=parent_feature,
+        )

         self.degrade_function = degrade_function

examples/mnist/batch_eth_mnist.py

Lines changed: 6 additions & 4 deletions
@@ -43,8 +43,8 @@
 parser.add_argument(
     "--w_dtype",
     type=str,
-    default='float32',
-    help='Datatype to use for weights. Examples: float32, float16, bfloat16 etc'
+    default="float32",
+    help="Datatype to use for weights. Examples: float32, float16, bfloat16 etc",
 )
 parser.add_argument("--train", dest="train", action="store_true")
 parser.add_argument("--test", dest="train", action="store_false")
@@ -109,7 +109,7 @@
     theta_plus=theta_plus,
     inpt_shape=(1, 28, 28),
     device=device,
-    w_dtype=getattr(torch, args.w_dtype)
+    w_dtype=getattr(torch, args.w_dtype),
 )

 # Directs network to GPU
@@ -279,7 +279,9 @@
         image = batch["image"][:, 0].view(28, 28)
         inpt = inputs["X"][:, 0].view(time, 784).sum(0).view(28, 28)
         lable = batch["label"][0]
-        input_exc_weights = network.connections[("X", "Ae")].feature_index['weight'].value
+        input_exc_weights = (
+            network.connections[("X", "Ae")].feature_index["weight"].value
+        )
         square_weights = get_square_weights(
             input_exc_weights.view(784, n_neurons), n_sqrt, 28
         )
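
As a hedged note on the new flag: the script converts the --w_dtype string into a torch dtype attribute via getattr, so any dtype name exposed on torch works. The example invocation and the bfloat16 choice below are assumptions for illustration, not taken from the repository docs.

# Sketch of how the --w_dtype string maps to a torch dtype (mirrors the diff above).
# Assumed invocation: python examples/mnist/batch_eth_mnist.py --w_dtype bfloat16 --train
import torch

w_dtype_name = "bfloat16"                # what argparse would parse from --w_dtype
w_dtype = getattr(torch, w_dtype_name)   # -> torch.bfloat16
assert isinstance(w_dtype, torch.dtype)  # guard against a torch attribute that is not a dtype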
