Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
feat(ingest): support dynamic imports for transformer methods
  • Loading branch information
hsheth2 committed Jul 9, 2021
commit cec31f105751af182d4cc86f575db461634ae54b
13 changes: 13 additions & 0 deletions metadata-ingestion/src/datahub/configuration/import_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pydantic

from datahub.ingestion.api.registry import import_key


def _pydantic_resolver(v):
if isinstance(v, str):
return import_key(v)
return v


def pydantic_resolve_key(field):
    """Build a reusable pydantic pre-validator that resolves dotted-path strings
    assigned to *field* into the objects they name."""
    decorator = pydantic.validator(field, pre=True, allow_reuse=True)
    return decorator(_pydantic_resolver)
12 changes: 9 additions & 3 deletions metadata-ingestion/src/datahub/ingestion/api/registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import importlib
import inspect
from typing import Dict, Generic, Type, TypeVar, Union
from typing import Any, Dict, Generic, Type, TypeVar, Union

import entrypoints
import typing_inspect
Expand All @@ -11,6 +11,13 @@
T = TypeVar("T")


def import_key(key: str) -> Any:
    """Dynamically import and return the object named by a dotted path.

    Args:
        key: Fully qualified path such as ``"package.module.attr"``; the part
            before the final dot is imported as a module and the part after it
            is looked up on that module.

    Returns:
        The resolved module attribute (class, function, constant, ...).

    Raises:
        ValueError: If *key* contains no dot and so cannot be split into a
            module path and an attribute name. (Previously an ``assert``,
            which is silently stripped under ``python -O``.)
        ImportError: If the module part cannot be imported.
        AttributeError: If the attribute is missing from the module.
    """
    if "." not in key:
        raise ValueError(f"import key must contain a '.': {key!r}")
    module_name, item_name = key.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), item_name)


class Registry(Generic[T]):
    """Registry mapping string keys to classes of type ``T``.

    A mapping entry may instead hold an ``Exception`` — presumably the error
    captured while loading/registering that key, surfaced later on lookup
    (TODO confirm against the rest of the class, not visible here).
    """

    def __init__(self):
        # key -> registered class, or the Exception raised for that key
        self._mapping: Dict[str, Union[Type[T], Exception]] = {}
Expand Down Expand Up @@ -68,8 +75,7 @@ def get(self, key: str) -> Type[T]:
if key.find(".") >= 0:
# If the key contains a dot, we treat it as a import path and attempt
# to load it dynamically.
module_name, class_name = key.rsplit(".", 1)
MyClass = getattr(importlib.import_module(module_name), class_name)
MyClass = import_key(key)
self._check_cls(MyClass)
return MyClass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datahub.emitter.mce_builder as builder
from datahub.configuration.common import ConfigModel
from datahub.configuration.import_resolver import pydantic_resolve_key
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope
from datahub.ingestion.api.transform import Transformer
from datahub.metadata.schema_classes import (
Expand All @@ -22,6 +23,8 @@ class AddDatasetOwnershipConfig(ConfigModel):
]
default_actor: str = builder.make_user_urn("etl")

_resolve_owner_fn = pydantic_resolve_key("get_owners_to_add")


class AddDatasetOwnership(Transformer):
"""Transformer that adds owners to datasets according to a callback function."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datahub.emitter.mce_builder as builder
from datahub.configuration.common import ConfigModel
from datahub.configuration.import_resolver import pydantic_resolve_key
from datahub.ingestion.api.common import PipelineContext, RecordEnvelope
from datahub.ingestion.api.transform import Transformer
from datahub.metadata.schema_classes import (
Expand All @@ -20,6 +21,8 @@ class AddDatasetTagsConfig(ConfigModel):
Callable[[DatasetSnapshotClass], List[TagAssociationClass]],
]

_resolve_tag_fn = pydantic_resolve_key("get_tags_to_add")


class AddDatasetTags(Transformer):
"""Transformer that adds tags to datasets according to a callback function."""
Expand Down
24 changes: 23 additions & 1 deletion metadata-ingestion/tests/unit/test_transform_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from datahub.ingestion.transformer.add_dataset_ownership import (
SimpleAddDatasetOwnership,
)
from datahub.ingestion.transformer.add_dataset_tags import SimpleAddDatasetTags
from datahub.ingestion.transformer.add_dataset_tags import (
AddDatasetTags,
SimpleAddDatasetTags,
)


def make_generic_dataset():
Expand Down Expand Up @@ -120,3 +123,22 @@ def test_simple_dataset_tags_transformation(mock_time):
assert tags_aspect
assert len(tags_aspect.tags) == 2
assert tags_aspect.tags[0].tag == builder.make_tag_urn("NeedsDocumentation")


def dummy_tag_resolver_method(dataset_snapshot):
    """No-op tag resolver used by test_import_resolver: contributes no tags."""
    return list()


def test_import_resolver():
    """Verify that a dotted-path string config value is resolved to the callable."""
    ctx = PipelineContext(run_id="test-tags")
    transformer = AddDatasetTags.create(
        {
            "get_tags_to_add": "tests.unit.test_transform_dataset.dummy_tag_resolver_method"
        },
        ctx,
    )
    envelopes = [
        RecordEnvelope(record, metadata={}) for record in (make_generic_dataset(),)
    ]
    output = list(transformer.transform(envelopes))
    assert output