
Commit 8e2602c

Author: Ubuntu
Message: Save embeddings
Parent: 0ad2946

3 files changed: 21 additions, 11 deletions


main.py

Lines changed: 4 additions & 4 deletions
@@ -125,10 +125,10 @@ def sarcasm(text_field, label_field, train_filepath, test_filepath, options, hea
         label = train.predict(args.predict, cnn, text_field, label_field, args.cuda)
         print('\n[Text] {}\n[Label] {}\n'.format(args.predict, label))
     elif args.test:
-        try:
-            train.eval(test_iter, cnn, args)
-        except Exception as e:
-            print("\nSorry. The test dataset doesn't exist.\n")
+        # try:
+        train.eval(test_iter, cnn, args)
+        # except Exception as e:
+        #     print("\nSorry. The test dataset doesn't exist.\n")
     else:
         print()
         try:
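
This change disables the blanket try/except so that evaluation errors surface with a full traceback instead of being reported as a missing test set. If a guard is still wanted, checking for the file explicitly is one option; a minimal sketch, assuming the test_filepath argument from the enclosing sarcasm() signature points at the test data file:

import os

# Hypothetical narrower guard: only report a missing test set when the file
# really is absent, and let any other evaluation error propagate.
if not os.path.exists(test_filepath):
    print("\nSorry. The test dataset doesn't exist.\n")
else:
    train.eval(test_iter, cnn, args)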

model.py

Lines changed: 3 additions & 3 deletions
@@ -53,6 +53,6 @@ def forward(self, x):
         x3 = self.conv_and_pool(x,self.conv15) #(N,Co)
         x = torch.cat((x1, x2, x3), 1) # (N,len(Ks)*Co)
         '''
-        x = self.dropout(x) # (N, len(Ks)*Co)
-        logit = self.fc1(x) # (N, C)
-        return logit
+        x_drop = self.dropout(x) # (N, len(Ks)*Co)
+        logit = self.fc1(x_drop) # (N, C)
+        return logit, x
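
With this change, forward() returns both the classification logit and the pre-dropout concatenated feature vector x, so callers can use x as a fixed-size sentence embedding. Under model.eval() the dropout is a no-op anyway, so returning the pre-dropout tensor only differs during training. For context, a minimal sketch of how the modified method might sit in a Kim-style CNN text classifier; the conv15, dropout, and fc1 names come from the diff, while the remaining layer names and every hyperparameter below are assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN_Text(nn.Module):
    # Sketch only: vocab size, embedding dim, dropout rate, and filter
    # count are assumptions, not values from the repository.
    def __init__(self, vocab_size=10000, embed_dim=128, class_num=2, Co=100):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.conv13 = nn.Conv2d(1, Co, (3, embed_dim))
        self.conv14 = nn.Conv2d(1, Co, (4, embed_dim))
        self.conv15 = nn.Conv2d(1, Co, (5, embed_dim))
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(3 * Co, class_num)

    def conv_and_pool(self, x, conv):
        x = F.relu(conv(x)).squeeze(3)                # (N, Co, W')
        return F.max_pool1d(x, x.size(2)).squeeze(2)  # (N, Co)

    def forward(self, x):
        x = self.embed(x).unsqueeze(1)                # (N, 1, W, D)
        x1 = self.conv_and_pool(x, self.conv13)       # (N, Co)
        x2 = self.conv_and_pool(x, self.conv14)       # (N, Co)
        x3 = self.conv_and_pool(x, self.conv15)       # (N, Co)
        x = torch.cat((x1, x2, x3), 1)                # (N, len(Ks)*Co)
        x_drop = self.dropout(x)                      # dropout feeds only the classifier
        logit = self.fc1(x_drop)                      # (N, C)
        return logit, x                               # x: pre-dropout sentence embedding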

train.py

Lines changed: 14 additions & 4 deletions
@@ -3,6 +3,7 @@
 import torch
 import torch.autograd as autograd
 import torch.nn.functional as F
+import numpy as np


 def train(train_iter, dev_iter, model, args):
@@ -23,7 +24,7 @@ def train(train_iter, dev_iter, model, args):
             feature, target = feature.cuda(), target.cuda()

         optimizer.zero_grad()
-        logit = model(feature)
+        logit, embedding = model(feature)

         #print('logit vector', logit.size())
         #print('target vector', target.size())
@@ -58,26 +59,35 @@ def train(train_iter, dev_iter, model, args):
 def eval(data_iter, model, args):
     model.eval()
     corrects, avg_loss = 0, 0
+    embeddings = []
     for batch in data_iter:
         feature, target = batch.text, batch.label
         feature.t_(), target.data.sub_(1)  # batch first, index align
         if args.cuda:
             feature, target = feature.cuda(), target.cuda()

-        logit = model(feature)
+        logit, embedding = model(feature)
         loss = F.cross_entropy(logit, target, size_average=False)

         avg_loss += loss
         corrects += (torch.max(logit, 1)
                      [1].view(target.size()).data == target.data).sum()
-
+        embeddings.extend(embedding)
+
     size = len(data_iter.dataset)
     avg_loss /= size
     accuracy = 100.0 * corrects/size
     print('\nEvaluation - loss: {:.6f} acc: {:.4f}%({}/{}) \n'.format(avg_loss,
                                                                       accuracy,
                                                                       corrects,
                                                                       size))
+    if args.test:
+        new_embeddings = []
+        for idx, embed in enumerate(embeddings):
+            print(embed.type())
+            new_embeddings.append(embed.detach().cpu().numpy())
+        print(len(embeddings), len(embeddings[0]))
+        np.save('./embeddings.npy', np.array(new_embeddings))
     return accuracy


@@ -92,7 +102,7 @@ def predict(text, model, text_field, label_feild, cuda_flag):
     if cuda_flag:
         x = x.cuda()
     print(x)
-    output = model(x)
+    output, embedding = model(x)
     _, predicted = torch.max(output, 1)
     #return label_feild.vocab.itos[predicted.data[0][0]+1]
     return label_feild.vocab.itos[predicted.data[0]+1]
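
eval() now collects one embedding row per test example and, since main.py routes --test straight into eval(), a test run writes them to ./embeddings.npy. A minimal sketch of a downstream consumer, assuming the file is read back on the same machine:

import numpy as np

# Load the matrix written by eval() during a --test run.
embeddings = np.load('./embeddings.npy')
print(embeddings.shape)  # expected: (num_test_examples, embedding_dim)

# Example follow-up: cosine similarity between the first two test examples.
a, b = embeddings[0], embeddings[1]
cosine = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
print('cosine similarity:', cosine)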
