diff --git a/code/rnnrbm.py b/code/rnnrbm.py
index db19c8f8..f7aad5f9 100644
--- a/code/rnnrbm.py
+++ b/code/rnnrbm.py
@@ -271,9 +271,10 @@ def test_rnnrbm(batch_size=100, num_epochs=200):
     model = RnnRbm()
     model.train(glob.glob('../data/Nottingham/train/*.mid'),
                 batch_size=batch_size, num_epochs=num_epochs)
+    return model
 
 if __name__ == '__main__':
-    test_rnnrbm()
+    model = test_rnnrbm()
     model.generate('sample1.mid')
     model.generate('sample2.mid')
     pylab.show()
diff --git a/doc/rnnrbm.txt b/doc/rnnrbm.txt
index 65b28bd1..ff9d718f 100644
--- a/doc/rnnrbm.txt
+++ b/doc/rnnrbm.txt
@@ -11,7 +11,7 @@ Modeling and generating sequences of polyphonic music with the RNN-RBM
 and `restricted Boltzmann machines (RBM) `_.
 
 .. note::
-  The code for this section is available for download here: `rnnrbm.py <../code/rnnrbm.py>`_.
+  The code for this section is available for download here: `rnnrbm.py `_.
   You will need the modified `Python MIDI package (GPL license) `_ in your ``$PYTHONPATH`` or in the working directory in order to convert MIDI files to and from piano-rolls.
   The script also assumes that the content of the `Nottingham Database of folk tunes `_ has been extracted in the ``../data`` directory.
 
@@ -266,7 +266,7 @@ We now have all the necessary ingredients to start training our network on real
                                                 n_hidden_recurrent)
 
         gradient = T.grad(cost, params, consider_constant=[v_sample])
-        updates_train.update(dict((p, p - lr * g) for p, g in zip(params,
+        updates_train.update(((p, p - lr * g) for p, g in zip(params,
                                                                 gradient)))
         self.train_function = theano.function([v], monitor,
                                                updates=updates_train)
@@ -288,7 +288,10 @@ We now have all the necessary ingredients to start training our network on real
 
         assert len(files) > 0, 'Training set is empty!' \
                                ' (did you download the data files?)'
-        dataset = [midiread(f, self.r, self.dt).piano_roll for f in files]
+        dataset = [midiread(f, self.r,
+                            self.dt).piano_roll.astype(theano.config.floatX)
+                   for f in files]
+
         try:
             for epoch in xrange(num_epochs):
                 numpy.random.shuffle(dataset)
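
A minimal sketch (not part of the patch) of what the ``astype(theano.config.floatX)`` cast in the last hunk is for: the compiled ``train_function`` takes the symbolic matrix ``v`` as its input, and ``v`` is created with Theano's default float type, so the piano-rolls returned by ``midiread`` must be cast to that same dtype before being fed in. The array below is a hypothetical stand-in for ``midiread(f, self.r, self.dt).piano_roll``::

    import numpy
    import theano

    # Stand-in for midiread(...).piano_roll: a binary (time x pitch) matrix,
    # which numpy builds as float64 by default.
    piano_roll = numpy.random.randint(0, 2, size=(100, 88)).astype(numpy.float64)

    # Cast to the float type Theano is configured with (float32 on GPU setups),
    # so the array matches the dtype expected by the compiled train_function.
    piano_roll = piano_roll.astype(theano.config.floatX)
    print(piano_roll.dtype)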