44import dis
55import importlib
66import inspect
7+ import linecache
78import os
89import re
910import sys
1415from numbers import Number
1516from pathlib import Path
1617
18+ from astor import to_source
1719
1820from sqlmesh .core import constants as c
1921from sqlmesh .utils import format_exception , unique
2022from sqlmesh .utils .errors import SQLMeshError
2123from sqlmesh .utils .pydantic import PydanticModel
2224
23- IGNORED_DECORATORS = {"macro" , "model" , "signal" }
24- IGNORED_DECORATOR_CALL_REGEX = re .compile (
25- rf'(?s)@({ "|" .join (IGNORED_DECORATORS )} )\s*\(.*?\)(\s|#.*?\n)*def'
26- )
27-
25+ IGNORE_DECORATORS = {"macro" , "model" , "signal" }
2826SERIALIZABLE_CALLABLES = (type , types .FunctionType )
2927LITERALS = (Number , str , bytes , tuple , list , dict , set , bool )
3028
@@ -102,6 +100,40 @@ def func_globals(func: t.Callable) -> t.Dict[str, t.Any]:
102100 return variables
103101
104102
103+ class ClassFoundException (Exception ):
104+ pass
105+
106+
107+ class _ClassFinder (ast .NodeVisitor ):
108+ def __init__ (self , qualname : str ) -> None :
109+ self .stack : t .List [str ] = []
110+ self .qualname = qualname
111+
112+ def visit_FunctionDef (self , node : ast .FunctionDef ) -> None :
113+ self .stack .append (node .name )
114+ self .stack .append ("<locals>" )
115+ self .generic_visit (node )
116+ self .stack .pop ()
117+ self .stack .pop ()
118+
119+ visit_AsyncFunctionDef = visit_FunctionDef # type: ignore
120+
121+ def visit_ClassDef (self , node : ast .ClassDef ) -> None :
122+ self .stack .append (node .name )
123+ if self .qualname == "." .join (self .stack ):
124+ # Return the decorator for the class if present
125+ if node .decorator_list :
126+ line_number = node .decorator_list [0 ].lineno
127+ else :
128+ line_number = node .lineno
129+
130+ # decrement by one since lines starts with indexing by zero
131+ line_number -= 1
132+ raise ClassFoundException (line_number )
133+ self .generic_visit (node )
134+ self .stack .pop ()
135+
136+
105137class _DecoratorDependencyFinder (ast .NodeVisitor ):
106138 def __init__ (self ) -> None :
107139 self .dependencies : t .List [str ] = []
@@ -117,7 +149,7 @@ def _extract_dependencies(self, node: ast.ClassDef | ast.FunctionDef) -> None:
117149 else :
118150 continue
119151
120- if dep in IGNORED_DECORATORS :
152+ if dep in IGNORE_DECORATORS :
121153 dependencies = []
122154 break
123155
@@ -134,9 +166,53 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
134166 visit_AsyncFunctionDef = visit_FunctionDef # type: ignore
135167
136168
169+ def getsource (obj : t .Any ) -> str :
170+ """Get the source of a function or class.
171+
172+ inspect.getsource doesn't find decorators in python < 3.9
173+ https://github.com/python/cpython/commit/696136b993e11b37c4f34d729a0375e5ad544ade
174+ """
175+ path = inspect .getsourcefile (obj )
176+ if path :
177+ module = inspect .getmodule (obj , path )
178+
179+ if module :
180+ lines = linecache .getlines (path , module .__dict__ )
181+ else :
182+ lines = linecache .getlines (path )
183+
184+ def join_source (lnum : int ) -> str :
185+ return "" .join (inspect .getblock (lines [lnum :]))
186+
187+ if inspect .isclass (obj ):
188+ qualname = obj .__qualname__
189+ source = "" .join (lines )
190+ tree = ast .parse (source )
191+ class_finder = _ClassFinder (qualname )
192+ try :
193+ class_finder .visit (tree )
194+ except ClassFoundException as e :
195+ return join_source (e .args [0 ])
196+ elif inspect .isfunction (obj ):
197+ obj = obj .__code__
198+ if hasattr (obj , "co_firstlineno" ):
199+ lnum = obj .co_firstlineno - 1
200+ pat = re .compile (r"^(\s*def\s)|(\s*async\s+def\s)|(.*(?<!\w)lambda(:|\s))|^(\s*@)" )
201+ while lnum > 0 :
202+ try :
203+ line = lines [lnum ]
204+ except IndexError :
205+ raise OSError ("lineno is out of bounds" )
206+ if pat .match (line ):
207+ break
208+ lnum = lnum - 1
209+ return join_source (lnum )
210+ raise SQLMeshError (f"Cannot find source for { obj } " )
211+
212+
137213def parse_source (func : t .Callable ) -> ast .Module :
138214 """Parse a function and returns an ast node."""
139- return ast .parse (textwrap .dedent (inspect . getsource (func )))
215+ return ast .parse (textwrap .dedent (getsource (func )))
140216
141217
142218def _decorator_name (decorator : ast .expr ) -> str :
@@ -158,14 +234,33 @@ def decorator_vars(func: t.Callable, root_node: t.Optional[ast.Module] = None) -
158234 return unique (finder .dependencies )
159235
160236
161- def remove_ignored_decorators (source : str ) -> str :
162- """
163- Removes decorator calls like @model(...) from the Python source code.
237+ def normalize_source (obj : t .Any ) -> str :
238+ """Rewrites an object's source with formatting and doc strings removed by using Python ast.
164239
165- We do this because we don't need to serialize the decorator or any value within its argument
166- list when hydrating the python environment; we only need the function definition itself.
240+ Args:
241+ obj: The object to fetch source from and convert to a string.
242+
243+ Returns:
244+ A string representation of the normalized function.
167245 """
168- return IGNORED_DECORATOR_CALL_REGEX .sub ("def" , source )
246+ root_node = parse_source (obj )
247+
248+ for node in ast .walk (root_node ):
249+ if isinstance (node , (ast .FunctionDef , ast .ClassDef )):
250+ for decorator in node .decorator_list :
251+ if _decorator_name (decorator ) in IGNORE_DECORATORS :
252+ node .decorator_list .remove (decorator )
253+
254+ # remove docstrings
255+ body = node .body
256+ if body and isinstance (body [0 ], ast .Expr ) and isinstance (body [0 ].value , ast .Str ):
257+ node .body = body [1 :]
258+
259+ # remove function return type annotation
260+ if isinstance (node , ast .FunctionDef ):
261+ node .returns = None
262+
263+ return to_source (root_node ).strip ()
169264
170265
171266def build_env (
@@ -352,7 +447,7 @@ def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable
352447 if relative_obj_file_path :
353448 serialized [k ] = Executable (
354449 name = name ,
355- payload = remove_ignored_decorators ( inspect . getsource ( v ). strip () ),
450+ payload = normalize_source ( v ),
356451 kind = ExecutableKind .DEFINITION ,
357452 # Do `as_posix` to serialize windows path back to POSIX
358453 path = t .cast (Path , file_path ).relative_to (path .absolute ()).as_posix (),
0 commit comments