Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Use new context invocation class.
  • Loading branch information
peterallenwebb committed Jan 30, 2024
commit b5862efc634e9bafdd3c5fd631583d7848acaab5
4 changes: 4 additions & 0 deletions core/dbt/cli/requires.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import dbt.tracking
from dbt_common.context import set_invocation_context
from dbt_common.invocation import reset_invocation_id

from dbt.version import installed as installed_version
from dbt.adapters.factory import adapter_management
from dbt.flags import set_flags, get_flag_dict
Expand Down Expand Up @@ -45,6 +47,8 @@ def wrapper(*args, **kwargs):
assert isinstance(ctx, Context)
ctx.obj = ctx.obj or {}

set_invocation_context()

# Flags
flags = Flags(ctx)
ctx.obj["flags"] = flags
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/config/renderer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, Any, Tuple, Optional, Union, Callable
import re
import os
from datetime import date

from dbt.clients.jinja import get_rendered
Expand All @@ -11,6 +10,7 @@
from dbt.context.base import BaseContext
from dbt.adapters.contracts.connection import HasCredentials
from dbt.exceptions import DbtProjectError
from dbt_common.context import get_invocation_context
from dbt_common.exceptions import CompilationError, RecursionError
from dbt_common.utils import deep_map_render

Expand Down Expand Up @@ -212,7 +212,7 @@ def render_value(self, value: Any, keypath: Optional[Keypath] = None) -> Any:
)
if m:
found = m.group(1)
value = os.environ[found]
value = get_invocation_context().env[found]
replace_this = SECRET_PLACEHOLDER.format(found)
return rendered.replace(replace_this, value)
else:
Expand Down
8 changes: 5 additions & 3 deletions core/dbt/context/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
SetStrictWrongTypeError,
ZipStrictWrongTypeError,
)
from dbt_common.context import get_invocation_context
from dbt_common.exceptions.macros import MacroReturn
from dbt_common.events.functions import fire_event, get_invocation_id
from dbt.events.types import JinjaLogInfo, JinjaLogDebug
Expand Down Expand Up @@ -303,8 +304,9 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
return_value = None
if var.startswith(SECRET_ENV_PREFIX):
raise SecretEnvVarLocationError(var)
if var in os.environ:
return_value = os.environ[var]
env = get_invocation_context().env
if var in env:
return_value = env[var]
elif default is not None:
return_value = default

Expand All @@ -313,7 +315,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
# that so we can skip partial parsing. Otherwise the file will be scheduled for
# reparsing. If the default changes, the file will have been updated and therefore
# will be scheduled for reparsing anyways.
self.env_vars[var] = return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER
self.env_vars[var] = return_value if var in env else DEFAULT_ENV_PLACEHOLDER

return return_value
else:
Expand Down
10 changes: 6 additions & 4 deletions core/dbt/context/configured.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Any, Dict, Optional

from dbt_common.context import get_invocation_context

from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
from dbt.adapters.contracts.connection import AdapterRequiredConfig
from dbt.node_types import NodeType
Expand Down Expand Up @@ -89,8 +90,9 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
return_value = None
if var.startswith(SECRET_ENV_PREFIX):
raise SecretEnvVarLocationError(var)
if var in os.environ:
return_value = os.environ[var]
env = get_invocation_context().env
if var in env:
return_value = env[var]
elif default is not None:
return_value = default

Expand All @@ -101,7 +103,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
# reparsing. If the default changes, the file will have been updated and therefore
# will be scheduled for reparsing anyways.
self.schema_yaml_vars.env_vars[var] = (
return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER
return_value if var in env else DEFAULT_ENV_PLACEHOLDER
)

return return_value
Expand Down
19 changes: 13 additions & 6 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
Iterable,
Mapping,
)

from typing_extensions import Protocol

from dbt.adapters.base.column import Column
from dbt.artifacts.resources import NodeVersion, RefArgs
from dbt_common.clients.jinja import MacroProtocol
from dbt_common.context import get_invocation_context
from dbt.adapters.factory import get_adapter, get_adapter_package_names, get_adapter_type_names
from dbt_common.clients import agate_helper
from dbt.clients.jinja import get_rendered, MacroGenerator, MacroStack, UnitTestMacroGenerator
Expand Down Expand Up @@ -1353,8 +1355,11 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
return_value = None
if var.startswith(SECRET_ENV_PREFIX):
raise SecretEnvVarLocationError(var)
if var in os.environ:
return_value = os.environ[var]

env = get_invocation_context().env

if var in env:
return_value = env[var]
elif default is not None:
return_value = default

Expand All @@ -1373,7 +1378,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
# reparsing. If the default changes, the file will have been updated and therefore
# will be scheduled for reparsing anyways.
self.manifest.env_vars[var] = (
return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER
return_value if var in env else DEFAULT_ENV_PLACEHOLDER
)

# hooks come from dbt_project.yml which doesn't have a real file_id
Expand Down Expand Up @@ -1793,8 +1798,10 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
return_value = None
if var.startswith(SECRET_ENV_PREFIX):
raise SecretEnvVarLocationError(var)
if var in os.environ:
return_value = os.environ[var]

env = get_invocation_context().env
if var in env:
return_value = env[var]
elif default is not None:
return_value = default

Expand All @@ -1806,7 +1813,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
# reparsing. If the default changes, the file will have been updated and therefore
# will be scheduled for reparsing anyways.
self.manifest.env_vars[var] = (
return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER
return_value if var in env else DEFAULT_ENV_PLACEHOLDER
)
# the "model" should only be test nodes, but just in case, check
# TODO CT-211
Expand Down
14 changes: 8 additions & 6 deletions core/dbt/context/secret.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Any, Dict, Optional

from dbt_common.context import get_invocation_context

from .base import BaseContext, contextmember

from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
Expand Down Expand Up @@ -30,24 +31,25 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:
# if this is a 'secret' env var, just return the name of the env var
# instead of rendering the actual value here, to avoid any risk of
# Jinja manipulation. it will be subbed out later, in SecretRenderer.render_value
if var in os.environ and var.startswith(SECRET_ENV_PREFIX):
env = get_invocation_context().env
if var in env and var.startswith(SECRET_ENV_PREFIX):
return SECRET_PLACEHOLDER.format(var)

elif var in os.environ:
return_value = os.environ[var]
if var in env:
return_value = env[var]
elif default is not None:
return_value = default

if return_value is not None:
# store env vars in the internal manifest to power partial parsing
# if it's a 'secret' env var, we shouldn't even get here
# but just to be safe — don't save secrets
# but just to be safe, don't save secrets
if not var.startswith(SECRET_ENV_PREFIX):
# If the environment variable is set from a default, store a string indicating
# that so we can skip partial parsing. Otherwise the file will be scheduled for
# reparsing. If the default changes, the file will have been updated and therefore
# will be scheduled for reparsing anyways.
self.env_vars[var] = return_value if var in os.environ else DEFAULT_ENV_PLACEHOLDER
self.env_vars[var] = return_value if var in env else DEFAULT_ENV_PLACEHOLDER
return return_value
else:
raise EnvVarMissingError(var)
Expand Down
6 changes: 3 additions & 3 deletions core/dbt/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import json
import logging
import os
import sys
import time
import warnings
Expand All @@ -12,15 +11,16 @@
from typing import Optional, List, ContextManager, Callable, Dict, Any, Set

import logbook
from dbt.constants import SECRET_ENV_PREFIX

from dbt_common.context import get_invocation_context
from dbt_common.dataclass_schema import dbtClassMixin

STDOUT_LOG_FORMAT = "{record.message}"
DEBUG_LOG_FORMAT = "{record.time:%Y-%m-%d %H:%M:%S.%f%z} ({record.thread_name}): {record.message}"


def get_secret_env() -> List[str]:
return [v for k, v in os.environ.items() if k.startswith(SECRET_ENV_PREFIX)]
return get_invocation_context().env_secrets


ExceptionInformation = str
Expand Down
5 changes: 4 additions & 1 deletion core/dbt/parser/partial.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os
from copy import deepcopy
from typing import MutableMapping, Dict, List, Callable

from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.files import (
AnySourceFile,
ParseFileType,
parse_file_type_to_parser,
SchemaSourceFile,
)
from dbt_common.context import get_invocation_context
from dbt_common.events.functions import fire_event
from dbt_common.events.base_types import EventLevel
from dbt.events.types import (
Expand Down Expand Up @@ -159,7 +161,8 @@ def build_file_diff(self):
deleted = len(deleted) + len(deleted_schema_files)
changed = len(changed) + len(changed_schema_files)
event = PartialParsingEnabled(deleted=deleted, added=len(added), changed=changed)
if os.environ.get("DBT_PP_TEST"):

if get_invocation_context().env.get("DBT_PP_TEST"):
fire_event(event, level=EventLevel.INFO)
else:
fire_event(event)
Expand Down