1212from .. import activations , initializations , regularizers , constraints
1313from ..regularizers import ActivityRegularizer
1414
15- import marshal
16- import types
17- import sys
15+ import inspect
1816
1917
2018class Layer (object ):
@@ -1501,25 +1499,20 @@ class Lambda(Layer):
15011499 Takes one argument: the output of previous layer
15021500 output_shape: Expected output shape from function.
15031501 Could be a tuple or a function of the shape of the input
1502+ arguments: optional dictionary of keyword arguments to be passed
1503+ to the function.
15041504 '''
1505- def __init__ (self , function , output_shape = None , ** kwargs ):
1505+ def __init__ (self , function , output_shape = None , arguments = {}, ** kwargs ):
15061506 super (Lambda , self ).__init__ (** kwargs )
1507- py3 = sys .version_info [0 ] == 3
1508- if py3 :
1509- self .function = marshal .dumps (function .__code__ )
1510- else :
1511- assert hasattr (function , 'func_code' ), ('The Lambda layer "function"'
1512- ' argument must be a Python function.' )
1513- self .function = marshal .dumps (function .func_code )
1507+ self .function = function
1508+ self .arguments = arguments
15141509 if output_shape is None :
15151510 self ._output_shape = None
15161511 elif type (output_shape ) in {tuple , list }:
15171512 self ._output_shape = tuple (output_shape )
15181513 else :
1519- if py3 :
1520- self ._output_shape = marshal .dumps (output_shape .__code__ )
1521- else :
1522- self ._output_shape = marshal .dumps (output_shape .func_code )
1514+ assert hasattr (output_shape , '__call__' ), 'In Lambda, `output_shape` must be a list, a tuple, or a function.'
1515+ self ._output_shape = output_shape
15231516 super (Lambda , self ).__init__ ()
15241517
15251518 @property
@@ -1528,26 +1521,34 @@ def output_shape(self):
15281521 # if TensorFlow, we can infer the output shape directly:
15291522 if K ._BACKEND == 'tensorflow' :
15301523 # we assume output shape is not dependent on train/test mode
1531- x = self .get_input ()
1524+ x = self .get_output ()
15321525 return K .int_shape (x )
15331526 # otherwise, we default to the input shape
15341527 return self .input_shape
1535- elif type (self ._output_shape ) == tuple :
1528+ elif type (self ._output_shape ) in { tuple , list } :
15361529 nb_samples = self .input_shape [0 ] if self .input_shape else None
1537- return (nb_samples , ) + self ._output_shape
1530+ return (nb_samples ,) + tuple ( self ._output_shape )
15381531 else :
1539- output_shape_func = marshal .loads (self ._output_shape )
1540- output_shape_func = types .FunctionType (output_shape_func , globals ())
1541- shape = output_shape_func (self .input_shape )
1532+ shape = self ._output_shape (self .input_shape )
15421533 if type (shape ) not in {list , tuple }:
15431534 raise Exception ('output_shape function must return a tuple' )
15441535 return tuple (shape )
15451536
15461537 def get_output (self , train = False ):
15471538 X = self .get_input (train )
1548- func = marshal .loads (self .function )
1549- func = types .FunctionType (func , globals ())
1550- return func (X )
1539+ arguments = self .arguments
1540+ arg_spec = inspect .getargspec (self .function )
1541+ if 'train' in arg_spec .args :
1542+ arguments ['train' ] = train
1543+ return self .function (X , ** arguments )
1544+
1545+ def get_config (self ):
1546+ # note: not serializable at the moment.
1547+ config = {'function' : self .function ,
1548+ 'output_shape' : self ._output_shape ,
1549+ 'arguments' : self .arguments }
1550+ base_config = super (Lambda , self ).get_config ()
1551+ return dict (list (base_config .items ()) + list (config .items ()))
15511552
15521553
15531554class MaskedLambda (MaskedLayer , Lambda ):
@@ -1567,8 +1568,10 @@ class LambdaMerge(Lambda):
15671568 list of outputs from input layers
15681569 output_shape - Expected output shape from function.
15691570 Could be a tuple or a function of list of input shapes
1571+ arguments: optional dictionary of keyword arguments to be passed
1572+ to the function.
15701573 '''
1571- def __init__ (self , layers , function , output_shape = None ):
1574+ def __init__ (self , layers , function , output_shape = None , arguments = {} ):
15721575 if len (layers ) < 2 :
15731576 raise Exception ('Please specify two or more input layers '
15741577 '(or containers) to merge.' )
@@ -1577,6 +1580,7 @@ def __init__(self, layers, function, output_shape=None):
15771580 self .regularizers = []
15781581 self .constraints = []
15791582 self .updates = []
1583+ self .arguments = arguments
15801584 for l in self .layers :
15811585 params , regs , consts , updates = l .get_params ()
15821586 self .regularizers += regs
@@ -1586,45 +1590,39 @@ def __init__(self, layers, function, output_shape=None):
15861590 if p not in self .trainable_weights :
15871591 self .trainable_weights .append (p )
15881592 self .constraints .append (c )
1589- py3 = sys .version_info [0 ] == 3
1590- if py3 :
1591- self .function = marshal .dumps (function .__code__ )
1592- else :
1593- self .function = marshal .dumps (function .func_code )
1593+ self .function = function
15941594 if output_shape is None :
15951595 self ._output_shape = None
15961596 elif type (output_shape ) in {tuple , list }:
15971597 self ._output_shape = tuple (output_shape )
15981598 else :
1599- if py3 :
1600- self ._output_shape = marshal .dumps (output_shape .__code__ )
1601- else :
1602- self ._output_shape = marshal .dumps (output_shape .func_code )
1599+ assert hasattr (output_shape , '__call__' ), 'In LambdaMerge, `output_shape` must be a list, a tuple, or a function.'
1600+ self ._output_shape = output_shape
16031601 super (Lambda , self ).__init__ ()
16041602
16051603 @property
16061604 def output_shape (self ):
16071605 input_shapes = [layer .output_shape for layer in self .layers ]
16081606 if self ._output_shape is None :
16091607 return input_shapes [0 ]
1610- elif type (self ._output_shape ) == tuple :
1611- return (input_shapes [0 ][0 ], ) + self ._output_shape
1608+ elif type (self ._output_shape ) in { tuple , list } :
1609+ return (input_shapes [0 ][0 ],) + self ._output_shape
16121610 else :
1613- output_shape_func = marshal .loads (self ._output_shape )
1614- output_shape_func = types .FunctionType (output_shape_func , globals ())
1615- shape = output_shape_func (input_shapes )
1611+ shape = self ._output_shape (input_shapes )
16161612 if type (shape ) not in {list , tuple }:
1617- raise Exception ('output_shape function must return a tuple.' )
1613+ raise Exception ('In LambdaMerge, the ` output_shape` function must return a tuple.' )
16181614 return tuple (shape )
16191615
16201616 def get_params (self ):
16211617 return self .trainable_weights , self .regularizers , self .constraints , self .updates
16221618
16231619 def get_output (self , train = False ):
1624- func = marshal .loads (self .function )
1625- func = types .FunctionType (func , globals ())
16261620 inputs = [layer .get_output (train ) for layer in self .layers ]
1627- return func (inputs )
1621+ arguments = self .arguments
1622+ arg_spec = inspect .getargspec (self .function )
1623+ if 'train' in arg_spec .args :
1624+ arguments ['train' ] = train
1625+ return self .function (inputs , ** arguments )
16281626
16291627 def get_input (self , train = False ):
16301628 res = []
@@ -1660,10 +1658,12 @@ def set_weights(self, weights):
16601658 weights = weights [nb_param :]
16611659
16621660 def get_config (self ):
1661+ # note: not serializable at the moment.
16631662 config = {'name' : self .__class__ .__name__ ,
16641663 'layers' : [l .get_config () for l in self .layers ],
16651664 'function' : self .function ,
1666- 'output_shape' : self ._output_shape }
1665+ 'output_shape' : self ._output_shape ,
1666+ 'arguments' : self .arguments }
16671667 base_config = super (LambdaMerge , self ).get_config ()
16681668 return dict (list (base_config .items ()) + list (config .items ()))
16691669
0 commit comments