diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt.py b/metadata-ingestion/src/datahub/ingestion/source/dbt.py index 368833e95323e3..97199417e45ce9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, List from datahub.configuration import ConfigModel +from datahub.configuration.common import AllowDenyPattern from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.source import Source, SourceReport from datahub.ingestion.source.metadata_common import MetadataWorkUnit @@ -38,6 +39,7 @@ class DBTConfig(ConfigModel): env: str = "PROD" target_platform: str load_schemas: bool + node_type_pattern: AllowDenyPattern = AllowDenyPattern.allow_all() class DBTColumn: @@ -91,24 +93,25 @@ def extract_dbt_entities( load_catalog: bool, target_platform: str, environment: str, + node_type_pattern: AllowDenyPattern, ) -> List[DBTNode]: dbt_entities = [] for key in nodes: node = nodes[key] dbtNode = DBTNode() - if key not in catalog and load_catalog is False: + # check if node pattern allowed based on config file + if not node_type_pattern.allowed(node["resource_type"]): continue - - if "identifier" in node and load_catalog is False: - dbtNode.name = node["identifier"] - else: - dbtNode.name = node["name"] dbtNode.dbt_name = key dbtNode.database = node["database"] dbtNode.schema = node["schema"] dbtNode.dbt_file_path = node["original_file_path"] dbtNode.node_type = node["resource_type"] + if "identifier" in node and load_catalog is False: + dbtNode.name = node["identifier"] + else: + dbtNode.name = node["name"] if "materialized" in node["config"].keys(): # It's a model @@ -154,6 +157,7 @@ def loadManifestAndCatalog( load_catalog: bool, target_platform: str, environment: str, + node_type_pattern: AllowDenyPattern, ) -> List[DBTNode]: with open(manifest_path, "r") as manifest: with open(catalog_path, "r") as catalog: @@ -176,6 +180,7 @@ def loadManifestAndCatalog( load_catalog, target_platform, environment, + node_type_pattern, ) return nodes @@ -339,6 +344,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]: self.config.load_schemas, self.config.target_platform, self.config.env, + self.config.node_type_pattern, ) for node in nodes: