33import torch
44import torch .autograd as autograd
55import torch .nn .functional as F
6+ import numpy as np
67
78
89def train (train_iter , dev_iter , model , args ):
@@ -23,7 +24,7 @@ def train(train_iter, dev_iter, model, args):
2324 feature , target = feature .cuda (), target .cuda ()
2425
2526 optimizer .zero_grad ()
26- logit = model (feature )
27+ logit , embedding = model (feature )
2728
2829 #print('logit vector', logit.size())
2930 #print('target vector', target.size())
@@ -58,26 +59,35 @@ def train(train_iter, dev_iter, model, args):
5859def eval (data_iter , model , args ):
5960 model .eval ()
6061 corrects , avg_loss = 0 , 0
62+ embeddings = []
6163 for batch in data_iter :
6264 feature , target = batch .text , batch .label
6365 feature .t_ (), target .data .sub_ (1 ) # batch first, index align
6466 if args .cuda :
6567 feature , target = feature .cuda (), target .cuda ()
6668
67- logit = model (feature )
69+ logit , embedding = model (feature )
6870 loss = F .cross_entropy (logit , target , size_average = False )
6971
7072 avg_loss += loss
7173 corrects += (torch .max (logit , 1 )
7274 [1 ].view (target .size ()).data == target .data ).sum ()
73-
75+ embeddings .extend (embedding )
76+
7477 size = len (data_iter .dataset )
7578 avg_loss /= size
7679 accuracy = 100.0 * corrects / size
7780 print ('\n Evaluation - loss: {:.6f} acc: {:.4f}%({}/{}) \n ' .format (avg_loss ,
7881 accuracy ,
7982 corrects ,
8083 size ))
84+ if args .test :
85+ new_embeddings = []
86+ for idx , embed in enumerate (embeddings ):
87+ print (embed .type ())
88+ new_embeddings .append (embed .detach ().cpu ().numpy ())
89+ print (len (embeddings ), len (embeddings [0 ]))
90+ np .save ('./embeddings.npy' , np .array (new_embeddings ))
8191 return accuracy
8292
8393
@@ -92,7 +102,7 @@ def predict(text, model, text_field, label_feild, cuda_flag):
92102 if cuda_flag :
93103 x = x .cuda ()
94104 print (x )
95- output = model (x )
105+ output , embedding = model (x )
96106 _ , predicted = torch .max (output , 1 )
97107 #return label_feild.vocab.itos[predicted.data[0][0]+1]
98108 return label_feild .vocab .itos [predicted .data [0 ]+ 1 ]
0 commit comments