36 changes: 20 additions & 16 deletions tests/keras/backend/backend_test.py
@@ -1531,8 +1531,9 @@ def test_ctc_decode_greedy(self):
assert np.alltrue(decode_truth == decode_pred)
assert np.allclose(log_prob_truth, log_prob_pred)

'''tensorflow only, need special handle'''

@pytest.mark.skipif(K.backend() != 'tensorflow',
reason='Beam search is only implemented with '
'the TensorFlow backend.')
def test_ctc_decode_beam_search(self):
"""Test one batch, two beams - hibernating beam search."""

@@ -1554,10 +1555,10 @@ def test_ctc_decode_beam_search(self):
for t in range(seq_len_0)] + # Pad to max_time_steps = 8
2 * [np.zeros((1, depth), dtype=np.float32)])

inputs = KTF.variable(np.asarray(inputs).transpose((1, 0, 2)))
inputs = K.variable(np.asarray(inputs).transpose((1, 0, 2)))

# batch_size length vector of sequence_lengths
input_length = KTF.variable(np.array([seq_len_0], dtype=np.int32))
input_length = K.variable(np.array([seq_len_0], dtype=np.int32))
# batch_size length vector of negative log probabilities
log_prob_truth = np.array([
0.584855, # output beam 0
@@ -1569,18 +1570,18 @@
beam_width = 2
top_paths = 2

decode_pred_tf, log_prob_pred_tf = KTF.ctc_decode(inputs,
input_length,
greedy=False,
beam_width=beam_width,
top_paths=top_paths)
decode_pred_tf, log_prob_pred_tf = K.ctc_decode(inputs,
input_length,
greedy=False,
beam_width=beam_width,
top_paths=top_paths)

assert len(decode_pred_tf) == top_paths

log_prob_pred = KTF.eval(log_prob_pred_tf)
log_prob_pred = K.eval(log_prob_pred_tf)

for i in range(top_paths):
assert np.alltrue(decode_truth[i] == KTF.eval(decode_pred_tf[i]))
assert np.alltrue(decode_truth[i] == K.eval(decode_pred_tf[i]))

assert np.allclose(log_prob_truth, log_prob_pred)

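For reference, the call pattern the updated test exercises (beam-search decoding through the backend-agnostic K.ctc_decode rather than the TensorFlow-specific KTF module) looks roughly like the standalone sketch below; the shapes and beam parameters mirror the test above, and the random activations are purely illustrative:

import numpy as np
from keras import backend as K

# (batch, time, classes) softmax activations; random values for illustration
y_pred = K.variable(np.random.random((1, 8, 6)).astype(np.float32))
seq_lengths = K.variable(np.array([8], dtype=np.int32))

# Beam-search decoding is only available with the TensorFlow backend,
# hence the skipif guard on the test.
decoded, log_probs = K.ctc_decode(y_pred, seq_lengths,
                                  greedy=False, beam_width=2, top_paths=2)
best_paths = [K.eval(d) for d in decoded]  # top_paths decoded label sequences
neg_log_probs = K.eval(log_probs)          # shape (batch, top_paths)
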
@@ -1593,6 +1594,8 @@ def test_one_hot(self):
koh = K.eval(K.one_hot(K.variable(indices, dtype='int32'), num_classes))
assert np.all(koh == oh)

@pytest.mark.skipif(K.backend() == 'cntk',
reason='Sparse tensors are not supported in cntk.')
def test_sparse_dot(self):
x_d = np.array([0, 7, 2, 3], dtype=np.float32)
x_r = np.array([0, 2, 2, 3], dtype=np.int64)
@@ -1689,12 +1692,13 @@ def test_foldr(self):
assert p1 < p2
assert 9e-38 < p2 <= 1e-37

@pytest.mark.skipif(K.backend() == 'cntk',
reason='cntk has issues with negative numbers.')
def test_arange(self):
for test_value in (-20, 0, 1, 10):
a_list = []
dtype_list = []
# cntk has issue with negative number
for k in [KTH, KTF]:
for k in WITH_NP:
t = k.arange(test_value)
a = k.eval(t)
assert np.array_equal(a, np.arange(test_value))
@@ -1706,19 +1710,19 @@

for start, stop, step in ((0, 5, 1), (-5, 5, 2), (0, 1, 2)):
a_list = []
for k in [KTH, KTF]:
for k in WITH_NP:
a = k.eval(k.arange(start, stop, step))
assert np.array_equal(a, np.arange(start, stop, step))
a_list.append(a)
for i in range(len(a_list) - 1):
assert np.array_equal(a_list[i], a_list[i + 1])

for dtype in ('int32', 'int64', 'float32', 'float64'):
for k in [KTH, KTF]:
for k in WITH_NP:
t = k.arange(10, dtype=dtype)
assert k.dtype(t) == dtype

for k in [KTH, KTF]:
for k in WITH_NP:
start = k.constant(1, dtype='int32')
t = k.arange(start)
assert len(k.eval(t)) == 1
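The loops above iterate over WITH_NP, which this PR presumably defines earlier in the test module as the list of backends to check, extended with the NumPy reference implementations so every backend is compared against the same ground truth. A rough sketch of that setup; the alias KNP and the exact list construction are assumptions, not lines taken from this diff:

from keras.backend import theano_backend as KTH
from keras.backend import tensorflow_backend as KTF
import reference_operations as KNP  # NumPy reference ops added in this PR

# Assumed form of the backends-under-test list: real backends plus the NumPy
# reference, so k.arange / k.eval / k.dtype can be called uniformly in the loops.
WITH_NP = [KTH, KTF, KNP]
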
19 changes: 19 additions & 0 deletions tests/keras/backend/reference_operations.py
@@ -5,6 +5,7 @@

import numpy as np
import scipy.signal as signal
from keras.backend import floatx


def normalize_conv(func):
@@ -347,6 +348,10 @@ def repeat(x, n):
return y


def arange(start, stop=None, step=1, dtype='int32'):
return np.arange(start, stop, step, dtype)


def flatten(x):
return np.reshape(x, (-1,))

@@ -359,6 +364,20 @@ def eval(x):
return x


def dtype(x):
return x.dtype.name


def constant(value, dtype=None, shape=None, name=None):
if dtype is None:
dtype = floatx()
if shape is None:
shape = ()
np_value = value * np.ones(shape)
return np_value.astype(dtype)


def print_tensor(x, message=''):
print(x, message)
return x
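A small standalone check of the new reference ops, assuming the module is importable as reference_operations from the test directory; this is a sketch of intended usage, not part of the diff:

import numpy as np
import reference_operations as KNP

# arange mirrors K.arange and defaults to int32
t = KNP.arange(0, 5, 1)
assert KNP.dtype(t) == 'int32'
assert np.array_equal(KNP.eval(t), np.arange(0, 5, 1))

# constant broadcasts a scalar to the requested shape and casts to dtype
c = KNP.constant(3.0, dtype='float32', shape=(2, 2))
assert c.shape == (2, 2)
assert KNP.dtype(c) == 'float32'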