Skip to content

Commit 3c2e44a

Browse files
authored
Revert "Fix!: use inspect.getsource in favor of custom parsing, stop relying on astor" (#3909)
1 parent a14eb36 commit 3c2e44a

8 files changed

Lines changed: 174 additions & 78 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repos:
2222
"types": [python]
2323
files: *files
2424
require_serial: true
25-
exclude: ^(tests/fixtures|tests/core/metaprogramming_test_helper\.py)
25+
exclude: ^(tests/fixtures)
2626
- repo: https://github.com/pre-commit/mirrors-prettier
2727
rev: "fc26039"
2828
hooks:

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ ignore_missing_imports = True
1818
[mypy-tests.*]
1919
disallow_untyped_defs = False
2020

21+
[mypy-astor.*]
22+
ignore_missing_imports = True
23+
2124
[mypy-IPython.*]
2225
ignore_missing_imports = True
2326

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
},
3333
setup_requires=["setuptools_scm"],
3434
install_requires=[
35+
"astor",
3536
"click",
3637
"croniter",
3738
"duckdb!=0.10.3",

sqlmesh/core/model/common.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import typing as t
55
from pathlib import Path
66

7+
from astor import to_source
78
from sqlglot import exp
89
from sqlglot.helper import ensure_list
910

@@ -140,7 +141,7 @@ def parse_dependencies(
140141

141142
def get_first_arg(keyword_arg_name: str) -> t.Any:
142143
if node.args:
143-
first_arg: t.Optional[ast.AST] = node.args[0]
144+
first_arg: t.Optional[ast.expr] = node.args[0]
144145
else:
145146
first_arg = next(
146147
(
@@ -151,20 +152,14 @@ def get_first_arg(keyword_arg_name: str) -> t.Any:
151152
None,
152153
)
153154

154-
if first_arg is None:
155-
raise_config_error(
156-
f"Missing {keyword_arg_name} argument in '{func.attr}' call.",
157-
executable.path,
158-
)
159-
160155
try:
161-
expression = ast.unparse(t.cast(ast.AST, first_arg))
156+
expression = to_source(first_arg)
162157
return eval(expression, env)
163158
except Exception:
164159
if strict_resolution:
165-
raise_config_error(
166-
f"Argument '{expression.strip()}' must be resolvable at parse time",
167-
executable.path,
160+
raise ConfigError(
161+
f"Error resolving dependencies for '{executable.path}'. "
162+
f"Argument '{expression.strip()}' must be resolvable at parse time."
168163
)
169164

170165
if func.value.id == "context" and func.attr in ("table", "resolve_table"):

sqlmesh/utils/metaprogramming.py

Lines changed: 109 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import dis
55
import importlib
66
import inspect
7+
import linecache
78
import os
89
import re
910
import sys
@@ -14,17 +15,14 @@
1415
from numbers import Number
1516
from pathlib import Path
1617

18+
from astor import to_source
1719

1820
from sqlmesh.core import constants as c
1921
from sqlmesh.utils import format_exception, unique
2022
from sqlmesh.utils.errors import SQLMeshError
2123
from sqlmesh.utils.pydantic import PydanticModel
2224

23-
IGNORED_DECORATORS = {"macro", "model", "signal"}
24-
IGNORED_DECORATOR_CALL_REGEX = re.compile(
25-
rf'(?s)@({"|".join(IGNORED_DECORATORS)})\s*\(.*?\)(\s|#.*?\n)*def'
26-
)
27-
25+
IGNORE_DECORATORS = {"macro", "model", "signal"}
2826
SERIALIZABLE_CALLABLES = (type, types.FunctionType)
2927
LITERALS = (Number, str, bytes, tuple, list, dict, set, bool)
3028

@@ -102,6 +100,40 @@ def func_globals(func: t.Callable) -> t.Dict[str, t.Any]:
102100
return variables
103101

104102

103+
class ClassFoundException(Exception):
104+
pass
105+
106+
107+
class _ClassFinder(ast.NodeVisitor):
108+
def __init__(self, qualname: str) -> None:
109+
self.stack: t.List[str] = []
110+
self.qualname = qualname
111+
112+
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
113+
self.stack.append(node.name)
114+
self.stack.append("<locals>")
115+
self.generic_visit(node)
116+
self.stack.pop()
117+
self.stack.pop()
118+
119+
visit_AsyncFunctionDef = visit_FunctionDef # type: ignore
120+
121+
def visit_ClassDef(self, node: ast.ClassDef) -> None:
122+
self.stack.append(node.name)
123+
if self.qualname == ".".join(self.stack):
124+
# Return the decorator for the class if present
125+
if node.decorator_list:
126+
line_number = node.decorator_list[0].lineno
127+
else:
128+
line_number = node.lineno
129+
130+
# decrement by one since lines starts with indexing by zero
131+
line_number -= 1
132+
raise ClassFoundException(line_number)
133+
self.generic_visit(node)
134+
self.stack.pop()
135+
136+
105137
class _DecoratorDependencyFinder(ast.NodeVisitor):
106138
def __init__(self) -> None:
107139
self.dependencies: t.List[str] = []
@@ -117,7 +149,7 @@ def _extract_dependencies(self, node: ast.ClassDef | ast.FunctionDef) -> None:
117149
else:
118150
continue
119151

120-
if dep in IGNORED_DECORATORS:
152+
if dep in IGNORE_DECORATORS:
121153
dependencies = []
122154
break
123155

@@ -134,9 +166,53 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
134166
visit_AsyncFunctionDef = visit_FunctionDef # type: ignore
135167

136168

169+
def getsource(obj: t.Any) -> str:
170+
"""Get the source of a function or class.
171+
172+
inspect.getsource doesn't find decorators in python < 3.9
173+
https://github.com/python/cpython/commit/696136b993e11b37c4f34d729a0375e5ad544ade
174+
"""
175+
path = inspect.getsourcefile(obj)
176+
if path:
177+
module = inspect.getmodule(obj, path)
178+
179+
if module:
180+
lines = linecache.getlines(path, module.__dict__)
181+
else:
182+
lines = linecache.getlines(path)
183+
184+
def join_source(lnum: int) -> str:
185+
return "".join(inspect.getblock(lines[lnum:]))
186+
187+
if inspect.isclass(obj):
188+
qualname = obj.__qualname__
189+
source = "".join(lines)
190+
tree = ast.parse(source)
191+
class_finder = _ClassFinder(qualname)
192+
try:
193+
class_finder.visit(tree)
194+
except ClassFoundException as e:
195+
return join_source(e.args[0])
196+
elif inspect.isfunction(obj):
197+
obj = obj.__code__
198+
if hasattr(obj, "co_firstlineno"):
199+
lnum = obj.co_firstlineno - 1
200+
pat = re.compile(r"^(\s*def\s)|(\s*async\s+def\s)|(.*(?<!\w)lambda(:|\s))|^(\s*@)")
201+
while lnum > 0:
202+
try:
203+
line = lines[lnum]
204+
except IndexError:
205+
raise OSError("lineno is out of bounds")
206+
if pat.match(line):
207+
break
208+
lnum = lnum - 1
209+
return join_source(lnum)
210+
raise SQLMeshError(f"Cannot find source for {obj}")
211+
212+
137213
def parse_source(func: t.Callable) -> ast.Module:
138214
"""Parse a function and returns an ast node."""
139-
return ast.parse(textwrap.dedent(inspect.getsource(func)))
215+
return ast.parse(textwrap.dedent(getsource(func)))
140216

141217

142218
def _decorator_name(decorator: ast.expr) -> str:
@@ -158,14 +234,33 @@ def decorator_vars(func: t.Callable, root_node: t.Optional[ast.Module] = None) -
158234
return unique(finder.dependencies)
159235

160236

161-
def remove_ignored_decorators(source: str) -> str:
162-
"""
163-
Removes decorator calls like @model(...) from the Python source code.
237+
def normalize_source(obj: t.Any) -> str:
238+
"""Rewrites an object's source with formatting and doc strings removed by using Python ast.
164239
165-
We do this because we don't need to serialize the decorator or any value within its argument
166-
list when hydrating the python environment; we only need the function definition itself.
240+
Args:
241+
obj: The object to fetch source from and convert to a string.
242+
243+
Returns:
244+
A string representation of the normalized function.
167245
"""
168-
return IGNORED_DECORATOR_CALL_REGEX.sub("def", source)
246+
root_node = parse_source(obj)
247+
248+
for node in ast.walk(root_node):
249+
if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
250+
for decorator in node.decorator_list:
251+
if _decorator_name(decorator) in IGNORE_DECORATORS:
252+
node.decorator_list.remove(decorator)
253+
254+
# remove docstrings
255+
body = node.body
256+
if body and isinstance(body[0], ast.Expr) and isinstance(body[0].value, ast.Str):
257+
node.body = body[1:]
258+
259+
# remove function return type annotation
260+
if isinstance(node, ast.FunctionDef):
261+
node.returns = None
262+
263+
return to_source(root_node).strip()
169264

170265

171266
def build_env(
@@ -352,7 +447,7 @@ def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable
352447
if relative_obj_file_path:
353448
serialized[k] = Executable(
354449
name=name,
355-
payload=remove_ignored_decorators(inspect.getsource(v).strip()),
450+
payload=normalize_source(v),
356451
kind=ExecutableKind.DEFINITION,
357452
# Do `as_posix` to serialize windows path back to POSIX
358453
path=t.cast(Path, file_path).relative_to(path.absolute()).as_posix(),

tests/core/metaprogramming_test_helper.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

tests/core/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4576,7 +4576,7 @@ def test_default_catalog_sql(assert_exp_eq):
45764576

45774577

45784578
def test_default_catalog_python():
4579-
HASH_WITH_CATALOG = "1357281693"
4579+
HASH_WITH_CATALOG = "770057346"
45804580

45814581
@model(name="db.table", kind="full", columns={'"COL"': "int"})
45824582
def my_model(context, **kwargs):

0 commit comments

Comments
 (0)