
Commit defdc35

taehoonlee authored and Vijayabhaskar96 committed
Immigrate reference operations to a separate module (keras-team#9948)
1 parent 331ec0b commit defdc35

File tree

2 files changed: +173, -164 lines changed

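The change moves the NumPy/SciPy reference implementations used by the backend tests (convolution, depthwise and separable convolution, pooling, and a plain RNN) out of backend_test.py and into a standalone reference_operations module that the tests now import; the second changed file, presumably the new module itself carrying the remaining +164 lines, is not shown in this excerpt. Judging only from the call sites updated in the diff below, its public surface looks roughly like the sketch here; the stub signatures are inferred and the real bodies live in the new module:

    # Sketch of the relocated module's interface, inferred from the updated call
    # sites in backend_test.py below. Bodies are placeholders, not the real code.

    def conv(x, w, padding, data_format):
        """Reference convolution (formerly ref_conv in backend_test.py)."""

    def depthwise_conv(x, w, padding, data_format):
        """Reference depthwise convolution (formerly ref_depthwise_conv)."""

    def separable_conv(x, w1, w2, padding, data_format):
        """Depthwise convolution followed by a pointwise one (formerly ref_separable_conv)."""

    def pool(x, pool_size, strides, padding, data_format, pool_mode):
        """Reference 'max'/'avg' pooling (formerly ref_pool)."""

    def rnn(x, w, init, go_backwards=False, mask=None, unroll=False, input_length=None):
        """Reference recurrence returning (last_output, outputs, states), mirroring ref_rnn."""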

tests/keras/backend/backend_test.py

Lines changed: 9 additions & 164 deletions
@@ -1,14 +1,15 @@
 import pytest
 from numpy.testing import assert_allclose
 import numpy as np
-import scipy.signal as signal
 import scipy.sparse as sparse
 import warnings
 from keras.utils.test_utils import keras_test
 
 from keras import backend as K
 from keras.backend import floatx, set_floatx, variable
 from keras.utils.conv_utils import convert_kernel
+import reference_operations
+
 
 BACKENDS = []  # Holds a list of all available back-ends
 
@@ -192,162 +193,6 @@ def check_composed_tensor_operations(first_function_name, first_function_args,
     assert_list_pairwise(z_list)
 
 
-def normalize_ref_conv(func):
-    def wrapper(*args):
-        x = args[0]
-        w = args[1]
-        if x.ndim == 3:
-            w = np.flipud(w)
-            w = np.transpose(w, (1, 2, 0))
-            if args[3] == 'channels_last':
-                x = np.transpose(x, (0, 2, 1))
-        elif x.ndim == 4:
-            w = np.fliplr(np.flipud(w))
-            w = np.transpose(w, (2, 3, 0, 1))
-            if args[3] == 'channels_last':
-                x = np.transpose(x, (0, 3, 1, 2))
-        else:
-            w = np.flip(np.fliplr(np.flipud(w)), axis=2)
-            w = np.transpose(w, (3, 4, 0, 1, 2))
-            if args[3] == 'channels_last':
-                x = np.transpose(x, (0, 4, 1, 2, 3))
-
-        y = func(x, w, args[2], args[3])
-
-        if args[3] == 'channels_last':
-            if y.ndim == 3:
-                y = np.transpose(y, (0, 2, 1))
-            elif y.ndim == 4:
-                y = np.transpose(y, (0, 2, 3, 1))
-            else:
-                y = np.transpose(y, (0, 2, 3, 4, 1))
-
-        return y
-
-    return wrapper
-
-
-@normalize_ref_conv
-def ref_conv(x, w, padding, data_format):
-    y = []
-    for i in range(x.shape[0]):
-        _y = []
-        for j in range(w.shape[1]):
-            __y = []
-            for k in range(w.shape[0]):
-                __y.append(signal.convolve(x[i, k], w[k, j], mode=padding))
-            _y.append(np.sum(np.stack(__y, axis=-1), axis=-1))
-        y.append(_y)
-    y = np.array(y)
-    return y
-
-
-@normalize_ref_conv
-def ref_depthwise_conv(x, w, padding, data_format):
-    y = []
-    for i in range(x.shape[0]):
-        _y = []
-        for j in range(w.shape[0]):
-            __y = []
-            for k in range(w.shape[1]):
-                __y.append(signal.convolve(x[i, j], w[j, k], mode=padding))
-            _y.append(np.stack(__y, axis=0))
-        y.append(np.concatenate(_y, axis=0))
-    y = np.array(y)
-    return y
-
-
-def ref_separable_conv(x, w1, w2, padding, data_format):
-    x2 = ref_depthwise_conv(x, w1, padding, data_format)
-    return ref_conv(x2, w2, padding, data_format)
-
-
-def ref_pool(x, pool_size, strides, padding, data_format, pool_mode):
-    if data_format == 'channels_last':
-        if x.ndim == 3:
-            x = np.transpose(x, (0, 2, 1))
-        elif x.ndim == 4:
-            x = np.transpose(x, (0, 3, 1, 2))
-        else:
-            x = np.transpose(x, (0, 4, 1, 2, 3))
-
-    if padding == 'same':
-        pad = [(0, 0), (0, 0)] + [(s // 2, s // 2) for s in pool_size]
-        x = np.pad(x, pad, 'constant', constant_values=-np.inf)
-
-    # indexing trick
-    x = np.pad(x, [(0, 0), (0, 0)] + [(0, 1) for _ in pool_size],
-               'constant', constant_values=0)
-
-    if x.ndim == 3:
-        y = [x[:, :, k:k1:strides[0]]
-             for (k, k1) in zip(range(pool_size[0]), range(-pool_size[0], 0))]
-    elif x.ndim == 4:
-        y = []
-        for (k, k1) in zip(range(pool_size[0]), range(-pool_size[0], 0)):
-            for (l, l1) in zip(range(pool_size[1]), range(-pool_size[1], 0)):
-                y.append(x[:, :, k:k1:strides[0], l:l1:strides[1]])
-    else:
-        y = []
-        for (k, k1) in zip(range(pool_size[0]), range(-pool_size[0], 0)):
-            for (l, l1) in zip(range(pool_size[1]), range(-pool_size[1], 0)):
-                for (m, m1) in zip(range(pool_size[2]), range(-pool_size[2], 0)):
-                    y.append(x[:, :, k:k1:strides[0], l:l1:strides[1], m:m1:strides[2]])
-    y = np.stack(y, axis=-1)
-    if pool_mode == 'avg':
-        y = np.mean(np.ma.masked_invalid(y), axis=-1).data
-    elif pool_mode == 'max':
-        y = np.max(y, axis=-1)
-
-    if data_format == 'channels_last':
-        if y.ndim == 3:
-            y = np.transpose(y, (0, 2, 1))
-        elif y.ndim == 4:
-            y = np.transpose(y, (0, 2, 3, 1))
-        else:
-            y = np.transpose(y, (0, 2, 3, 4, 1))
-
-    return y
-
-
-def ref_rnn(x, w, init, go_backwards=False, mask=None, unroll=False, input_length=None):
-    w_i, w_h, w_o = w
-    h = []
-    o = []
-
-    if go_backwards:
-        t_list = range(x.shape[1] - 1, -1, -1)
-    else:
-        t_list = range(x.shape[1])
-
-    if mask is not None:
-        np_mask = K.eval(mask)
-    else:
-        np_mask = None
-
-    for (i, t) in enumerate(t_list):
-        h_t = np.dot(x[:, t], w_i)
-
-        if w_h is not None:
-            prev = h[i - 1] if i > 0 else init
-            h_t1 = np.dot(prev, w_h)
-            if np_mask is not None:
-                h_t1[np_mask[:, t] == 0] = prev[np_mask[:, t] == 0]
-        else:
-            h_t1 = 0
-
-        o_t = h_t + h_t1
-        if w_o is not None:
-            o_t = np.dot(o_t, w_o)
-        o.append(o_t)
-
-        if np_mask is not None:
-            h_t = h_t * np_mask[:, t].reshape(-1, 1)
-        h.append(h_t + h_t1)
-
-    return o[-1], np.stack(o, axis=1), np.stack(h, axis=1)
-
-
 class TestBackend(object):
 
     def test_is_keras_tensor(self):
@@ -724,7 +569,7 @@ def rnn_fn(x_k, h_k):
         ]
 
         for (i, kwargs) in enumerate(kwargs_list):
-            last_y1, y1, h1 = ref_rnn(x, [wi, wh, None], h0, **kwargs)
+            last_y1, y1, h1 = reference_operations.rnn(x, [wi, wh, None], h0, **kwargs)
             last_y2, y2, h2 = K.rnn(rnn_fn, x_k, h0_k, **kwargs)
 
             assert len(h2) == 1
@@ -771,8 +616,8 @@ def rnn_fn(x_k, h_k):
             y_k = K.dot(x_k, wi_k)
             return y_k, []
 
-        last_y1, y1, h1 = ref_rnn(x, [wi, None, None], None,
-                                  go_backwards=False, mask=None)
+        last_y1, y1, h1 = reference_operations.rnn(x, [wi, None, None], None,
+                                                   go_backwards=False, mask=None)
         last_y2, y2, h2 = K.rnn(rnn_fn, x_k, [],
                                 go_backwards=False, mask=None)
 
@@ -1071,7 +916,7 @@ def test_conv(self, op, input_shape, kernel_shape, padding, data_format):
         k = K.backend()
         _, x = parse_shape_or_val(input_shape)
         _, w = parse_shape_or_val(kernel_shape)
-        y1 = ref_conv(x, w, padding, data_format)
+        y1 = reference_operations.conv(x, w, padding, data_format)
         y2 = check_two_tensor_operation(
             op, x, w, [KTH if k == 'theano' else KC if k == 'cntk' else KTF],
             padding=padding, data_format=data_format,
@@ -1088,7 +933,7 @@ def test_depthwise_conv(self, op, input_shape, kernel_shape, padding, data_forma
         k = K.backend()
         _, x = parse_shape_or_val(input_shape)
         _, w = parse_shape_or_val(kernel_shape)
-        y1 = ref_depthwise_conv(x, w, padding, data_format)
+        y1 = reference_operations.depthwise_conv(x, w, padding, data_format)
         y2 = check_two_tensor_operation(
             op, x, w, [KTH if k == 'theano' else KC if k == 'cntk' else KTF],
             padding=padding, data_format=data_format,
@@ -1108,7 +953,7 @@ def test_depthwise_conv(self, op, input_shape, kernel_shape, padding, data_forma
     def test_pool(self, op, input_shape, pool_size, strides, padding, data_format, pool_mode):
         k = K.backend()
         _, x = parse_shape_or_val(input_shape)
-        y1 = ref_pool(x, pool_size, strides, padding, data_format, pool_mode)
+        y1 = reference_operations.pool(x, pool_size, strides, padding, data_format, pool_mode)
         y2 = check_single_tensor_operation(
             op, x, [KTH if k == 'theano' else KC if k == 'cntk' else KTF],
             pool_size=pool_size, strides=strides,
@@ -1174,7 +1019,7 @@ def test_separable_conv2d(self, op, input_shape, kernel_shape, depth_multiplier,
         _, x = parse_shape_or_val(input_shape)
         _, depthwise = parse_shape_or_val(kernel_shape + (input_depth, depth_multiplier))
         _, pointwise = parse_shape_or_val((1, 1) + (input_depth * depth_multiplier, 7))
-        y1 = ref_separable_conv(x, depthwise, pointwise, padding, data_format)
+        y1 = reference_operations.separable_conv(x, depthwise, pointwise, padding, data_format)
         if K.backend() == 'cntk':
             y2 = cntk_func_three_tensor(
                 op, input_shape,
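For orientation, a minimal usage sketch of one relocated op, assuming reference_operations.conv keeps the behaviour of the ref_conv implementation deleted above (Keras-style kernels shaped (kernel_h, kernel_w, in_channels, out_channels) and SciPy-style 'same'/'valid' padding). The shapes are illustrative, and the snippet assumes it is run from tests/keras/backend/, where the new module lives:

    import numpy as np

    import reference_operations  # the module introduced by this commit

    x = np.random.random((2, 3, 8, 8))   # (batch, channels, height, width), channels_first
    w = np.random.random((3, 3, 3, 4))   # (kernel_h, kernel_w, in_channels, out_channels)

    y_ref = reference_operations.conv(x, w, 'same', 'channels_first')
    print(y_ref.shape)                   # (2, 4, 8, 8); 'same' padding keeps the spatial size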
