
Commit 49546bf

Merge pull request Shawn1993#12 from oneTaken/master
update details for better use
2 parents: 2533be9 + 8dce7f2


4 files changed (+51, -37 lines)


main.py

Lines changed: 13 additions & 14 deletions
@@ -19,8 +19,10 @@
 parser.add_argument('-test-interval', type=int, default=100, help='how many steps to wait before testing [default: 100]')
 parser.add_argument('-save-interval', type=int, default=500, help='how many steps to wait before saving [default:500]')
 parser.add_argument('-save-dir', type=str, default='snapshot', help='where to save the snapshot')
+parser.add_argument('-early-stop', type=int, default=1000, help='iteration numbers to stop without performance increasing')
+parser.add_argument('-save-best', type=bool, default=True, help='whether to save when get best performance')
 # data
-parser.add_argument('-shuffle', action='store_true', default=False, help='shuffle the data every epoch' )
+parser.add_argument('-shuffle', action='store_true', default=False, help='shuffle the data every epoch')
 # model
 parser.add_argument('-dropout', type=float, default=0.5, help='the probability for dropout [default: 0.5]')
 parser.add_argument('-max-norm', type=float, default=3.0, help='l2 constraint of parameters [default: 3.0]')
@@ -30,7 +32,7 @@
 parser.add_argument('-static', action='store_true', default=False, help='fix the embedding')
 # device
 parser.add_argument('-device', type=int, default=-1, help='device to use for iterate data, -1 mean cpu [default: -1]')
-parser.add_argument('-no-cuda', action='store_true', default=False, help='disable the gpu' )
+parser.add_argument('-no-cuda', action='store_true', default=False, help='disable the gpu')
 # option
 parser.add_argument('-snapshot', type=str, default=None, help='filename of model snapshot [default: None]')
 parser.add_argument('-predict', type=str, default=None, help='predict the sentence given')
@@ -69,7 +71,7 @@ def mr(text_field, label_field, **kargs):
 text_field = data.Field(lower=True)
 label_field = data.Field(sequential=False)
 train_iter, dev_iter = mr(text_field, label_field, device=-1, repeat=False)
-#train_iter, dev_iter, test_iter = sst(text_field, label_field, device=-1, repeat=False)
+# train_iter, dev_iter, test_iter = sst(text_field, label_field, device=-1, repeat=False)


 # update args and print
@@ -85,33 +87,30 @@ def mr(text_field, label_field, **kargs):


 # model
-if args.snapshot is None:
-    cnn = model.CNN_Text(args)
-else :
-    print('\nLoading model from [%s]...' % args.snapshot)
-    try:
-        cnn = torch.load(args.snapshot)
-    except :
-        print("Sorry, This snapshot doesn't exist."); exit()
+cnn = model.CNN_Text(args)
+if args.snapshot is not None:
+    print('\nLoading model from {}...'.format(args.snapshot))
+    cnn.load_state_dict(torch.load(args.snapshot))

 if args.cuda:
+    torch.cuda.set_device(args.device)
     cnn = cnn.cuda()


 # train or predict
 if args.predict is not None:
     label = train.predict(args.predict, cnn, text_field, label_field, args.cuda)
     print('\n[Text] {}\n[Label] {}\n'.format(args.predict, label))
-elif args.test :
+elif args.test:
     try:
         train.eval(test_iter, cnn, args)
     except Exception as e:
         print("\nSorry. The test dataset doesn't exist.\n")
-else :
+else:
     print()
     try:
         train.train(train_iter, dev_iter, cnn, args)
     except KeyboardInterrupt:
-        print('-' * 89)
+        print('\n' + '-' * 89)
         print('Exiting from training early')
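
The reworked snapshot handling builds the model first and then restores weights with load_state_dict, matching the state_dict-only saving introduced in train.py below. A minimal sketch of that round trip, assuming the repo's model.CNN_Text and an already parsed args namespace; the snapshot path is illustrative:

import torch
import model  # the repo's model.py

cnn = model.CNN_Text(args)  # args: the argparse namespace built in main.py
torch.save(cnn.state_dict(), 'snapshot/best_steps_500.pt')  # weights only, no pickled class

cnn = model.CNN_Text(args)  # rebuild the architecture, then restore the weights
cnn.load_state_dict(torch.load('snapshot/best_steps_500.pt'))
cnn.eval()  # standard step before inference; disables dropout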

model.py

Lines changed: 12 additions & 12 deletions
@@ -1,11 +1,13 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.autograd import Variable

-class CNN_Text(nn.Module):
+
+class CNN_Text(nn.Module):

     def __init__(self, args):
-        super(CNN_Text,self).__init__()
+        super(CNN_Text, self).__init__()
         self.args = args

         V = args.embed_num
@@ -16,7 +18,7 @@ def __init__(self, args):
         Ks = args.kernel_sizes

         self.embed = nn.Embedding(V, D)
-        #self.convs1 = [nn.Conv2d(Ci, Co, (K, D)) for K in Ks]
+        # self.convs1 = [nn.Conv2d(Ci, Co, (K, D)) for K in Ks]
         self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (K, D)) for K in Ks])
         '''
         self.conv13 = nn.Conv2d(Ci, Co, (3, D))
@@ -27,23 +29,21 @@
         self.fc1 = nn.Linear(len(Ks)*Co, C)

     def conv_and_pool(self, x, conv):
-        x = F.relu(conv(x)).squeeze(3) #(N,Co,W)
+        x = F.relu(conv(x)).squeeze(3) # (N, Co, W)
         x = F.max_pool1d(x, x.size(2)).squeeze(2)
         return x

-
     def forward(self, x):
-        x = self.embed(x) # (N,W,D)
+        x = self.embed(x) # (N, W, D)

         if self.args.static:
             x = Variable(x)

-        x = x.unsqueeze(1) # (N,Ci,W,D)
-
-        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] #[(N,Co,W), ...]*len(Ks)
+        x = x.unsqueeze(1) # (N, Ci, W, D)

+        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] # [(N, Co, W), ...]*len(Ks)

-        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] #[(N,Co), ...]*len(Ks)
+        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [(N, Co), ...]*len(Ks)

         x = torch.cat(x, 1)

@@ -53,6 +53,6 @@ def forward(self, x):
         x3 = self.conv_and_pool(x,self.conv15) #(N,Co)
         x = torch.cat((x1, x2, x3), 1) # (N,len(Ks)*Co)
         '''
-        x = self.dropout(x) # (N,len(Ks)*Co)
-        logit = self.fc1(x) # (N,C)
+        x = self.dropout(x) # (N, len(Ks)*Co)
+        logit = self.fc1(x) # (N, C)
         return logit
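
The commented-out plain-list version kept next to the nn.ModuleList construction hints at why the wrapper matters: submodules stored in an ordinary Python list are not registered with the parent module, so their weights never reach model.parameters() and the optimizer would skip the convolution filters. A small illustration of the difference, with arbitrary shapes:

import torch.nn as nn

class WithList(nn.Module):
    def __init__(self):
        super(WithList, self).__init__()
        self.convs = [nn.Conv2d(1, 4, (k, 8)) for k in (3, 4, 5)]  # not registered

class WithModuleList(nn.Module):
    def __init__(self):
        super(WithModuleList, self).__init__()
        self.convs = nn.ModuleList([nn.Conv2d(1, 4, (k, 8)) for k in (3, 4, 5)])  # registered

print(len(list(WithList().parameters())))        # 0 - the optimizer would see nothing
print(len(list(WithModuleList().parameters())))  # 6 - a weight and a bias per conv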

mydatasets.py

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
 import os
 import random
 import tarfile
-from six.moves import urllib
+import urllib
 from torchtext import data


@@ -86,7 +86,7 @@ def clean_str(string):
         super(MR, self).__init__(examples, fields, **kwargs)

     @classmethod
-    def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True ,root='.', **kwargs):
+    def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True, root='.', **kwargs):
         """Create dataset objects for splits of the MR dataset.

         Arguments:
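
The import change drops the six compatibility shim in favor of the standard library. Worth keeping in mind that on Python 3 the download helpers live in the urllib.request submodule, so whatever call site fetches the archive needs the full urllib.request path. A minimal sketch, with an illustrative URL and filename:

import urllib.request

url = 'https://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz'
urllib.request.urlretrieve(url, 'rt-polaritydata.tar.gz')  # download the MR tarball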

train.py

Lines changed: 24 additions & 9 deletions
@@ -12,6 +12,8 @@ def train(train_iter, dev_iter, model, args):
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

     steps = 0
+    best_acc = 0
+    last_step = 0
     model.train()
     for epoch in range(1, args.epochs+1):
         for batch in train_iter:
@@ -40,12 +42,17 @@ def train(train_iter, dev_iter, model, args):
                                                              corrects,
                                                              batch.batch_size))
             if steps % args.test_interval == 0:
-                eval(dev_iter, model, args)
-            if steps % args.save_interval == 0:
-                if not os.path.isdir(args.save_dir): os.makedirs(args.save_dir)
-                save_prefix = os.path.join(args.save_dir, 'snapshot')
-                save_path = '{}_steps{}.pt'.format(save_prefix, steps)
-                torch.save(model, save_path)
+                dev_acc = eval(dev_iter, model, args)
+                if dev_acc > best_acc:
+                    best_acc = dev_acc
+                    last_step = steps
+                    if args.save_best:
+                        save(model, args.save_dir, 'best', steps)
+                else:
+                    if steps - last_step >= args.early_stop:
+                        print('early stop by {} steps.'.format(args.early_stop))
+            elif steps % args.save_interval == 0:
+                save(model, args.save_dir, 'snapshot', steps)


 def eval(data_iter, model, args):
@@ -65,13 +72,13 @@ def eval(data_iter, model, args):
                      [1].view(target.size()).data == target.data).sum()

     size = len(data_iter.dataset)
-    avg_loss = avg_loss/size
+    avg_loss /= size
     accuracy = 100.0 * corrects/size
-    model.train()
     print('\nEvaluation - loss: {:.6f} acc: {:.4f}%({}/{}) \n'.format(avg_loss,
                                                                       accuracy,
                                                                       corrects,
                                                                       size))
+    return accuracy


 def predict(text, model, text_field, label_feild, cuda_flag):
@@ -83,8 +90,16 @@ def predict(text, model, text_field, label_feild, cuda_flag):
     x = text_field.tensor_type(text)
     x = autograd.Variable(x, volatile=True)
     if cuda_flag:
-        x =x.cuda()
+        x = x.cuda()
     print(x)
     output = model(x)
     _, predicted = torch.max(output, 1)
     return label_feild.vocab.itos[predicted.data[0][0]+1]
+
+
+def save(model, save_dir, save_prefix, steps):
+    if not os.path.isdir(save_dir):
+        os.makedirs(save_dir)
+    save_prefix = os.path.join(save_dir, save_prefix)
+    save_path = '{}_steps_{}.pt'.format(save_prefix, steps)
+    torch.save(model.state_dict(), save_path)
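
The new bookkeeping in train() tracks the best dev accuracy and the step at which it was reached, saves a 'best' snapshot whenever the accuracy improves (when -save-best is set), and reports once no improvement has been seen for -early-stop steps. Below is a stripped-down sketch of that pattern outside the repo's loop; evaluate_fn and save_fn stand in for the repo's eval() and save(), and the sketch breaks out of the loop where the commit only prints a notice:

def fit(total_steps, test_interval, early_stop, evaluate_fn, save_fn):
    best_acc = 0
    last_step = 0
    for step in range(1, total_steps + 1):
        # ... one optimization step on a training batch would go here ...
        if step % test_interval == 0:
            dev_acc = evaluate_fn()          # evaluate on the dev set
            if dev_acc > best_acc:
                best_acc = dev_acc           # new best -> remember it
                last_step = step             # and the step it happened at
                save_fn('best', step)        # checkpoint the best weights
            elif step - last_step >= early_stop:
                break                        # no improvement for early_stop steps
    return best_acc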
