Skip to content

Commit 24daab1

Browse files
taehoonleefchollet
authored andcommitted
Enable Xception to work on Theano and CNTK (#10024)
* Enable Xception to work on Theano and CNTK * Fix different predictions over all the backends
1 parent 918c599 commit 24daab1

File tree

2 files changed

+5
-12
lines changed

2 files changed

+5
-12
lines changed

keras/applications/xception.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,6 @@
99
and that the input preprocessing function
1010
is also different (same as Inception V3).
1111
12-
Also do note that this model is only available for the TensorFlow backend,
13-
due to its reliance on `SeparableConvolution` layers.
14-
1512
# Reference
1613
1714
- [Xception: Deep Learning with Depthwise Separable Convolutions](https://arxiv.org/abs/1610.02357)
@@ -37,6 +34,7 @@
3734
from ..layers import GlobalMaxPooling2D
3835
from ..engine import get_source_inputs
3936
from ..utils.data_utils import get_file
37+
from ..utils import layer_utils
4038
from .. import backend as K
4139
from . import imagenet_utils
4240
from .imagenet_utils import decode_predictions
@@ -53,10 +51,8 @@ def Xception(include_top=True, weights='imagenet',
5351
classes=1000):
5452
"""Instantiates the Xception architecture.
5553
56-
Optionally loads weights pre-trained
57-
on ImageNet. This model is available for TensorFlow only,
58-
and can only be used with inputs following the TensorFlow
59-
data format `(width, height, channels)`.
54+
Optionally loads weights pre-trained on ImageNet. This model can
55+
only be used with the data format `(width, height, channels)`.
6056
You should set `image_data_format='channels_last'` in your Keras config
6157
located at ~/.keras/keras.json.
6258
@@ -110,9 +106,6 @@ def Xception(include_top=True, weights='imagenet',
110106
raise ValueError('If using `weights` as imagenet with `include_top`'
111107
' as true, `classes` should be 1000')
112108

113-
if K.backend() != 'tensorflow':
114-
raise RuntimeError('The Xception model is only available with '
115-
'the TensorFlow backend.')
116109
if K.image_data_format() != 'channels_last':
117110
warnings.warn('The Xception model is only available for the '
118111
'input data format "channels_last" '
@@ -261,6 +254,8 @@ def Xception(include_top=True, weights='imagenet',
261254
cache_subdir='models',
262255
file_hash='b0042744bf5b25fce3cb969f33bebb97')
263256
model.load_weights(weights_path)
257+
if K.backend() == 'theano':
258+
layer_utils.convert_all_kernels_in_model(model)
264259
elif weights is not None:
265260
model.load_weights(weights)
266261

tests/keras/applications/applications_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,6 @@ def test_vgg():
108108
_test_app_pooling(app, last_dim)
109109

110110

111-
@pytest.mark.skipif((K.backend() != 'tensorflow'),
112-
reason='Requires TensorFlow backend')
113111
def test_xception():
114112
app = applications.Xception
115113
last_dim = 2048

0 commit comments

Comments
 (0)