
Commit 5f1aba1

committed
Compatibility updates
See Shawn1993@d1f176c
1 parent 75cee56 commit 5f1aba1


cnn_text_classification.py

Lines changed: 14 additions & 17 deletions
@@ -9,7 +9,6 @@
 from sklearn.model_selection import train_test_split as split
 from sklearn.utils.class_weight import compute_sample_weight
 from time import time
-from torch.autograd import Variable
 from torchtext.data import Dataset, Example, Field, Iterator, Pipeline
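
The dropped import reflects PyTorch 0.4's merge of Tensor and Variable: wrapping a tensor in Variable has been a no-op since then. A minimal illustration (not part of the diff):

    import torch
    from torch.autograd import Variable

    # Since PyTorch 0.4, Variable(...) returns a plain Tensor; autograd
    # works on tensors directly via requires_grad.
    a = Variable(torch.tensor([1.0, 2.0]))
    b = torch.tensor([1.0, 2.0])
    print(type(a) is type(b))  # True: both are torch.Tensor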

@@ -68,7 +67,7 @@ def __eval(self, data_iter):
         softmax = nn.Softmax(dim=1) if self.scoring == "roc_auc" else None

         for batch in data_iter:
-            feature, target = batch.text.data.t(), batch.label.data.sub(1)
+            feature, target = batch.text.t_(), batch.label.sub_(1)

             if self.cuda and torch.cuda.is_available():
                 feature, target = feature.cuda(), target.cuda()
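
For context (an illustrative sketch, not from the commit): the trailing underscore marks an in-place method, so t_() transposes the batch tensor itself and sub_(1) shifts 1-based labels to 0-based without detouring through the deprecated .data attribute:

    import torch

    text = torch.tensor([[1, 2, 3], [4, 5, 6]])  # toy (batch, seq) input
    label = torch.tensor([1, 2])                 # 1-based labels
    feature, target = text.t_(), label.sub_(1)   # in-place transpose and shift
    print(feature.shape, target)                 # torch.Size([3, 2]) tensor([0, 1])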
@@ -80,10 +79,10 @@ def __eval(self, data_iter):
             if self.scoring == "roc_auc":
                 pred = [[float(p) for p in dist] for dist in softmax(logit)]
             else:
-                pred = torch.max(logit, 1)[1].view(target.size()).data.tolist()
+                pred = torch.max(logit, 1)[1].view(target.size()).tolist()

             preds += pred
-            targets += target.data.tolist()
+            targets += target.tolist()

         targets = [self.__label_field.vocab.itos[targ + 1] for targ in targets]
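
Same deprecation pattern here; since the merge, tolist() is called on the tensor itself rather than through .data. A toy example (values made up):

    import torch

    logit = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # toy logits
    pred = torch.max(logit, 1)[1].tolist()          # argmax indices as a plain list
    print(pred)  # [1, 0]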

@@ -131,7 +130,7 @@ def fit(self, X, y, sample_weight=None):

         for epoch in range(self.epochs):
             for batch in train_iter:
-                feature, target = batch.text.data.t(), batch.label.data.sub(1)
+                feature, target = batch.text.t_(), batch.label.sub_(1)

                 if self.cuda and torch.cuda.is_available():
                     feature, target = feature.cuda(), target.cuda()
@@ -167,7 +166,6 @@ def fit(self, X, y, sample_weight=None):
         if self.verbose > 0:
             self.__print_elapsed_time(time() - start)

-        torch.cuda.empty_cache()
         return self

     def __predict(self, X):
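
Dropping empty_cache() matches PyTorch's guidance: the call only returns unused cached blocks to the driver and never frees live tensors, so invoking it after every fit() mostly adds overhead. A sketch, assuming a PyTorch version that provides torch.cuda.memory_reserved():

    import torch

    if torch.cuda.is_available():
        before = torch.cuda.memory_reserved()  # bytes held by the caching allocator
        torch.cuda.empty_cache()               # releases unused cached blocks only
        print(before, "->", torch.cuda.memory_reserved())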
@@ -182,17 +180,16 @@ def __predict(self, X):
             text = self.__text_field.preprocess(text)
             text = self.__pad(text, max_kernel_size, True)
             text = [[self.__text_field.vocab.stoi[x] for x in text]]
-            x = Variable(torch.tensor(text))
+            x = torch.tensor(text)
             x = x.cuda() if self.cuda and torch.cuda.is_available() else x

             y_output.append(self.__model(x))

-        torch.cuda.empty_cache()
         return y_output

     def predict(self, X):
         y_pred = [torch.argmax(yi, 1) for yi in self.__predict(X)]
-        return [self.__label_field.vocab.itos[yi.data[0] + 1] for yi in y_pred]
+        return [self.__label_field.vocab.itos[yi.item() + 1] for yi in y_pred]

     def predict_proba(self, X):
         softmax = nn.Softmax(dim=1)
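
The .item() change follows the same deprecation: indexing .data[0] on a one-element tensor is replaced by .item(), which returns a Python scalar. Illustrative only:

    import torch

    yi = torch.argmax(torch.tensor([[0.2, 0.7, 0.1]]), 1)  # tensor([1])
    print(yi.item())  # 1, a plain Python int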
@@ -253,7 +250,7 @@ def __preprocess_text(self, text):
         if self.preprocessor is None:
             return self.__clean_str(text)

-        return self.__clean_str(self.preprocessor(text))
+        return self.preprocessor(text)

     def __print_elapsed_time(self, seconds):
         sc = round(seconds)
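
Note the behavior change in this hunk: a user-supplied preprocessor now fully owns text cleaning, and __clean_str is only the fallback. A standalone sketch of the new logic (function name and defaults hypothetical):

    def preprocess_text(text, preprocessor=None, clean_str=str.strip):
        # hypothetical free-function version of __preprocess_text after the change
        if preprocessor is None:
            return clean_str(text)     # default cleaning only
        return preprocessor(text)      # custom output is no longer re-cleaned

    print(preprocess_text("  Hello  "))             # "Hello"
    print(preprocess_text("  Hello  ", str.lower))  # "  hello  "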
@@ -289,10 +286,10 @@ def __init__(self, embed_num, embed_dim, class_num, kernel_num,

         Ks = kernel_sizes
         module_list = [nn.Conv2d(1, kernel_num, (K, embed_dim)) for K in Ks]
-        self.__convs1 = nn.ModuleList(module_list)
+        self.__convs = nn.ModuleList(module_list)
         self.__dropout = nn.Dropout(dropout)
-        self.__fc1 = nn.Linear(len(Ks) * kernel_num, class_num)
-        self.__static = static
+        self.__fc = nn.Linear(len(Ks) * kernel_num, class_num)
+        self.__embed.weight.requires_grad = not static

         if activation_func == "relu":
             self.__f = F.relu
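
Setting requires_grad on the embedding weight replaces the old Variable-based detach in forward(): a "static" model now simply excludes the embedding from gradient updates. An illustrative sketch with made-up sizes:

    import torch.nn as nn

    embed = nn.Embedding(10000, 128)         # hypothetical vocab size and dim
    static = True
    embed.weight.requires_grad = not static  # frozen: gradients are not computed
    print(any(p.requires_grad for p in embed.parameters()))  # False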
@@ -302,11 +299,11 @@ def __init__(self, embed_num, embed_dim, class_num, kernel_num,
             self.__f = lambda x: x

     def forward(self, x):
-        x = Variable(self.__embed(x)) if self.__static else self.__embed(x)
-        x = x.unsqueeze(1)
-        x = [self.__f(conv(x), inplace=True).squeeze(3) for conv in self.__convs1]
+        x = self.__embed(x).unsqueeze(1)
+        x = [self.__f(cnv(x), inplace=True).squeeze(3) for cnv in self.__convs]
         x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
-        return self.__fc1(self.__dropout(torch.cat(x, 1)))
+        return self.__fc(self.__dropout(torch.cat(x, 1)))
+

 class _Eval():
     def __init__(self, preds):
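
For reference, a minimal self-contained sketch of the rewritten forward() flow (conv, activation, max-pool over time, concatenate); sizes are made up, not from the repo:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    batch, seq_len, embed_dim, kernel_num = 2, 20, 128, 100
    convs = nn.ModuleList(nn.Conv2d(1, kernel_num, (k, embed_dim)) for k in (3, 4, 5))
    x = torch.randn(batch, 1, seq_len, embed_dim)           # embedded text + channel dim
    x = [F.relu(c(x)).squeeze(3) for c in convs]            # (batch, kernel_num, seq-k+1)
    x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]  # (batch, kernel_num) each
    print(torch.cat(x, 1).shape)                            # torch.Size([2, 300])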
