Skip to content

Commit 8b7c187

Browse files
committed
make rnnrbm work with float32.
1 parent 3fe0484 commit 8b7c187

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

code/rnnrbm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,10 @@ def train(self, files, batch_size=100, num_epochs=200):
218218

219219
assert len(files) > 0, 'Training set is empty!' \
220220
' (did you download the data files?)'
221-
dataset = [midiread(f, self.r, self.dt).piano_roll for f in files]
221+
dataset = [midiread(f, self.r,
222+
self.dt).piano_roll.astype(theano.config.floatX)
223+
for f in files]
224+
222225
try:
223226
for epoch in xrange(num_epochs):
224227
numpy.random.shuffle(dataset)

0 commit comments

Comments
 (0)