Skip to content

Commit fd1d392

Browse files
committed
Merge pull request lisa-lab#55 from carriepl/lstm_tutorial
Add basic LSTM tutorial
2 parents 30f10a9 + 4bb23cf commit fd1d392

File tree

5 files changed

+790
-0
lines changed

5 files changed

+790
-0
lines changed

code/imdb.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import cPickle
2+
import gzip
3+
import os
4+
import sys
5+
import time
6+
7+
import numpy
8+
9+
import theano
10+
import theano.tensor as T
11+
12+
13+
def prepare_data(seqs, labels, maxlen=None):
14+
# x: a list of sentences
15+
lengths = [len(s) for s in seqs]
16+
17+
if maxlen is not None:
18+
new_seqs = []
19+
new_labels = []
20+
new_lengths = []
21+
for l, s, y in zip(lengths, seqs, labels):
22+
if l < maxlen:
23+
new_seqs.append(s)
24+
new_labels.append(y)
25+
new_lengths.append(l)
26+
lengths = new_lengths
27+
labels = new_labels
28+
seqs = new_seqs
29+
30+
if len(lengths) < 1:
31+
return None, None, None
32+
33+
n_samples = len(seqs)
34+
maxlen = numpy.max(lengths)
35+
36+
x = numpy.zeros((maxlen, n_samples)).astype('int64')
37+
x_mask = numpy.zeros((maxlen, n_samples)).astype('float32')
38+
for idx, s in enumerate(seqs):
39+
x[:lengths[idx], idx] = s
40+
x_mask[:lengths[idx], idx] = 1.
41+
42+
return x, x_mask, labels
43+
44+
45+
def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1):
46+
''' Loads the dataset
47+
48+
:type dataset: string
49+
:param dataset: the path to the dataset (here IMDB)
50+
'''
51+
52+
#############
53+
# LOAD DATA #
54+
#############
55+
56+
print '... loading data'
57+
58+
# Load the dataset
59+
f = open(path, 'rb')
60+
train_set = cPickle.load(f)
61+
test_set = cPickle.load(f)
62+
f.close()
63+
64+
# split training set into validation set
65+
train_set_x, train_set_y = train_set
66+
n_samples = len(train_set_x)
67+
sidx = numpy.random.permutation(n_samples)
68+
n_train = int(numpy.round(n_samples * (1. - valid_portion)))
69+
valid_set_x = [train_set_x[s] for s in sidx[n_train:]]
70+
valid_set_y = [train_set_y[s] for s in sidx[n_train:]]
71+
train_set_x = [train_set_x[s] for s in sidx[:n_train]]
72+
train_set_y = [train_set_y[s] for s in sidx[:n_train]]
73+
74+
train_set = (train_set_x, train_set_y)
75+
valid_set = (valid_set_x, valid_set_y)
76+
77+
def remove_unk(x):
78+
return [[1 if w >= n_words else w for w in sen] for sen in x]
79+
80+
test_set_x, test_set_y = test_set
81+
valid_set_x, valid_set_y = valid_set
82+
train_set_x, train_set_y = train_set
83+
84+
train_set_x = remove_unk(train_set_x)
85+
valid_set_x = remove_unk(valid_set_x)
86+
test_set_x = remove_unk(test_set_x)
87+
88+
train = (train_set_x, train_set_y)
89+
valid = (valid_set_x, valid_set_y)
90+
test = (test_set_x, test_set_y)
91+
92+
return train, valid, test

0 commit comments

Comments
 (0)