| 
2 | 2 | #  | 
3 | 3 | # SPDX-License-Identifier: MIT  | 
4 | 4 | 
 
  | 
 | 5 | +# Run with 'python tools/extract_pyi.py shared-bindings/ path/to/stub/dir  | 
 | 6 | +# You can also test a specific library in shared-bindings by putting the path  | 
 | 7 | +# to that directory instead  | 
 | 8 | + | 
 | 9 | +import ast  | 
5 | 10 | import os  | 
 | 11 | +import re  | 
6 | 12 | import sys  | 
7 |  | -import astroid  | 
8 | 13 | import traceback  | 
9 | 14 | 
 
  | 
10 |  | -top_level = sys.argv[1].strip("/")  | 
11 |  | -stub_directory = sys.argv[2]  | 
 | 15 | +import isort  | 
 | 16 | +import black  | 
 | 17 | + | 
 | 18 | + | 
 | 19 | +IMPORTS_IGNORE = frozenset({'int', 'float', 'bool', 'str', 'bytes', 'tuple', 'list', 'set', 'dict', 'bytearray', 'slice', 'file', 'buffer', 'range', 'array', 'struct_time'})  | 
 | 20 | +IMPORTS_TYPING = frozenset({'Any', 'Optional', 'Union', 'Tuple', 'List', 'Sequence', 'NamedTuple', 'Iterable', 'Iterator', 'Callable', 'AnyStr', 'overload', 'Type'})  | 
 | 21 | +IMPORTS_TYPES = frozenset({'TracebackType'})  | 
 | 22 | +CPY_TYPING = frozenset({'ReadableBuffer', 'WriteableBuffer', 'AudioSample', 'FrameBuffer'})  | 
 | 23 | + | 
 | 24 | + | 
 | 25 | +def is_typed(node, allow_any=False):  | 
 | 26 | +    if node is None:  | 
 | 27 | +        return False  | 
 | 28 | +    if allow_any:  | 
 | 29 | +        return True  | 
 | 30 | +    elif isinstance(node, ast.Name) and node.id == "Any":  | 
 | 31 | +        return False  | 
 | 32 | +    elif isinstance(node, ast.Attribute) and type(node.value) == ast.Name \  | 
 | 33 | +            and node.value.id == "typing" and node.attr == "Any":  | 
 | 34 | +        return False  | 
 | 35 | +    return True  | 
 | 36 | + | 
 | 37 | + | 
 | 38 | +def find_stub_issues(tree):  | 
 | 39 | +    for node in ast.walk(tree):  | 
 | 40 | +        if isinstance(node, ast.AnnAssign):  | 
 | 41 | +            if not is_typed(node.annotation):  | 
 | 42 | +                yield ("WARN", f"Missing attribute type on line {node.lineno}")  | 
 | 43 | +            if isinstance(node.value, ast.Constant) and node.value.value == Ellipsis:  | 
 | 44 | +                yield ("WARN", f"Unnecessary Ellipsis assignment (= ...) on line {node.lineno}.")  | 
 | 45 | +        elif isinstance(node, ast.Assign):  | 
 | 46 | +            if isinstance(node.value, ast.Constant) and node.value.value == Ellipsis:  | 
 | 47 | +                yield ("WARN", f"Unnecessary Ellipsis assignment (= ...) on line {node.lineno}.")  | 
 | 48 | +        elif isinstance(node, ast.arguments):  | 
 | 49 | +            allargs = list(node.args + node.kwonlyargs)  | 
 | 50 | +            if sys.version_info >= (3, 8):  | 
 | 51 | +                allargs.extend(node.posonlyargs)  | 
 | 52 | +            for arg_node in allargs:  | 
 | 53 | +                if not is_typed(arg_node.annotation) and (arg_node.arg != "self" and arg_node.arg != "cls"):  | 
 | 54 | +                    yield ("WARN", f"Missing argument type: {arg_node.arg} on line {arg_node.lineno}")  | 
 | 55 | +            if node.vararg and not is_typed(node.vararg.annotation, allow_any=True):  | 
 | 56 | +                yield ("WARN", f"Missing argument type: *{node.vararg.arg} on line {node.vararg.lineno}")  | 
 | 57 | +            if node.kwarg and not is_typed(node.kwarg.annotation, allow_any=True):  | 
 | 58 | +                yield ("WARN", f"Missing argument type: **{node.kwarg.arg} on line {node.kwarg.lineno}")  | 
 | 59 | +        elif isinstance(node, ast.FunctionDef):  | 
 | 60 | +            if not is_typed(node.returns):  | 
 | 61 | +                yield ("WARN", f"Missing return type: {node.name} on line {node.lineno}")  | 
 | 62 | + | 
 | 63 | + | 
 | 64 | +def extract_imports(tree):  | 
 | 65 | +    modules = set()  | 
 | 66 | +    typing = set()  | 
 | 67 | +    types = set()  | 
 | 68 | +    cpy_typing = set()  | 
 | 69 | + | 
 | 70 | +    def collect_annotations(anno_tree):  | 
 | 71 | +        if anno_tree is None:  | 
 | 72 | +            return  | 
 | 73 | +        for node in ast.walk(anno_tree):  | 
 | 74 | +            if isinstance(node, ast.Name):  | 
 | 75 | +                if node.id in IMPORTS_IGNORE:  | 
 | 76 | +                    continue  | 
 | 77 | +                elif node.id in IMPORTS_TYPING:  | 
 | 78 | +                    typing.add(node.id)  | 
 | 79 | +                elif node.id in IMPORTS_TYPES:  | 
 | 80 | +                    types.add(node.id)  | 
 | 81 | +                elif node.id in CPY_TYPING:  | 
 | 82 | +                    cpy_typing.add(node.id)  | 
 | 83 | +            elif isinstance(node, ast.Attribute):  | 
 | 84 | +                if isinstance(node.value, ast.Name):  | 
 | 85 | +                    modules.add(node.value.id)  | 
 | 86 | + | 
 | 87 | +    for node in ast.walk(tree):  | 
 | 88 | +        if isinstance(node, (ast.AnnAssign, ast.arg)):  | 
 | 89 | +            collect_annotations(node.annotation)  | 
 | 90 | +        elif isinstance(node, ast.Assign):  | 
 | 91 | +            collect_annotations(node.value)  | 
 | 92 | +        elif isinstance(node, ast.FunctionDef):  | 
 | 93 | +            collect_annotations(node.returns)  | 
 | 94 | +            for deco in node.decorator_list:  | 
 | 95 | +                if isinstance(deco, ast.Name) and (deco.id in IMPORTS_TYPING):  | 
 | 96 | +                    typing.add(deco.id)  | 
 | 97 | + | 
 | 98 | +    return {  | 
 | 99 | +        "modules": sorted(modules),  | 
 | 100 | +        "typing": sorted(typing),  | 
 | 101 | +        "types": sorted(types),  | 
 | 102 | +        "cpy_typing": sorted(cpy_typing),  | 
 | 103 | +    }  | 
 | 104 | + | 
 | 105 | + | 
 | 106 | +def find_references(tree):  | 
 | 107 | +    for node in ast.walk(tree):  | 
 | 108 | +        if isinstance(node, ast.arguments):  | 
 | 109 | +            for node in ast.walk(node):  | 
 | 110 | +                if isinstance(node, ast.Attribute):  | 
 | 111 | +                    if isinstance(node.value, ast.Name) and node.value.id[0].isupper():  | 
 | 112 | +                        yield node.value.id  | 
 | 113 | + | 
12 | 114 | 
 
  | 
13 | 115 | def convert_folder(top_level, stub_directory):  | 
14 | 116 |     ok = 0  | 
15 | 117 |     total = 0  | 
16 | 118 |     filenames = sorted(os.listdir(top_level))  | 
17 |  | -    pyi_lines = []  | 
 | 119 | +    stub_fragments = []  | 
 | 120 | +    references = set()  | 
 | 121 | + | 
18 | 122 |     for filename in filenames:  | 
19 | 123 |         full_path = os.path.join(top_level, filename)  | 
20 | 124 |         file_lines = []  | 
21 | 125 |         if os.path.isdir(full_path):  | 
22 |  | -            mok, mtotal = convert_folder(full_path, os.path.join(stub_directory, filename))  | 
 | 126 | +            (mok, mtotal) = convert_folder(full_path, os.path.join(stub_directory, filename))  | 
23 | 127 |             ok += mok  | 
24 | 128 |             total += mtotal  | 
25 | 129 |         elif filename.endswith(".c"):  | 
26 |  | -            with open(full_path, "r") as f:  | 
 | 130 | +            with open(full_path, "r", encoding="utf-8") as f:  | 
27 | 131 |                 for line in f:  | 
 | 132 | +                    line = line.rstrip()  | 
28 | 133 |                     if line.startswith("//|"):  | 
29 |  | -                        if line[3] == " ":  | 
 | 134 | +                        if len(line) == 3:  | 
 | 135 | +                            line = ""  | 
 | 136 | +                        elif line[3] == " ":  | 
30 | 137 |                             line = line[4:]  | 
31 |  | -                        elif line[3] == "\n":  | 
32 |  | -                            line = line[3:]  | 
33 | 138 |                         else:  | 
34 |  | -                            continue  | 
 | 139 | +                            line = line[3:]  | 
 | 140 | +                            print("[WARN] There must be at least one space after '//|'")  | 
35 | 141 |                         file_lines.append(line)  | 
36 | 142 |         elif filename.endswith(".pyi"):  | 
37 | 143 |             with open(full_path, "r") as f:  | 
38 |  | -                file_lines.extend(f.readlines())  | 
 | 144 | +                file_lines.extend(line.rstrip() for line in f)  | 
 | 145 | + | 
 | 146 | +        fragment = "\n".join(file_lines).strip()  | 
 | 147 | +        try:  | 
 | 148 | +            tree = ast.parse(fragment)  | 
 | 149 | +        except SyntaxError as e:  | 
 | 150 | +            print(f"[ERROR] Failed to parse a Python stub from {full_path}")  | 
 | 151 | +            traceback.print_exception(type(e), e, e.__traceback__)  | 
 | 152 | +            return (ok, total + 1)  | 
 | 153 | +        references.update(find_references(tree))  | 
39 | 154 | 
 
  | 
40 |  | -        # Always put the contents from an __init__ first.  | 
41 |  | -        if filename.startswith("__init__."):  | 
42 |  | -            pyi_lines = file_lines + pyi_lines  | 
43 |  | -        else:  | 
44 |  | -            pyi_lines.extend(file_lines)  | 
 | 155 | +        if fragment:  | 
 | 156 | +            name = os.path.splitext(os.path.basename(filename))[0]  | 
 | 157 | +            if name == "__init__" or (name in references):  | 
 | 158 | +                stub_fragments.insert(0, fragment)  | 
 | 159 | +            else:  | 
 | 160 | +                stub_fragments.append(fragment)  | 
45 | 161 | 
 
  | 
46 |  | -    if not pyi_lines:  | 
47 |  | -        return ok, total  | 
 | 162 | +    if not stub_fragments:  | 
 | 163 | +        return (ok, total)  | 
48 | 164 | 
 
  | 
49 | 165 |     stub_filename = os.path.join(stub_directory, "__init__.pyi")  | 
50 | 166 |     print(stub_filename)  | 
51 |  | -    stub_contents = "".join(pyi_lines)  | 
 | 167 | +    stub_contents = "\n\n".join(stub_fragments)  | 
 | 168 | + | 
 | 169 | +    # Validate the stub code.  | 
 | 170 | +    try:  | 
 | 171 | +        tree = ast.parse(stub_contents)  | 
 | 172 | +    except SyntaxError as e:  | 
 | 173 | +        traceback.print_exception(type(e), e, e.__traceback__)  | 
 | 174 | +        return (ok, total)  | 
 | 175 | + | 
 | 176 | +    error = False  | 
 | 177 | +    for (level, msg) in find_stub_issues(tree):  | 
 | 178 | +        if level == "ERROR":  | 
 | 179 | +            error = True  | 
 | 180 | +        print(f"[{level}] {msg}")  | 
 | 181 | + | 
 | 182 | +    total += 1  | 
 | 183 | +    if not error:  | 
 | 184 | +        ok += 1  | 
 | 185 | + | 
 | 186 | +    # Add import statements  | 
 | 187 | +    imports = extract_imports(tree)  | 
 | 188 | +    import_lines = ["from __future__ import annotations"]  | 
 | 189 | +    if imports["types"]:  | 
 | 190 | +        import_lines.append("from types import " + ", ".join(imports["types"]))  | 
 | 191 | +    if imports["typing"]:  | 
 | 192 | +        import_lines.append("from typing import " + ", ".join(imports["typing"]))  | 
 | 193 | +    if imports["cpy_typing"]:  | 
 | 194 | +        import_lines.append("from _typing import " + ", ".join(imports["cpy_typing"]))  | 
 | 195 | +    import_lines.extend(f"import {m}" for m in imports["modules"])  | 
 | 196 | +    import_body = "\n".join(import_lines)  | 
 | 197 | +    m = re.match(r'(\s*""".*?""")', stub_contents, flags=re.DOTALL)  | 
 | 198 | +    if m:  | 
 | 199 | +        stub_contents = m.group(1) + "\n\n" + import_body + "\n\n" + stub_contents[m.end():]  | 
 | 200 | +    else:  | 
 | 201 | +        stub_contents = import_body + "\n\n" + stub_contents  | 
 | 202 | + | 
 | 203 | +    # Code formatting  | 
 | 204 | +    stub_contents = isort.code(stub_contents)  | 
 | 205 | +    stub_contents = black.format_str(stub_contents, mode=black.FileMode(is_pyi=True))  | 
 | 206 | + | 
52 | 207 |     os.makedirs(stub_directory, exist_ok=True)  | 
53 | 208 |     with open(stub_filename, "w") as f:  | 
54 | 209 |         f.write(stub_contents)  | 
55 | 210 | 
 
  | 
56 |  | -    # Validate that the module is a parseable stub.  | 
57 |  | -    total += 1  | 
58 |  | -    try:  | 
59 |  | -        tree = astroid.parse(stub_contents)  | 
60 |  | -        for i in tree.body:  | 
61 |  | -            if 'name' in i.__dict__:  | 
62 |  | -                print(i.__dict__['name'])  | 
63 |  | -                for j in i.body:  | 
64 |  | -                    if isinstance(j, astroid.scoped_nodes.FunctionDef):  | 
65 |  | -                        if None in j.args.__dict__['annotations']:  | 
66 |  | -                            print(f"Missing parameter type: {j.__dict__['name']} on line {j.__dict__['lineno']}\n")  | 
67 |  | -                        if j.returns:  | 
68 |  | -                            if 'Any' in j.returns.__dict__.values():  | 
69 |  | -                                print(f"Missing return type: {j.__dict__['name']} on line {j.__dict__['lineno']}")  | 
70 |  | -                    elif isinstance(j, astroid.node_classes.AnnAssign):  | 
71 |  | -                        if 'name' in j.__dict__['annotation'].__dict__:  | 
72 |  | -                            if j.__dict__['annotation'].__dict__['name'] == 'Any':  | 
73 |  | -                                print(f"missing attribute type on line {j.__dict__['lineno']}")  | 
 | 211 | +    return (ok, total)  | 
74 | 212 | 
 
  | 
75 |  | -        ok += 1  | 
76 |  | -    except astroid.exceptions.AstroidSyntaxError as e:  | 
77 |  | -        e = e.__cause__  | 
78 |  | -        traceback.print_exception(type(e), e, e.__traceback__)  | 
79 |  | -    print()  | 
80 |  | -    return ok, total  | 
81 | 213 | 
 
  | 
82 |  | -ok, total = convert_folder(top_level, stub_directory)  | 
 | 214 | +if __name__ == "__main__":  | 
 | 215 | +    top_level = sys.argv[1].strip("/")  | 
 | 216 | +    stub_directory = sys.argv[2]  | 
 | 217 | + | 
 | 218 | +    (ok, total) = convert_folder(top_level, stub_directory)  | 
83 | 219 | 
 
  | 
84 |  | -print(f"{ok} ok out of {total}")  | 
 | 220 | +    print(f"Parsing .pyi files: {total - ok} failed, {ok} passed")  | 
85 | 221 | 
 
  | 
86 |  | -if ok != total:  | 
87 |  | -    sys.exit(total - ok)  | 
 | 222 | +    if ok != total:  | 
 | 223 | +        sys.exit(total - ok)  | 
0 commit comments