Skip to content
Prev Previous commit
Next Next commit
Inject sub-modules when importing with importlib
  • Loading branch information
nicoddemus committed Apr 2, 2021
commit d67bb82e2b595711f57723d35d3b0e0210ba973b
22 changes: 22 additions & 0 deletions src/_pytest/pathlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from posixpath import sep as posix_sep
from types import ModuleType
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import Optional
Expand Down Expand Up @@ -508,6 +509,7 @@ def import_path(
mod = importlib.util.module_from_spec(spec)
sys.modules[module_name] = mod
spec.loader.exec_module(mod) # type: ignore[union-attr]
insert_missing_modules(sys.modules, module_name)
return mod

pkg_path = resolve_package_path(path)
Expand Down Expand Up @@ -593,6 +595,26 @@ def module_name_from_path(path: Path, root: Path) -> str:
return ".".join(path_parts)


def insert_missing_modules(modules: Dict[str, ModuleType], module_name: str) -> None:
"""
Used by ``import_path`` to create intermediate modules when using mode=importlib.

When we want to import a module as "src.tests.test_foo" for example, we need
to create empty modules "src" and "src.tests" after inserting "src.tests.test_foo",
otherwise "src.tests.test_foo" is not importable by ``__import__``.
"""
module_parts = module_name.split(".")
while module_name:
if module_name not in modules:
module = ModuleType(
module_name,
doc="Empty module created by pytest's importmode=importlib.",
)
modules[module_name] = module
module_parts.pop(-1)
module_name = ".".join(module_parts)


def resolve_package_path(path: Path) -> Optional[Path]:
"""Return the Python package path by looking for the last
directory upwards which still contains an __init__.py.
Expand Down
193 changes: 106 additions & 87 deletions testing/test_pathlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from _pytest.pathlib import get_lock_path
from _pytest.pathlib import import_path
from _pytest.pathlib import ImportPathMismatchError
from _pytest.pathlib import insert_missing_modules
from _pytest.pathlib import maybe_delete_a_numbered_dir
from _pytest.pathlib import module_name_from_path
from _pytest.pathlib import resolve_package_path
Expand Down Expand Up @@ -285,6 +286,10 @@ def test_importmode_importlib(self, simple_module: Path, tmp_path: Path) -> None
module = import_path(simple_module, mode="importlib", root=tmp_path)
assert module.foo(2) == 42 # type: ignore[attr-defined]
assert str(simple_module.parent) not in sys.path
assert module.__name__ in sys.modules
assert module.__name__ == "_src.tests.mymod"
assert "_src" in sys.modules
assert "_src.tests" in sys.modules

def test_importmode_twice_is_different_module(self, simple_module: Path, tmp_path: Path) -> None:
"""`importlib` mode always returns a new module."""
Expand Down Expand Up @@ -447,112 +452,126 @@ def test_samefile_false_negatives(tmp_path: Path, monkeypatch: MonkeyPatch) -> N
assert getattr(module, "foo")() == 42


@pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+")
def test_importmode_importlib_with_dataclass(tmp_path: Path) -> None:
"""Ensure that importlib mode works with a module containing dataclasses (#7856)."""
fn = tmp_path.joinpath("src/tests/test_dataclass.py")
fn.parent.mkdir(parents=True)
fn.write_text(
dedent(
"""
from dataclasses import dataclass
class TestImportLibMode:
@pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+")
def test_importmode_importlib_with_dataclass(self, tmp_path: Path) -> None:
"""Ensure that importlib mode works with a module containing dataclasses (#7856)."""
fn = tmp_path.joinpath("_src/tests/test_dataclass.py")
fn.parent.mkdir(parents=True)
fn.write_text(
dedent(
"""
from dataclasses import dataclass

@dataclass
class Data:
value: str
"""
@dataclass
class Data:
value: str
"""
)
)
)

module = import_path(fn, mode="importlib", root=tmp_path)
Data: Any = getattr(module, "Data")
data = Data(value="foo")
assert data.value == "foo"
assert data.__module__ == "src.tests.test_dataclass"
module = import_path(fn, mode="importlib", root=tmp_path)
Data: Any = getattr(module, "Data")
data = Data(value="foo")
assert data.value == "foo"
assert data.__module__ == "_src.tests.test_dataclass"

def test_importmode_importlib_with_pickle(self, tmp_path: Path) -> None:
"""Ensure that importlib mode works with pickle (#7859)."""
fn = tmp_path.joinpath("_src/tests/test_pickle.py")
fn.parent.mkdir(parents=True)
fn.write_text(
dedent(
"""
import pickle

def test_importmode_importlib_with_pickle(tmp_path: Path) -> None:
"""Ensure that importlib mode works with pickle (#7859)."""
fn = tmp_path.joinpath("src/tests/test_pickle.py")
fn.parent.mkdir(parents=True)
fn.write_text(
dedent(
"""
import pickle

def _action():
return 42
def _action():
return 42

def round_trip():
s = pickle.dumps(_action)
return pickle.loads(s)
"""
def round_trip():
s = pickle.dumps(_action)
return pickle.loads(s)
"""
)
)
)

module = import_path(fn, mode="importlib", root=tmp_path)
round_trip = getattr(module, "round_trip")
action = round_trip()
assert action() == 42

module = import_path(fn, mode="importlib", root=tmp_path)
round_trip = getattr(module, "round_trip")
action = round_trip()
assert action() == 42

def test_importmode_importlib_with_pickle_separate_modules(tmp_path: Path) -> None:
"""
Ensure that importlib mode works can load pickles that look similar but are
defined in separate modules.
"""
fn1 = tmp_path.joinpath("src/m1/tests/test.py")
fn1.parent.mkdir(parents=True)
fn1.write_text(
dedent(
"""
import attr
import pickle
def test_importmode_importlib_with_pickle_separate_modules(
self, tmp_path: Path
) -> None:
"""
Ensure that importlib mode works can load pickles that look similar but are
defined in separate modules.
"""
fn1 = tmp_path.joinpath("_src/m1/tests/test.py")
fn1.parent.mkdir(parents=True)
fn1.write_text(
dedent(
"""
import attr
import pickle

@attr.s(auto_attribs=True)
class Data:
x: int = 42
"""
@attr.s(auto_attribs=True)
class Data:
x: int = 42
"""
)
)
)

fn2 = tmp_path.joinpath("src/m2/tests/test.py")
fn2.parent.mkdir(parents=True)
fn2.write_text(
dedent(
"""
import attr
import pickle
fn2 = tmp_path.joinpath("_src/m2/tests/test.py")
fn2.parent.mkdir(parents=True)
fn2.write_text(
dedent(
"""
import attr
import pickle

@attr.s(auto_attribs=True)
class Data:
x: str = ""
"""
@attr.s(auto_attribs=True)
class Data:
x: str = ""
"""
)
)
)

import pickle
import pickle

def round_trip(obj):
s = pickle.dumps(obj)
return pickle.loads(s)

module = import_path(fn1, mode="importlib", root=tmp_path)
Data1 = getattr(module, "Data")

def round_trip(obj):
s = pickle.dumps(obj)
return pickle.loads(s)
module = import_path(fn2, mode="importlib", root=tmp_path)
Data2 = getattr(module, "Data")

module = import_path(fn1, mode="importlib", root=tmp_path)
Data1 = getattr(module, "Data")
assert round_trip(Data1(20)) == Data1(20)
assert round_trip(Data2("hello")) == Data2("hello")
assert Data1.__module__ == "_src.m1.tests.test"
assert Data2.__module__ == "_src.m2.tests.test"

module = import_path(fn2, mode="importlib", root=tmp_path)
Data2 = getattr(module, "Data")
def test_module_name_from_path(self, tmp_path: Path) -> None:
result = module_name_from_path(tmp_path / "src/tests/test_foo.py", tmp_path)
assert result == "src.tests.test_foo"

assert round_trip(Data1(20)) == Data1(20)
assert round_trip(Data2("hello")) == Data2("hello")
assert Data1.__module__ == "src.m1.tests.test"
assert Data2.__module__ == "src.m2.tests.test"
# Path is not relative to root dir: use the full path to obtain the module name.
result = module_name_from_path(Path("/home/foo/test_foo.py"), Path("/bar"))
assert result == "home.foo.test_foo"

def test_insert_missing_modules(self) -> None:
modules = {"src.tests.foo": ModuleType("src.tests.foo")}
insert_missing_modules(modules, "src.tests.foo")
assert sorted(modules) == ["src", "src.tests", "src.tests.foo"]

def test_module_name_from_path(tmp_path: Path) -> None:
result = module_name_from_path(tmp_path / "src/tests/test_foo.py", tmp_path)
assert result == "src.tests.test_foo"
mod = ModuleType("mod", doc="My Module")
modules = {"src": mod}
insert_missing_modules(modules, "src")
assert modules == {"src": mod}

# Path is not relative to root dir: use the full path to obtain the module name.
result = module_name_from_path(Path("/home/foo/test_foo.py"), Path("/bar"))
assert result == "home.foo.test_foo"
modules = {}
insert_missing_modules(modules, "")
assert modules == {}