Commit 87618a1

Author: Kevin Chang
Commit message: initial
1 parent 96479e6 commit 87618a1

File tree

14 files changed: +1409 -21 lines

.DS_Store (8 KB, binary file not shown)

.pytest_cache/v/cache/lastfailed (11 additions, 1 deletion)

@@ -1 +1,11 @@
-{}
+{
+    "test/analysis/test_analyzers.py": true,
+    "test/conversion/test_conversion.py": true,
+    "test/encoding/test_encoding.py": true,
+    "test/models/test_models.py": true,
+    "test/network/test_connections.py": true,
+    "test/network/test_learning.py": true,
+    "test/network/test_monitors.py": true,
+    "test/network/test_network.py": true,
+    "test/network/test_nodes.py": true
+}

BINDSNET paper.pdf (5.85 MB, binary file not shown)

bindsnet/.DS_Store (8 KB, binary file not shown)
Lines changed: 171 additions & 0 deletions (new file)

import torch
import random
from typing import Optional


def generate_positive_sample(
    x_input: torch.Tensor,
    true_label: int,
    num_classes: int,
) -> torch.Tensor:
    """
    Generates a positive sample x_pos by embedding the true label into x_input.

    The first `num_classes` elements of the output vector are set to 0,
    except at the index corresponding to `true_label`, which is set to the
    maximum value of `x_input`. The remaining elements are copied from
    `x_input`.

    Args:
        x_input: The original flattened input vector (1D tensor).
        true_label: The 0-indexed true class label of x_input.
        num_classes: The total number of classes (c).

    Returns:
        A new tensor x_pos with the true label embedded.
    """
    if not isinstance(x_input, torch.Tensor) or x_input.ndim != 1:
        raise ValueError("x_input must be a 1D PyTorch Tensor.")
    if num_classes <= 0:
        raise ValueError("num_classes must be positive.")
    if not (0 <= true_label < num_classes):
        raise ValueError(
            f"True label {true_label} is out of bounds for {num_classes} classes."
        )

    d = x_input.shape[0]
    m = torch.max(x_input) if d > 0 else torch.tensor(0.0, dtype=x_input.dtype)  # handle empty x_input

    # Initialize x_pos with zeros, matching dtype and device of x_input.
    x_pos = torch.zeros_like(x_input)

    # Part 1: Embed the true label in the first `num_classes` elements.
    # All of these elements are 0, except at index `true_label`.
    if true_label < min(num_classes, d):  # ensure true_label lies within the modifiable part
        x_pos[true_label] = m

    # Part 2: Copy the rest of the original input vector
    # (elements from index `num_classes` to `d - 1`).
    if d > num_classes:
        x_pos[num_classes:] = x_input[num_classes:]
    # If num_classes >= d, only the first d elements are modified and the copy is skipped.

    return x_pos
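

# Worked example (illustrative values, not taken from the commit): with
# x_input = [0.2, 0.7, 0.1, 0.4], num_classes = 2, and true_label = 1,
# max(x_input) = 0.7, so the first two slots become [0.0, 0.7] and the
# tail is copied through:
#   generate_positive_sample(x_input, 1, 2) -> tensor([0.0, 0.7, 0.1, 0.4])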


def generate_negative_sample(
    x_input: torch.Tensor,
    true_label: int,
    num_classes: int,
    false_label_override: Optional[int] = None,
) -> torch.Tensor:
    """
    Generates a negative sample x_neg by embedding a false label into x_input.

    The first `num_classes` elements of the output vector are set to 0,
    except at the index corresponding to the chosen false label, which is
    set to the maximum value of `x_input`. The remaining elements are copied
    from `x_input`.

    Args:
        x_input: The original flattened input vector (1D tensor).
        true_label: The 0-indexed true class label of x_input.
        num_classes: The total number of classes (c).
        false_label_override: Optional. A specific 0-indexed false class label
            to embed. If None, a false label is chosen uniformly at random
            from the labels different from `true_label`. This parameter can
            be used to implement a "hard labeling" strategy externally.

    Returns:
        A new tensor x_neg with the false label embedded.
    """
    if not isinstance(x_input, torch.Tensor) or x_input.ndim != 1:
        raise ValueError("x_input must be a 1D PyTorch Tensor.")
    if num_classes <= 0:
        raise ValueError("num_classes must be positive.")
    if not (0 <= true_label < num_classes):
        raise ValueError(
            f"True label {true_label} is out of bounds for {num_classes} classes."
        )

    chosen_false_label: int
    if false_label_override is not None:
        chosen_false_label = false_label_override
        if not (0 <= chosen_false_label < num_classes):
            raise ValueError(
                f"Provided false_label_override {chosen_false_label} is out of bounds for {num_classes} classes."
            )
        if chosen_false_label == true_label:
            raise ValueError(
                f"Provided false_label_override {chosen_false_label} cannot be the same as true_label {true_label}."
            )
    else:
        if num_classes <= 1:
            raise ValueError(
                "Cannot randomly choose a distinct false label with fewer than 2 classes."
            )
        possible_false_labels = [i for i in range(num_classes) if i != true_label]
        chosen_false_label = random.choice(possible_false_labels)

    d = x_input.shape[0]
    m = torch.max(x_input) if d > 0 else torch.tensor(0.0, dtype=x_input.dtype)  # handle empty x_input

    # Initialize x_neg with zeros, matching dtype and device of x_input.
    x_neg = torch.zeros_like(x_input)

    # Part 1: Embed the false label in the first `num_classes` elements.
    if chosen_false_label < min(num_classes, d):  # ensure chosen_false_label is within bounds
        x_neg[chosen_false_label] = m

    # Part 2: Copy the rest of the original input vector.
    if d > num_classes:
        x_neg[num_classes:] = x_input[num_classes:]

    return x_neg
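

# Worked example (illustrative): with the same x_input = [0.2, 0.7, 0.1, 0.4],
# num_classes = 2, and true_label = 1, the only possible false label is 0, so
#   generate_negative_sample(x_input, 1, 2) -> tensor([0.7, 0.0, 0.1, 0.4])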


# --- Example usage (for demonstration if you run this file directly) ---
if __name__ == "__main__":
    # Example: 10 features in the original input, 4 classes.
    original_x = torch.rand(10)  # random data
    true_class_label = 1
    total_classes = 4

    print(f"Original x_input: {original_x}")
    print(f"True label: {true_class_label}")
    print(f"Num classes: {total_classes}")
    print("-" * 30)

    x_positive = generate_positive_sample(
        original_x, true_class_label, total_classes
    )
    print(f"x_pos (true label {true_class_label} embedded): {x_positive}")
    print("-" * 30)

    x_negative_random = generate_negative_sample(
        original_x, true_class_label, total_classes
    )
    print(f"x_neg (random false label embedded): {x_negative_random}")
    print("-" * 30)

    specific_false = 3
    if specific_false == true_class_label:  # ensure it is actually false for the example
        specific_false = 0 if true_class_label != 0 else 2

    x_negative_specific = generate_negative_sample(
        original_x, true_class_label, total_classes, false_label_override=specific_false
    )
    print(f"x_neg (specific false label {specific_false} embedded): {x_negative_specific}")
    print("-" * 30)

    # Edge case: num_classes > len(x_input).
    short_x = torch.tensor([0.1, 0.9])
    true_short_label = 0
    classes_short = 3
    print(f"Short original x_input: {short_x}")
    x_pos_short = generate_positive_sample(short_x, true_short_label, classes_short)
    print(f"x_pos_short (true label {true_short_label}, num_classes {classes_short}): {x_pos_short}")
    x_neg_short = generate_negative_sample(short_x, true_short_label, classes_short, false_label_override=1)
    print(f"x_neg_short (false label 1, num_classes {classes_short}): {x_neg_short}")

bindsnet/learning/MCC_learning.py

Lines changed: 46 additions & 1 deletion

@@ -12,8 +12,13 @@
 )
 from ..utils import im2col_indices

+# MCC (multi-compartment connection) is faster and more flexible than the old
+# design. The old connection is a plain link between two objects and can only
+# carry a weight; to add a mask you had to add another object or another
+# weight. An MCC pipe can carry delays and weights in a single object.
+# TODO: run the experiments.

-class MCC_LearningRule(ABC):
+class MCC_LearningRule(ABC):  # multi-compartment connection
     # language=rst
     """
     Abstract base class for learning rules.
@@ -719,3 +724,43 @@ def reset_state_variables(self) -> None:
         self.eligibility.zero_()
         self.eligibility_trace.zero_()
         return
+
+
+class MyBackpropVariant(MCC_LearningRule):
+    def __init__(self, connection, feature_value, **kwargs):
+        super().__init__(connection=connection, feature_value=feature_value, **kwargs)
+        # Potentially initialize other parameters specific to this variant here.
+        self.update = self._custom_connection_update
+
+    def _custom_connection_update(self, **kwargs) -> None:
+        # Assumes an 'error_signal' for the target layer is passed in kwargs,
+        # that a surrogate gradient for the target activations is available or
+        # computed, and that pre-synaptic activity (spikes or traces) comes
+        # from self.source.
+        if "error_signal" not in kwargs:
+            return  # or raise, if a missing error signal should be fatal
+
+        error_signal = kwargs["error_signal"]  # error at the target neurons
+
+        # This is conceptual and depends on the specific variant:
+        # 1. Get pre-synaptic activity (e.g., self.source.s or self.source.x).
+        # 2. 'error_signal' corresponds to the error at the post-synaptic
+        #    (target) neurons.
+        # 3. Compute weight updates, e.g.
+        #        delta_w = learning_rate * error_signal * pre_synaptic_activity
+        #    (a simplification; actual SNN backprop is more involved and would
+        #    use, e.g., surrogate derivatives of target neuron potentials).
+        #
+        # Simplified example (update_matrix shaped source x target):
+        # update_matrix = torch.outer(self.source.s.float().mean(dim=0), error_signal)
+        # self.feature_value += self.nu[0] * update_matrix * self.connection.dt
+
+        # Call the parent's update for decay, clamping, etc.
+        super().update()
+
+    def reset_state_variables(self) -> None:
+        # Reset any internal state if this rule acquires some.
+        pass
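
To make the commented-out update concrete, here is a minimal, self-contained
sketch of the simplified delta rule described in those comments. It uses
plain torch tensors in place of the rule's attributes (self.source.s,
self.nu, self.connection.dt, self.feature_value); all names, shapes, and
values are assumptions for illustration, not the MCC API:

    import torch

    # Hypothetical sizes: 4 source neurons, 3 target neurons, 10 time steps.
    n_source, n_target = 4, 3
    nu = 0.01  # learning rate (plays the role of self.nu[0])
    dt = 1.0   # simulation time step (plays the role of self.connection.dt)

    feature_value = torch.zeros(n_source, n_target)               # connection weights
    source_s = torch.bernoulli(torch.full((10, n_source), 0.2))   # pre-synaptic spikes
    error_signal = torch.randn(n_target)                          # error at target neurons

    # delta_w = lr * outer(mean pre-synaptic activity, error at target) * dt
    update_matrix = torch.outer(source_s.float().mean(dim=0), error_signal)
    feature_value += nu * update_matrix * dt

This is only the outer-product step the comments call "very abstract"; a real
SNN backprop variant would replace error_signal with a properly propagated
error using surrogate derivatives, as the comments note.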
