Skip to content

Commit 67b73c3

Browse files
committed
Fixed bug in predict for cuda tensor
1 parent 2696894 commit 67b73c3

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def mr(text_field, label_field, **kargs):
100100

101101
# train or predict
102102
if args.predict is not None:
103-
label = train.predict(args.predict, cnn, text_field, label_field)
103+
label = train.predict(args.predict, cnn, text_field, label_field, args.cuda)
104104
print('\n[Text] {}[Label] {}\n'.format(args.predict, label))
105105
elif args.test :
106106
try:

train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,16 @@ def eval(data_iter, model, args):
7474
size))
7575

7676

77-
def predict(text, model, text_field, label_feild):
77+
def predict(text, model, text_field, label_feild, cuda_flag):
7878
assert isinstance(text, str)
7979
model.eval()
8080
# text = text_field.tokenize(text)
8181
text = text_field.preprocess(text)
8282
text = [[text_field.vocab.stoi[x] for x in text]]
8383
x = text_field.tensor_type(text)
8484
x = autograd.Variable(x, volatile=True)
85+
if cuda_flag:
86+
x =x.cuda()
8587
print(x)
8688
output = model(x)
8789
_, predicted = torch.max(output, 1)

0 commit comments

Comments
 (0)