1919parser .add_argument ('-test-interval' , type = int , default = 100 , help = 'how many steps to wait before testing [default: 100]' )
2020parser .add_argument ('-save-interval' , type = int , default = 500 , help = 'how many steps to wait before saving [default:500]' )
2121parser .add_argument ('-save-dir' , type = str , default = 'snapshot' , help = 'where to save the snapshot' )
22+ parser .add_argument ('-early-stop' , type = int , default = 1000 , help = 'iteration numbers to stop without performance increasing' )
23+ parser .add_argument ('-save-best' , type = bool , default = True , help = 'whether to save when get best performance' )
2224# data
23- parser .add_argument ('-shuffle' , action = 'store_true' , default = False , help = 'shuffle the data every epoch' )
25+ parser .add_argument ('-shuffle' , action = 'store_true' , default = False , help = 'shuffle the data every epoch' )
2426# model
2527parser .add_argument ('-dropout' , type = float , default = 0.5 , help = 'the probability for dropout [default: 0.5]' )
2628parser .add_argument ('-max-norm' , type = float , default = 3.0 , help = 'l2 constraint of parameters [default: 3.0]' )
3032parser .add_argument ('-static' , action = 'store_true' , default = False , help = 'fix the embedding' )
3133# device
3234parser .add_argument ('-device' , type = int , default = - 1 , help = 'device to use for iterate data, -1 mean cpu [default: -1]' )
33- parser .add_argument ('-no-cuda' , action = 'store_true' , default = False , help = 'disable the gpu' )
35+ parser .add_argument ('-no-cuda' , action = 'store_true' , default = False , help = 'disable the gpu' )
3436# option
3537parser .add_argument ('-snapshot' , type = str , default = None , help = 'filename of model snapshot [default: None]' )
3638parser .add_argument ('-predict' , type = str , default = None , help = 'predict the sentence given' )
@@ -69,7 +71,7 @@ def mr(text_field, label_field, **kargs):
6971text_field = data .Field (lower = True )
7072label_field = data .Field (sequential = False )
7173train_iter , dev_iter = mr (text_field , label_field , device = - 1 , repeat = False )
72- #train_iter, dev_iter, test_iter = sst(text_field, label_field, device=-1, repeat=False)
74+ # train_iter, dev_iter, test_iter = sst(text_field, label_field, device=-1, repeat=False)
7375
7476
7577# update args and print
@@ -85,33 +87,30 @@ def mr(text_field, label_field, **kargs):
8587
8688
8789# model
88- if args .snapshot is None :
89- cnn = model .CNN_Text (args )
90- else :
91- print ('\n Loading model from [%s]...' % args .snapshot )
92- try :
93- cnn = torch .load (args .snapshot )
94- except :
95- print ("Sorry, This snapshot doesn't exist." ); exit ()
90+ cnn = model .CNN_Text (args )
91+ if args .snapshot is not None :
92+ print ('\n Loading model from {}...' .format (args .snapshot ))
93+ cnn .load_state_dict (torch .load (args .snapshot ))
9694
9795if args .cuda :
96+ torch .cuda .set_device (args .device )
9897 cnn = cnn .cuda ()
9998
10099
101100# train or predict
102101if args .predict is not None :
103102 label = train .predict (args .predict , cnn , text_field , label_field , args .cuda )
104103 print ('\n [Text] {}\n [Label] {}\n ' .format (args .predict , label ))
105- elif args .test :
104+ elif args .test :
106105 try :
107106 train .eval (test_iter , cnn , args )
108107 except Exception as e :
109108 print ("\n Sorry. The test dataset doesn't exist.\n " )
110- else :
109+ else :
111110 print ()
112111 try :
113112 train .train (train_iter , dev_iter , cnn , args )
114113 except KeyboardInterrupt :
115- print ('-' * 89 )
114+ print ('\n ' + ' -' * 89 )
116115 print ('Exiting from training early' )
117116
0 commit comments