From 36a4f86a1ff495c5c0129fbb3f824c08a6296c75 Mon Sep 17 00:00:00 2001 From: tobymao Date: Wed, 7 Dec 2022 23:05:12 -0800 Subject: [PATCH] add custom source parser because it's broken in < 3.9 --- sqlmesh/utils/metaprogramming.py | 96 ++++++++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 6 deletions(-) diff --git a/sqlmesh/utils/metaprogramming.py b/sqlmesh/utils/metaprogramming.py index c5c044debf..e7e4ac919a 100644 --- a/sqlmesh/utils/metaprogramming.py +++ b/sqlmesh/utils/metaprogramming.py @@ -1,6 +1,7 @@ import ast import dis import inspect +import linecache import os import re import sys @@ -52,8 +53,95 @@ def func_globals(func: t.Callable) -> t.Dict[str, t.Any]: return variables +class ClassFoundException(Exception): + pass + + +class _ClassFinder(ast.NodeVisitor): + def __init__(self, qualname): + self.stack = [] + self.qualname = qualname + + def visit_FunctionDef(self, node): + self.stack.append(node.name) + self.stack.append("") + self.generic_visit(node) + self.stack.pop() + self.stack.pop() + + visit_AsyncFunctionDef = visit_FunctionDef + + def visit_ClassDef(self, node): + self.stack.append(node.name) + if self.qualname == ".".join(self.stack): + # Return the decorator for the class if present + if node.decorator_list: + line_number = node.decorator_list[0].lineno + else: + line_number = node.lineno + + # decrement by one since lines starts with indexing by zero + line_number -= 1 + raise ClassFoundException(line_number) + self.generic_visit(node) + self.stack.pop() + + +def getsource(obj: t.Any) -> str: + """Get the source of a function or class. + + inspect.getsource doesn't find decorators in python < 3.9 + https://github.com/python/cpython/commit/696136b993e11b37c4f34d729a0375e5ad544ade + """ + path = inspect.getsourcefile(obj) + if path: + module = inspect.getmodule(obj, path) + + if module: + lines = linecache.getlines(path, module.__dict__) + else: + lines = linecache.getlines(path) + + def join_source(lnum: int) -> str: + return "".join(inspect.getblock(lines[lnum:])) + + if inspect.isclass(obj): + qualname = obj.__qualname__ + source = "".join(lines) + tree = ast.parse(source) + class_finder = _ClassFinder(qualname) + try: + class_finder.visit(tree) + except ClassFoundException as e: + return join_source(e.args[0]) + elif inspect.isfunction(obj): + obj = obj.__code__ + if hasattr(obj, "co_firstlineno"): + lnum = obj.co_firstlineno - 1 + pat = re.compile( + r"^(\s*def\s)|(\s*async\s+def\s)|(.*(? 0: + try: + line = lines[lnum] + except IndexError: + raise OSError("lineno is out of bounds") + if pat.match(line): + break + return join_source(lnum) + raise SQLMeshError(f"Cannot find source for {obj}") + + +def unparse(node: ast.Module) -> str: + if sys.version_info < (3, 9): + import astor + + return astor.to_source(node).strip() + return ast.unparse(node).strip() + + def _parse_source(func: t.Callable) -> ast.Module: - return ast.parse(textwrap.dedent(inspect.getsource(func))) + return ast.parse(textwrap.dedent(getsource(func))) def _decorator_name(decorator: ast.expr) -> str: @@ -110,11 +198,7 @@ def normalize_source(obj: t.Any) -> str: elif isinstance(node, ast.arg): node.annotation = None - if sys.version_info < (3, 9): - import astor - - return astor.to_source(root_node).strip() - return ast.unparse(root_node) + return unparse(root_node) def build_env(obj: t.Any, *, env: t.Dict[str, t.Any], name: str, module: str) -> None: