Skip to content

Commit 4927a27

Browse files
authored
Fix!: use inspect.getsource in favor of custom parsing, stop relying on astor (#3857)
1 parent bb7cb1a commit 4927a27

8 files changed

Lines changed: 78 additions & 174 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repos:
2222
"types": [python]
2323
files: *files
2424
require_serial: true
25-
exclude: ^(tests/fixtures)
25+
exclude: ^(tests/fixtures|tests/core/metaprogramming_test_helper\.py)
2626
- repo: https://github.com/pre-commit/mirrors-prettier
2727
rev: "fc26039"
2828
hooks:

setup.cfg

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ ignore_missing_imports = True
1818
[mypy-tests.*]
1919
disallow_untyped_defs = False
2020

21-
[mypy-astor.*]
22-
ignore_missing_imports = True
23-
2421
[mypy-IPython.*]
2522
ignore_missing_imports = True
2623

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
},
3333
setup_requires=["setuptools_scm"],
3434
install_requires=[
35-
"astor",
3635
"click",
3736
"croniter",
3837
"duckdb!=0.10.3",

sqlmesh/core/model/common.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import typing as t
55
from pathlib import Path
66

7-
from astor import to_source
87
from sqlglot import exp
98
from sqlglot.helper import ensure_list
109

@@ -141,7 +140,7 @@ def parse_dependencies(
141140

142141
def get_first_arg(keyword_arg_name: str) -> t.Any:
143142
if node.args:
144-
first_arg: t.Optional[ast.expr] = node.args[0]
143+
first_arg: t.Optional[ast.AST] = node.args[0]
145144
else:
146145
first_arg = next(
147146
(
@@ -152,14 +151,20 @@ def get_first_arg(keyword_arg_name: str) -> t.Any:
152151
None,
153152
)
154153

154+
if first_arg is None:
155+
raise_config_error(
156+
f"Missing {keyword_arg_name} argument in '{func.attr}' call.",
157+
executable.path,
158+
)
159+
155160
try:
156-
expression = to_source(first_arg)
161+
expression = ast.unparse(t.cast(ast.AST, first_arg))
157162
return eval(expression, env)
158163
except Exception:
159164
if strict_resolution:
160-
raise ConfigError(
161-
f"Error resolving dependencies for '{executable.path}'. "
162-
f"Argument '{expression.strip()}' must be resolvable at parse time."
165+
raise_config_error(
166+
f"Argument '{expression.strip()}' must be resolvable at parse time",
167+
executable.path,
163168
)
164169

165170
if func.value.id == "context" and func.attr in ("table", "resolve_table"):

sqlmesh/utils/metaprogramming.py

Lines changed: 14 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import dis
55
import importlib
66
import inspect
7-
import linecache
87
import os
98
import re
109
import sys
@@ -15,14 +14,17 @@
1514
from numbers import Number
1615
from pathlib import Path
1716

18-
from astor import to_source
1917

2018
from sqlmesh.core import constants as c
2119
from sqlmesh.utils import format_exception, unique
2220
from sqlmesh.utils.errors import SQLMeshError
2321
from sqlmesh.utils.pydantic import PydanticModel
2422

25-
IGNORE_DECORATORS = {"macro", "model", "signal"}
23+
IGNORED_DECORATORS = {"macro", "model", "signal"}
24+
IGNORED_DECORATOR_CALL_REGEX = re.compile(
25+
rf'(?s)@({"|".join(IGNORED_DECORATORS)})\s*\(.*?\)(\s|#.*?\n)*def'
26+
)
27+
2628
SERIALIZABLE_CALLABLES = (type, types.FunctionType)
2729
LITERALS = (Number, str, bytes, tuple, list, dict, set, bool)
2830

@@ -100,40 +102,6 @@ def func_globals(func: t.Callable) -> t.Dict[str, t.Any]:
100102
return variables
101103

102104

103-
class ClassFoundException(Exception):
104-
pass
105-
106-
107-
class _ClassFinder(ast.NodeVisitor):
108-
def __init__(self, qualname: str) -> None:
109-
self.stack: t.List[str] = []
110-
self.qualname = qualname
111-
112-
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
113-
self.stack.append(node.name)
114-
self.stack.append("<locals>")
115-
self.generic_visit(node)
116-
self.stack.pop()
117-
self.stack.pop()
118-
119-
visit_AsyncFunctionDef = visit_FunctionDef # type: ignore
120-
121-
def visit_ClassDef(self, node: ast.ClassDef) -> None:
122-
self.stack.append(node.name)
123-
if self.qualname == ".".join(self.stack):
124-
# Return the decorator for the class if present
125-
if node.decorator_list:
126-
line_number = node.decorator_list[0].lineno
127-
else:
128-
line_number = node.lineno
129-
130-
# decrement by one since lines starts with indexing by zero
131-
line_number -= 1
132-
raise ClassFoundException(line_number)
133-
self.generic_visit(node)
134-
self.stack.pop()
135-
136-
137105
class _DecoratorDependencyFinder(ast.NodeVisitor):
138106
def __init__(self) -> None:
139107
self.dependencies: t.List[str] = []
@@ -149,7 +117,7 @@ def _extract_dependencies(self, node: ast.ClassDef | ast.FunctionDef) -> None:
149117
else:
150118
continue
151119

152-
if dep in IGNORE_DECORATORS:
120+
if dep in IGNORED_DECORATORS:
153121
dependencies = []
154122
break
155123

@@ -166,53 +134,9 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
166134
visit_AsyncFunctionDef = visit_FunctionDef # type: ignore
167135

168136

169-
def getsource(obj: t.Any) -> str:
170-
"""Get the source of a function or class.
171-
172-
inspect.getsource doesn't find decorators in python < 3.9
173-
https://github.com/python/cpython/commit/696136b993e11b37c4f34d729a0375e5ad544ade
174-
"""
175-
path = inspect.getsourcefile(obj)
176-
if path:
177-
module = inspect.getmodule(obj, path)
178-
179-
if module:
180-
lines = linecache.getlines(path, module.__dict__)
181-
else:
182-
lines = linecache.getlines(path)
183-
184-
def join_source(lnum: int) -> str:
185-
return "".join(inspect.getblock(lines[lnum:]))
186-
187-
if inspect.isclass(obj):
188-
qualname = obj.__qualname__
189-
source = "".join(lines)
190-
tree = ast.parse(source)
191-
class_finder = _ClassFinder(qualname)
192-
try:
193-
class_finder.visit(tree)
194-
except ClassFoundException as e:
195-
return join_source(e.args[0])
196-
elif inspect.isfunction(obj):
197-
obj = obj.__code__
198-
if hasattr(obj, "co_firstlineno"):
199-
lnum = obj.co_firstlineno - 1
200-
pat = re.compile(r"^(\s*def\s)|(\s*async\s+def\s)|(.*(?<!\w)lambda(:|\s))|^(\s*@)")
201-
while lnum > 0:
202-
try:
203-
line = lines[lnum]
204-
except IndexError:
205-
raise OSError("lineno is out of bounds")
206-
if pat.match(line):
207-
break
208-
lnum = lnum - 1
209-
return join_source(lnum)
210-
raise SQLMeshError(f"Cannot find source for {obj}")
211-
212-
213137
def parse_source(func: t.Callable) -> ast.Module:
214138
"""Parse a function and returns an ast node."""
215-
return ast.parse(textwrap.dedent(getsource(func)))
139+
return ast.parse(textwrap.dedent(inspect.getsource(func)))
216140

217141

218142
def _decorator_name(decorator: ast.expr) -> str:
@@ -234,33 +158,14 @@ def decorator_vars(func: t.Callable, root_node: t.Optional[ast.Module] = None) -
234158
return unique(finder.dependencies)
235159

236160

237-
def normalize_source(obj: t.Any) -> str:
238-
"""Rewrites an object's source with formatting and doc strings removed by using Python ast.
239-
240-
Args:
241-
obj: The object to fetch source from and convert to a string.
242-
243-
Returns:
244-
A string representation of the normalized function.
161+
def remove_ignored_decorators(source: str) -> str:
245162
"""
246-
root_node = parse_source(obj)
247-
248-
for node in ast.walk(root_node):
249-
if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
250-
for decorator in node.decorator_list:
251-
if _decorator_name(decorator) in IGNORE_DECORATORS:
252-
node.decorator_list.remove(decorator)
163+
Removes decorator calls like @model(...) from the Python source code.
253164
254-
# remove docstrings
255-
body = node.body
256-
if body and isinstance(body[0], ast.Expr) and isinstance(body[0].value, ast.Str):
257-
node.body = body[1:]
258-
259-
# remove function return type annotation
260-
if isinstance(node, ast.FunctionDef):
261-
node.returns = None
262-
263-
return to_source(root_node).strip()
165+
We do this because we don't need to serialize the decorator or any value within its argument
166+
list when hydrating the python environment; we only need the function definition itself.
167+
"""
168+
return IGNORED_DECORATOR_CALL_REGEX.sub("def", source)
264169

265170

266171
def build_env(
@@ -447,7 +352,7 @@ def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable
447352
if relative_obj_file_path:
448353
serialized[k] = Executable(
449354
name=name,
450-
payload=normalize_source(v),
355+
payload=remove_ignored_decorators(inspect.getsource(v).strip()),
451356
kind=ExecutableKind.DEFINITION,
452357
# Do `as_posix` to serialize windows path back to POSIX
453358
path=t.cast(Path, file_path).relative_to(path.absolute()).as_posix(),
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# This module is conditionally imported to demonstrate that we can
2+
# serialize `match` statements, which are only available in Python 3.10+.
3+
# If we included this code as-is in test_metaprogramming.py, the
4+
# test harness would crash when using Python 3.9 in CI
5+
6+
7+
def match_expression():
8+
match 5:
9+
case 5:
10+
return 1
11+
return 0

tests/core/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4576,7 +4576,7 @@ def test_default_catalog_sql(assert_exp_eq):
45764576

45774577

45784578
def test_default_catalog_python():
4579-
HASH_WITH_CATALOG = "770057346"
4579+
HASH_WITH_CATALOG = "1357281693"
45804580

45814581
@model(name="db.table", kind="full", columns={'"COL"': "int"})
45824582
def my_model(context, **kwargs):

0 commit comments

Comments
 (0)