Skip to content

Commit 71a6019

Browse files
Merge pull request #311 from ELanning/fix-random-network-baseline
Update Examples from deprecated Pipeline to EnvironmentPipeline
2 parents 6c0bde0 + f849f53 commit 71a6019

File tree

6 files changed

+51
-53
lines changed

6 files changed

+51
-53
lines changed

bindsnet/pipeline/action.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def select_multinomial(pipeline: EnvironmentPipeline, **kwargs) -> int:
2626

2727
assert (
2828
output.n % action_space.n == 0
29-
), "Output layer size not divisible by size of action space."
29+
), f"Output layer size of {output.n} is not divisible by action space size of {action_space.n}."
3030

3131
pop_size = int(output.n / action_space.n)
3232
spikes = output.s
@@ -66,11 +66,11 @@ def select_softmax(pipeline: EnvironmentPipeline, **kwargs) -> int:
6666

6767
assert (
6868
pipeline.network.layers[output].n == pipeline.env.action_space.n
69-
), "Output layer size not equal to size of action space."
69+
), "Output layer size is not equal to the size of the action space."
7070

7171
assert hasattr(
7272
pipeline, "spike_record"
73-
), "EnvironmentPipeline has not attribute named: spike_record."
73+
), "EnvironmentPipeline is missing the attribute: spike_record."
7474

7575
# Sum of previous iterations' spikes (Not yet implemented)
7676
spikes = torch.sum(pipeline.spike_record[output], dim=1)

examples/breakout/breakout.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
import torch
2-
31
from bindsnet.network import Network
4-
from bindsnet.pipeline import Pipeline
2+
from bindsnet.pipeline import EnvironmentPipeline
53
from bindsnet.encoding import bernoulli
64
from bindsnet.network.topology import Connection
75
from bindsnet.environment import GymEnvironment
@@ -27,12 +25,12 @@
2725
network.add_connection(inpt_middle, source="Input Layer", target="Hidden Layer")
2826
network.add_connection(middle_out, source="Hidden Layer", target="Output Layer")
2927

30-
# Load SpaceInvaders environment.
28+
# Load the Breakout environment.
3129
environment = GymEnvironment("BreakoutDeterministic-v4")
3230
environment.reset()
3331

3432
# Build pipeline from specified components.
35-
pipeline = Pipeline(
33+
pipeline = EnvironmentPipeline(
3634
network,
3735
environment,
3836
encoding=bernoulli,
@@ -47,12 +45,16 @@
4745

4846
# Run environment simulation for 100 episodes.
4947
for i in range(100):
50-
# initialize episode reward
51-
reward = 0
48+
total_reward = 0
5249
pipeline.reset_()
53-
while True:
54-
pipeline.step()
55-
reward += pipeline.reward
56-
if pipeline.done:
57-
break
58-
print("Episode " + str(i) + " reward:", reward)
50+
is_done = False
51+
while not is_done:
52+
# Broken until select_softmax is implemented.
53+
result = pipeline.env_step()
54+
pipeline.step(result)
55+
56+
reward = result[1]
57+
total_reward += reward
58+
59+
is_done = result[2]
60+
print(f"Episode {i} total reward:{total_reward}")

examples/breakout/breakout_stdp.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
import torch
2-
31
from bindsnet.network import Network
4-
from bindsnet.pipeline import Pipeline
2+
from bindsnet.pipeline import EnvironmentPipeline
53
from bindsnet.learning import MSTDP
64
from bindsnet.encoding import bernoulli
75
from bindsnet.network.topology import Connection
@@ -36,12 +34,12 @@
3634
network.add_connection(inpt_middle, source="Input Layer", target="Hidden Layer")
3735
network.add_connection(middle_out, source="Hidden Layer", target="Output Layer")
3836

39-
# Load SpaceInvaders environment.
37+
# Load the Breakout environment.
4038
environment = GymEnvironment("BreakoutDeterministic-v4")
4139
environment.reset()
4240

4341
# Build pipeline from specified components.
44-
pipeline = Pipeline(
42+
environment_pipeline = EnvironmentPipeline(
4543
network,
4644
environment,
4745
encoding=bernoulli,
@@ -55,30 +53,28 @@
5553
)
5654

5755

58-
# Train agent for 100 episodes.
56+
def run_pipeline(pipeline, episode_count):
57+
for i in range(episode_count):
58+
total_reward = 0
59+
pipeline.reset_()
60+
is_done = False
61+
while not is_done:
62+
# Broken until select_softmax is implemented.
63+
result = pipeline.env_step()
64+
pipeline.step(result)
65+
66+
reward = result[1]
67+
total_reward += reward
68+
69+
is_done = result[2]
70+
print(f"Episode {i} total reward:{total_reward}")
71+
72+
5973
print("Training: ")
60-
for i in range(100):
61-
pipeline.reset_()
62-
# initialize episode reward
63-
reward = 0
64-
while True:
65-
pipeline.step()
66-
reward += pipeline.reward
67-
if pipeline.done:
68-
break
69-
print("Episode " + str(i) + " reward:", reward)
74+
run_pipeline(environment_pipeline, episode_count=100)
7075

7176
# stop MSTDP
72-
pipeline.network.learning = False
77+
environment_pipeline.network.learning = False
7378

7479
print("Testing: ")
75-
for i in range(100):
76-
pipeline.reset_()
77-
# initialize episode reward
78-
reward = 0
79-
while True:
80-
pipeline.step()
81-
reward += pipeline.reward
82-
if pipeline.done:
83-
break
84-
print("Episode " + str(i) + " reward:", reward)
80+
run_pipeline(environment_pipeline, episode_count=100)

examples/asteroids/random_network_baseline.py renamed to examples/breakout/random_network_baseline.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from bindsnet.network import Network
55
from bindsnet.learning import Hebbian
6-
from bindsnet.pipeline import Pipeline
6+
from bindsnet.pipeline import EnvironmentPipeline
77
from bindsnet.encoding import bernoulli
88
from bindsnet.network.monitors import Monitor
99
from bindsnet.environment import GymEnvironment
@@ -43,9 +43,9 @@
4343
network = Network(dt=dt)
4444

4545
# Layers of neurons.
46-
inpt = Input(shape=(110, 84), traces=True) # Input layer
46+
inpt = Input(shape=(80, 80), traces=True) # Input layer
4747
exc = LIFNodes(n=n_neurons, refrac=0, traces=True) # Excitatory layer
48-
readout = LIFNodes(n=14, refrac=0, traces=True) # Readout layer
48+
readout = LIFNodes(n=16, refrac=0, traces=True) # Readout layer
4949
layers = {"X": inpt, "E": exc, "R": readout}
5050

5151
# Connections between layers.
@@ -93,11 +93,11 @@
9393
if layer in voltages:
9494
network.add_monitor(voltages[layer], name="%s_voltages" % layer)
9595

96-
# Load SpaceInvaders environment.
97-
environment = GymEnvironment("Asteroids-v0")
96+
# Load the Breakout environment.
97+
environment = GymEnvironment("BreakoutDeterministic-v4")
9898
environment.reset()
9999

100-
pipeline = Pipeline(
100+
pipeline = EnvironmentPipeline(
101101
network,
102102
environment,
103103
encoding=bernoulli,
@@ -120,9 +120,11 @@
120120
i = 0
121121
try:
122122
while i < n:
123-
pipeline.step()
123+
result = pipeline.env_step()
124+
pipeline.step(result)
124125

125-
if pipeline.done:
126+
is_done = result[2]
127+
if is_done:
126128
pipeline.reset_()
127129

128130
i += 1

examples/mnist/minimal_reservoir.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from bindsnet.network import Network
1111
from bindsnet.pipeline import DataLoaderPipeline
12-
from bindsnet.network.monitors import Monitor
1312
from bindsnet.network.topology import Connection
1413
from bindsnet.network.nodes import LIFNodes, Input
1514

examples/space_invaders/et_space_invaders.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from bindsnet.network.topology import Connection
1212
from bindsnet.pipeline import EnvironmentPipeline
1313
from bindsnet.pipeline.action import select_multinomial
14-
from bindsnet.analysis.plotting import plot_weights
1514

1615

1716
parser = argparse.ArgumentParser()

0 commit comments

Comments
 (0)