@@ -142,7 +142,6 @@ def __init__(
142142 self , "enforce_polarity" , False
143143 ), "enforce_polarity isn't supported for sparse tensors"
144144
145-
146145 @staticmethod
147146 def cast_dtype_if_needed (value , value_dtype ):
148147 if value .dtype != value_dtype :
@@ -495,7 +494,13 @@ def __init__(
495494 # Send boolean to tensor (priming wont work if it's not a tensor)
496495 value = torch .tensor (value )
497496
498- super ().__init__ (name = name , value = value , value_dtype = torch .bool , sparse = sparse , batch_size = batch_size )
497+ super ().__init__ (
498+ name = name ,
499+ value = value ,
500+ value_dtype = torch .bool ,
501+ sparse = sparse ,
502+ batch_size = batch_size ,
503+ )
499504 self .name = name
500505 self .value = value
501506
@@ -736,7 +741,12 @@ def __init__(
736741 :param batch_size: Mini-batch size.
737742 """
738743 super ().__init__ (
739- name = name , value = value , value_dtype = value_dtype , range = range , sparse = sparse , batch_size = batch_size
744+ name = name ,
745+ value = value ,
746+ value_dtype = value_dtype ,
747+ range = range ,
748+ sparse = sparse ,
749+ batch_size = batch_size ,
740750 )
741751
742752 def reset_state_variables (self ) -> None :
@@ -864,7 +874,13 @@ def forward(self, x):
864874 self .const_update_rate = const_update_rate
865875 self .const_decay = const_decay
866876
867- super ().__init__ (name = name , value = value , value_dtype = self .value_dtype , sparse = sparse , batch_size = batch_size )
877+ super ().__init__ (
878+ name = name ,
879+ value = value ,
880+ value_dtype = self .value_dtype ,
881+ sparse = sparse ,
882+ batch_size = batch_size ,
883+ )
868884
869885 def compute (self , conn_spikes ) -> Union [torch .Tensor , float , int ]:
870886
@@ -968,7 +984,13 @@ def forward(self, x):
968984 self .const_update_rate = const_update_rate
969985 self .const_decay = const_decay
970986
971- super ().__init__ (name = name , value = value , value_dtype = self .value_dtype , sparse = sparse , batch_size = batch_size )
987+ super ().__init__ (
988+ name = name ,
989+ value = value ,
990+ value_dtype = self .value_dtype ,
991+ sparse = sparse ,
992+ batch_size = batch_size ,
993+ )
972994
973995 def compute (self , conn_spikes ) -> Union [torch .Tensor , float , int ]:
974996
0 commit comments