Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions deptry/imports/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from dataclasses import dataclass
from pathlib import Path

import chardet


@dataclass
class ImportExtractor(ABC):
Expand All @@ -25,3 +27,8 @@ def _extract_imports_from_ast(tree: ast.AST) -> set[str]:
imported_modules.add(node.module.split(".")[0])

return imported_modules

@staticmethod
def _get_file_encoding(file: Path) -> str:
with open(file, "rb") as f:
return chardet.detect(f.read())["encoding"]
22 changes: 16 additions & 6 deletions deptry/imports/extractors/notebook_import_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ast
import itertools
import json
import logging
import re
from dataclasses import dataclass
from pathlib import Path
Expand All @@ -17,17 +18,26 @@ class NotebookImportExtractor(ImportExtractor):

def extract_imports(self) -> set[str]:
notebook = self._read_ipynb_file(self.file)
if not notebook:
return set()

cells = self._keep_code_cells(notebook)
import_statements = [self._extract_import_statements_from_cell(cell) for cell in cells]

tree = ast.parse("\n".join(itertools.chain.from_iterable(import_statements)), str(self.file))

return self._extract_imports_from_ast(tree)

@staticmethod
def _read_ipynb_file(path_to_ipynb: Path) -> dict[str, Any]:
with open(path_to_ipynb) as f:
notebook: dict[str, Any] = json.load(f)
@classmethod
def _read_ipynb_file(cls, path_to_ipynb: Path) -> dict[str, Any] | None:
try:
with open(path_to_ipynb) as ipynb_file:
notebook: dict[str, Any] = json.load(ipynb_file)
except UnicodeDecodeError:
try:
with open(path_to_ipynb, encoding=cls._get_file_encoding(path_to_ipynb)) as ipynb_file:
notebook = json.load(ipynb_file, strict=False)
except UnicodeDecodeError:
logging.warning(f"Warning: File {path_to_ipynb} could not be decoded. Skipping...")
return None
return notebook

@staticmethod
Expand Down
8 changes: 0 additions & 8 deletions deptry/imports/extractors/python_import_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
import ast
import logging
from dataclasses import dataclass
from pathlib import Path

import chardet

from deptry.imports.extractors.base import ImportExtractor

Expand All @@ -27,8 +24,3 @@ def extract_imports(self) -> set[str]:
return set()

return self._extract_imports_from_ast(tree)

@staticmethod
def _get_file_encoding(file: Path) -> str:
with open(file, "rb") as f:
return chardet.detect(f.read())["encoding"]
42 changes: 42 additions & 0 deletions tests/imports/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,48 @@ def test_import_parser_file_encodings(file_content: str, encoding: str | None, t
assert get_imported_modules_from_file(Path(random_file_name)) == {"foo"}


@pytest.mark.parametrize(
("code_cell_content", "encoding"),
[
(
["import foo", "print('嘉大')"],
"utf-8",
),
(
["import foo", "print('Æ Ç')"],
"iso-8859-15",
),
(
["import foo", "print('嘉大')"],
"utf-16",
),
(
["my_string = '🐺'", "import foo"],
None,
),
],
)
def test_import_parser_file_encodings_ipynb(code_cell_content: list[str], encoding: str | None, tmp_path: Path) -> None:
random_file_name = f"file_{uuid.uuid4()}.ipynb"

with run_within_dir(tmp_path):
with open(random_file_name, "w", encoding=encoding) as f:
file_content = f"""{{
"cells": [
{{
"cell_type": "code",
"metadata": {{}},
"source": [
{", ".join([ f'"{code_line}"' for code_line in code_cell_content])}
]
}}
]}}"""
f.write(file_content)
print(file_content)

assert get_imported_modules_from_file(Path(random_file_name)) == {"foo"}


def test_import_parser_file_encodings_warning(tmp_path: Path, caplog: LogCaptureFixture) -> None:
with run_within_dir(tmp_path):
with open("file1.py", "w", encoding="utf-8") as f:
Expand Down