@@ -549,150 +549,6 @@ def __init__(
         self.add_connection(recurrent_conn, source="Y", target="Y")


-import snntorch as snn
-
-class FFSNN(nn.Module):
-    # language=rst
-    """
-    A simple feedforward Spiking Neural Network (SNN) using snntorch,
-    designed for use with the ForwardForwardPipeline. It consists of a
-    sequence of Linear layers, each followed by a Leaky Integrate-and-Fire
-    (LIF) spiking neuron layer.
-    """
-
-    def __init__(
-        self,
-        input_size: int,  # e.g., 794 (image features + num_classes)
-        hidden_sizes: List[int],  # e.g., [500, 500]
-        output_size: Optional[int] = None,  # set if the last FF layer is also a classification output
-        beta: Union[float, torch.Tensor] = 0.9,  # decay rate for snn.Leaky neurons
-        threshold: float = 1.0,  # firing threshold for snn.Leaky neurons
-        reset_mechanism: str = "subtract",  # "subtract", "zero", or "none"
-        # Add other snn.Leaky parameters if needed, e.g., spike_grad.
-    ) -> None:
-        # language=rst
-        """
-        Constructor for FFSNN.
-
-        :param input_size: Number of input features (after encoding and label embedding).
-        :param hidden_sizes: A list of integers, where each integer is the number of
-            neurons in a hidden layer.
-        :param output_size: Optional. Number of neurons in the final layer if it is
-            distinct or specifically for output. If None, the last size in
-            ``hidden_sizes`` is the final FF layer.
-        :param beta: Membrane potential decay rate for Leaky neurons.
-        :param threshold: Firing threshold for Leaky neurons.
-        :param reset_mechanism: Reset mechanism for Leaky neurons after a spike.
-        """
-        super().__init__()
-
-        self.input_size = input_size
-        self.hidden_sizes = hidden_sizes
-        self.output_size = output_size
-        self.beta = beta
-        self.threshold = threshold
-        self.reset_mechanism = reset_mechanism
-
-        self.fc_layers = nn.ModuleList()
-        self.snn_layers = nn.ModuleList()
-        self._ff_layer_pairs_info = []
-
-        current_dim = self.input_size  # starts at, e.g., 794
-        for i, hidden_dim in enumerate(self.hidden_sizes):
-            # e.g., layer 1: 794 -> 500, layer 2: 500 -> 500
-            linear_layer = nn.Linear(current_dim, hidden_dim)
-            self.fc_layers.append(linear_layer)
-
-            snn_layer = snn.Leaky(
-                beta=self.beta,
-                threshold=self.threshold,
-                reset_mechanism=self.reset_mechanism,
-            )
-            self.snn_layers.append(snn_layer)
-            self._ff_layer_pairs_info.append((linear_layer, snn_layer))
-            current_dim = hidden_dim  # input size for the next layer
-
-        # Optional final classifier head (not typical for pure FF layers).
-        if self.output_size is not None:
-            self.fc_out = nn.Linear(current_dim, self.output_size)
-            # A further spiking layer could follow if the output should spike:
-            # self.snn_out = snn.Leaky(...)
-            # self._ff_layer_pairs_info.append((self.fc_out, self.snn_out))  # if FF applies here too
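-        # Illustrative construction (hypothetical sizes, not original code):
-        #   net = FFSNN(input_size=794, hidden_sizes=[500, 500])
-        #   net.get_ff_layer_pairs()  # -> two (Linear, Leaky) pairs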
-
-    def forward(self, x_batch_time: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-        # language=rst
-        """
-        Defines the forward pass of the SNN over time.
-
-        This method is for standalone use with time-series input: the
-        ForwardForwardPipeline._run_snn_batch runs its own time loop over the
-        layer modules per time step and does not call it.
-
-        :param x_batch_time: Input tensor of shape [batch_size, time_steps, num_features].
-        :return: Output spikes of the final layer at the last time step, and a list
-            of the final membrane potentials of all spiking layers.
-        """
-        # snntorch neurons initialize their own state when ``None`` is passed
-        # as the membrane potential on the first call (sized from the input
-        # batch), so per-batch setup reduces to mapping each spiking layer to
-        # ``None``.
-        spiking_layer_states = {snn_layer: None for _, snn_layer in self._ff_layer_pairs_info}
-
-        # Flatten the (Linear, Leaky) pairs into one execution sequence,
-        # mirroring the per-time-step iteration in _run_snn_batch.
-        network_sequence = [module for pair in self._ff_layer_pairs_info for module in pair]
-
-        for t in range(x_batch_time.shape[1]):  # iterate over time steps
-            layer_input = x_batch_time[:, t, :]
-            for module in network_sequence:
-                if isinstance(module, snn.SpikingNeuron):
-                    # snn.Leaky returns (output spikes, updated membrane potential).
-                    spk_out, new_mem = module(layer_input, spiking_layer_states[module])
-                    spiking_layer_states[module] = new_mem
-                    layer_input = spk_out
-                else:  # nn.Linear
-                    layer_input = module(layer_input)
-            # After the last module, ``layer_input`` holds the final layer's
-            # spikes for time step t.
-
-        return layer_input, [spiking_layer_states[snn_layer] for _, snn_layer in self._ff_layer_pairs_info]
-
-
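-    # Hypothetical standalone call (shapes assumed): spk, mems = net(x) with
-    # x of shape [batch, T, input_size]; spk is the last layer's spikes at the
-    # final time step, of shape [batch, hidden_sizes[-1]].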
-    def get_ff_layer_pairs(self) -> List[Tuple[nn.Linear, snn.SpikingNeuron]]:
-        # language=rst
-        """
-        Returns the list of (Linear, SpikingNeuron) pairs for Forward-Forward training.
-        """
-        return self._ff_layer_pairs_info
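-    # Sketch of how a pipeline might consume the pairs (the goodness objective
-    # shown is an assumption, not taken from this repository):
-    #   for linear, leaky in net.get_ff_layer_pairs():
-    #       spk, mem = leaky(linear(x_t), mem)
-    #       goodness = spk.pow(2).mean()  # maximized for positive samples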
-
-
 class FFSNN_BindsNET(Network):
     # language=rst
     """