Skip to content

Commit 3fe0484

Browse files
committed
Add the new rnnrbm tutorial in the normal test
1 parent 74e1d35 commit 3fe0484

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

code/rnnrbm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,13 @@ def generate(self, filename, show=True):
258258
pylab.title('generated piano-roll')
259259

260260

261-
if __name__ == '__main__':
261+
def test_rnnrbm(batch_size=100, num_epochs=200):
262262
model = RnnRbm()
263-
model.train(glob.glob('../data/Nottingham/train/*.mid'))
263+
model.train(glob.glob('../data/Nottingham/train/*.mid'),
264+
batch_size=batch_size, num_epochs=num_epochs)
265+
266+
if __name__ == '__main__':
267+
test_rnnrbm()
264268
model.generate('sample1.mid')
265269
model.generate('sample2.mid')
266270
pylab.show()

code/test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import logistic_sgd
1111
import mlp
1212
import rbm
13+
import rnnrbm
1314
import SdA
1415

1516

@@ -49,6 +50,8 @@ def test_rbm():
4950
rbm.test_rbm(training_epochs=1, batch_size=300, n_chains=1, n_samples=1,
5051
n_hidden=20, output_folder='tmp_rbm_plots')
5152

53+
def test_rnnrbm():
54+
rnnrbm.test_rnnrbm(num_epochs=1)
5255

5356
def speed():
5457
"""

0 commit comments

Comments
 (0)