Skip to content

Commit b889a4d

Browse files
committed
Style+consistency fix for target placeholders test
1 parent 9a24cc6 commit b889a4d

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

tests/keras/engine/test_training.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -807,14 +807,6 @@ def test_target_tensors():
807807
model.train_on_batch(input_val, None,
808808
sample_weight={'dense_a': np.random.random((10,))})
809809

810-
if K.backend() == 'tensorflow':
811-
import tensorflow as tf
812-
# test with custom TF placeholder as target
813-
pl_target_a = tf.placeholder('float32', shape=(None, 4))
814-
model.compile(optimizer='rmsprop', loss='mse',
815-
target_tensors=[pl_target_a, target_b])
816-
model.train_on_batch(input_val, target_val_a)
817-
818810

819811
@keras_test
820812
def test_model_custom_target_tensors():
@@ -847,18 +839,35 @@ def test_model_custom_target_tensors():
847839
output_b_np = np.random.random((10, 3))
848840

849841
out = model.train_on_batch([input_a_np, input_b_np],
850-
[output_a_np, output_b_np], {y: np.random.random((10, 4)), y1: np.random.random((10, 3))})
851-
842+
[output_a_np, output_b_np],
843+
{y: np.random.random((10, 4)),
844+
y1: np.random.random((10, 3))})
852845
# test dictionary of target_tensors
853846
with pytest.raises(ValueError):
854-
model.compile(optimizer, loss, metrics=[], loss_weights=loss_weights,
855-
sample_weight_mode=None, target_tensors={'does_not_exist': y2})
847+
model.compile(optimizer, loss,
848+
metrics=[],
849+
loss_weights=loss_weights,
850+
sample_weight_mode=None,
851+
target_tensors={'does_not_exist': y2})
856852
# test dictionary of target_tensors
857-
model.compile(optimizer, loss, metrics=[], loss_weights=loss_weights,
858-
sample_weight_mode=None, target_tensors={'dense_1': y, 'dropout': y1})
859-
853+
model.compile(optimizer, loss,
854+
metrics=[],
855+
loss_weights=loss_weights,
856+
sample_weight_mode=None,
857+
target_tensors={'dense_1': y, 'dropout': y1})
860858
out = model.train_on_batch([input_a_np, input_b_np],
861-
[output_a_np, output_b_np], {y: np.random.random((10, 4)), y1: np.random.random((10, 3))})
859+
[output_a_np, output_b_np],
860+
{y: np.random.random((10, 4)),
861+
y1: np.random.random((10, 3))})
862+
863+
if K.backend() == 'tensorflow':
864+
import tensorflow as tf
865+
# test with custom TF placeholder as target
866+
pl_target_a = tf.placeholder('float32', shape=(None, 4))
867+
model.compile(optimizer='rmsprop', loss='mse',
868+
target_tensors={'dense_1': pl_target_a})
869+
model.train_on_batch([input_a_np, input_b_np],
870+
[output_a_np, output_b_np])
862871

863872

864873
if __name__ == '__main__':

0 commit comments

Comments
 (0)