Skip to content
This repository was archived by the owner on Sep 2, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import re
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from mashumaro.helper import pass_through

from functools import lru_cache
import agate
from requests.exceptions import ConnectionError
Expand Down Expand Up @@ -35,7 +37,7 @@
from dbt.events.types import SQLQuery
from dbt.version import __version__ as dbt_version

from dbt.dataclass_schema import StrEnum
from dbt.dataclass_schema import ExtensibleDbtClassMixin, StrEnum

logger = AdapterLogger("BigQuery")

Expand Down Expand Up @@ -92,6 +94,12 @@ class BigQueryAdapterResponse(AdapterResponse):
slot_ms: Optional[int] = None


@dataclass
class DataprocBatchConfig(ExtensibleDbtClassMixin):
    """Opaque wrapper around the user-supplied ``dataproc_batch`` profile key.

    The profile value is a mapping mirroring the
    ``google.cloud.dataproc_v1.Batch`` message; it is held verbatim here and
    only interpreted later when the batch request is built
    (see ``_configure_batch_from_config``).
    """

    # Raw batch specification exactly as parsed from the profile YAML.
    # Declared as a dataclass field instead of via a hand-written __init__:
    # the original combined @dataclass with a manual __init__ and declared no
    # fields, so dataclass machinery saw an empty class. The constructor
    # signature DataprocBatchConfig(batch_config) is unchanged.
    batch_config: dict


@dataclass
class BigQueryCredentials(Credentials):
method: BigQueryConnectionMethod
Expand Down Expand Up @@ -124,6 +132,13 @@ class BigQueryCredentials(Credentials):
dataproc_cluster_name: Optional[str] = None
gcs_bucket: Optional[str] = None

dataproc_batch: Optional[DataprocBatchConfig] = field(
metadata={
"serialization_strategy": pass_through,
},
default=None,
)

scopes: Optional[Tuple[str, ...]] = (
"https://www.googleapis.com/auth/bigquery",
"https://www.googleapis.com/auth/cloud-platform",
Expand Down
53 changes: 37 additions & 16 deletions dbt/adapters/bigquery/python_submissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from google.api_core import retry
from google.api_core.client_options import ClientOptions
from google.cloud import storage, dataproc_v1 # type: ignore
from google.protobuf.json_format import ParseDict

OPERATION_RETRY_TIME = 10

Expand Down Expand Up @@ -120,23 +121,9 @@ def _get_job_client(self) -> dataproc_v1.BatchControllerClient:
)

def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job:
# create the Dataproc Serverless job config
batch = dataproc_v1.Batch()
batch.pyspark_batch.main_python_file_uri = self.gcs_location
# how to keep this up to date?
# we should probably also open this up to be configurable
jar_file_uri = self.parsed_model["config"].get(
"jar_file_uri",
"gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.12-0.21.1.jar",
)
batch.pyspark_batch.jar_file_uris = [jar_file_uri]
# should we make all of these spark/dataproc properties configurable?
# https://cloud.google.com/dataproc-serverless/docs/concepts/properties
# https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#runtimeconfig
batch.runtime_config.properties = {
"spark.executor.instances": "2",
}
batch = self._configure_batch()
parent = f"projects/{self.credential.execution_project}/locations/{self.credential.dataproc_region}"

request = dataproc_v1.CreateBatchRequest(
parent=parent,
batch=batch,
Expand All @@ -156,3 +143,37 @@ def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job:
# .blob(f"{matches.group(2)}.000000000")
# .download_as_string()
# )

def _configure_batch(self):
    """Build the ``dataproc_v1.Batch`` spec for this model's PySpark job.

    Starts from hard-coded defaults (entry-point script, BigQuery connector
    jar, executor count) and then overlays the user's ``dataproc_batch``
    profile config, which may override any of the defaults.

    :returns: a configured ``dataproc_v1.Batch``.
    :raises ValueError: if ``dataproc_batch`` cannot be parsed as a valid
        Batch specification.
    """
    # create the Dataproc Serverless job config
    batch = dataproc_v1.Batch()

    # Apply defaults
    batch.pyspark_batch.main_python_file_uri = self.gcs_location
    jar_file_uri = self.parsed_model["config"].get(
        "jar_file_uri",
        "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.12-0.21.1.jar",
    )
    batch.pyspark_batch.jar_file_uris = [jar_file_uri]

    # https://cloud.google.com/dataproc-serverless/docs/concepts/properties
    # https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#runtimeconfig
    batch.runtime_config.properties = {
        "spark.executor.instances": "2",
    }

    # Apply configuration from dataproc_batch key, possibly overriding defaults.
    if self.credential.dataproc_batch:
        try:
            self._configure_batch_from_config(self.credential.dataproc_batch, batch)
        except Exception as e:
            # NOTE(review): ParseDict appears to require proto map values
            # (e.g. runtime_config.properties) to already be strings;
            # non-string YAML scalars surface here as a parse failure —
            # confirm and document in the profile docs.
            docurl = "https://cloud.google.com/dataproc-serverless/docs/reference/rpc/google.cloud.dataproc.v1#google.cloud.dataproc.v1.Batch"
            raise ValueError(
                f"Unable to parse dataproc_batch as valid batch specification. See {docurl}. {str(e)}"
            ) from e

    return batch

@classmethod
def _configure_batch_from_config(cls, config_dict, target):
    """Overlay ``config_dict`` onto the protobuf message wrapped by ``target``.

    Uses protobuf's ``ParseDict`` to merge a plain dict (the parsed
    ``dataproc_batch`` profile key) into ``target._pb`` — the underlying
    ``google.cloud.dataproc_v1.Batch`` protobuf message. Keys or values that
    do not match the Batch schema cause ``ParseDict`` to raise; the caller
    wraps that in a ValueError.
    """
    ParseDict(config_dict, target._pb)
56 changes: 56 additions & 0 deletions tests/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,43 @@ def setUp(self):
'threads': 1,
'location': 'Solar Station',
},
'dataproc-serverless-configured' : {
'type': 'bigquery',
'method': 'oauth',
'schema': 'dummy_schema',
'threads': 1,
'gcs_bucket': 'dummy-bucket',
'dataproc_region': 'europe-west1',
'submission_method': 'serverless',
'dataproc_batch': {
'environment_config' : {
'execution_config' : {
'service_account': 'dbt@dummy-project.iam.gserviceaccount.com',
'subnetwork_uri': 'dataproc',
'network_tags': [ "foo", "bar" ]
}
},
'labels': {
'dbt': 'rocks',
'number': '1'
},
'runtime_config': {
'properties': {
'spark.executor.instances': '4',
'spark.driver.memory': '1g'
}
}
}
},
'dataproc-serverless-default' : {
'type': 'bigquery',
'method': 'oauth',
'schema': 'dummy_schema',
'threads': 1,
'gcs_bucket': 'dummy-bucket',
'dataproc_region': 'europe-west1',
'submission_method': 'serverless'
}
},
'target': 'oauth',
}
Expand Down Expand Up @@ -183,6 +220,25 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection):
connection.handle
mock_open_connection.assert_called_once()

@patch('dbt.adapters.bigquery.connections.get_bigquery_defaults', return_value=('credentials', 'project_id'))
@patch('dbt.adapters.bigquery.BigQueryConnectionManager.open', return_value=_bq_conn())
def test_acquire_connection_dataproc_serverless(self, mock_open_connection, mock_get_bigquery_defaults):
    """A profile with a `dataproc_batch` key must validate cleanly and the
    connection must open lazily (only when .handle is first accessed)."""
    adapter = self.get_adapter('dataproc-serverless-configured')
    mock_get_bigquery_defaults.assert_called_once()
    try:
        connection = adapter.acquire_connection('dummy')
        self.assertEqual(connection.type, 'bigquery')

    except dbt.exceptions.ValidationException as e:
        # Profile parsing (including dataproc_batch) should not raise.
        self.fail('got ValidationException: {}'.format(str(e)))

    # Dropped the original `except BaseException as e: raise` clause: a bare
    # re-raise of everything else is the default behavior and the binding
    # was unused.
    mock_open_connection.assert_not_called()
    connection.handle
    mock_open_connection.assert_called_once()

@patch('dbt.adapters.bigquery.BigQueryConnectionManager.open', return_value=_bq_conn())
def test_acquire_connection_service_account_validations(self, mock_open_connection):
adapter = self.get_adapter('service_account')
Expand Down
57 changes: 57 additions & 0 deletions tests/unit/test_configure_dataproc_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from unittest.mock import patch

from dbt.adapters.bigquery.python_submissions import ServerlessDataProcHelper
from google.cloud import dataproc_v1

from .test_bigquery_adapter import BaseTestBigQueryAdapter

# Test application of dataproc_batch configuration to a
# google.cloud.dataproc_v1.Batch object.
# This reuses the machinery from BaseTestBigQueryAdapter to get hold of the
# parsed credentials
class TestConfigureDataprocBatch(BaseTestBigQueryAdapter):
    """Test application of dataproc_batch configuration to a
    google.cloud.dataproc_v1.Batch object.

    This reuses the machinery from BaseTestBigQueryAdapter to get hold of
    the parsed credentials.
    """

    @patch('dbt.adapters.bigquery.connections.get_bigquery_defaults', return_value=('credentials', 'project_id'))
    def test_configure_dataproc_serverless_batch(self, mock_get_bigquery_defaults):
        """A configured `dataproc_batch` key is applied onto a Batch proto."""
        adapter = self.get_adapter('dataproc-serverless-configured')
        mock_get_bigquery_defaults.assert_called_once()

        credentials = adapter.acquire_connection('dummy').credentials
        self.assertIsNotNone(credentials)

        batch_config = credentials.dataproc_batch
        self.assertIsNotNone(batch_config)

        # Raw profile values to compare the configured Batch against.
        # (The original assigned raw_batch_config twice; once is enough.)
        raw_batch_config = self.raw_profile['outputs']['dataproc-serverless-configured']['dataproc_batch']
        raw_environment_config = raw_batch_config['environment_config']
        raw_execution_config = raw_environment_config['execution_config']
        raw_labels = raw_batch_config['labels']
        raw_rt_config = raw_batch_config['runtime_config']

        batch = dataproc_v1.Batch()

        ServerlessDataProcHelper._configure_batch_from_config(batch_config, batch)

        # google's protobuf types expose maps as dict[str, str]
        def to_str_values(d):
            return {k: str(v) for k, v in d.items()}

        self.assertEqual(batch.environment_config.execution_config.service_account, raw_execution_config['service_account'])
        self.assertFalse(batch.environment_config.execution_config.network_uri)
        self.assertEqual(batch.environment_config.execution_config.subnetwork_uri, raw_execution_config['subnetwork_uri'])
        self.assertEqual(batch.environment_config.execution_config.network_tags, raw_execution_config['network_tags'])
        self.assertEqual(batch.labels, to_str_values(raw_labels))
        self.assertEqual(batch.runtime_config.properties, to_str_values(raw_rt_config['properties']))

    @patch('dbt.adapters.bigquery.connections.get_bigquery_defaults', return_value=('credentials', 'project_id'))
    def test_default_dataproc_serverless_batch(self, mock_get_bigquery_defaults):
        """Without a `dataproc_batch` key, the credentials carry None."""
        adapter = self.get_adapter('dataproc-serverless-default')
        mock_get_bigquery_defaults.assert_called_once()

        credentials = adapter.acquire_connection('dummy').credentials
        self.assertIsNotNone(credentials)

        self.assertIsNone(credentials.dataproc_batch)