diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index cb2883f80..c87ec1758 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -42,6 +42,7 @@ """ from __future__ import print_function +import ctypes import dis from functools import partial import imp @@ -69,6 +70,20 @@ from io import BytesIO as StringIO PY3 = True +pythonapi = None +try: + from ctypes import pythonapi +except ImportError: + pass + +PyCell_Set = None +if pythonapi: + try: + PyCell_Set = ctypes.PYFUNCTYPE(ctypes.c_int, ctypes.py_object, ctypes.py_object)( + ('PyCell_Set', ctypes.pythonapi), ((1, 'cell'), (1, 'value'))) + except AttributeError: + pass + #relevant opcodes STORE_GLOBAL = opcode.opmap['STORE_GLOBAL'] DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL'] @@ -129,6 +144,10 @@ def _walk_global_ops(code): yield op, instr.arg +def _make_cell(value): + return (lambda: value).__closure__[0] + + class CloudPickler(Pickler): dispatch = Pickler.dispatch.copy() @@ -333,9 +352,13 @@ def save_function_tuple(self, func): self._save_subimports(code, set(f_globals.values()) | set(closure)) + closure_cells = closure + if closure is not None and PyCell_Set is not None: + closure_cells = list(map(lambda _: _make_cell(None), closure)) + # create a skeleton function object and memoize it save(_make_skel_func) - save((code, closure, base_globals)) + save((code, closure_cells, base_globals)) write(pickle.REDUCE) self.memoize(func) @@ -343,9 +366,20 @@ def save_function_tuple(self, func): save(f_globals) save(defaults) save(dct) + save(closure) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple + def save_cell(self, obj): + save = self.save + write = self.write + + save(_make_cell) + save((obj.cell_contents,)) + write(pickle.REDUCE) + + dispatch[_make_cell(None).__class__] = save_cell + _extract_code_globals_cache = ( weakref.WeakKeyDictionary() if sys.version_info >= (2, 7) and not hasattr(sys, "pypy_version_info") @@ -799,21 +833,20 @@ def _gen_ellipsis(): def _gen_not_implemented(): return NotImplemented -def _fill_function(func, globals, defaults, dict): +def _fill_function(func, globals, defaults, dict, closure=None): """ Fills in the rest of function data into the skeleton function object that were created via _make_skel_func(). """ func.__globals__.update(globals) func.__defaults__ = defaults func.__dict__ = dict + if closure is not None and PyCell_Set: + for i, v in enumerate(closure): + PyCell_Set(func.__closure__[i], v) return func -def _make_cell(value): - return (lambda: value).__closure__[0] - - def _reconstruct_closure(values): return tuple([_make_cell(v) for v in values]) diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index 29deb8cdf..afea3077b 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -12,6 +12,11 @@ import base64 import subprocess +try: + from ctypes import pythonapi +except ImportError: + pythonapi = None + try: # try importing numpy and scipy. These are not hard dependencies and # tests should be skipped if these modules are not available @@ -133,6 +138,20 @@ def test_nested_lambdas(self): f2 = lambda x: f1(x) // b self.assertEqual(pickle_depickle(f2)(1), 1) + @pytest.mark.skipif(not pythonapi or not hasattr(pythonapi, 'PyCell_Set'), + reason="missing required Python C API functionality") + def test_recursive_nested_function(self): + def f1(): + def g(): return g + return g + def f2(base): + def g(n): return base if n <= 1 else n * g(n - 1) + return g + g1 = pickle_depickle(f1()) + self.assertEqual(g1(), g1) + g2 = pickle_depickle(f2(2)) + self.assertEqual(g2(5), 240) + @pytest.mark.skipif(sys.version_info >= (3, 4) and sys.version_info < (3, 4, 3), reason="subprocess has a bug in 3.4.0 to 3.4.2")