@@ -136,23 +136,30 @@ def test_TensorBoard():
136136 nb_class = nb_class )
137137 y_test = np_utils .to_categorical (y_test )
138138 y_train = np_utils .to_categorical (y_train )
139- # case 1 Sequential wo accuracy
140- with tf .Graph ().as_default ():
141- session = tf .Session ('' )
142- KTF ._set_session (session )
143- model = Sequential ()
144- model .add (Dense (nb_hidden , input_dim = input_dim , activation = 'relu' ))
145- model .add (Dense (nb_class , activation = 'softmax' ))
146- model .compile (loss = 'categorical_crossentropy' , optimizer = 'sgd' )
147139
148- tsb = callbacks .TensorBoard (log_dir = filepath , histogram_freq = 1 )
149- cbks = [tsb ]
150- model .fit (X_train , y_train , batch_size = batch_size , show_accuracy = True ,
151- validation_data = (X_test , y_test ), callbacks = cbks , nb_epoch = 2 )
152- assert os .path .exists (filepath )
153- shutil .rmtree (filepath )
140+ def data_generator (train ):
141+ if train :
142+ max_batch_index = len (X_train ) // batch_size
143+ else :
144+ max_batch_index = len (X_test ) // batch_size
145+ i = 0
146+ while 1 :
147+ if train :
148+ yield (X_train [i * batch_size : (i + 1 ) * batch_size ], y_train [i * batch_size : (i + 1 ) * batch_size ])
149+ else :
150+ yield (X_test [i * batch_size : (i + 1 ) * batch_size ], y_test [i * batch_size : (i + 1 ) * batch_size ])
151+ i += 1
152+ i = i % max_batch_index
153+
154+ def data_generator_graph (train ):
155+ while 1 :
156+ if train :
157+ yield {'X_vars' : X_train , 'output' : y_train }
158+ else :
159+ yield {'X_vars' : X_test , 'output' : y_test }
160+
161+ # case 1 Sequential
154162
155- # case 2 Sequential w accuracy
156163 with tf .Graph ().as_default ():
157164 session = tf .Session ('' )
158165 KTF ._set_session (session )
@@ -163,12 +170,42 @@ def test_TensorBoard():
163170
164171 tsb = callbacks .TensorBoard (log_dir = filepath , histogram_freq = 1 )
165172 cbks = [tsb ]
173+
174+ # fit with validation data
175+ model .fit (X_train , y_train , batch_size = batch_size , show_accuracy = False ,
176+ validation_data = (X_test , y_test ), callbacks = cbks , nb_epoch = 2 )
177+
178+ # fit with validation data and accuracy
166179 model .fit (X_train , y_train , batch_size = batch_size , show_accuracy = True ,
167180 validation_data = (X_test , y_test ), callbacks = cbks , nb_epoch = 2 )
181+
182+ # fit generator with validation data
183+ model .fit_generator (data_generator (True ), len (X_train ), nb_epoch = 2 ,
184+ show_accuracy = False ,
185+ validation_data = (X_test , y_test ),
186+ callbacks = cbks )
187+
188+ # fit generator without validation data
189+ model .fit_generator (data_generator (True ), len (X_train ), nb_epoch = 2 ,
190+ show_accuracy = False ,
191+ callbacks = cbks )
192+
193+ # fit generator with validation data and accuracy
194+ model .fit_generator (data_generator (True ), len (X_train ), nb_epoch = 2 ,
195+ show_accuracy = True ,
196+ validation_data = (X_test , y_test ),
197+ callbacks = cbks )
198+
199+ # fit generator without validation data and accuracy
200+ model .fit_generator (data_generator (True ), len (X_train ), nb_epoch = 2 ,
201+ show_accuracy = True ,
202+ callbacks = cbks )
203+
168204 assert os .path .exists (filepath )
169205 shutil .rmtree (filepath )
170206
171- # case 3 Graph
207+ # case 2 Graph
208+
172209 with tf .Graph ().as_default ():
173210 session = tf .Session ('' )
174211 KTF ._set_session (session )
@@ -185,10 +222,27 @@ def test_TensorBoard():
185222
186223 tsb = callbacks .TensorBoard (log_dir = filepath , histogram_freq = 1 )
187224 cbks = [tsb ]
225+
226+ # fit with validation
188227 model .fit ({'X_vars' : X_train , 'output' : y_train },
189228 batch_size = batch_size ,
190229 validation_data = {'X_vars' : X_test , 'output' : y_test },
191230 callbacks = cbks , nb_epoch = 2 )
231+
232+ # fit wo validation
233+ model .fit ({'X_vars' : X_train , 'output' : y_train },
234+ batch_size = batch_size ,
235+ callbacks = cbks , nb_epoch = 2 )
236+
237+ # fit generator with validation
238+ model .fit_generator (data_generator_graph (True ), 1000 , nb_epoch = 2 ,
239+ validation_data = {'X_vars' : X_test , 'output' : y_test },
240+ callbacks = cbks )
241+
242+ # fit generator wo validation
243+ model .fit_generator (data_generator_graph (True ), 1000 , nb_epoch = 2 ,
244+ callbacks = cbks )
245+
192246 assert os .path .exists (filepath )
193247 shutil .rmtree (filepath )
194248
0 commit comments