Skip to content

Commit d788ebc

Browse files
committed
Add basic LSTM tutorial
1 parent 30f10a9 commit d788ebc

File tree

4 files changed

+674
-0
lines changed

4 files changed

+674
-0
lines changed

code/imdb.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import cPickle
2+
import gzip
3+
import os
4+
import sys
5+
import time
6+
7+
import numpy
8+
9+
import theano
10+
import theano.tensor as T
11+
12+
def prepare_data(seqs, labels, maxlen=None):
    """Pad a batch of sequences into a (maxlen, n_samples) int64 matrix.

    Sequences are laid out column-wise (time along axis 0) and
    zero-padded at the end; a parallel float32 mask marks the valid
    time steps of each column.

    :param seqs: list of sequences (each a list of int word indices)
    :param labels: list of labels, parallel to ``seqs``
    :param maxlen: if not None, drop every sample whose sequence length
        is >= maxlen (strictly shorter sequences are kept)
    :return: ``(x, x_mask, labels)`` where ``x`` has shape
        (longest kept length, n_samples); returns ``(None, None, None)``
        when no samples remain after filtering
    """
    # x: a list of sentences
    lengths = [len(s) for s in seqs]

    if maxlen is not None:
        # Keep only the samples whose sequence is strictly shorter
        # than maxlen, keeping lengths/seqs/labels aligned.
        kept = [(l, s, y) for l, s, y in zip(lengths, seqs, labels)
                if l < maxlen]
        lengths = [l for l, _, _ in kept]
        seqs = [s for _, s, _ in kept]
        labels = [y for _, _, y in kept]

    # Nothing left (or nothing passed in): signal with a sentinel
    # rather than crashing on numpy.max of an empty list.
    if len(lengths) < 1:
        return None, None, None

    n_samples = len(seqs)
    maxlen = numpy.max(lengths)

    x = numpy.zeros((maxlen, n_samples)).astype('int64')
    x_mask = numpy.zeros((maxlen, n_samples)).astype('float32')
    for idx, s in enumerate(seqs):
        # Column idx holds sample idx; rows beyond its length stay 0.
        x[:lengths[idx], idx] = s
        x_mask[:lengths[idx], idx] = 1.

    return x, x_mask, labels
43+
def load_data(path="/data/lisatmp3/chokyun/tweets_sa/imdb/aclImdb/imdb.pkl", n_words=100000, valid_portion=0.1):
44+
''' Loads the dataset
45+
46+
:type dataset: string
47+
:param dataset: the path to the dataset (here IMDB)
48+
'''
49+
50+
#############
51+
# LOAD DATA #
52+
#############
53+
54+
print '... loading data'
55+
56+
# Load the dataset
57+
f = open(path, 'rb')
58+
train_set = cPickle.load(f)
59+
test_set = cPickle.load(f)
60+
f.close()
61+
62+
# split training set into validation set
63+
train_set_x, train_set_y = train_set
64+
n_samples = len(train_set_x)
65+
sidx = numpy.random.permutation(n_samples)
66+
n_train = int(numpy.round(n_samples * (1. - valid_portion)))
67+
valid_set_x = [train_set_x[s] for s in sidx[n_train:]]
68+
valid_set_y = [train_set_y[s] for s in sidx[n_train:]]
69+
train_set_x = [train_set_x[s] for s in sidx[:n_train]]
70+
train_set_y = [train_set_y[s] for s in sidx[:n_train]]
71+
72+
train_set = (train_set_x, train_set_y)
73+
valid_set = (valid_set_x, valid_set_y)
74+
75+
def remove_unk(x):
76+
return [[1 if w >= n_words else w for w in sen] for sen in x]
77+
78+
test_set_x, test_set_y = test_set
79+
valid_set_x, valid_set_y = valid_set
80+
train_set_x, train_set_y = train_set
81+
82+
train_set_x = remove_unk(train_set_x)
83+
valid_set_x = remove_unk(valid_set_x)
84+
test_set_x = remove_unk(test_set_x)
85+
86+
train = (train_set_x, train_set_y)
87+
valid = (valid_set_x, valid_set_y)
88+
test = (test_set_x, test_set_y)
89+
90+
return train, valid, test
91+
92+

0 commit comments

Comments
 (0)