Skip to content

Commit f19692c

Browse files
committed
Refactor code for improved readability by formatting arguments in constructor calls across multiple classes.
1 parent f5a716d commit f19692c

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

bindsnet/models/models.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,11 @@ def __init__(
208208
source=exc_layer,
209209
target=inh_layer,
210210
device=device,
211-
pipeline=[Weight("weight", w, value_dtype=w_dtype, range=[0, self.exc], sparse=sparse)],
211+
pipeline=[
212+
Weight(
213+
"weight", w, value_dtype=w_dtype, range=[0, self.exc], sparse=sparse
214+
)
215+
],
212216
)
213217
w = -self.inh * (
214218
torch.ones(self.n_neurons, self.n_neurons)
@@ -220,7 +224,15 @@ def __init__(
220224
source=inh_layer,
221225
target=exc_layer,
222226
device=device,
223-
pipeline=[Weight("weight", w, value_dtype=w_dtype, range=[-self.inh, 0], sparse=sparse)],
227+
pipeline=[
228+
Weight(
229+
"weight",
230+
w,
231+
value_dtype=w_dtype,
232+
range=[-self.inh, 0],
233+
sparse=sparse,
234+
)
235+
],
224236
)
225237

226238
# Add to network

bindsnet/network/topology_features.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ def __init__(
142142
self, "enforce_polarity", False
143143
), "enforce_polarity isn't supported for sparse tensors"
144144

145-
146145
@staticmethod
147146
def cast_dtype_if_needed(value, value_dtype):
148147
if value.dtype != value_dtype:
@@ -495,7 +494,13 @@ def __init__(
495494
# Send boolean to tensor (priming wont work if it's not a tensor)
496495
value = torch.tensor(value)
497496

498-
super().__init__(name=name, value=value, value_dtype=torch.bool, sparse=sparse, batch_size=batch_size)
497+
super().__init__(
498+
name=name,
499+
value=value,
500+
value_dtype=torch.bool,
501+
sparse=sparse,
502+
batch_size=batch_size,
503+
)
499504
self.name = name
500505
self.value = value
501506

@@ -736,7 +741,12 @@ def __init__(
736741
:param batch_size: Mini-batch size.
737742
"""
738743
super().__init__(
739-
name=name, value=value, value_dtype=value_dtype, range=range, sparse=sparse, batch_size=batch_size
744+
name=name,
745+
value=value,
746+
value_dtype=value_dtype,
747+
range=range,
748+
sparse=sparse,
749+
batch_size=batch_size,
740750
)
741751

742752
def reset_state_variables(self) -> None:
@@ -864,7 +874,13 @@ def forward(self, x):
864874
self.const_update_rate = const_update_rate
865875
self.const_decay = const_decay
866876

867-
super().__init__(name=name, value=value, value_dtype=self.value_dtype, sparse=sparse, batch_size=batch_size)
877+
super().__init__(
878+
name=name,
879+
value=value,
880+
value_dtype=self.value_dtype,
881+
sparse=sparse,
882+
batch_size=batch_size,
883+
)
868884

869885
def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
870886

@@ -968,7 +984,13 @@ def forward(self, x):
968984
self.const_update_rate = const_update_rate
969985
self.const_decay = const_decay
970986

971-
super().__init__(name=name, value=value, value_dtype=self.value_dtype, sparse=sparse, batch_size=batch_size)
987+
super().__init__(
988+
name=name,
989+
value=value,
990+
value_dtype=self.value_dtype,
991+
sparse=sparse,
992+
batch_size=batch_size,
993+
)
972994

973995
def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
974996

0 commit comments

Comments
 (0)