Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 90 additions & 6 deletions sqlmesh/utils/metaprogramming.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast
import dis
import inspect
import linecache
import os
import re
import sys
Expand Down Expand Up @@ -52,8 +53,95 @@ def func_globals(func: t.Callable) -> t.Dict[str, t.Any]:
return variables


class ClassFoundException(Exception):
pass


class _ClassFinder(ast.NodeVisitor):
def __init__(self, qualname):
self.stack = []
self.qualname = qualname

def visit_FunctionDef(self, node):
self.stack.append(node.name)
self.stack.append("<locals>")
self.generic_visit(node)
self.stack.pop()
self.stack.pop()

visit_AsyncFunctionDef = visit_FunctionDef

def visit_ClassDef(self, node):
self.stack.append(node.name)
if self.qualname == ".".join(self.stack):
# Return the decorator for the class if present
if node.decorator_list:
line_number = node.decorator_list[0].lineno
else:
line_number = node.lineno

# decrement by one since lines starts with indexing by zero
line_number -= 1
raise ClassFoundException(line_number)
self.generic_visit(node)
self.stack.pop()


def getsource(obj: t.Any) -> str:
"""Get the source of a function or class.

inspect.getsource doesn't find decorators in python < 3.9
https://github.com/python/cpython/commit/696136b993e11b37c4f34d729a0375e5ad544ade
"""
path = inspect.getsourcefile(obj)
if path:
module = inspect.getmodule(obj, path)

if module:
lines = linecache.getlines(path, module.__dict__)
else:
lines = linecache.getlines(path)

def join_source(lnum: int) -> str:
return "".join(inspect.getblock(lines[lnum:]))

if inspect.isclass(obj):
qualname = obj.__qualname__
source = "".join(lines)
tree = ast.parse(source)
class_finder = _ClassFinder(qualname)
try:
class_finder.visit(tree)
except ClassFoundException as e:
return join_source(e.args[0])
elif inspect.isfunction(obj):
obj = obj.__code__
if hasattr(obj, "co_firstlineno"):
lnum = obj.co_firstlineno - 1
pat = re.compile(
r"^(\s*def\s)|(\s*async\s+def\s)|(.*(?<!\w)lambda(:|\s))|^(\s*@)"
)
while lnum > 0:
try:
line = lines[lnum]
except IndexError:
raise OSError("lineno is out of bounds")
if pat.match(line):
break
return join_source(lnum)
raise SQLMeshError(f"Cannot find source for {obj}")


def unparse(node: ast.Module) -> str:
if sys.version_info < (3, 9):
import astor

return astor.to_source(node).strip()
return ast.unparse(node).strip()


def _parse_source(func: t.Callable) -> ast.Module:
return ast.parse(textwrap.dedent(inspect.getsource(func)))
return ast.parse(textwrap.dedent(getsource(func)))


def _decorator_name(decorator: ast.expr) -> str:
Expand Down Expand Up @@ -110,11 +198,7 @@ def normalize_source(obj: t.Any) -> str:
elif isinstance(node, ast.arg):
node.annotation = None

if sys.version_info < (3, 9):
import astor

return astor.to_source(root_node).strip()
return ast.unparse(root_node)
return unparse(root_node)


def build_env(obj: t.Any, *, env: t.Dict[str, t.Any], name: str, module: str) -> None:
Expand Down