 from __future__ import division
 from __future__ import print_function

-from txt.models.tsf import TSF
+import cPickle as pkl
+import tensorflow as tf
+import json
+import os
+
+from txtgen.hyperparams import HParams
+from txtgen.models.tsf import TSF

 from trainer_base import TrainerBase
-from utils import log_print, get_batches
+from utils import log_print, get_batches, logits2word, write_sent
+from stats import Stats

 class TSFTrainer(TrainerBase):
   """TSF trainer."""
   def __init__(self, hparams=None):
-    TrainerBase.__init__(self, hparams)
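+    # Resolve user-supplied hparams against the class defaults so every
+    # expected key, including the new ones below, is always present.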
+    self._hparams = HParams(hparams, self.default_hparams())
+    TrainerBase.__init__(self, self._hparams)

   @staticmethod
   def default_hparams():
     return {
-      "name": "tsf"
+      "name": "tsf",
       "rho": 1.,
       "gamma_init": 1,
       "gamma_decay": 0.5,
       "gamma_min": 0.001,
-      "disp_interval": 1000,
-      "batch_size": 128
+      "disp_interval": 10,
+      "batch_size": 128,
38+ "vocab_size" : 10000 ,
39+ "max_len" : 20 ,
40+ "max_epoch" : 20
2841 }

   def load_data(self):
@@ -41,7 +54,7 @@ def load_data(self):
     return vocab, train, val, test

   def eval_model(self, model, sess, vocab, data0, data1, output_path):
-    batches = utils.get_batches(data0, data1, vocab["word2id"],
+    batches = get_batches(data0, data1, vocab["word2id"],
                           self._hparams.batch_size, shuffle=False)
     losses = Stats()

@@ -57,26 +70,26 @@ def eval_model(self, model, sess, vocab, data0, data1, output_path):
                    w_loss=batch_size, w_g=batch_size,
                    w_ppl=word_size, w_d=batch_size,
                    w_d0=batch_size, w_d1=batch_size)
-      ori = utils.logits2word(logits_ori, vocab["word2id"])
-      tsf = utils.logits2word(logits_tsf, vocab["word2id"])
+      ori = logits2word(logits_ori, vocab["word2id"])
+      tsf = logits2word(logits_tsf, vocab["word2id"])
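+      # Each batch stacks the two styles: the first half is style 0 and the
+      # second half style 1, so split the transferred output accordingly.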
       half = self._hparams.batch_size // 2
       data0_ori += tsf[:half]
       data1_ori += tsf[half:]

-    utils.write_sent(data0_ori, output_path + ".0.tsf")
-    utils.write_sent(data1_ori, output_path + ".1.tsf")
+    write_sent(data0_ori, output_path + ".0.tsf")
+    write_sent(data1_ori, output_path + ".1.tsf")
     return losses

   def train(self):
-    if FLAGS.config:
-      with open(FLAGS.config) as f:
+    if "config" in self._hparams.keys():
+      with open(self._hparams.config) as f:
         self._hparams = HParams(pkl.load(f))

     log_print("Start training with hparams:")
-    log_print(self._hparams)
-    if not FLAGS.config:
+    log_print(json.dumps(self._hparams.todict(), indent=2))
+    if "config" not in self._hparams.keys():
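+      # Snapshot the resolved hparams next to the experiment outputs so the
+      # run can be reproduced later by passing a "config" hparam.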
       with open(os.path.join(self._hparams.expt_dir, self._hparams.name)
-                + ".config") as f:
+                + ".config", "w") as f:
         pkl.dump(self._hparams, f)

     vocab, train, val, test = self.load_data()
@@ -90,45 +103,41 @@ def train(self):
     model = TSF(self._hparams)
     log_print("finished building model")
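
+    # Warm-start from a checkpoint when a "model" hparam is supplied;
+    # otherwise initialize all variables from scratch.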
-    if FLAGS.model:
-      model.saver.restore(ses, FLAGS.model)
+    if "model" in self._hparams.keys():
+      model.saver.restore(sess, self._hparams.model)
     else:
-      sess.run(tf.global_variable_initializer())
-      sess.run(tf.local_variable_initializer())
+      sess.run(tf.global_variables_initializer())
+      sess.run(tf.local_variables_initializer())

     losses = Stats()
     gamma = self._hparams.gamma_init
     step = 0
     for epoch in range(self._hparams["max_epoch"]):
-      for batch in utils.get_batches(train[0], train[1], vocab["word2id"],
+      for batch in get_batches(train[0], train[1], vocab["word2id"],
                                model._hparams.batch_size, shuffle=True):
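+        # Train each style discriminator once per batch, passing the current
+        # annealed gamma rather than a fixed hparam value.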
-        loss_d0 = model.train_d0_step(sess, batch, self._hparams.rho,
-                                      self._hparams.gamma,
-                                      model._hparams.learning_rate)
-        loss_d1 = model.train_d1_step(sess, batch, self._hparams.rho,
-                                      self._hparams.gamma,
-                                      model._hparams.learning_rate)
+        loss_d0 = model.train_d0_step(sess, batch, self._hparams.rho, gamma)
+        loss_d1 = model.train_d1_step(sess, batch, self._hparams.rho, gamma)
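
+        # Take an adversarial generator step only once both discriminators
+        # look competent (loss below 1.2); otherwise fall back to a plain
+        # autoencoder step.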
         if loss_d0 < 1.2 and loss_d1 < 1.2:
           loss, loss_g, ppl_g, loss_d = model.train_g_step(
-            sess, batch, self._hparams.rho, self._hparams.gamma,
-            model._hparams.leanring_rate)
+            sess, batch, self._hparams.rho, gamma)
         else:
           loss, loss_g, ppl_g, loss_d = model.train_ae_step(
-            sess, batch, self._hparams.rho, self._hparams.gamma,
-            model._hparams.leanring_rate)
+            sess, batch, self._hparams.rho, gamma)

-        losses.add(loss, loss_g, ppl_g, loss_d, loss_d0, loss_d1)
+        losses.append(loss, loss_g, ppl_g, loss_d, loss_d0, loss_d1)

         step += 1
-        if step % self._hparams.disp_interval:
-          log_print(losses)
+        if step % self._hparams.disp_interval == 0:
+          log_print(str(losses))
           losses.reset()

       # eval on dev
       dev_loss = self.eval_model(
         model, sess, vocab, val[0], val[1],
-        os.path.join(FLAGS.expt, "sentiment.dev.epoch%d" % (epoch)))
+        os.path.join(self._hparams.expt_dir,
+                     "sentiment.dev.epoch%d" % (epoch)))
       log_print("dev: " + str(dev_loss))
       if dev_loss.loss < best_dev:
         best_dev = dev_loss.loss
@@ -143,9 +152,9 @@ def train(self):
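+      # Anneal gamma (presumably the soft-sampling temperature) once per
+      # epoch, flooring it at gamma_min.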
       gamma = max(self._hparams.gamma_min, gamma * self._hparams.gamma_decay)


-def main():
+def main(unused_args):
   trainer = TSFTrainer()
   trainer.train()
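
+# tf.app.run() parses command-line flags and then calls main(argv), which is
+# why main now takes a (unused) argument.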
 if __name__ == "__main__":
-  main()
+  tf.app.run()