Skip to content

Commit f47bb92

Browse files
committed
show preliminary results
1 parent 6949c37 commit f47bb92

File tree

2 files changed

+43
-26
lines changed

2 files changed

+43
-26
lines changed

code/rnnrbm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def __init__(self, n_hidden=150, n_hidden_recurrent=100, lr=0.001, r=(21, 109),
175175
self.generate_function = theano.function([], v_t, updates=updates_generate)
176176

177177

178-
def train(self, files, batch_size=100, num_epochs=150):
178+
def train(self, files, batch_size=100, num_epochs=200):
179179
'''Train the RNN-RBM via stochastic gradient descent (SGD) using MIDI files converted to piano-rolls.
180180
181181
files : list of strings
@@ -217,16 +217,17 @@ def generate(self, filename, show=True):
217217
midiwrite(filename, piano_roll, self.r, self.dt)
218218
if show:
219219
extent = (0, self.dt * len(piano_roll)) + self.r
220+
pylab.figure()
220221
pylab.imshow(piano_roll.T, origin='lower', aspect='auto', interpolation='nearest', cmap=pylab.cm.gray_r, extent=extent)
221222
pylab.xlabel('time (s)')
222223
pylab.ylabel('MIDI note number')
223224
pylab.title('generated piano-roll')
224-
pylab.show()
225225

226226

227227
if __name__ == '__main__':
228228
model = RnnRbm()
229229
model.train(glob.glob('Nottingham/train/*.mid'))
230230
model.generate('sample1.mid')
231231
model.generate('sample2.mid')
232+
pylab.show()
232233

doc/rnnrbm.txt

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ Modeling and generating sequences of polyphonic music with the RNN-RBM
1717
The script also assumes that the content of the `Nottingham Database of folk tunes <http://www-etud.iro.umontreal.ca/~boulanni/Nottingham.zip>`_ has been extracted in the working directory.
1818
Alternative MIDI datasets are available `here <http://www-etud.iro.umontreal.ca/~boulanni/icml2012>`_.
1919

20+
.. caution::
21+
Depending on your locally installed Theano version, you may have problems running this script.
22+
If this is the case, please use the 'bleeding-edge' developer version from github and checkout the 0.6rc1 release:
23+
``git checkout e7a68bd9b39d722fef933425e0593cdb8998e141``.
24+
2025

2126
The RNN-RBM
2227
+++++++++++++++++++++++++
@@ -241,7 +246,7 @@ We now have all the necessary ingredients to start training our network on real
241246
self.generate_function = theano.function([], v_t, updates=updates_generate)
242247

243248

244-
def train(self, files, batch_size=100, num_epochs=150):
249+
def train(self, files, batch_size=100, num_epochs=200):
245250
'''Train the RNN-RBM via stochastic gradient descent (SGD) using MIDI files converted to piano-rolls.
246251

247252
files : list of strings
@@ -283,57 +288,68 @@ We now have all the necessary ingredients to start training our network on real
283288
midiwrite(filename, piano_roll, self.r, self.dt)
284289
if show:
285290
extent = (0, self.dt * len(piano_roll)) + self.r
291+
pylab.figure()
286292
pylab.imshow(piano_roll.T, origin='lower', aspect='auto', interpolation='nearest', cmap=pylab.cm.gray_r, extent=extent)
287293
pylab.xlabel('time (s)')
288294
pylab.ylabel('MIDI note number')
289295
pylab.title('generated piano-roll')
290-
pylab.show()
291296

292297

293298
if __name__ == '__main__':
294299
model = RnnRbm()
295300
model.train(glob.glob('Nottingham/train/*.mid'))
296301
model.generate('sample1.mid')
297302
model.generate('sample2.mid')
303+
pylab.show()
298304

299305

300306
Results
301307
++++++++
302308

303-
We ran the code on the Nottingham database for 150 epochs; training took approximately 48 hours.
309+
We ran the code on the Nottingham database for 200 epochs; training took approximately 24 hours.
304310

305311
The output was the following:
306312

307313
.. code-block:: text
308314

309-
Epoch 1/150 -15.0308940028
310-
Epoch 2/150 -10.4892606673
311-
Epoch 3/150 -10.2394696138
312-
Epoch 4/150 -10.1431669994
313-
Epoch 5/150 -9.7005382843
314-
Epoch 6/150 -8.5985647524
315-
Epoch 7/150 -8.35115428534
316-
Epoch 8/150 -8.26453580552
317-
Epoch 9/150 -8.21208991542
318-
Epoch 10/150 -8.16847274143
319-
Epoch 11/150 -8.03408036522
320-
Epoch 12/150 -7.72139097234
321-
Epoch 13/150 -7.56626080635
315+
Epoch 1/150 -15.0154373583
316+
Epoch 2/150 -10.4948703701
317+
Epoch 3/150 -10.2507567848
318+
Epoch 4/150 -10.1417621708
319+
Epoch 5/150 -9.69403756276
320+
Epoch 6/150 -8.6036962785
321+
Epoch 7/150 -8.35180803953
322+
Epoch 8/150 -8.26202621624
323+
Epoch 9/150 -8.21526214665
324+
Epoch 10/150 -8.16552397791
325+
326+
... truncated for brevity ...
322327

323-
... truncated for brevity ...
328+
Epoch 140/150 -5.09668220315
329+
Epoch 141/150 -5.08657006002
330+
Epoch 142/150 -5.09776776338
331+
Epoch 143/150 -5.10151042486
332+
Epoch 144/150 -5.07677377181
333+
Epoch 145/150 -5.07374453388
334+
Epoch 146/150 -inf
335+
Epoch 147/150 -5.06393939067
336+
Epoch 148/150 -5.07493685431
337+
Epoch 149/150 -5.06504525246
338+
Epoch 150/150 -5.04567771601
324339

325-
Epoch ..
326340

327341

328-
The figures below show the piano-roll of two sample sequences and the corresponding MIDI files are provided:
342+
The figures below show the piano-rolls of two sample sequences and we provide the corresponding MIDI files:
329343

330-
.. figure:: images/rnnrbm.png
344+
.. figure:: images/sample1.png
345+
:scale: 60%
331346

332-
`sample1.mid <sample1.mid>`_
347+
Listen to `sample1.mid <http://www-etud.iro.umontreal.ca/~boulanni/sample1.mid>`_
333348

334-
.. figure:: images/rnnrbm.png
349+
.. figure:: images/sample1.png
350+
:scale: 60%
335351

336-
`sample2.mid <sample2.mid>`_
352+
Listen to `sample2.mid <http://www-etud.iro.umontreal.ca/~boulanni/sample2.mid>`_
337353

338354

339355
How to improve this code
@@ -343,7 +359,7 @@ The code shown in this tutorial is a stripped-down version that can be improved
343359

344360
* Pretraining techniques: initialize the :math:`W,b_v,b_h` parameters with independent RBMs with fully shuffled frames (i.e. :math:`W_{uh}=W_{uv}=W_{uu}=W_{vu}=0`); initialize the :math:`W_{uv},W_{uu},W_{vu},b_u` parameters of the RNN with the auxiliary cross-entropy objective via either SGD or, preferably, Hessian-free optimization [BoulangerLewandowski12]_.
345361
* Optimization techniques: gradient clipping, Nesterov momentum and the use of NADE for conditional density estimation.
346-
* Preprocessing: transposing the sequences in common tonality (e.g. C major / minor) and normalizing the tempo in beats (quarternotes) per minute can yield important improvement in the generative quality of the model.
362+
* Preprocessing: transposing the sequences in a common tonality (e.g. C major / minor) and normalizing the tempo in beats (quarternotes) per minute can yield substantial improvement in the generative quality of the model.
347363
* Hyperparameter search: learning rate (separately for the RBM and RNN parts), learning rate schedules, batch size, number of hidden units (recurrent and RBM), momentum coefficient, momentum schedule, Gibbs chain length :math:`k` and early stopping.
348364
* Learn the initial condition :math:`u^{(0)}` as a model parameter.
349365

0 commit comments

Comments
 (0)