Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Tf1 -> Tf2: layer_norm to tf.keras.layers, LuongAttention
  • Loading branch information
woj-i committed Aug 28, 2020
commit 1032bc1085f8322de4ab7bca5649b972937ba595
2 changes: 1 addition & 1 deletion btgym/algorithms/nn/ae.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers import flatten as batch_flatten
from tensorflow.contrib.layers import layer_norm as norm_layer

from btgym.algorithms.nn.layers import normalized_columns_initializer, linear, conv2d

Expand Down Expand Up @@ -35,6 +34,7 @@ def conv2d_encoder(x,
layer_shapes = [x.get_shape()]
layer_outputs = []
for i, layer_spec in enumerate(layer_config, 1):
norm_layer = tf.keras.layers.LayerNormalization()
x = tf.nn.elu(
norm_layer(
conv2d(
Expand Down
2 changes: 1 addition & 1 deletion btgym/algorithms/nn/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
from tensorflow.contrib.layers import layer_norm as norm_layer
from tensorflow.python.util.nest import flatten as flatten_nested

from btgym.algorithms.nn.layers import normalized_columns_initializer, categorical_sample
Expand Down Expand Up @@ -38,6 +37,7 @@ def conv_2d_network(x,
"""
with tf.compat.v1.variable_scope(name, reuse=reuse):
for i, num_filters in enumerate(conv_2d_num_filters):
norm_layer = tf.keras.layers.LayerNormalization()
x = tf.nn.elu(
norm_layer(
conv_2d_layer_ref(
Expand Down
9 changes: 5 additions & 4 deletions btgym/research/casual_conv/networks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import tensorflow as tf
from tensorflow.contrib.layers import layer_norm as norm_layer
import tensorflow_addons as tfa

import numpy as np
import math
Expand Down Expand Up @@ -78,7 +78,8 @@ def conv_1d_casual_encoder(
# b2t:
y = tf.reshape(y, [-1, num_time_batches, conv_1d_num_filters], name='layer_{}_output'.format(i))

y = norm_layer(y)
normalization_layer = tf.keras.layers.LayerNormalization()
y = normalization_layer(y)
if conv_1d_activation is not None:
y = conv_1d_activation(y)

Expand Down Expand Up @@ -137,7 +138,7 @@ def conv_1d_casual_encoder(
return encoded


def attention_layer(inputs, attention_ref=tf.contrib.seq2seq.LuongAttention, name='attention_layer', **kwargs):
def attention_layer(inputs, attention_ref=tfa.seq2seq.LuongAttention, name='attention_layer', **kwargs):
"""
Temporal attention layer.
    Computes attention context based on the last (leftmost) value in the time dimension.
Expand Down Expand Up @@ -201,7 +202,7 @@ def conv_1d_casual_attention_encoder(
conv_1d_num_filters=32,
conv_1d_filter_size=2,
conv_1d_activation=tf.nn.elu,
conv_1d_attention_ref=tf.contrib.seq2seq.LuongAttention,
conv_1d_attention_ref=tfa.seq2seq.LuongAttention,
name='casual_encoder',
keep_prob=None,
conv_1d_gated=False,
Expand Down
2 changes: 1 addition & 1 deletion btgym/research/encoder_test/networks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
from tensorflow.contrib.layers import layer_norm as norm_layer
from tensorflow.python.util.nest import flatten as flatten_nested

from btgym.algorithms.nn.layers import normalized_columns_initializer, categorical_sample
Expand Down Expand Up @@ -35,6 +34,7 @@ def conv_2d_network_skip(x,
layers = []
with tf.compat.v1.variable_scope(name, reuse=reuse):
for i, num_filters in enumerate(conv_2d_num_filters):
norm_layer = tf.keras.layers.LayerNormalization()
x = tf.nn.elu(
norm_layer(
conv_2d_layer_ref(
Expand Down