
Commit 555a0d9

Update
1 parent 440b1cd commit 555a0d9

File tree

13 files changed (+12018 −69 lines)


.gitignore

Lines changed: 9 additions & 0 deletions
@@ -93,3 +93,12 @@ ENV/
 
 # model cache
 *.pt
+
+# dataset
+rt-polaritydata
+trees
+*.tar
+*.zip
+
+# pycharm
+.idea

.idea/vcs.xml

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default.

README.md

Lines changed: 26 additions & 7 deletions
@@ -1,10 +1,9 @@
 ## Introduction
-This is the implementation of kim's [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882) paper in PyTorch.
+This is the implementation of Kim's [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882) paper in PyTorch.
 
-Kim's implementation of the model in Theano:
+1. Kim's implementation of the model in Theano:
 [https://github.com/yoonkim/CNN_sentence](https://github.com/yoonkim/CNN_sentence)
-
-Denny Britz has an implementation in Tensorflow:
+2. Denny Britz has an implementation in Tensorflow:
 [https://github.com/dennybritz/cnn-text-classification-tf](https://github.com/dennybritz/cnn-text-classification-tf)
 
 ## Requirement
@@ -13,6 +12,16 @@ Denny Britz has an implementation in Tensorflow:
 * torchtext > 0.1
 * numpy
 
+## Result
+I have only tried two datasets so far, MR and SST.
+
+|Dataset|Class Size|Best Result|Kim's Paper Result|
+|---|---|---|---|
+|MR|2|77.5% (CNN-rand-static)|76.1% (CNN-rand-nostatic)|
+|SST|5|37.2% (CNN-rand-static)|45.0% (CNN-rand-nostatic)|
+
+I haven't seriously tuned the hyper-parameters for SST yet.
+
 ## Usage
 ```
 ./main.py -h
@@ -71,11 +80,20 @@ Batch[100] - loss: 0.655424 acc: 59.3750%
 Evaluation - loss: 0.672396 acc: 57.6923%(615/1066)
 ```
 
+## Test
+If you have built a test set, you can run testing like this:
+
+```
+./main.py -test -snapshot="./snapshot/2017-02-11-15-50/snapshot_steps1500.pt"
+```
+The snapshot option tells the program which saved model to load. If you don't set it, the model starts from scratch.
+
 ## Predict
 * **Example1**
 
 ```
-./main.py -predict="Hello my dear , I love you so much ." -snapshot="./snapshot/2017-02-11-15-50/snapshot_steps1500.pt"
+./main.py -predict="Hello my dear , I love you so much ." \
+          -snapshot="./snapshot/2017-02-11-15-50/snapshot_steps1500.pt"
 ```
 You will get:
 
@@ -88,7 +106,8 @@ Evaluation - loss: 0.672396 acc: 57.6923%(615/1066)
 * **Example2**
 
 ```
-./main.py -predict="You just make me so sad and I have to leave you ." -snapshot="./snapshot/2017-02-11-15-50/snapshot_steps1500.pt"
+./main.py -predict="You just make me so sad and I have to leave you ." \
+          -snapshot="./snapshot/2017-02-11-15-50/snapshot_steps1500.pt"
 ```
 You will get:
 
@@ -99,7 +118,7 @@ Evaluation - loss: 0.672396 acc: 57.6923%(615/1066)
 [Label] negative
 ```
 
-
+Your text must be space-separated, including punctuation.
 
 ## Reference
 * [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882)

main.py

Lines changed: 75 additions & 52 deletions
@@ -1,6 +1,4 @@
 #! /usr/bin/env python
-
-
 import os
 import argparse
 import datetime
@@ -13,76 +11,101 @@
 
 
 parser = argparse.ArgumentParser(description='CNN text classificer')
-parser.add_argument('-batch-size', type=int, default=64, metavar='N', help='batch size for training [default: 50]')
-parser.add_argument('-lr', type=float, default=0.001, metavar='LR', help='initial learning rate [default: 0.01]')
-parser.add_argument('-epochs', type=int, default=200, metavar='N', help='number of epochs for train [default: 10]')
-parser.add_argument('-dropout', type=float, default=0.5, metavar='', help='the probability for dropout [default: 0.5]')
-parser.add_argument('-max_norm', type=float, default=3.0, help='l2 constraint of parameters')
-parser.add_argument('-cpu', action='store_true', default=False, help='disable the gpu' )
-parser.add_argument('-device', type=int, default=-1, help='device to use for iterate data')
-# model
-parser.add_argument('-embed-dim', type=int, default=128)
-parser.add_argument('-static', action='store_true', default=False, help='fix the embedding')
-parser.add_argument('-kernel-sizes', type=str, default='3,4,5', help='Comma-separated kernel size to use for convolution')
-parser.add_argument('-kernel-num', type=int, default=100, help='number of each kind of kernel')
-parser.add_argument('-class-num', type=int, default=2, help='number of class')
+# learning
+parser.add_argument('-lr', type=float, default=0.001, help='initial learning rate [default: 0.001]')
+parser.add_argument('-epochs', type=int, default=256, help='number of epochs for train [default: 256]')
+parser.add_argument('-batch-size', type=int, default=64, help='batch size for training [default: 64]')
+parser.add_argument('-log-interval', type=int, default=1, help='how many steps to wait before logging training status [default: 1]')
+parser.add_argument('-test-interval', type=int, default=100, help='how many steps to wait before testing [default: 100]')
+parser.add_argument('-save-interval', type=int, default=500, help='how many steps to wait before saving [default: 500]')
+parser.add_argument('-save-dir', type=str, default='snapshot', help='where to save the checkpoint')
 # data
 parser.add_argument('-shuffle', action='store_true', default=False, help='shuffle the data every epoch' )
-parser.add_argument('-num-workers', type=int, default=0, help='how many subprocesses to use for data loading [default: 0]')
-# log
-parser.add_argument('-log-interval', type=int, default=1, help='how many batches to wait before logging training status')
-parser.add_argument('-test-interval', type=int, default=100, help='how many epochs to wait before testing')
-parser.add_argument('-save-interval', type=int, default=100, help='how many epochs to wait before saving')
-parser.add_argument('-predict', type=str, default=None, help='predict the sentence given')
+# model
+parser.add_argument('-dropout', type=float, default=0.5, help='the probability for dropout [default: 0.5]')
+parser.add_argument('-max-norm', type=float, default=3.0, help='l2 constraint of parameters [default: 3.0]')
+parser.add_argument('-embed-dim', type=int, default=128, help='number of embedding dimension [default: 128]')
+parser.add_argument('-kernel-num', type=int, default=100, help='number of each kind of kernel')
+parser.add_argument('-kernel-sizes', type=str, default='3,4,5', help='comma-separated kernel sizes to use for convolution')
+parser.add_argument('-static', action='store_true', default=False, help='fix the embedding')
+# device
+parser.add_argument('-device', type=int, default=-1, help='device to use for iterating data, -1 means cpu [default: -1]')
+parser.add_argument('-no-cuda', action='store_true', default=False, help='disable the gpu')
+# option
 parser.add_argument('-snapshot', type=str, default=None, help='filename of model snapshot [default: None]')
-parser.add_argument('-save-dir', type=str, default='snapshot', help='where to save the checkpoint')
+parser.add_argument('-predict', type=str, default=None, help='predict the sentence given')
+parser.add_argument('-test', action='store_true', default=False, help='train or test')
 args = parser.parse_args()
-args.cuda = not args.cpu and torch.cuda.is_available()
-args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
-args.save_dir = os.path.join(args.save_dir, datetime.datetime.now().strftime('%Y-%m-%d-%H-%M'))
+
 
 # load SST dataset
-'''
-print("Loading data...")
-text_field = data.Field(lower=True)
-label_field = data.Field(sequential=False)
-train_data, dev_data, test_data = datasets.SST.splits(text_field, label_field, fine_grained=True)
-text_field.build_vocab(train_data, dev_data, test_data)
-label_field.build_vocab(train_data)
-train_iter, dev_iter, test_iter = data.BucketIterator.splits(
-    (train_data, dev_data, test_data),
-    batch_sizes=(args.batch_size,
-                 len(dev_data),
-                 len(test_data)),
-    device=-1, repeat=False)
-'''
+def sst(text_field, label_field, **kargs):
+    train_data, dev_data, test_data = datasets.SST.splits(text_field, label_field, fine_grained=True)
+    text_field.build_vocab(train_data, dev_data, test_data)
+    label_field.build_vocab(train_data, dev_data, test_data)
+    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
+        (train_data, dev_data, test_data),
+        batch_sizes=(args.batch_size,
+                     len(dev_data),
+                     len(test_data)),
+        **kargs)
+    return train_iter, dev_iter, test_iter
+
 
 # load MR dataset
-print("Loading data...")
+def mr(text_field, label_field, **kargs):
+    train_data, dev_data = mydatasets.MR.splits(text_field, label_field)
+    text_field.build_vocab(train_data, dev_data)
+    label_field.build_vocab(train_data, dev_data)
+    train_iter, dev_iter = data.Iterator.splits(
+        (train_data, dev_data),
+        batch_sizes=(args.batch_size, len(dev_data)),
+        **kargs)
+    return train_iter, dev_iter
+
+
+# load data
+print("\nLoading data...")
 text_field = data.Field(lower=True)
 label_field = data.Field(sequential=False)
-train_data, dev_data = mydatasets.MR.splits(text_field, label_field)
-text_field.build_vocab(train_data, dev_data)
-label_field.build_vocab(train_data)
-train_iter, dev_iter = data.Iterator.splits((train_data, dev_data),
-                                            batch_sizes=(args.batch_size, len(dev_data)),
-                                            device=-1, repeat=False)
+#train_iter, dev_iter = mr(text_field, label_field, device=-1, repeat=False)
+train_iter, dev_iter, test_iter = sst(text_field, label_field, device=-1, repeat=False)
 
-# args from vacab
+
+# update args and print
 args.embed_num = len(text_field.vocab)
+args.class_num = len(label_field.vocab) - 1
+args.cuda = (not args.no_cuda) and torch.cuda.is_available(); del args.no_cuda
+args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
+args.save_dir = os.path.join(args.save_dir, datetime.datetime.now().strftime('%Y-%m-%d-%H-%M'))
+
+print("\nParameters:")
+for attr, value in sorted(args.__dict__.items()):
+    print("\t{}={}".format(attr.upper(), value))
+
 
 # model
 if args.snapshot is None:
     cnn = model.CNN_Text(args)
 else :
-    print('Loading model from [%s]...' % args.snapshot)
-    cnn = torch.load(args.snapshot)
+    print('\nLoading model from [%s]...' % args.snapshot)
+    try:
+        cnn = torch.load(args.snapshot)
+    except:
+        print("Sorry, this snapshot doesn't exist."); exit()
+
 
 # train or predict
 if args.predict is not None:
     label = train.predict(args.predict, cnn, text_field, label_field)
-    print('\n[Text] %s'% args.predict)
-    print('[Label] %s\n'% label)
-else:
+    print('\n[Text]  {}\n[Label] {}\n'.format(args.predict, label))
+elif args.test:
+    try:
+        train.eval(test_iter, cnn, args)
+    except Exception as e:
+        print("\nSorry. The test dataset doesn't exist.\n")
+else:
+    print()
     train.train(train_iter, dev_iter, cnn, args)
+
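The new line `args.class_num = len(label_field.vocab) - 1` relies on torchtext reserving index 0 of a non-sequential field's vocabulary for `<unk>`, so the real labels start at index 1; this is also why `train.predict` adds 1 before looking the prediction up in `itos`. A minimal sketch of that assumption, using older torchtext and a toy label list that is illustrative only:

```
# Minimal sketch, assuming older torchtext's default '<unk>' handling for a
# non-sequential Field; the toy label list below is illustrative only.
from torchtext import data

label_field = data.Field(sequential=False)
label_field.build_vocab(['positive', 'negative', 'positive'])

print(label_field.vocab.itos)            # e.g. ['<unk>', 'positive', 'negative']
class_num = len(label_field.vocab) - 1   # drop the '<unk>' slot -> 2 real classes
# The model then predicts indices 0..class_num-1, and predict() shifts them
# back with +1 before mapping through vocab.itos.
```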

model.py

Lines changed: 8 additions & 3 deletions
@@ -16,10 +16,12 @@ def __init__(self, args):
         Ks = args.kernel_sizes
 
         self.embed = nn.Embedding(V, D)
-        #self.convs1 = [nn.Conv2d(Ci, Co, (K, D)) for K in Ks]
+        self.convs1 = [nn.Conv2d(Ci, Co, (K, D)) for K in Ks]
+        '''
         self.conv13 = nn.Conv2d(Ci, Co, (3, D))
         self.conv14 = nn.Conv2d(Ci, Co, (4, D))
         self.conv15 = nn.Conv2d(Ci, Co, (5, D))
+        '''
         self.dropout = nn.Dropout(args.dropout)
         self.fc1 = nn.Linear(len(Ks)*Co, C)
 
@@ -36,12 +38,15 @@ def forward(self, x):
         x = Variable(x)
 
         x = x.unsqueeze(1) # (N,Ci,W,D)
-        #x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] #[(N,Co,W), ...]*len(Ks)
-        #x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] #[(N,Co), ...]*len(Ks)
+        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] #[(N,Co,W), ...]*len(Ks)
+        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] #[(N,Co), ...]*len(Ks)
+        x = torch.cat(x, 1)
+        '''
         x1 = self.conv_and_pool(x,self.conv13) #(N,Co)
         x2 = self.conv_and_pool(x,self.conv14) #(N,Co)
         x3 = self.conv_and_pool(x,self.conv15) #(N,Co)
         x = torch.cat((x1, x2, x3), 1) # (N,len(Ks)*Co)
+        '''
         x = self.dropout(x) # (N,len(Ks)*Co)
         logit = self.fc1(x) # (N,C)
         return logit
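Note that `self.convs1` is a plain Python list here, so the Conv2d layers are not registered as submodules and `model.parameters()` and `.cuda()` will not pick them up automatically. A hedged alternative sketch (not the repo's code) that wraps them in `nn.ModuleList` while keeping the same conv-relu-maxpool-concat flow:

```
# Sketch only: the class name and default sizes are illustrative, not the repo's.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TextCNNSketch(nn.Module):
    def __init__(self, vocab_size=100, embed_dim=128, class_num=2,
                 kernel_num=100, kernel_sizes=(3, 4, 5), dropout=0.5):
        super(TextCNNSketch, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # nn.ModuleList registers each Conv2d, so parameters(), cuda() and
        # state_dict() all see the parallel convolution branches.
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, kernel_num, (k, embed_dim)) for k in kernel_sizes])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(len(kernel_sizes) * kernel_num, class_num)

    def forward(self, x):                                        # x: (N, W) word indices
        x = self.embed(x).unsqueeze(1)                           # (N, 1, W, D)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]  # [(N, Co, W'), ...]
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]   # [(N, Co), ...]
        x = self.dropout(torch.cat(x, 1))                        # (N, len(Ks)*Co)
        return self.fc(x)                                        # (N, C)
```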

mydatasets.py

Lines changed: 33 additions & 6 deletions
@@ -1,20 +1,46 @@
 import re
 import os
 import random
+import tarfile
+from six.moves import urllib
 from torchtext import data
 
 
-class MR(data.ZipDataset):
+class TarDataset(data.Dataset):
+    """Defines a Dataset loaded from a downloadable tar archive.
+
+    Attributes:
+        url: URL where the tar archive can be downloaded.
+        filename: Filename of the downloaded tar archive.
+        dirname: Name of the top-level directory within the tar archive that
+            contains the data files.
+    """
+
+    @classmethod
+    def download_or_unzip(cls, root):
+        path = os.path.join(root, cls.dirname)
+        if not os.path.isdir(path):
+            tpath = os.path.join(root, cls.filename)
+            if not os.path.isfile(tpath):
+                print('downloading')
+                urllib.request.urlretrieve(cls.url, tpath)
+            with tarfile.open(tpath, 'r') as tfile:
+                print('extracting')
+                tfile.extractall(root)
+        return os.path.join(path, '')
+
+
+class MR(TarDataset):
 
     url = 'https://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz'
     filename = 'rt-polaritydata.tar'
-    dirname = os.path.join('rt-polaritydata', 'rt-polaritydata')
+    dirname = 'rt-polaritydata'
 
     @staticmethod
     def sort_key(ex):
         return len(ex.text)
 
-    def __init__(self, text_field, label_field, path='.', examples=None, **kwargs):
+    def __init__(self, text_field, label_field, path=None, examples=None, **kwargs):
         """Create an MR dataset instance given a path and fields.
 
         Arguments:
@@ -49,7 +75,7 @@ def clean_str(string):
         fields = [('text', text_field), ('label', label_field)]
 
         if examples is None:
-            path = os.path.expanduser(path)
+            path = self.dirname if path is None else path
            examples = []
            with open(os.path.join(path, 'rt-polarity.neg')) as f:
                examples += [
@@ -60,13 +86,14 @@ def clean_str(string):
         super(MR, self).__init__(examples, fields, **kwargs)
 
     @classmethod
-    def splits(cls, text_field, label_field, dev_ratio=.1, root='.', **kwargs):
+    def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True, root='.', **kwargs):
         """Create dataset objects for splits of the MR dataset.
 
         Arguments:
             text_field: The field that will be used for the sentence.
             label_field: The field that will be used for label data.
             dev_ratio: The ratio that will be used to get split validation dataset.
+            shuffle: Whether to shuffle the data before splitting.
             root: The root directory that the dataset's zip archive will be
                 expanded into; therefore the directory in whose trees
                 subdirectory the data files will be stored.
@@ -76,7 +103,7 @@ def splits(cls, text_field, label_field, dev_ratio=.1, root='.', **kwargs):
         """
         path = cls.download_or_unzip(root)
         examples = cls(text_field, label_field, path=path, **kwargs).examples
-        random.shuffle(examples)
+        if shuffle: random.shuffle(examples)
         dev_index = -1 * int(dev_ratio*len(examples))
 
         return (cls(text_field, label_field, examples=examples[:dev_index]),
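For reference, a minimal usage sketch of the new download path, mirroring the `mr()` helper in main.py and the older torchtext API it uses (the batch size and device values are illustrative):

```
# Minimal usage sketch: the first call downloads and extracts
# rt-polaritydata.tar.gz into ./rt-polaritydata via TarDataset.download_or_unzip,
# then shuffles the examples and splits off 10% as a dev set.
from torchtext import data
import mydatasets

text_field = data.Field(lower=True)
label_field = data.Field(sequential=False)

train_data, dev_data = mydatasets.MR.splits(text_field, label_field)
text_field.build_vocab(train_data, dev_data)
label_field.build_vocab(train_data, dev_data)

train_iter, dev_iter = data.Iterator.splits(
    (train_data, dev_data),
    batch_sizes=(64, len(dev_data)),
    device=-1, repeat=False)
```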
File renamed without changes.
File renamed without changes.

train.py

Lines changed: 6 additions & 1 deletion
@@ -31,7 +31,11 @@ def train(train_iter, dev_iter, model, args):
                 corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
                 accuracy = corrects/batch.batch_size * 100.0
                 sys.stdout.write(
-                    '\rBatch[{}] - loss: {:.6f} acc: {:.4f}%'.format(steps, loss.data[0], accuracy))
+                    '\rBatch[{}] - loss: {:.6f} acc: {:.4f}%({}/{})'.format(steps,
+                                                                            loss.data[0],
+                                                                            accuracy,
+                                                                            corrects,
+                                                                            batch.batch_size))
             if steps % args.test_interval == 0:
                 eval(dev_iter, model, args)
             if steps % args.save_interval == 0:
@@ -75,6 +79,7 @@ def predict(text, model, text_field, label_feild):
     text = [[text_field.vocab.stoi[x] for x in text]]
     x = text_field.tensor_type(text)
     x = autograd.Variable(x, volatile=True)
+    print(x)
     output = model(x)
     _, predicted = torch.max(output, 1)
     return label_feild.vocab.itos[predicted.data[0][0]+1]
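One caveat with the logged accuracy: `corrects / batch.batch_size` is integer division on Python 2, so the reported percentage collapses to 0 or 100. A small hedged sketch (not the repo's code) with explicit float division:

```
# Hedged sketch: batch accuracy with explicit float division.
import torch

def batch_accuracy(logit, target):
    """logit: (N, C) class scores; target: (N,) gold label indices."""
    preds = torch.max(logit, 1)[1].view(target.size())
    corrects = (preds.data == target.data).sum()
    return 100.0 * float(corrects) / target.size(0)
```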

trainDevTestTrees_PTB.zip

771 KB
Binary file not shown.
