@@ -725,42 +725,313 @@ def reset_state_variables(self) -> None:
725725 self .eligibility_trace .zero_ ()
726726 return
727727
728+ # Remove the MyBackpropVariant class and replace with:
728729
729-
730- class MyBackpropVariant (MCC_LearningRule ):
731- def __init__ (self , connection , feature_value , ** kwargs ):
732- super ().__init__ (connection = connection , feature_value = feature_value , ** kwargs )
733- # Potentially initialize other parameters specific to your variant
734- self .update = self ._custom_connection_update
735-
736- def _custom_connection_update (self , ** kwargs ) -> None :
737- # Assume 'error_signal' for the target layer is passed in kwargs
738- # Assume 'surrogate_grad_target' for target neuron activations is available or computed
739- # Assume 'source_activity' (e.g., spikes or trace) is from self.source
class ForwardForwardMCCLearning(MCC_LearningRule):
    """
    Forward-Forward learning rule for a ``MulticompartmentConnection``.

    Integrates Hinton-style Forward-Forward contrastive learning with the
    MulticompartmentConnection architecture, enabling layer-local learning
    without backpropagation through time.

    Per positive/negative sample pair the rule:
        1. Computes a "goodness" score from target-layer activity.
        2. Caches activations and goodness keyed by sample type.
        3. Once both sample types have been seen, applies a contrastive
           (Hebbian for positive, anti-Hebbian for negative) weight update
           and clears the accumulated statistics.

    NOTE(review): this class assumes ``connection.w`` has shape
    ``[source_neurons, target_neurons]`` and that the parent
    ``MCC_LearningRule`` stores ``nu`` as a scalar usable in tensor
    arithmetic — confirm against the parent class definition.
    """

    def __init__(
        self,
        alpha_loss: float = 0.6,
        goodness_fn: str = "mean_squared",
        nu: float = 0.001,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        **kwargs,
    ) -> None:
        """
        Initialize the Forward-Forward MCC learning rule.

        Args:
            alpha_loss: Goodness threshold parameter in the FF loss.
            goodness_fn: Goodness computation method, one of
                ``"mean_squared"`` or ``"sum_squared"``.
            nu: Learning rate for weight updates.
            momentum: Momentum factor for weight updates (0 disables).
            weight_decay: Decoupled weight-decay factor (0 disables).
            **kwargs: Additional arguments passed to ``MCC_LearningRule``.
        """
        super().__init__(nu=nu, **kwargs)

        self.alpha_loss = alpha_loss
        self.goodness_fn = goodness_fn
        self.momentum = momentum
        self.weight_decay = weight_decay

        # Accumulated Forward-Forward statistics; a weight update fires
        # only once both the positive and negative halves are populated.
        self.positive_goodness = None
        self.negative_goodness = None
        self.positive_activations = None
        self.negative_activations = None

        # Most recent FF loss, retained for monitoring (was previously
        # computed and silently discarded).
        self.last_ff_loss = None

        # Momentum buffer, lazily allocated to match the weight tensor.
        self.velocity = None

        # Bookkeeping for callers that drive sample presentation.
        self.current_sample_type = None
        self.samples_processed = 0

    def update(
        self,
        connection: 'MulticompartmentConnection',
        source_s: torch.Tensor,
        target_s: torch.Tensor,
        **kwargs,
    ) -> None:
        """
        Perform one Forward-Forward learning step.

        Called by the MCC during each simulation step. Accumulates
        statistics for positive and negative samples, then applies a
        contrastive update once both sample types are available.

        Args:
            connection: Parent ``MulticompartmentConnection``.
            source_s: Source layer spikes ``[batch_size, source_neurons]``.
            target_s: Target layer spikes ``[batch_size, target_neurons]``.
            **kwargs: May include ``sample_type`` ("positive"/"negative").

        Raises:
            ValueError: If the resolved sample type is neither
                "positive" nor "negative".
        """
        # Learning disabled on this connection.
        if not connection.w.requires_grad:
            return

        # An empty batch carries no statistics and would yield 0/0 = NaN
        # in the batch-size normalization of the weight update.
        if source_s.numel() == 0 or target_s.numel() == 0:
            return

        # Explicit kwarg takes precedence over the sticky sample type set
        # via set_sample_type(); fall back to "positive" for backward
        # compatibility when neither is given.
        sample_type = kwargs.get('sample_type', self.current_sample_type)
        if sample_type is None:
            sample_type = "positive"

        current_goodness = self._compute_goodness(target_s)

        # detach(): stored statistics must not keep autograd graphs alive.
        if sample_type == "positive":
            self.positive_goodness = current_goodness.detach()
            self.positive_activations = {
                'source': source_s.detach(),
                'target': target_s.detach(),
            }
        elif sample_type == "negative":
            self.negative_goodness = current_goodness.detach()
            self.negative_activations = {
                'source': source_s.detach(),
                'target': target_s.detach(),
            }
        else:
            raise ValueError(
                f"Invalid sample_type: {sample_type}. "
                f"Must be 'positive' or 'negative'"
            )

        self.samples_processed += 1

        # Apply the contrastive update only when a full positive/negative
        # pair has been collected, then clear the pair.
        if (
            self.positive_goodness is not None
            and self.negative_goodness is not None
            and self.positive_activations is not None
            and self.negative_activations is not None
        ):
            self._apply_forward_forward_update(connection)
            self._reset_accumulated_data()

    def _compute_goodness(self, target_activity: torch.Tensor) -> torch.Tensor:
        """
        Compute the Forward-Forward goodness score from target activity.

        Args:
            target_activity: Target neuron spikes ``[batch_size, neurons]``.

        Returns:
            Goodness scores, shape ``[batch_size]``.

        Raises:
            ValueError: If ``self.goodness_fn`` is not a known method.
        """
        if self.goodness_fn == "mean_squared":
            # Mean squared activity across neurons (original FF paper).
            return torch.mean(target_activity ** 2, dim=1)
        if self.goodness_fn == "sum_squared":
            # Sum of squared activity across neurons.
            return torch.sum(target_activity ** 2, dim=1)
        raise ValueError(f"Unknown goodness function: {self.goodness_fn}")

    def _apply_forward_forward_update(self, connection: 'MulticompartmentConnection'):
        """
        Apply the Forward-Forward contrastive weight update.

        Strengthens weights along source→target co-activity for positive
        samples and weakens it for negative samples:
            ΔW = η * (s_pos^T t_pos / B_pos  -  s_neg^T t_neg / B_neg)

        Args:
            connection: Parent ``MulticompartmentConnection``.
        """
        w = connection.w

        # FF loss is retained for monitoring via get_ff_loss()/stats;
        # it does not drive the (local, Hebbian-style) update below.
        self.last_ff_loss = self._compute_ff_loss(
            self.positive_goodness, self.negative_goodness
        )

        pos_source = self.positive_activations['source']
        pos_target = self.positive_activations['target']
        neg_source = self.negative_activations['source']
        neg_target = self.negative_activations['target']

        # Batch-normalized outer products of pre/post activity.
        delta_w_pos = torch.mm(pos_source.t(), pos_target) / pos_source.shape[0]
        delta_w_neg = -torch.mm(neg_source.t(), neg_target) / neg_source.shape[0]

        delta_w = self.nu * (delta_w_pos + delta_w_neg)

        # Decoupled weight decay (not scaled by the learning rate).
        if self.weight_decay > 0:
            delta_w = delta_w - self.weight_decay * w

        # Heavy-ball momentum on the combined update.
        if self.momentum > 0:
            if self.velocity is None:
                self.velocity = torch.zeros_like(w)
            self.velocity = self.momentum * self.velocity + delta_w
            delta_w = self.velocity

        # In-place update outside autograd.
        with torch.no_grad():
            w.add_(delta_w)

        self._apply_weight_constraints(connection)

    def _compute_ff_loss(
        self,
        goodness_pos: torch.Tensor,
        goodness_neg: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the Forward-Forward contrastive loss for monitoring.

        L = softplus(α - g_pos) + softplus(g_neg - α)
          = log(1 + exp(-g_pos + α)) + log(1 + exp(g_neg - α))

        softplus is used instead of the naive log(1 + exp(x)) form, which
        overflows to inf for large |x|.

        Args:
            goodness_pos: Goodness scores for positive samples.
            goodness_neg: Goodness scores for negative samples.

        Returns:
            Scalar mean Forward-Forward loss across the batch.
        """
        # Encourage high goodness for positive samples.
        loss_pos = torch.nn.functional.softplus(self.alpha_loss - goodness_pos)
        # Encourage low goodness for negative samples.
        loss_neg = torch.nn.functional.softplus(goodness_neg - self.alpha_loss)
        return torch.mean(loss_pos + loss_neg)

    def _apply_weight_constraints(self, connection: 'MulticompartmentConnection'):
        """
        Apply weight bounds and normalization if the connection defines them.

        Args:
            connection: Parent connection with optional ``wmin``/``wmax``
                bounds and an optional ``norm`` mode ("l1"/"l2").
        """
        w = connection.w

        # Clamp to bounds when both are present and set (None-guarded for
        # consistency with the norm check below).
        if (
            getattr(connection, 'wmin', None) is not None
            and getattr(connection, 'wmax', None) is not None
        ):
            with torch.no_grad():
                w.clamp_(connection.wmin, connection.wmax)

        # Per-output-neuron weight normalization; epsilon avoids division
        # by zero for silent columns.
        if getattr(connection, 'norm', None) is not None:
            with torch.no_grad():
                if connection.norm == "l2":
                    w.div_(w.norm(dim=0, keepdim=True) + 1e-8)
                elif connection.norm == "l1":
                    w.div_(w.abs().sum(dim=0, keepdim=True) + 1e-8)

    def _reset_accumulated_data(self):
        """Reset accumulated positive and negative sample data."""
        self.positive_goodness = None
        self.negative_goodness = None
        self.positive_activations = None
        self.negative_activations = None

    def set_sample_type(self, sample_type: str):
        """
        Set the sticky sample type used by subsequent ``update`` calls.

        Args:
            sample_type: Either "positive" or "negative".

        Raises:
            ValueError: For any other value.
        """
        if sample_type not in ["positive", "negative"]:
            raise ValueError(f"Invalid sample_type: {sample_type}")
        self.current_sample_type = sample_type

    def get_goodness_scores(self) -> dict:
        """Get current goodness scores for positive and negative samples."""
        return {
            'positive_goodness': self.positive_goodness,
            'negative_goodness': self.negative_goodness,
        }

    def get_ff_loss(self) -> torch.Tensor:
        """
        Return the Forward-Forward loss for the currently accumulated pair,
        or a zero tensor when either half of the pair is missing.
        """
        if self.positive_goodness is not None and self.negative_goodness is not None:
            return self._compute_ff_loss(self.positive_goodness, self.negative_goodness)
        return torch.tensor(0.0)

    def reset_state(self):
        """Reset all learning rule state (accumulators, momentum, counters)."""
        self._reset_accumulated_data()
        self.last_ff_loss = None
        self.velocity = None
        self.current_sample_type = None
        self.samples_processed = 0

    def get_learning_stats(self) -> dict:
        """Get learning rule statistics and configuration."""
        return {
            'learning_rule_type': 'ForwardForwardMCCLearning',
            'alpha_loss': self.alpha_loss,
            'goodness_fn': self.goodness_fn,
            'learning_rate': self.nu,
            'momentum': self.momentum,
            'weight_decay': self.weight_decay,
            'samples_processed': self.samples_processed,
            'current_sample_type': self.current_sample_type,
            'has_positive_data': self.positive_goodness is not None,
            'has_negative_data': self.negative_goodness is not None,
        }

    def __repr__(self):
        """String representation of the learning rule."""
        return (
            f"ForwardForwardMCCLearning("
            f"nu={self.nu}, "
            f"alpha_loss={self.alpha_loss}, "
            f"goodness_fn='{self.goodness_fn}', "
            f"momentum={self.momentum}, "
            f"weight_decay={self.weight_decay})"
        )
0 commit comments