
Commit 79f406e

updated to latest tf2 style
1 parent 2860c1d commit 79f406e

File tree: 1 file changed (+28 / -28 lines)


model.py

Lines changed: 28 additions & 28 deletions
@@ -19,27 +19,27 @@
 from __future__ import division
 from __future__ import print_function

-import tensorflow.compat.v1 as tf
-
-from tensorflow.python.keras.models import Model
-from tensorflow.python.keras import layers
-from tensorflow.python.keras.layers import Input
-from tensorflow.python.keras.layers import Lambda
-from tensorflow.python.keras.layers import Activation
-from tensorflow.python.keras.layers import Concatenate
-from tensorflow.python.keras.layers import Add
-from tensorflow.python.keras.layers import Dropout
-from tensorflow.python.keras.layers import BatchNormalization
-from tensorflow.python.keras.layers import Conv2D
-from tensorflow.python.keras.layers import DepthwiseConv2D
-from tensorflow.python.keras.layers import ZeroPadding2D
-from tensorflow.python.keras.layers import GlobalAveragePooling2D
-from tensorflow.python.keras.layers import UpSampling2D
-from tensorflow.python.keras.utils.layer_utils import get_source_inputs
-from tensorflow.python.keras.utils.data_utils import get_file
-from tensorflow.python.keras import backend as K
-from tensorflow.python.keras.activations import relu
-from tensorflow.python.keras.applications.imagenet_utils import preprocess_input
+import tensorflow as tf
+from tensorflow.keras.mixed_precision import experimental as mixed_precision
+policy = mixed_precision.Policy('mixed_float16')
+mixed_precision.set_policy(policy)
+from tensorflow.keras.models import Model
+from tensorflow.keras import layers
+from tensorflow.keras.layers import Input
+from tensorflow.keras.layers import Lambda
+from tensorflow.keras.layers import Activation
+from tensorflow.keras.layers import Concatenate
+from tensorflow.keras.layers import Add
+from tensorflow.keras.layers import Dropout
+from tensorflow.keras.layers import BatchNormalization
+from tensorflow.keras.layers import Conv2D
+from tensorflow.keras.layers import DepthwiseConv2D
+from tensorflow.keras.layers import ZeroPadding2D
+from tensorflow.keras.layers import GlobalAveragePooling2D
+from tensorflow.keras.utils import get_file
+from tensorflow.keras.utils import get_source_inputs
+from tensorflow.keras.applications.imagenet_utils import preprocess_input
+import tensorflow.keras.backend as K

 WEIGHTS_PATH_X = "https://github.com/bonlime/keras-deeplab-v3-plus/releases/download/1.1/deeplabv3_xception_tf_dim_ordering_tf_kernels.h5"
 WEIGHTS_PATH_MOBILE = "https://github.com/bonlime/keras-deeplab-v3-plus/releases/download/1.1/deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels.h5"
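Note on the new import block: setting the mixed-precision policy at import time makes it global before any layer is built, so every layer in model.py computes in float16 while keeping float32 weights. A minimal sketch of what the policy controls, using the same experimental API as the commit (this namespace exists in the TF 2.1–2.3 era; later releases expose tf.keras.mixed_precision.set_global_policy instead); the Conv2D below is illustrative only, not a layer from model.py:

import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as mixed_precision

# Set the global policy before building any layers, as the commit does at import time.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)

print(policy.compute_dtype)    # 'float16' -> layer math runs in half precision
print(policy.variable_dtype)   # 'float32' -> weights stay in single precision

# Any layer built from here on picks up the policy automatically.
layer = tf.keras.layers.Conv2D(8, 3)        # illustrative layer
y = layer(tf.zeros([1, 32, 32, 3]))
print(y.dtype)                              # float16 output under the policy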
@@ -169,7 +169,7 @@ def _make_divisible(v, divisor, min_value=None):


 def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id, skip_connection, rate=1):
-    in_channels = inputs.shape[-1].value  # inputs._keras_shape[-1]
+    in_channels = inputs.shape[-1]  # inputs._keras_shape[-1]
     pointwise_conv_filters = int(filters * alpha)
     pointwise_filters = _make_divisible(pointwise_conv_filters, 8)
     x = inputs
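Dropping `.value` reflects a TF2 behaviour change: indexing a tensor's shape now returns a plain Python int (or None for unknown dimensions) rather than a Dimension object. A quick illustration with a throwaway tensor, not part of the model:

import tensorflow as tf

x = tf.zeros([1, 64, 64, 32])       # throwaway tensor just to show the shape API
in_channels = x.shape[-1]           # plain int in TF2; TF1 needed `.shape[-1].value`
assert isinstance(in_channels, int) and in_channels == 32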
@@ -364,8 +364,8 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)
     # Image Feature branch
     b4 = GlobalAveragePooling2D()(x)
     # from (b_size, channels)->(b_size, 1, 1, channels)
-    b4 = Lambda(lambda x: K.expand_dims(x, 1))(b4)
-    b4 = Lambda(lambda x: K.expand_dims(x, 1))(b4)
+    b4 = Lambda(lambda x: tf.expand_dims(x, 1))(b4)
+    b4 = Lambda(lambda x: tf.expand_dims(x, 1))(b4)
     b4 = Conv2D(256, (1, 1), padding='same',
                 use_bias=False, name='image_pooling')(b4)
     b4 = BatchNormalization(name='image_pooling_BN', epsilon=1e-5)(b4)
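The two Lambda layers only reshape the globally pooled features from (batch, channels) to (batch, 1, 1, channels) so the following 1x1 convolution can be applied; the change is simply calling tf.expand_dims instead of the Keras backend alias. The same reshaping in isolation, with illustrative tensor shapes:

import tensorflow as tf

pooled = tf.zeros([2, 256])          # (batch, channels), e.g. after GlobalAveragePooling2D
b4 = tf.expand_dims(pooled, 1)       # (batch, 1, channels)
b4 = tf.expand_dims(b4, 1)           # (batch, 1, 1, channels)
print(b4.shape)                      # (2, 1, 1, 256)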
@@ -374,7 +374,7 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)
     size_before = K.int_shape(x)

     b4 = Lambda(lambda x: tf.image.resize(x, size_before[1:3],
-                                          method='bilinear', align_corners=True))(b4)
+                                          method='bilinear'))(b4)
     # b4 = UpSampling2D(size=(size_before[1],size_before[2]),interpolation='bilinear')(b4)
     # simple 1x1
     b0 = Conv2D(256, (1, 1), padding='same', use_bias=False, name='aspp0')(x)
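Dropping align_corners is forced by the API: the TF2 tf.image.resize has no such argument (it always uses half-pixel centers), so the bilinear upsampling will differ very slightly from the old align_corners=True output. A standalone sketch of the call, with an illustrative feature map and target size:

import tensorflow as tf

feat = tf.zeros([1, 16, 16, 256])                        # illustrative ASPP feature map
up = tf.image.resize(feat, (64, 64), method='bilinear')  # no align_corners kwarg in TF2
print(up.shape)                                          # (1, 64, 64, 256)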
@@ -401,8 +401,8 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)
     x = Conv2D(256, (1, 1), padding='same',
               use_bias=False, name='concat_projection')(x)
     x = BatchNormalization(name='concat_projection_BN', epsilon=1e-5)(x)
-    x = Activation('elu')(x)
     x = Dropout(0.1)(x)
+    x = Activation('elu')(x)
     # DeepLab v.3+ decoder

     if backbone == 'xception':
@@ -435,7 +435,7 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)
     # size_out = K.int_shape(img_input)
     # x = UpSampling2D(size=(size_out[1] // size_in[1], size_out[2] // size_in[2]), interpolation='bilinear')(x)
     size_before3 = K.int_shape(img_input)
-    x = Lambda(lambda xx: tf.image.resize(xx, size_before3[1:3], method='bilinear', align_corners=True))(x)
+    x = Lambda(lambda xx: tf.image.resize(xx, size_before3[1:3], method='bilinear'))(x)
     # Ensure that the model takes into account
     # any potential predecessors of `input_tensor`.
     if input_tensor is not None:
@@ -445,7 +445,7 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)

     if activation in {'softmax', 'sigmoid'}:
         x = Activation(activation)(x)
-
+    x = Activation('linear', dtype='float32')(x)
     model = Model(inputs=inputs, outputs=x, name='deeplabv3plus')

     # load weights
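The added Activation('linear', dtype='float32') is the usual counterpart of the mixed_float16 policy set up in the imports: intermediate layers emit float16, so the model output is explicitly cast back to float32 to keep losses and metrics numerically stable. A minimal sketch of the pattern on a toy graph (layer sizes are illustrative, not the DeepLab decoder):

import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.mixed_precision import experimental as mixed_precision

mixed_precision.set_policy(mixed_precision.Policy('mixed_float16'))

inputs = layers.Input(shape=(64, 64, 3))                              # toy input
x = layers.Conv2D(8, 3, padding='same', activation='relu')(inputs)    # float16 compute
x = layers.Conv2D(21, 1, padding='same')(x)                           # float16 logits
x = layers.Activation('linear', dtype='float32')(x)                   # cast back, as in the commit
model = Model(inputs, x)
print(model.output.dtype)                                             # float32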
