diff --git a/code/rnnrbm.py b/code/rnnrbm.py index 900ffdc6..43bda691 100644 --- a/code/rnnrbm.py +++ b/code/rnnrbm.py @@ -288,7 +288,8 @@ def generate(self, filename, show=True): def test_rnnrbm(batch_size=100, num_epochs=200): model = RnnRbm() - re = os.path.join(os.path.split(os.path.dirname(__file__))[0], + cwd = os.path.dirname(os.path.abspath(__file__)) + re = os.path.join(os.path.split(cwd)[0], 'data', 'Nottingham', 'train', '*.mid') model.train(glob.glob(re), batch_size=batch_size, num_epochs=num_epochs)