Skip to content

Commit 7f41135

Browse files
tensorflower-gardenerAmit Patankar
authored andcommitted
Support tensors as dropout rates again, by removing the min(max(..)..) clipping
that throws TypeError for tensors. PiperOrigin-RevId: 157000504
1 parent 12f033d commit 7f41135

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

tensorflow/python/layers/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def __init__(self, rate=0.5,
244244
name=None,
245245
**kwargs):
246246
super(Dropout, self).__init__(name=name, **kwargs)
247-
self.rate = min(1., max(0., rate))
247+
self.rate = rate
248248
self.noise_shape = noise_shape
249249
self.seed = seed
250250

tensorflow/python/layers/core_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,18 @@ def testFunctionalDropout(self):
337337
np_output = sess.run(dropped, feed_dict={training: False})
338338
self.assertAllClose(np.ones((5, 5)), np_output)
339339

340+
def testDynamicRate(self):
341+
with self.test_session() as sess:
342+
rate = array_ops.placeholder(dtype='float32', name='rate')
343+
dp = core_layers.Dropout(rate, name='dropout')
344+
inputs = array_ops.ones((5, 5))
345+
dropped = dp.apply(inputs, training=True)
346+
sess.run(variables.global_variables_initializer())
347+
np_output = sess.run(dropped, feed_dict={rate: 0.5})
348+
self.assertAlmostEqual(0., np_output.min())
349+
np_output = sess.run(dropped, feed_dict={rate: 0.0})
350+
self.assertAllClose(np.ones((5, 5)), np_output)
351+
340352

341353
if __name__ == '__main__':
342354
test.main()

0 commit comments

Comments
 (0)