Skip to content

Commit e01d4c0

Browse files
authored
Add restrict-access to dbt_project.yml (#7962)
1 parent 7a6beda commit e01d4c0

File tree

9 files changed

+243
-38
lines changed

9 files changed

+243
-38
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
kind: Features
2+
body: Add restrict-access to dbt_project.yml
3+
time: 2023-06-27T13:27:49.114257-04:00
4+
custom:
5+
Author: michelleark
6+
Issue: "7713"

core/dbt/config/project.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
497497
config_version=cfg.config_version,
498498
unrendered=unrendered,
499499
project_env_vars=project_env_vars,
500+
restrict_access=cfg.restrict_access,
500501
)
501502
# sanity check - this means an internal issue
502503
project.validate()
@@ -607,6 +608,7 @@ class Project:
607608
config_version: int
608609
unrendered: RenderComponents
609610
project_env_vars: Dict[str, Any]
611+
restrict_access: bool
610612

611613
@property
612614
def all_source_paths(self) -> List[str]:
@@ -675,6 +677,7 @@ def to_project_config(self, with_packages=False):
675677
"vars": self.vars.to_dict(),
676678
"require-dbt-version": [v.to_version_string() for v in self.dbt_version],
677679
"config-version": self.config_version,
680+
"restrict-access": self.restrict_access,
678681
}
679682
)
680683
if self.query_comment:

core/dbt/config/runtime.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def from_parts(
172172
config_version=project.config_version,
173173
unrendered=project.unrendered,
174174
project_env_vars=project.project_env_vars,
175+
restrict_access=project.restrict_access,
175176
profile_env_vars=profile.profile_env_vars,
176177
profile_name=profile.profile_name,
177178
target_name=profile.target_name,

core/dbt/context/providers.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -506,19 +506,24 @@ def resolve(
506506
target_version=target_version,
507507
disabled=isinstance(target_model, Disabled),
508508
)
509-
elif (
510-
target_model.resource_type == NodeType.Model
511-
and target_model.access == AccessType.Private
512-
# don't raise this reference error for ad hoc 'preview' queries
513-
and self.model.resource_type != NodeType.SqlOperation
514-
and self.model.resource_type != NodeType.RPCCall # TODO: rm
509+
elif self.manifest.is_invalid_private_ref(
510+
self.model, target_model, self.config.dependencies
515511
):
516-
if not self.model.group or self.model.group != target_model.group:
517-
raise DbtReferenceError(
518-
unique_id=self.model.unique_id,
519-
ref_unique_id=target_model.unique_id,
520-
group=cast_to_str(target_model.group),
521-
)
512+
raise DbtReferenceError(
513+
unique_id=self.model.unique_id,
514+
ref_unique_id=target_model.unique_id,
515+
access=AccessType.Private,
516+
scope=cast_to_str(target_model.group),
517+
)
518+
elif self.manifest.is_invalid_protected_ref(
519+
self.model, target_model, self.config.dependencies
520+
):
521+
raise DbtReferenceError(
522+
unique_id=self.model.unique_id,
523+
ref_unique_id=target_model.unique_id,
524+
access=AccessType.Protected,
525+
scope=target_model.package_name,
526+
)
522527

523528
self.validate(target_model, target_name, target_package, target_version)
524529
return self.create_relation(target_model)

core/dbt/contracts/graph/manifest.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from dbt.events.functions import fire_event
5656
from dbt.events.types import MergedFromState, UnpinnedRefNewVersionAvailable
5757
from dbt.events.contextvars import get_node_info
58-
from dbt.node_types import NodeType
58+
from dbt.node_types import NodeType, AccessType
5959
from dbt.flags import get_flags, MP_CONTEXT
6060
from dbt import tracking
6161
import dbt.utils
@@ -1123,6 +1123,50 @@ def resolve_doc(
11231123
return result
11241124
return None
11251125

1126+
def is_invalid_private_ref(
1127+
self, node: GraphMemberNode, target_model: MaybeNonSource, dependencies: Optional[Mapping]
1128+
) -> bool:
1129+
dependencies = dependencies or {}
1130+
if not isinstance(target_model, ModelNode):
1131+
return False
1132+
1133+
is_private_ref = (
1134+
target_model.access == AccessType.Private
1135+
# don't raise this reference error for ad hoc 'preview' queries
1136+
and node.resource_type != NodeType.SqlOperation
1137+
and node.resource_type != NodeType.RPCCall # TODO: rm
1138+
)
1139+
target_dependency = dependencies.get(target_model.package_name)
1140+
restrict_package_access = target_dependency.restrict_access if target_dependency else False
1141+
1142+
# TODO: SemanticModel and SourceDefinition do not have group, and so should not be able to make _any_ private ref.
1143+
return is_private_ref and (
1144+
not hasattr(node, "group")
1145+
or not node.group
1146+
or node.group != target_model.group
1147+
or restrict_package_access
1148+
)
1149+
1150+
def is_invalid_protected_ref(
1151+
self, node: GraphMemberNode, target_model: MaybeNonSource, dependencies: Optional[Mapping]
1152+
) -> bool:
1153+
dependencies = dependencies or {}
1154+
if not isinstance(target_model, ModelNode):
1155+
return False
1156+
1157+
is_protected_ref = (
1158+
target_model.access == AccessType.Protected
1159+
# don't raise this reference error for ad hoc 'preview' queries
1160+
and node.resource_type != NodeType.SqlOperation
1161+
and node.resource_type != NodeType.RPCCall # TODO: rm
1162+
)
1163+
target_dependency = dependencies.get(target_model.package_name)
1164+
restrict_package_access = target_dependency.restrict_access if target_dependency else False
1165+
1166+
return is_protected_ref and (
1167+
node.package_name != target_model.package_name and restrict_package_access
1168+
)
1169+
11261170
# Called by RunTask.defer_to_manifest
11271171
def merge_from_artifact(
11281172
self,

core/dbt/contracts/project.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ class Project(HyphenatedDbtClassMixin, Replaceable):
223223
)
224224
packages: List[PackageSpec] = field(default_factory=list)
225225
query_comment: Optional[Union[QueryComment, NoValue, str]] = field(default_factory=NoValue)
226+
restrict_access: bool = False
226227

227228
@classmethod
228229
def validate(cls, data):

core/dbt/exceptions.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from dbt.dataclass_schema import ValidationError
99
from dbt.events.helpers import env_secrets, scrub_secrets
10-
from dbt.node_types import NodeType
10+
from dbt.node_types import NodeType, AccessType
1111
from dbt.ui import line_wrap_message
1212

1313
import dbt.dataclass_schema
@@ -1219,16 +1219,18 @@ def __init__(self, exc: ValidationError, node):
12191219

12201220

12211221
class DbtReferenceError(ParsingError):
1222-
def __init__(self, unique_id: str, ref_unique_id: str, group: str):
1222+
def __init__(self, unique_id: str, ref_unique_id: str, access: AccessType, scope: str):
12231223
self.unique_id = unique_id
12241224
self.ref_unique_id = ref_unique_id
1225-
self.group = group
1225+
self.access = access
1226+
self.scope = scope
1227+
self.scope_type = "group" if self.access == AccessType.Private else "package"
12261228
super().__init__(msg=self.get_message())
12271229

12281230
def get_message(self) -> str:
12291231
return (
12301232
f"Node {self.unique_id} attempted to reference node {self.ref_unique_id}, "
1231-
f"which is not allowed because the referenced node is private to the {self.group} group."
1233+
f"which is not allowed because the referenced node is {self.access} to the '{self.scope}' {self.scope_type}."
12321234
)
12331235

12341236

core/dbt/parser/manifest.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ def load(self):
508508
# determine whether they need processing.
509509
start_process = time.perf_counter()
510510
self.process_sources(self.root_project.project_name)
511-
self.process_refs(self.root_project.project_name)
511+
self.process_refs(self.root_project.project_name, self.root_project.dependencies)
512512
self.process_docs(self.root_project)
513513
self.process_metrics(self.root_project)
514514
self.check_valid_group_config()
@@ -533,7 +533,10 @@ def load(self):
533533
external_nodes_modified = self.inject_external_nodes()
534534
if external_nodes_modified:
535535
self.manifest.rebuild_ref_lookup()
536-
self.process_refs(self.root_project.project_name)
536+
self.process_refs(
537+
self.root_project.project_name,
538+
self.root_project.dependencies,
539+
)
537540
# parent and child maps will be rebuilt by write_manifest
538541

539542
if not skip_parsing:
@@ -1038,23 +1041,23 @@ def track_project_load(self):
10381041

10391042
# Takes references in 'refs' array of nodes and exposures, finds the target
10401043
# node, and updates 'depends_on.nodes' with the unique id
1041-
def process_refs(self, current_project: str):
1044+
def process_refs(self, current_project: str, dependencies: Optional[Dict[str, Project]]):
10421045
for node in self.manifest.nodes.values():
10431046
if node.created_at < self.started_at:
10441047
continue
1045-
_process_refs(self.manifest, current_project, node)
1048+
_process_refs(self.manifest, current_project, node, dependencies)
10461049
for exposure in self.manifest.exposures.values():
10471050
if exposure.created_at < self.started_at:
10481051
continue
1049-
_process_refs(self.manifest, current_project, exposure)
1052+
_process_refs(self.manifest, current_project, exposure, dependencies)
10501053
for metric in self.manifest.metrics.values():
10511054
if metric.created_at < self.started_at:
10521055
continue
1053-
_process_refs(self.manifest, current_project, metric)
1056+
_process_refs(self.manifest, current_project, metric, dependencies)
10541057
for semantic_model in self.manifest.semantic_models.values():
10551058
if semantic_model.created_at < self.started_at:
10561059
continue
1057-
_process_refs(self.manifest, current_project, semantic_model)
1060+
_process_refs(self.manifest, current_project, semantic_model, dependencies)
10581061
self.update_semantic_model(semantic_model)
10591062

10601063
# Takes references in 'metrics' array of nodes and exposures, finds the target
@@ -1372,9 +1375,13 @@ def _process_docs_for_metrics(context: Dict[str, Any], metric: Metric) -> None:
13721375
metric.description = get_rendered(metric.description, context)
13731376

13741377

1375-
def _process_refs(manifest: Manifest, current_project: str, node) -> None:
1378+
def _process_refs(
1379+
manifest: Manifest, current_project: str, node, dependencies: Optional[Mapping[str, Project]]
1380+
) -> None:
13761381
"""Given a manifest and node in that manifest, process its refs"""
13771382

1383+
dependencies = dependencies or {}
1384+
13781385
if isinstance(node, SeedNode):
13791386
return
13801387

@@ -1413,18 +1420,20 @@ def _process_refs(manifest: Manifest, current_project: str, node) -> None:
14131420
)
14141421

14151422
continue
1416-
elif (
1417-
isinstance(target_model, ModelNode)
1418-
and target_model.access == AccessType.Private
1419-
and node.resource_type != NodeType.SqlOperation
1420-
and node.resource_type != NodeType.RPCCall # TODO: rm
1421-
):
1422-
if not node.group or node.group != target_model.group:
1423-
raise dbt.exceptions.DbtReferenceError(
1424-
unique_id=node.unique_id,
1425-
ref_unique_id=target_model.unique_id,
1426-
group=dbt.utils.cast_to_str(target_model.group),
1427-
)
1423+
elif manifest.is_invalid_private_ref(node, target_model, dependencies):
1424+
raise dbt.exceptions.DbtReferenceError(
1425+
unique_id=node.unique_id,
1426+
ref_unique_id=target_model.unique_id,
1427+
access=AccessType.Private,
1428+
scope=dbt.utils.cast_to_str(target_model.group),
1429+
)
1430+
elif manifest.is_invalid_protected_ref(node, target_model, dependencies):
1431+
raise dbt.exceptions.DbtReferenceError(
1432+
unique_id=node.unique_id,
1433+
ref_unique_id=target_model.unique_id,
1434+
access=AccessType.Protected,
1435+
scope=target_model.package_name,
1436+
)
14281437

14291438
target_model_id = target_model.unique_id
14301439
node.depends_on.add_node(target_model_id)
@@ -1577,7 +1586,7 @@ def process_macro(config: RuntimeConfig, manifest: Manifest, macro: Macro) -> No
15771586
def process_node(config: RuntimeConfig, manifest: Manifest, node: ManifestNode):
15781587

15791588
_process_sources_for_node(manifest, config.project_name, node)
1580-
_process_refs(manifest, config.project_name, node)
1589+
_process_refs(manifest, config.project_name, node, config.dependencies)
15811590
ctx = generate_runtime_docs_context(config, node, manifest, config.project_name)
15821591
_process_docs_for_node(ctx, node)
15831592

0 commit comments

Comments
 (0)