Skip to content

Commit d545515

Browse files
committed
Touch-ups to addition RNN example
1 parent 588ce7a commit d545515

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

examples/addition_rnn.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,21 @@
2525
http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf
2626
Theoretically it introduces shorter term dependencies between source and target.
2727
28+
2829
Two digits inverted:
29-
+ One layer JZS1 (128 HN) with 55 iterations = 99% train/test accuracy
30+
+ One layer JZS1 (128 HN), 5k training examples = 99% train/test accuracy in 55 epochs
31+
3032
Three digits inverted:
31-
+ One layer JZS1 (128 HN) with 19 iterations = 99% train/test accuracy
33+
+ One layer JZS1 (128 HN), 50k training examples = 99% train/test accuracy in 100 epochs
34+
35+
3236
Four digits inverted:
33-
+ One layer JZS1 (128 HN) with 20 iterations = 99% train/test accuracy
37+
+ One layer JZS1 (128 HN), 400k training examples = 99% train/test accuracy in 20 epochs
38+
39+
3440
Five digits inverted:
35-
+ One layer JZS1 (128 HN) with 28 iterations = 99% train/test accuracy
41+
+ One layer JZS1 (128 HN), 550k training examples = 99% train/test accuracy in 30 epochs
42+
3643
"""
3744

3845

@@ -61,9 +68,14 @@ def decode(self, X, calc_argmax=True):
6168
X = X.argmax(axis=-1)
6269
return ''.join(self.indices_char[x] for x in X)
6370

71+
72+
class colors:
73+
ok = '\033[92m'
74+
fail = '\033[91m'
75+
close = '\033[0m'
76+
6477
# Parameters for the model and dataset
65-
# Note: Training size is number of queries to generate, not final number of unique queries
66-
TRAINING_SIZE = 800000
78+
TRAINING_SIZE = 50000
6779
DIGITS = 3
6880
INVERT = True
6981
# Try replacing JZS1 with LSTM, GRU, or SimpleRNN
@@ -80,7 +92,7 @@ def decode(self, X, calc_argmax=True):
8092
expected = []
8193
seen = set()
8294
print('Generating data...')
83-
for i in xrange(TRAINING_SIZE):
95+
while len(questions) < TRAINING_SIZE:
8496
f = lambda: int(''.join(np.random.choice(list('0123456789')) for i in xrange(np.random.randint(1, DIGITS + 1))))
8597
a, b = f(), f()
8698
# Skip any addition questions we've already seen
@@ -132,7 +144,7 @@ def decode(self, X, calc_argmax=True):
132144
model.compile(loss='categorical_crossentropy', optimizer='adam')
133145

134146
# Train the model each generation and show predictions against the validation dataset
135-
for iteration in range(1, 60):
147+
for iteration in range(1, 200):
136148
print()
137149
print('-' * 50)
138150
print('Iteration', iteration)
@@ -148,5 +160,5 @@ def decode(self, X, calc_argmax=True):
148160
guess = ctable.decode(preds[0], calc_argmax=False)
149161
print('Q', q[::-1] if INVERT else q)
150162
print('T', correct)
151-
print('☑' if correct == guess else '☒', guess)
163+
print(colors.ok + '☑' + colors.close if correct == guess else colors.fail + '☒' + colors.close, guess)
152164
print('---')

0 commit comments

Comments
 (0)