Skip to content

Commit 6949c37

Browse files
committed
first version of RNN-RBM tutorial
1 parent 90a1bed commit 6949c37

File tree

7 files changed

+2348
-2
lines changed

7 files changed

+2348
-2
lines changed

code/rnnrbm.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# Author: Nicolas Boulanger-Lewandowski
2+
# University of Montreal (2012)
3+
# RNN-RBM deep learning tutorial
4+
# More information at http://deeplearning.net/tutorial/rnnrbm.html
5+
6+
import numpy
7+
import pylab
8+
import glob
9+
import sys
10+
11+
from midi.utils import midiread, midiwrite
12+
import theano
13+
import theano.tensor as T
14+
from theano.tensor.shared_randomstreams import RandomStreams
15+
16+
# Fix numpy's global seed first: it drives both the parameter initialization
# below and the seed handed to the Theano random streams.
numpy.random.seed(0xdeadbeef)
rng = RandomStreams(seed=numpy.random.randint(2 ** 30))
# Turn off Theano's `warn.subtensor_merge_bug` configuration flag.
theano.config.warn.subtensor_merge_bug = False
19+
20+
21+
def build_rbm(v, W, bv, bh, k):
    '''Build the symbolic graph of a k-step Gibbs chain for an RBM.

    v : Theano vector or matrix
        Starting point of the chain. With a matrix, every row is run as its
        own chain (batch).
    W : Theano matrix
        RBM weight matrix.
    bv : Theano vector
        RBM visible biases.
    bh : Theano vector
        RBM hidden biases.
    k : scalar or Theano scalar
        Number of Gibbs steps to run.

    Return a (v_sample, cost, monitor, updates) tuple:

    v_sample : Theano vector or matrix with the same shape as `v`
        Visible state obtained after the last Gibbs step.
    cost : Theano scalar
        Free-energy difference between `v` and `v_sample`, divided by
        `v.shape[0]`. Its gradient with respect to W, bv, bh (holding
        `v_sample` constant) is the CD-k learning signal.
    monitor : Theano scalar
        Sum of v*log(mean_v) + (1-v)*log(1-mean_v), divided by `v.shape[0]`
        (pseudo log-likelihood used for monitoring).
    updates : dictionary of Theano variable -> Theano variable
        The `updates` object returned by scan for the Gibbs chain.'''

    def gibbs_step(visible):
        # One full sweep: sample hidden given visible, then visible given hidden.
        hidden_prob = T.nnet.sigmoid(T.dot(visible, W) + bh)
        hidden = rng.binomial(size=hidden_prob.shape, n=1, p=hidden_prob,
                              dtype=theano.config.floatX)
        visible_prob = T.nnet.sigmoid(T.dot(hidden, W.T) + bv)
        visible_sample = rng.binomial(size=visible_prob.shape, n=1,
                                      p=visible_prob,
                                      dtype=theano.config.floatX)
        return visible_prob, visible_sample

    def next_visible(visible):
        # Only the sampled visible state is carried along the chain.
        return gibbs_step(visible)[1]

    chain, updates = theano.scan(next_visible, outputs_info=[v], n_steps=k)
    v_sample = chain[-1]

    # One extra sweep from the final sample yields the visible means that the
    # monitoring quantity is computed from.
    mean_v = gibbs_step(v_sample)[0]
    on_term = T.xlogx.xlogy0(v, mean_v)
    off_term = T.xlogx.xlogy0(1 - v, 1 - mean_v)
    monitor = (on_term + off_term).sum() / v.shape[0]

    def free_energy(visible):
        visible_term = (visible * bv).sum()
        hidden_term = T.log(1 + T.exp(T.dot(visible, W) + bh)).sum()
        return -visible_term - hidden_term

    cost = (free_energy(v) - free_energy(v_sample)) / v.shape[0]

    return v_sample, cost, monitor, updates
66+
67+
68+
def shared_normal(num_rows, num_cols, scale=1):
    '''Return a (num_rows, num_cols) Theano shared matrix whose entries are
    drawn from a zero-mean normal distribution with standard deviation `scale`.

    Values come from numpy's global random generator and are cast to floatX.'''
    values = numpy.random.normal(scale=scale, size=(num_rows, num_cols))
    return theano.shared(values.astype(theano.config.floatX))
71+
72+
73+
def shared_zeros(*shape):
    '''Return a Theano shared variable of the given shape filled with zeros.

    The dimensions are passed as separate positional arguments, e.g.
    `shared_zeros(n)` for a vector; the values have dtype floatX.'''
    zeros = numpy.zeros(shape, dtype=theano.config.floatX)
    return theano.shared(zeros)
76+
77+
78+
def build_rnnrbm(n_visible, n_hidden, n_hidden_recurrent):
    '''Build the symbolic RNN-RBM: freshly initialized shared parameters plus
    the Theano expressions needed for training and for sequence generation.

    n_visible : integer
        Number of visible units.
    n_hidden : integer
        Number of hidden units of the conditional RBMs.
    n_hidden_recurrent : integer
        Number of hidden units of the RNN.

    Return a (v, v_sample, cost, monitor, params, updates_train, v_t,
    updates_generate) tuple:

    v : Theano matrix
        Symbolic input sequence (used during training).
    v_sample : Theano matrix
        Negative particles produced by the Gibbs chains for the CD gradient
        estimate (used during training).
    cost : Theano scalar
        Expression whose gradient, with `v_sample` held constant, is used as
        the log-likelihood gradient of the RNN-RBM (used during training).
    monitor : Theano scalar
        Frame-level pseudo-likelihood, useful for monitoring training.
    params : tuple of Theano shared variables
        Model parameters to be optimized.
    updates_train : dictionary of Theano variable -> Theano variable
        Updates to pass to theano.function when compiling the training function.
    v_t : Theano matrix
        Symbolic generated sequence (used during sampling).
    updates_generate : dictionary of Theano variable -> Theano variable
        Updates to pass to theano.function when compiling the generation
        function.'''

    # RBM weights and biases.
    W = shared_normal(n_visible, n_hidden, 0.01)
    bv = shared_zeros(n_visible)
    bh = shared_zeros(n_hidden)
    # Connections from the RNN state to the RBM biases, from the visible
    # units to the RNN state, and the recurrent RNN connections.
    Wuh = shared_normal(n_hidden_recurrent, n_hidden, 0.0001)
    Wuv = shared_normal(n_hidden_recurrent, n_visible, 0.0001)
    Wvu = shared_normal(n_visible, n_hidden_recurrent, 0.0001)
    Wuu = shared_normal(n_hidden_recurrent, n_hidden_recurrent, 0.0001)
    bu = shared_zeros(n_hidden_recurrent)

    # learned parameters as shared variables
    params = (W, bv, bh, Wuh, Wuv, Wvu, Wuu, bu)

    v = T.matrix()  # a training sequence
    u0 = T.zeros((n_hidden_recurrent,))  # initial RNN hidden state

    def recurrence(v_t, u_tm1):
        # With a given `v_t` this is a deterministic step that yields the
        # time-dependent biases bv_t, bh_t. With `v_t` set to None the same
        # step first draws v_t from a dedicated Gibbs chain (generation); the
        # sample is returned so it can feed the sequence history.
        bv_t = bv + T.dot(u_tm1, Wuv)
        bh_t = bh + T.dot(u_tm1, Wuh)
        sampling = v_t is None
        gibbs_updates = None
        if sampling:
            v_t, _, _, gibbs_updates = build_rbm(T.zeros((n_visible,)), W,
                                                 bv_t, bh_t, k=25)
        u_t = T.tanh(bu + T.dot(v_t, Wvu) + T.dot(u_tm1, Wuu))
        if sampling:
            return [v_t, u_t], gibbs_updates
        return [u_t, bv_t, bh_t]

    def train_step(v_t, u_tm1, *_):
        # The trailing arguments are the non-sequences (params); unused here.
        return recurrence(v_t, u_tm1)

    # Training: run the deterministic recurrence over `v` to get every
    # {bv_t, bh_t}, then train the conditional RBMs as one batch.
    (u_t, bv_t, bh_t), updates_train = theano.scan(
        train_step,
        sequences=v,
        outputs_info=[u0, None, None],
        non_sequences=params)
    v_sample, cost, monitor, updates_rbm = build_rbm(v, W, bv_t[:], bh_t[:],
                                                     k=15)
    updates_train.update(updates_rbm)

    def generate_step(u_tm1, *_):
        return recurrence(None, u_tm1)

    # Generation: symbolic loop that samples one frame per step.
    (v_t, u_t), updates_generate = theano.scan(
        generate_step,
        outputs_info=[None, u0],
        non_sequences=params,
        n_steps=200)

    return (v, v_sample, cost, monitor, params, updates_train, v_t,
            updates_generate)
146+
147+
148+
class RnnRbm:
    '''Simple class to build and train an RNN-RBM from MIDI files and to
    generate sample sequences.'''

    def __init__(self, n_hidden=150, n_hidden_recurrent=100, lr=0.001,
                 r=(21, 109), dt=0.3):
        '''Construct and compile the Theano functions for training and
        sequence generation.

        n_hidden : integer
            Number of hidden units of the conditional RBMs.
        n_hidden_recurrent : integer
            Number of hidden units of the RNN.
        lr : float
            Learning rate.
        r : (integer, integer) tuple
            Pitch range of the piano-roll in MIDI note numbers, including
            r[0] but not r[1], such that r[1]-r[0] is the number of visible
            units of the RBM at a given time step. The default (21, 109)
            corresponds to the full range of piano (88 notes).
        dt : float
            Sampling period when converting the MIDI files into piano-rolls,
            or equivalently the time difference between consecutive time
            steps.'''

        self.r = r
        self.dt = dt
        (v, v_sample, cost, monitor, params, updates_train, v_t,
         updates_generate) = build_rnnrbm(r[1] - r[0], n_hidden,
                                          n_hidden_recurrent)

        gradient = T.grad(cost, params, consider_constant=[v_sample])
        # Pass the SGD updates as an ordered list of (variable, new value)
        # pairs rather than a plain dict so that the order in which they are
        # merged into the scan updates is deterministic.
        updates_train.update(
            [(p, p - lr * g) for p, g in zip(params, gradient)])
        self.train_function = theano.function([v], monitor,
                                              updates=updates_train)
        self.generate_function = theano.function([], v_t,
                                                 updates=updates_generate)

    def train(self, files, batch_size=100, num_epochs=150):
        '''Train the RNN-RBM via stochastic gradient descent (SGD) using MIDI
        files converted to piano-rolls.

        files : list of strings
            List of MIDI files that will be loaded as piano-rolls for training.
        batch_size : integer
            Training sequences will be split into subsequences of at most this
            size before applying the SGD updates.
        num_epochs : integer
            Number of epochs (pass over the training set) performed. The user
            can safely interrupt training with Ctrl+C at any time.

        Raises ValueError if `files` is empty.'''

        files = list(files)
        if not files:
            # Without this check the loop below would run on nothing and
            # print `nan` for every epoch, hiding the real problem.
            raise ValueError('Training set is empty!'
                             ' (did you download the data files?)')

        # The compiled training function takes a floatX matrix, so convert
        # each piano-roll to floatX up front.
        dataset = [numpy.asarray(midiread(f, self.r, self.dt).piano_roll,
                                 dtype=theano.config.floatX)
                   for f in files]
        try:
            for epoch in range(num_epochs):
                numpy.random.shuffle(dataset)
                costs = []

                for sequence in dataset:
                    for i in range(0, len(sequence), batch_size):
                        cost = self.train_function(sequence[i:i + batch_size])
                        costs.append(cost)

                # Single-argument print gives the same output on Python 2 and 3.
                print('Epoch %i/%i %s' % (epoch + 1, num_epochs,
                                          numpy.mean(costs)))
                sys.stdout.flush()

        except KeyboardInterrupt:
            print('Interrupted by user.')

    def generate(self, filename, show=True):
        '''Generate a sample sequence, save it as a MIDI file and optionally
        plot the resulting piano-roll.

        filename : string
            A MIDI file will be created at this location.
        show : boolean
            If True, a piano-roll of the generated sequence will be shown.'''

        piano_roll = self.generate_function()
        midiwrite(filename, piano_roll, self.r, self.dt)
        if show:
            # tuple() lets `r` be given as a list as well as a tuple.
            extent = (0, self.dt * len(piano_roll)) + tuple(self.r)
            pylab.imshow(piano_roll.T, origin='lower', aspect='auto',
                         interpolation='nearest', cmap=pylab.cm.gray_r,
                         extent=extent)
            pylab.xlabel('time (s)')
            pylab.ylabel('MIDI note number')
            pylab.title('generated piano-roll')
            pylab.show()
225+
226+
227+
if __name__ == '__main__':
    # Train on the Nottingham training set, then write two generated samples.
    model = RnnRbm()
    training_files = glob.glob('Nottingham/train/*.mid')
    model.train(training_files)
    for output_path in ('sample1.mid', 'sample2.mid'):
        model.generate(output_path)
232+

doc/contents.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ Contents
1919
rbm
2020
DBN
2121
hmc
22+
rnnrbm
2223
utilities
2324
references

doc/images/rnnrbm.png

20.9 KB
Loading

0 commit comments

Comments
 (0)