Commit 36f4b18

runnable version

Former-commit-id: fc162e1
Parent: 032890b

8 files changed: 115 additions, 88 deletions

examples/tsf/stats.py

Lines changed: 2 additions & 2 deletions

@@ -8,7 +8,7 @@
 
 class Stats():
     def __init__(self):
-        self.clear()
+        self.reset()
 
     def reset(self):
         self._loss, self._g, self._ppl, self._d, self._d0, self._d1 \
@@ -17,7 +17,7 @@ def reset(self):
             = 0, 0, 0, 0, 0, 0
 
     def append(self, loss, g, ppl, d, d0, d1,
-               w_loss=1., w_g=1., w_ppl=1. w_d=1, w_d0=1., w_d1=1.):
+               w_loss=1., w_g=1., w_ppl=1., w_d=1, w_d0=1., w_d1=1.):
         self._loss.append(loss*w_loss)
         self._g.append(g*w_g)
         self._ppl.append(ppl*w_ppl)

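A note on the Stats fix above: append() accumulates weighted per-batch values, so reset() needs to leave list-valued fields behind (the context line that re-initializes them to scalars would make append() fail). A minimal sketch of the weighted-accumulator pattern, trimmed to three hypothetical fields:

from __future__ import division

# Hypothetical sketch of the accumulator pattern, not the repo's exact class.
class StatsSketch(object):
    def __init__(self):
        self.reset()

    def reset(self):
        # Lists, so append() can accumulate values batch by batch.
        self._loss, self._g, self._ppl = [], [], []

    def append(self, loss, g, ppl, w_loss=1., w_g=1., w_ppl=1.):
        # Weights (e.g. batch size or word count) rescale each value so the
        # display can average per example or per word.
        self._loss.append(loss * w_loss)
        self._g.append(g * w_g)
        self._ppl.append(ppl * w_ppl)

    def __str__(self):
        n = max(len(self._loss), 1)
        return "loss %.3f g %.3f ppl %.3f" % (
            sum(self._loss) / n, sum(self._g) / n, sum(self._ppl) / n)
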
examples/tsf/trainer_base.py

Lines changed: 12 additions & 5 deletions

@@ -5,6 +5,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import pdb
+
 import tensorflow as tf
 
 from txtgen.hyperparams import HParams
@@ -19,17 +21,22 @@
 flags.DEFINE_string("config", "", "config to load.")
 flags.DEFINE_string("model", "", "model to load.")
 
-class TrainerBase(objec):
+class TrainerBase(object):
     """Base class for trainer."""
     def __init__(self, hparams=None):
+        flags_hparams = self.default_hparams()
         if FLAGS.data_dir:
-            self._hparams['data_dir'] = FLAGS.data_dir
+            flags_hparams["data_dir"] = FLAGS.data_dir
         if FLAGS.expt_dir:
-            self._hparams['expt_dir'] = FLAGS.expt_dir
+            flags_hparams["expt_dir"] = FLAGS.expt_dir
         if FLAGS.log_dir:
-            self._hparams['log_dir'] = FLAGS.log_dir
+            flags_hparams["log_dir"] = FLAGS.log_dir
+        if FLAGS.config:
+            flags_hparams["config"] = FLAGS.config
+        if FLAGS.model:
+            flags_hparams["model"] = FLAGS.model
 
-        self._hparams = HParams(hparams, self.default_hparams())
+        self._hparams = HParams(self._hparams, flags_hparams)
 
     @staticmethod
     def default_hparams():

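The rewritten __init__ above builds a three-level precedence: built-in defaults, overridden by any non-empty command-line flags, overridden in turn by hparams the subclass already set. A sketch of that merge order using plain dicts in place of txtgen.hyperparams.HParams (assuming, as HParams(a, b) appears to here, that values from the first argument win over the second):

# Plain-dict sketch of the flag-merging order; names mirror the diff above.
def merge_hparams(subclass_hparams, flag_values):
    # Start from the defaults, as TrainerBase.__init__ does.
    flags_hparams = {"data_dir": "", "expt_dir": "", "log_dir": ""}
    # Non-empty flags override the defaults.
    for key in ("data_dir", "expt_dir", "log_dir", "config", "model"):
        if flag_values.get(key):
            flags_hparams[key] = flag_values[key]
    # hparams set by the subclass win over flags and defaults.
    merged = dict(flags_hparams)
    merged.update(subclass_hparams or {})
    return merged

print(merge_hparams({"expt_dir": "expt1"}, {"data_dir": "/data/yelp"}))
# {'data_dir': '/data/yelp', 'expt_dir': 'expt1', 'log_dir': ''}
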
examples/tsf/tsf_trainer.py

Lines changed: 46 additions & 37 deletions

@@ -5,26 +5,39 @@
 from __future__ import division
 from __future__ import print_function
 
-from txt.models.tsf import TSF
+import pdb
+
+import cPickle as pkl
+import tensorflow as tf
+import json
+import os
+
+from txtgen.hyperparams import HParams
+from txtgen.models.tsf import TSF
 
 from trainer_base import TrainerBase
-from utils import log_print, get_batches
+from utils import *
+from stats import Stats
 
 class TSFTrainer(TrainerBase):
     """TSF trainer."""
     def __init__(self, hparams=None):
-        TrainerBase.__init__(self, hparams)
+        self._hparams = HParams(hparams, self.default_hparams())
+        TrainerBase.__init__(self, self._hparams)
 
     @staticmethod
     def default_hparams():
         return {
-            "name": "tsf"
+            "name": "tsf",
             "rho": 1.,
             "gamma_init": 1,
             "gamma_decay": 0.5,
             "gamma_min": 0.001,
-            "disp_interval": 1000,
-            "batch_size": 128
+            "disp_interval": 10,
+            "batch_size": 128,
+            "vocab_size": 10000,
+            "max_len": 20,
+            "max_epoch": 20
         }
 
     def load_data(self):
@@ -41,7 +54,7 @@ def load_data(self):
         return vocab, train, val, test
 
     def eval_model(self, model, sess, vocab, data0, data1, output_path):
-        batches = utils.get_batches(data0, data1, vocab["word2id"],
+        batches = get_batches(data0, data1, vocab["word2id"],
                               self._hparams.batch_size, shuffle=False)
         losses = Stats()
 
@@ -57,26 +70,26 @@ def eval_model(self, model, sess, vocab, data0, data1, output_path):
                           w_loss=batch_size, w_g=batch_size,
                           w_ppl=word_size, w_d=batch_size,
                           w_d0=batch_size, w_d1=batch_size)
-            ori = utils.logits2word(logits_ori, vocab["word2id"])
-            tsf = utils.logits2word(logits_tsf, vocab["word2id"])
+            ori = logits2word(logits_ori, vocab["word2id"])
+            tsf = logits2word(logits_tsf, vocab["word2id"])
             half = self._hparams.batch_size/2
             data0_ori += tsf[:half]
             data1_ori += tsf[half:]
 
-        utils.write_sent(data0_ori, output_path + ".0.tsf")
-        utils.write_sent(data1_ori, output_path + ".1.tsf")
+        write_sent(data0_ori, output_path + ".0.tsf")
+        write_sent(data1_ori, output_path + ".1.tsf")
         return losses
 
     def train(self):
-        if FLAGS.config:
-            with open(FLAGS.config) as f:
+        if "config" in self._hparams.keys():
+            with open(self._hparams.config) as f:
                 self._hparams = HParams(pkl.load(f))
 
         log_print("Start training with hparams:")
-        log_print(self._hparams)
-        if not FLAGS.config:
+        log_print(json.dumps(self._hparams.todict(), indent=2))
+        if not "config" in self._hparams.keys():
             with open(os.path.join(self._hparams.expt_dir, self._hparams.name)
-                      + ".config") as f:
+                      + ".config", "w") as f:
                 pkl.dump(self._hparams, f)
 
         vocab, train, val, test = self.load_data()
@@ -90,45 +103,41 @@ def train(self):
         model = TSF(self._hparams)
         log_print("finished building model")
 
-        if FLAGS.model:
-            model.saver.restore(ses, FLAGS.model)
+        if "model" in self._hparams.keys():
+            model.saver.restore(sess, self._hparams.model)
         else:
-            sess.run(tf.global_variable_initializer())
-            sess.run(tf.local_variable_initializer())
+            sess.run(tf.global_variables_initializer())
+            sess.run(tf.local_variables_initializer())
 
         losses = Stats()
         gamma = self._hparams.gamma_init
         step = 0
         for epoch in range(self._hparams["max_epoch"]):
-            for batch in utils.get_batches(train[0], train[1], vocab["word2id"],
+            for batch in get_batches(train[0], train[1], vocab["word2id"],
                                      model._hparams.batch_size, shuffle=True):
-                loss_d0 = model.train_d0_step(sess, batch, self._hparams.rho,
-                                              self._hparams.gamma,
-                                              model._hparams.learning_rate)
-                loss_d1 = model.train_d1_step(sess, batch, self._hparams.rho,
-                                              self._hparams.gamma,
-                                              model._hparams.learning_rate)
+                pdb.set_trace()
+                loss_d0 = model.train_d0_step(sess, batch, self._hparams.rho, gamma)
+                loss_d1 = model.train_d1_step(sess, batch, self._hparams.rho, gamma)
 
                 if loss_d0 < 1.2 and loss_d1 < 1.2:
                     loss, loss_g, ppl_g, loss_d = model.train_g_step(
-                        sess, batch, self._hparams.rho, self._hparams.gamma,
-                        model._hparams.leanring_rate)
+                        sess, batch, self._hparams.rho, gamma)
                 else:
                     loss, loss_g, ppl_g, loss_d = model.train_ae_step(
-                        sess, batch, self._hparams.rho, self._hparams.gamma,
-                        model._hparams.leanring_rate)
+                        sess, batch, self._hparams.rho, gamma)
 
-                losses.add(loss, loss_g, ppl_g, loss_d, loss_d0, loss_d1)
+                losses.append(loss, loss_g, ppl_g, loss_d, loss_d0, loss_d1)
 
                 step += 1
-                if step % self._hparams.disp_interval:
-                    log_print(losses)
+                if step % self._hparams.disp_interval == 0:
+                    log_print(str(losses))
                     losses.reset()
 
             # eval on dev
             dev_loss = self.eval_model(
                 model, sess, vocab, val[0], val[1],
-                os.path.join(FLAGS.expt, "sentiment.dev.epoch%d"%(epoch)))
+                os.path.join(self._hparams.expt_dir,
+                             "sentiment.dev.epoch%d"%(epoch)))
             log_print("dev:" + dev_loss)
             if dev_loss.loss < best_dev:
                 best_dev = dev_loss.loss
@@ -143,9 +152,9 @@ def train(self):
         gamma = max(self._hparams.gamma_min, gamma * self._hparams.gamma_decay)
 
 
-def main():
+def main(unused_args):
     trainer = TSFTrainer()
     trainer.train()
 
 if __name__ == "__main__":
-    main()
+    tf.app.run()

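The restructured training loop drives three objectives from one batch stream: both style discriminators step on every batch, and the generator takes an adversarial step only when both discriminator losses fall below a fixed threshold (1.2 here), otherwise it falls back to plain autoencoding; gamma is decayed toward gamma_min once per epoch. A schematic of that schedule, assuming only the step-method signatures visible in the diff:

# Schematic of the gated schedule; the model step methods are assumed to
# match the diff's signatures (sess, batch, rho, gamma) and return losses.
def run_epoch(model, sess, batches, rho, gamma, d_threshold=1.2):
    for batch in batches:
        # The two style discriminators always train.
        loss_d0 = model.train_d0_step(sess, batch, rho, gamma)
        loss_d1 = model.train_d1_step(sess, batch, rho, gamma)
        if loss_d0 < d_threshold and loss_d1 < d_threshold:
            # Discriminators are informative enough: adversarial generator step.
            model.train_g_step(sess, batch, rho, gamma)
        else:
            # Otherwise keep the generator on the autoencoding objective.
            model.train_ae_step(sess, batch, rho, gamma)

def anneal_gamma(gamma, gamma_decay=0.5, gamma_min=0.001):
    # Mirrors the per-epoch update at the end of train().
    return max(gamma_min, gamma * gamma_decay)
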
examples/tsf/utils.py

Lines changed: 5 additions & 2 deletions

@@ -3,6 +3,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import pdb
+
 import time
 import random
 
@@ -47,7 +49,7 @@ def get_batch(x, y, word2id, batch_size, min_len=5):
 
 def get_batches(x0, x1, word2id, batch_size, shuffle=True):
     # half as batch size
-    batch_size = batch_size / 2
+    batch_size = batch_size // 2
     if len(x0) < len(x1):
         x0 = makeup(x0, len(x1))
     if len(x1) < len(x0):
@@ -66,7 +68,8 @@ def get_batches(x0, x1, word2id, batch_size, shuffle=True):
             break
         batches.append(get_batch(x0[s:t] + x1[s:t],
                                  [0] * (t-s) + [1]*(t-s),
-                                 word2id))
+                                 word2id,
+                                 batch_size))
         s = t
 
     return batches

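The / to // change matters because utils.py imports from __future__ import division, so / performs true division and returns a float even on Python 2; a float half-batch then breaks the slice bounds computed from it. A quick illustration:

from __future__ import division

batch_size = 128
half = batch_size / 2    # 64.0: true division yields a float
half = batch_size // 2   # 64: floor division keeps an int

x0 = list(range(100))
print(len(x0[:half]))    # fine with the int
# x0[:batch_size / 2] would raise TypeError: slice indices must be integers
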
txtgen/models/tsf/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,3 +1,3 @@
 
-from txtgen.models.text_style_transfer.text_style_transfer import *
-from txtgen.models.text_style_transfer.ops import *
+from txtgen.models.tsf.tsf import *
+from txtgen.models.tsf.ops import *

txtgen/models/tsf/ops.py

Lines changed: 8 additions & 9 deletions

@@ -22,15 +22,14 @@ def get_rnn_cell(hparams=None):
         "state_keep_prob": 1.0,
     }
 
-    if hparams is None or isinstance(hparams, dict):
-        hparams = HParams(hparams, default_hparams)
+    hparams = HParams(hparams, default_hparams, allow_new_hparam=True)
 
     cells = []
     for i in range(hparams.num_layers):
-        if cell_type == "BasicLSTMCell":
+        if hparams.type == "BasicLSTMCell":
             cell = rnn.BasicLSTMCell(hparams.size)
-        else:  # cell_type == "GRUCell":
-            cell = rnn.BasicGRUCell(hparams.size)
+        else:  # hparams.type == "GRUCell"
+            cell = rnn.GRUCell(hparams.size)
 
         cell = rnn.DropoutWrapper(
             cell = cell,
@@ -75,7 +74,7 @@ def loop_func(output):
 
     return loop_func
 
-def greey_softmax(proj_layer, embedding):
+def greedy_softmax(proj_layer, embedding):
     def loop_func(output):
         logits = proj_layer(output)
         word = tf.argmax(logits, axis=1)
@@ -94,14 +93,14 @@ def rnn_decode(h, inp, length, cell, loop_func, scope):
         output, h = cell(inp, h)
         inp, logits, sample = loop_func(output)
         logits_seq.append(tf.expand_dims(logits, 1))
-        if sample:
+        if sample is not None:
            sample_seq.append(tf.expand_dims(sample, 1))
         else:
            sample_seq.append(sample)
 
     h_seq = tf.concat(h_seq, 1)
     logits_seq = tf.concat(h_seq, 1)
-    if sample[0]:
+    if sample[0] is not None:
        sample_seq = tf.concat(sample, 1)
     else:
        sample_seq = None
@@ -120,7 +119,7 @@ def adv_loss(x_real, x_fake, discriminator):
 def retrieve_variables(scopes):
     var = []
     for scope in scopes:
-        var += tf.get_collections(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
+        var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
     return var
 
 def feed_dict(model, batch, rho, gamma, dropout, learning_rate):

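Two of the ops.py fixes deserve a note: tf.get_collections does not exist (the TF 1.x API is tf.get_collection), and truth-testing a symbolic tensor with "if sample:" raises a TypeError in graph mode, which is why presence is checked with "is not None". A small TF 1.x illustration:

import tensorflow as tf  # TF 1.x assumed, matching the rest of the repo

sample = tf.constant([1, 2, 3])
if sample is not None:   # identity check: safe way to ask "was one produced?"
    print("have a sample tensor")
try:
    bool(sample)         # graph-mode tensors have no Python truth value
except TypeError as e:
    print("truth-testing a Tensor fails:", e)

# Collecting trainable variables under a scope uses tf.get_collection:
with tf.variable_scope("encoder"):
    tf.get_variable("w", shape=[2, 2])
print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="encoder"))
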