Skip to content

Commit 2533be9

Browse files
authored
Merge pull request Shawn1993#11 from SeanRosario/master
Fixed cuda tensor bug during prediction
2 parents 2696894 + 220d081 commit 2533be9

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

main.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,18 @@ def mr(text_field, label_field, **kargs):
100100

101101
# train or predict
102102
if args.predict is not None:
103-
label = train.predict(args.predict, cnn, text_field, label_field)
104-
print('\n[Text] {}[Label] {}\n'.format(args.predict, label))
103+
label = train.predict(args.predict, cnn, text_field, label_field, args.cuda)
104+
print('\n[Text] {}\n[Label] {}\n'.format(args.predict, label))
105105
elif args.test :
106106
try:
107107
train.eval(test_iter, cnn, args)
108108
except Exception as e:
109109
print("\nSorry. The test dataset doesn't exist.\n")
110110
else :
111111
print()
112-
train.train(train_iter, dev_iter, cnn, args)
113-
112+
try:
113+
train.train(train_iter, dev_iter, cnn, args)
114+
except KeyboardInterrupt:
115+
print('-' * 89)
116+
print('Exiting from training early')
114117

train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,16 @@ def eval(data_iter, model, args):
7474
size))
7575

7676

77-
def predict(text, model, text_field, label_feild):
77+
def predict(text, model, text_field, label_feild, cuda_flag):
7878
assert isinstance(text, str)
7979
model.eval()
8080
# text = text_field.tokenize(text)
8181
text = text_field.preprocess(text)
8282
text = [[text_field.vocab.stoi[x] for x in text]]
8383
x = text_field.tensor_type(text)
8484
x = autograd.Variable(x, volatile=True)
85+
if cuda_flag:
86+
x =x.cuda()
8587
print(x)
8688
output = model(x)
8789
_, predicted = torch.max(output, 1)

0 commit comments

Comments
 (0)