Multi-layer Recurrent Neural Networks (LSTM, RNN) for character-level language models in Python using TensorFlow.
Inspired by Andrej Karpathy's char-rnn.
If you don’t have pip, install it first:
- sudo apt-get install python3-pip
If you don’t have virtualenv, install it with:
- pip install virtualenv
If you haven’t set up your virtual environment yet, create one named venv within the char-rnn directory:
- cd to the directory
- virtualenv venv
From that directory, activate the virtualenv with:
- source venv/bin/activate
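To double-check that the environment is active before running anything else, the standard library is enough. This is a minimal sketch; the helper name in_virtualenv is hypothetical and not part of this repository. It relies on sys.prefix differing from sys.base_prefix inside a venv or a recent virtualenv, and on older virtualenv releases setting sys.real_prefix instead.

```python
import sys


def in_virtualenv() -> bool:
    """Return True if this interpreter runs inside a virtual environment."""
    # venv (and virtualenv >= 20) point sys.prefix at the environment,
    # while sys.base_prefix keeps pointing at the base installation.
    base_prefix = getattr(sys, "base_prefix", sys.prefix)
    # Older virtualenv releases set sys.real_prefix instead.
    return hasattr(sys, "real_prefix") or sys.prefix != base_prefix


if __name__ == "__main__":
    print("virtualenv active" if in_virtualenv() else "no virtualenv active")
```

Run it with the same python that train.py will use; after source venv/bin/activate it should report an active virtualenv.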
Run with:
- python train.py
To train on new text, save it as input.txt in a new folder under ./data/ and run:
- python train.py --data_dir=./data/name-of-new-folder
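The --data_dir flag points training at a folder that holds the text. How train.py resolves that flag internally isn't shown here, so the following is only a hypothetical sketch using Python's argparse: the default directory ./data/tinyshakespeare and the helper names parse_args, input_path and check_input are assumptions, while the input.txt file name follows the Sherlock Holmes example in this README.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse a --data_dir flag (hypothetical sketch, not train.py's parser)."""
    parser = argparse.ArgumentParser(description="Train a character-level RNN (sketch).")
    # Assumed default: the README trains on tinyshakespeare when no flag is
    # given, but the exact default path used by train.py is not confirmed.
    parser.add_argument("--data_dir", default="./data/tinyshakespeare",
                        help="folder containing the training text as input.txt")
    return parser.parse_args(argv)


def input_path(args):
    """Return the path of the training file inside --data_dir."""
    return os.path.join(args.data_dir, "input.txt")


def check_input(args):
    """Return the training file path, or raise if it does not exist yet."""
    path = input_path(args)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"expected training text at {path}")
    return path
```

For example, input_path(parse_args(["--data_dir=./data/sherlock/"])) gives ./data/sherlock/input.txt on POSIX systems.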
To view output graphs and logs in TensorBoard:
- In a separate terminal window:
- cd to the directory
- tensorboard --logdir=./logs/
- Open a browser to http://localhost:6006 or the IP/port that TensorBoard reports.
To generate new text:
- In a separate terminal window:
- cd to the directory
- python sample.py
Hyperparameter settings:
- rnn_size: 128
- learning_rate: 0.008
- dropout: 0.5
- seq: 50
Tuning notes:
- Check the number of model parameters against the data set size
- Tune rnn_size and the dropout rate accordingly
- Think about acoustic / rhythmic feature engineering
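The note about checking the number of parameters against the data size can be made concrete with a rough count. The sketch below assumes a stack of standard LSTM layers (four gates, each with input weights, recurrent weights and a bias), an optional character embedding of width rnn_size, and a softmax output layer; num_layers=2 and the vocabulary size of 65 are illustrative assumptions, not values confirmed by this repository. The comparison uses a rule of thumb often quoted for char-rnn models: the parameter count and the number of training characters should be of roughly the same order of magnitude.

```python
def char_rnn_param_count(vocab_size, rnn_size, num_layers=2, embed=True):
    """Approximate trainable parameters of a stacked-LSTM char-rnn (assumed architecture)."""
    # Optional embedding: one rnn_size-wide vector per character.
    embedding = vocab_size * rnn_size if embed else 0
    # Without an embedding, the first layer sees one-hot vectors of width vocab_size.
    input_size = rnn_size if embed else vocab_size
    lstm = 0
    for layer in range(num_layers):
        x = input_size if layer == 0 else rnn_size
        # 4 gates, each with weights over [input, previous hidden state] plus a bias.
        lstm += 4 * (rnn_size * (x + rnn_size) + rnn_size)
    # Output projection from hidden state to per-character logits, plus bias.
    softmax = rnn_size * vocab_size + vocab_size
    return embedding + lstm + softmax


def same_order_of_magnitude(n_params, n_chars):
    """Rule-of-thumb check: parameters within 10x of the training character count."""
    ratio = n_params / n_chars
    return 0.1 <= ratio <= 10


if __name__ == "__main__":
    # rnn_size=128 matches the setting above; vocab_size=65 is only an example.
    print(char_rnn_param_count(vocab_size=65, rnn_size=128))  # 279873
```

The character count of your own corpus is simply the length of input.txt read as text; if the parameter count dwarfs it, a smaller rnn_size or stronger dropout is the usual response.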
To train with default parameters on the tinyshakespeare corpus, run python train.py. To list all available parameters, run python train.py --help.
To sample from a checkpointed model, run python sample.py.
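How sample.py draws text from a checkpoint isn't described here, but character-level sampling generally works one step at a time: the model returns a probability for every character in the vocabulary, one character is drawn from that distribution, appended to the text, and fed back in. The sketch below illustrates only that loop; next_probs is a hypothetical stand-in for the trained network, and cycle_model is a toy replacement used purely for demonstration.

```python
import random
from typing import Callable, Sequence


def sample_text(next_probs: Callable[[str], Sequence[float]],
                vocab: Sequence[str], prime: str, length: int,
                seed: int = 0) -> str:
    """Generate `length` characters after `prime`, one draw at a time.

    `next_probs` is a hypothetical stand-in for the trained model: given the
    text so far, it returns one probability per character in `vocab`.
    """
    rng = random.Random(seed)  # seeded so runs are reproducible
    text = prime
    for _ in range(length):
        probs = next_probs(text)
        # Draw in proportion to the probabilities instead of always taking
        # the most likely character.
        text += rng.choices(vocab, weights=probs, k=1)[0]
    return text


VOCAB = "abc"


def cycle_model(text: str) -> list:
    """Toy 'model' that always predicts the character after the last one."""
    nxt = (VOCAB.index(text[-1]) + 1) % len(VOCAB)
    return [1.0 if i == nxt else 0.0 for i in range(len(VOCAB))]


if __name__ == "__main__":
    print(sample_text(cycle_model, VOCAB, prime="a", length=5))  # abcabc
```

Drawing from the distribution, rather than always picking the most likely character, is what keeps generated text from collapsing into short repeated loops.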
You can use any plain text file as input. For example, you could download The Complete Sherlock Holmes as follows:
cd data
mkdir sherlock
cd sherlock
wget https://sherlock-holm.es/stories/plain-text/cnus.txt
mv cnus.txt input.txt
Then start training from the top-level directory using python train.py --data_dir=./data/sherlock/
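Because any plain text file can serve as input, a character-level model has to derive its vocabulary from the file itself and turn the text into integer ids before training. This repository's loader isn't reproduced here; the following is a minimal sketch of that step, assuming UTF-8 text and ordering characters by frequency (one common choice, not a requirement). The function names are hypothetical.

```python
from collections import Counter
from pathlib import Path


def build_vocab(text):
    """Return (chars, char_to_id), most frequent characters first."""
    counts = Counter(text)
    # Sort by descending count; break ties by the character itself so the
    # mapping is deterministic.
    chars = sorted(counts, key=lambda c: (-counts[c], c))
    char_to_id = {c: i for i, c in enumerate(chars)}
    return chars, char_to_id


def encode(text, char_to_id):
    """Map each character to its integer id."""
    return [char_to_id[c] for c in text]


def decode(ids, chars):
    """Map integer ids back to a string."""
    return "".join(chars[i] for i in ids)


def load_corpus(data_dir):
    """Read data_dir/input.txt (UTF-8 is an assumption) and encode it."""
    text = Path(data_dir, "input.txt").read_text(encoding="utf-8")
    chars, char_to_id = build_vocab(text)
    return chars, encode(text, char_to_id)
```

For the Sherlock Holmes example, load_corpus("./data/sherlock") would return the vocabulary and the encoded text; len(chars) is the vocabulary size the model's input and output layers need.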
A quick tip to concatenate many small disparate .txt files into one large training file: ls *.txt | xargs -L 1 cat >> input.txt
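The shell one-liner appends to input.txt and xargs splits file names on whitespace, so running it twice, or on files with spaces in their names, can give surprising results. A Python alternative, sketched under the assumption that all files are UTF-8 text, rebuilds input.txt from scratch and skips the output file itself; concat_txt_files is a hypothetical helper name.

```python
from pathlib import Path


def concat_txt_files(src_dir, out_name="input.txt"):
    """Concatenate every *.txt file in src_dir into src_dir/out_name.

    Files are joined in sorted name order and read as UTF-8 (an assumption).
    The output file is excluded from the inputs and overwritten, not appended to.
    """
    src = Path(src_dir)
    out_path = src / out_name
    # Collect inputs before opening the output, skipping the output itself.
    parts = sorted(p for p in src.glob("*.txt") if p.name != out_name)
    with out_path.open("w", encoding="utf-8") as out:
        for p in parts:
            out.write(p.read_text(encoding="utf-8"))
    return out_path
```

Because the output is rewritten on every call, re-running it after adding new files yields the same result as running it once.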
To visualize training progress, model graphs, and internal state histograms, start TensorBoard and point it at your log_dir, e.g.:
$ tensorboard --logdir=./logs/
Then open a browser to http://localhost:6006 or the IP/port that TensorBoard reports.
Roadmap:
- Add explanatory comments
- Expose more command-line arguments
- Compare accuracy and performance with char-rnn
- More TensorBoard instrumentation
Please feel free to:
- Leave feedback in the issues
- Open a Pull Request
- Join the Gitter chat
- Share your success stories and data sets!