
Commit 37fd456: Lambda layer improvements
Parent: 0c1af09

File tree: 1 file changed (+44, -44 lines)

1 file changed

+44
-44
lines changed

keras/layers/core.py

Lines changed: 44 additions & 44 deletions
@@ -12,9 +12,7 @@
 from .. import activations, initializations, regularizers, constraints
 from ..regularizers import ActivityRegularizer
 
-import marshal
-import types
-import sys
+import inspect
 
 
 class Layer(object):
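Note: the dropped `marshal`/`types`/`sys` imports were only needed to serialize function bytecode; `inspect` replaces them because the layer now stores the live function and only reads its signature. A minimal sketch (not from the diff; `my_func` is an illustrative stand-in) of the introspection call the new code relies on, using the Python 2-era `inspect.getargspec` API:

import inspect

def my_func(x, train=False):
    return x

# getargspec exposes the positional argument names, which is how
# get_output() detects whether the wrapped function accepts `train`.
spec = inspect.getargspec(my_func)
print(spec.args)             # ['x', 'train']
print('train' in spec.args)  # True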
@@ -1501,25 +1499,20 @@ class Lambda(Layer):
             Takes one argument: the output of previous layer
         output_shape: Expected output shape from function.
             Could be a tuple or a function of the shape of the input
+        arguments: optional dictionary of keyword arguments to be passed
+            to the function.
     '''
-    def __init__(self, function, output_shape=None, **kwargs):
+    def __init__(self, function, output_shape=None, arguments={}, **kwargs):
         super(Lambda, self).__init__(**kwargs)
-        py3 = sys.version_info[0] == 3
-        if py3:
-            self.function = marshal.dumps(function.__code__)
-        else:
-            assert hasattr(function, 'func_code'), ('The Lambda layer "function"'
-                                                    ' argument must be a Python function.')
-            self.function = marshal.dumps(function.func_code)
+        self.function = function
+        self.arguments = arguments
         if output_shape is None:
             self._output_shape = None
         elif type(output_shape) in {tuple, list}:
             self._output_shape = tuple(output_shape)
         else:
-            if py3:
-                self._output_shape = marshal.dumps(output_shape.__code__)
-            else:
-                self._output_shape = marshal.dumps(output_shape.func_code)
+            assert hasattr(output_shape, '__call__'), 'In Lambda, `output_shape` must be a list, a tuple, or a function.'
+            self._output_shape = output_shape
         super(Lambda, self).__init__()
 
     @property
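Note: `Lambda.__init__` now keeps the function object directly and accepts an optional `arguments` dict of extra keyword arguments to forward on every call. A hedged usage sketch against the Keras 0.x Sequential API; the `scale` helper and the layer sizes are illustrative, not part of the commit:

from keras.models import Sequential
from keras.layers.core import Dense, Lambda

def scale(x, factor=1.0):
    # `factor` is supplied through the new `arguments` dict
    return x * factor

model = Sequential()
model.add(Dense(16, input_dim=8))
# the output shape is unchanged, so output_shape can be omitted
model.add(Lambda(scale, arguments={'factor': 0.5}))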
@@ -1528,26 +1521,34 @@ def output_shape(self):
             # if TensorFlow, we can infer the output shape directly:
             if K._BACKEND == 'tensorflow':
                 # we assume output shape is not dependent on train/test mode
-                x = self.get_input()
+                x = self.get_output()
                 return K.int_shape(x)
             # otherwise, we default to the input shape
             return self.input_shape
-        elif type(self._output_shape) == tuple:
+        elif type(self._output_shape) in {tuple, list}:
             nb_samples = self.input_shape[0] if self.input_shape else None
-            return (nb_samples, ) + self._output_shape
+            return (nb_samples,) + tuple(self._output_shape)
         else:
-            output_shape_func = marshal.loads(self._output_shape)
-            output_shape_func = types.FunctionType(output_shape_func, globals())
-            shape = output_shape_func(self.input_shape)
+            shape = self._output_shape(self.input_shape)
             if type(shape) not in {list, tuple}:
                 raise Exception('output_shape function must return a tuple')
             return tuple(shape)
 
     def get_output(self, train=False):
         X = self.get_input(train)
-        func = marshal.loads(self.function)
-        func = types.FunctionType(func, globals())
-        return func(X)
+        arguments = self.arguments
+        arg_spec = inspect.getargspec(self.function)
+        if 'train' in arg_spec.args:
+            arguments['train'] = train
+        return self.function(X, **arguments)
+
+    def get_config(self):
+        # note: not serializable at the moment.
+        config = {'function': self.function,
+                  'output_shape': self._output_shape,
+                  'arguments': self.arguments}
+        base_config = super(Lambda, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
 
 
 class MaskedLambda(MaskedLayer, Lambda):
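Note: `get_output` now inspects the wrapped function's signature and forwards the current train/test flag whenever the function declares a `train` argument. An illustrative sketch (the `train_test_switch` helper is hypothetical, not from the diff):

from keras.layers.core import Lambda

def train_test_switch(x, train=False):
    # `train` is injected by get_output() because it appears
    # in the function signature
    if train:
        return x * 0.5
    return x

layer = Lambda(train_test_switch)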
@@ -1567,8 +1568,10 @@ class LambdaMerge(Lambda):
             list of outputs from input layers
         output_shape - Expected output shape from function.
             Could be a tuple or a function of list of input shapes
+        arguments: optional dictionary of keyword arguments to be passed
+            to the function.
     '''
-    def __init__(self, layers, function, output_shape=None):
+    def __init__(self, layers, function, output_shape=None, arguments={}):
         if len(layers) < 2:
             raise Exception('Please specify two or more input layers '
                             '(or containers) to merge.')
@@ -1577,6 +1580,7 @@ def __init__(self, layers, function, output_shape=None):
         self.regularizers = []
         self.constraints = []
         self.updates = []
+        self.arguments = arguments
         for l in self.layers:
             params, regs, consts, updates = l.get_params()
             self.regularizers += regs
@@ -1586,45 +1590,39 @@ def __init__(self, layers, function, output_shape=None):
                 if p not in self.trainable_weights:
                     self.trainable_weights.append(p)
                     self.constraints.append(c)
-        py3 = sys.version_info[0] == 3
-        if py3:
-            self.function = marshal.dumps(function.__code__)
-        else:
-            self.function = marshal.dumps(function.func_code)
+        self.function = function
         if output_shape is None:
             self._output_shape = None
         elif type(output_shape) in {tuple, list}:
             self._output_shape = tuple(output_shape)
         else:
-            if py3:
-                self._output_shape = marshal.dumps(output_shape.__code__)
-            else:
-                self._output_shape = marshal.dumps(output_shape.func_code)
+            assert hasattr(output_shape, '__call__'), 'In LambdaMerge, `output_shape` must be a list, a tuple, or a function.'
+            self._output_shape = output_shape
         super(Lambda, self).__init__()
 
     @property
     def output_shape(self):
         input_shapes = [layer.output_shape for layer in self.layers]
         if self._output_shape is None:
             return input_shapes[0]
-        elif type(self._output_shape) == tuple:
-            return (input_shapes[0][0], ) + self._output_shape
+        elif type(self._output_shape) in {tuple, list}:
+            return (input_shapes[0][0],) + self._output_shape
         else:
-            output_shape_func = marshal.loads(self._output_shape)
-            output_shape_func = types.FunctionType(output_shape_func, globals())
-            shape = output_shape_func(input_shapes)
+            shape = self._output_shape(input_shapes)
             if type(shape) not in {list, tuple}:
-                raise Exception('output_shape function must return a tuple.')
+                raise Exception('In LambdaMerge, the `output_shape` function must return a tuple.')
             return tuple(shape)
 
     def get_params(self):
         return self.trainable_weights, self.regularizers, self.constraints, self.updates
 
     def get_output(self, train=False):
-        func = marshal.loads(self.function)
-        func = types.FunctionType(func, globals())
         inputs = [layer.get_output(train) for layer in self.layers]
-        return func(inputs)
+        arguments = self.arguments
+        arg_spec = inspect.getargspec(self.function)
+        if 'train' in arg_spec.args:
+            arguments['train'] = train
+        return self.function(inputs, **arguments)
 
     def get_input(self, train=False):
         res = []
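Note: `LambdaMerge` gets the same treatment: the merge function is stored as-is, and the `arguments` dict plus the automatic `train` flag are forwarded to it. A hedged sketch assuming the Keras 0.x API; `weighted_sum` and the layer sizes are illustrative only:

from keras.layers.core import Dense, LambdaMerge

def weighted_sum(inputs, alpha=0.5):
    # `inputs` is the list of outputs from the merged layers;
    # `alpha` arrives via the `arguments` dict
    return alpha * inputs[0] + (1 - alpha) * inputs[1]

left = Dense(16, input_dim=8)
right = Dense(16, input_dim=8)
merged = LambdaMerge([left, right], weighted_sum, arguments={'alpha': 0.3})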
@@ -1660,10 +1658,12 @@ def set_weights(self, weights):
             weights = weights[nb_param:]
 
     def get_config(self):
+        # note: not serializable at the moment.
         config = {'name': self.__class__.__name__,
                   'layers': [l.get_config() for l in self.layers],
                   'function': self.function,
-                  'output_shape': self._output_shape}
+                  'output_shape': self._output_shape,
+                  'arguments': self.arguments}
         base_config = super(LambdaMerge, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
 
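Note: as the new `# note: not serializable at the moment.` comments warn, `get_config` now returns live Python objects rather than marshalled bytecode. A sketch of the consequence (illustrative, not diff output):

from keras.layers.core import Lambda

layer = Lambda(lambda x: x * 2)
config = layer.get_config()
# config['function'] is a live function object, so serializing the
# config (and thus any model containing a Lambda) fails, e.g.:
# json.dumps(config)  ->  TypeError: <function <lambda> ...> is not JSON serializable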