Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion code/rnnrbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,10 @@ def test_rnnrbm(batch_size=100, num_epochs=200):
model = RnnRbm()
model.train(glob.glob('../data/Nottingham/train/*.mid'),
batch_size=batch_size, num_epochs=num_epochs)
return model

if __name__ == '__main__':
test_rnnrbm()
model = test_rnnrbm()
model.generate('sample1.mid')
model.generate('sample2.mid')
pylab.show()
9 changes: 6 additions & 3 deletions doc/rnnrbm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Modeling and generating sequences of polyphonic music with the RNN-RBM
and `restricted Boltzmann machines (RBM) <rbm.html>`_.

.. note::
The code for this section is available for download here: `rnnrbm.py <../code/rnnrbm.py>`_.
The code for this section is available for download here: `rnnrbm.py <code/rnnrbm.py>`_.

You will need the modified `Python MIDI package (GPL license) <http://www.iro.umontreal.ca/~lisa/deep/midi.zip>`_ in your ``$PYTHONPATH`` or in the working directory in order to convert MIDI files to and from piano-rolls.
The script also assumes that the content of the `Nottingham Database of folk tunes <http://www.iro.umontreal.ca/~lisa/deep/data/Nottingham.zip>`_ has been extracted in the ``../data`` directory.
Expand Down Expand Up @@ -266,7 +266,7 @@ We now have all the necessary ingredients to start training our network on real
n_hidden_recurrent)

gradient = T.grad(cost, params, consider_constant=[v_sample])
updates_train.update(dict((p, p - lr * g) for p, g in zip(params,
updates_train.update(((p, p - lr * g) for p, g in zip(params,
gradient)))
self.train_function = theano.function([v], monitor,
updates=updates_train)
Expand All @@ -288,7 +288,10 @@ We now have all the necessary ingredients to start training our network on real

assert len(files) > 0, 'Training set is empty!' \
' (did you download the data files?)'
dataset = [midiread(f, self.r, self.dt).piano_roll for f in files]
dataset = [midiread(f, self.r,
self.dt).piano_roll.astype(theano.config.floatX)
for f in files]

try:
for epoch in xrange(num_epochs):
numpy.random.shuffle(dataset)
Expand Down