define snowflake catalog
prabodh1194 committed Jun 8, 2024
commit 87d8548b3b6394d2ccbc38b9ff1b2d791bd95909
2 changes: 1 addition & 1 deletion Makefile
@@ -19,7 +19,7 @@ install-poetry:
pip install poetry==1.8.2

install-dependencies:
poetry install -E pyarrow -E hive -E s3fs -E glue -E adlfs -E duckdb -E ray -E sql-postgres -E gcsfs -E sql-sqlite -E daft
poetry install -E pyarrow -E hive -E s3fs -E glue -E adlfs -E duckdb -E ray -E sql-postgres -E gcsfs -E sql-sqlite -E daft -E snowflake

install: | install-poetry install-dependencies

1,348 changes: 729 additions & 619 deletions poetry.lock

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions pyiceberg/catalog/__init__.py
@@ -106,6 +106,7 @@ class CatalogType(Enum):
GLUE = "glue"
DYNAMODB = "dynamodb"
SQL = "sql"
SNOWFLAKE = "snowflake"


def load_rest(name: str, conf: Properties) -> Catalog:
@@ -152,12 +153,22 @@ def load_sql(name: str, conf: Properties) -> Catalog:
) from exc


def load_snowflake(name: str, conf: Properties) -> Catalog:
try:
from pyiceberg.catalog.snowflake_catalog import SnowflakeCatalog

return SnowflakeCatalog(name, **conf)
except ImportError as exc:
raise NotInstalledError("Snowflake support not installed: pip install 'pyiceberg[snowflake]'") from exc


AVAILABLE_CATALOGS: dict[CatalogType, Callable[[str, Properties], Catalog]] = {
CatalogType.REST: load_rest,
CatalogType.HIVE: load_hive,
CatalogType.GLUE: load_glue,
CatalogType.DYNAMODB: load_dynamodb,
CatalogType.SQL: load_sql,
CatalogType.SNOWFLAKE: load_snowflake,
}


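With the new CatalogType member and the load_snowflake factory registered above, a Snowflake catalog would presumably be loaded like any other pyiceberg catalog via load_catalog. A minimal sketch, assuming the connection properties consumed by SnowflakeCatalog.__init__ later in this diff (user, account, and one of password / authenticator / private_key); the catalog name, account, and credential values below are illustrative only:

from pyiceberg.catalog import load_catalog

# Hypothetical configuration; the "type" key is what routes to load_snowflake
# through AVAILABLE_CATALOGS. Replace account/user/credentials with real values.
catalog = load_catalog(
    "sf",
    **{
        "type": "snowflake",
        "user": "ICEBERG_USER",
        "account": "my_org-my_account",
        "password": "********",  # or "authenticator" / "private_key"
    },
)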
221 changes: 221 additions & 0 deletions pyiceberg/catalog/snowflake_catalog.py
@@ -0,0 +1,221 @@
from __future__ import annotations

import json
import os
from dataclasses import dataclass
from typing import Iterator, List, Optional, Set, Union

import pyarrow as pa
from boto3.session import Session
from snowflake.connector import DictCursor, SnowflakeConnection

from pyiceberg.catalog import MetastoreCatalog, PropertiesUpdateSummary
from pyiceberg.exceptions import NoSuchTableError, TableAlreadyExistsError
from pyiceberg.io import S3_ACCESS_KEY_ID, S3_REGION, S3_SECRET_ACCESS_KEY, S3_SESSION_TOKEN
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import CommitTableRequest, CommitTableResponse, StaticTable, Table, sorting
from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties


class SnowflakeCatalog(MetastoreCatalog):
@dataclass(frozen=True, eq=True)
class _SnowflakeIdentifier:
database: str | None
schema: str | None
table: str | None

def __iter__(self) -> Iterator[str]:
"""
Iterate over the non-None parts of the identifier.

Returns:
Iterator[str]: Iterator of the non-None parts of the identifier.
"""
yield from filter(None, [self.database, self.schema, self.table])

@classmethod
def table_from_string(cls, identifier: str) -> SnowflakeCatalog._SnowflakeIdentifier:
parts = identifier.split(".")
if len(parts) == 1:
return cls(None, None, parts[0])
elif len(parts) == 2:
return cls(None, parts[0], parts[1])
elif len(parts) == 3:
return cls(parts[0], parts[1], parts[2])

raise ValueError(f"Invalid identifier: {identifier}")

@classmethod
def schema_from_string(cls, identifier: str) -> SnowflakeCatalog._SnowflakeIdentifier:
parts = identifier.split(".")
if len(parts) == 1:
return cls(None, parts[0], None)
elif len(parts) == 2:
return cls(parts[0], parts[1], None)

raise ValueError(f"Invalid identifier: {identifier}")

@property
def table_name(self) -> str:
return ".".join(self)

@property
def schema_name(self) -> str:
return ".".join(self)

def __init__(self, name: str, **properties: str):
super().__init__(name, **properties)

params = {
"user": properties["user"],
"account": properties["account"],
}

if "authenticator" in properties:
params["authenticator"] = properties["authenticator"]

if "password" in properties:
params["password"] = properties["password"]

if "private_key" in properties:
params["private_key"] = properties["private_key"]

self.connection = SnowflakeConnection(**params)

def load_table(self, identifier: Union[str, Identifier]) -> Table:
sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
identifier if isinstance(identifier, str) else ".".join(identifier)
)

metadata_query = "SELECT SYSTEM$GET_ICEBERG_TABLE_INFORMATION(%s) AS METADATA"

with self.connection.cursor(DictCursor) as cursor:
try:
cursor.execute(metadata_query, (sf_identifier.table_name,))
metadata = json.loads(cursor.fetchone()["METADATA"])["metadataLocation"]
except Exception as e:
raise NoSuchTableError(f"Table {sf_identifier.table_name} not found") from e

session = Session()
credentials = session.get_credentials()
current_credentials = credentials.get_frozen_credentials()

s3_props = {
S3_ACCESS_KEY_ID: current_credentials.access_key,
S3_SECRET_ACCESS_KEY: current_credentials.secret_key,
S3_SESSION_TOKEN: current_credentials.token,
S3_REGION: os.environ.get("AWS_REGION", "us-east-1"),
}

tbl = StaticTable.from_metadata(metadata, properties=s3_props)
tbl.identifier = tuple(identifier.split(".")) if isinstance(identifier, str) else identifier
tbl.catalog = self

return tbl

def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table:
query = "CREATE ICEBERG TABLE (%s) METADATA_FILE_PATH = (%s)"
sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
identifier if isinstance(identifier, str) else ".".join(identifier)
)

with self.connection.cursor(DictCursor) as cursor:
try:
cursor.execute(query, (sf_identifier.table_name, metadata_location))
except Exception as e:
raise TableAlreadyExistsError(f"Table {sf_identifier.table_name} already exists") from e

return self.load_table(identifier)

def drop_table(self, identifier: Union[str, Identifier]) -> None:
sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
identifier if isinstance(identifier, str) else ".".join(identifier)
)

query = "DROP TABLE IF EXISTS (%s)"

with self.connection.cursor(DictCursor) as cursor:
cursor.execute(query, (sf_identifier.table_name,))

def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table:
sf_from_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
from_identifier if isinstance(from_identifier, str) else ".".join(from_identifier)
)
sf_to_identifier = SnowflakeCatalog._SnowflakeIdentifier.table_from_string(
to_identifier if isinstance(to_identifier, str) else ".".join(to_identifier)
)

query = "ALTER TABLE (%s) RENAME TO (%s)"

with self.connection.cursor(DictCursor) as cursor:
cursor.execute(query, (sf_from_identifier.table_name, sf_to_identifier.table_name))

return self.load_table(to_identifier)

def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse:
raise NotImplementedError

def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None:
sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.schema_from_string(
namespace if isinstance(namespace, str) else ".".join(namespace)
)

db_query = "CREATE DATABASE IF NOT EXISTS (%s)"
schema_query = "CREATE SCHEMA IF NOT EXISTS (%s)"

with self.connection.cursor(DictCursor) as cursor:
if sf_identifier.database:
cursor.execute(db_query, (sf_identifier.database,))
cursor.execute(schema_query, (sf_identifier.schema_name,))

def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.schema_from_string(
namespace if isinstance(namespace, str) else ".".join(namespace)
)

sf_query = "DROP SCHEMA IF EXISTS (%s)"
db_query = "DROP DATABASE IF EXISTS (%s)"
Review comment (Contributor): Do we want to drop the database as well? I would expect only the schema.


with self.connection.cursor(DictCursor) as cursor:
if sf_identifier.database:
cursor.execute(db_query, (sf_identifier.database,))
cursor.execute(sf_query, (sf_identifier.schema_name,))
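
If dropping only the schema is the desired behavior raised in the review comment above, a schema-only variant of drop_namespace might look like the sketch below. It is illustrative only and simply mirrors the identifier handling used in this PR, never issuing DROP DATABASE:

    # Hypothetical schema-only variant (not part of this change).
    def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
        sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.schema_from_string(
            namespace if isinstance(namespace, str) else ".".join(namespace)
        )

        with self.connection.cursor(DictCursor) as cursor:
            cursor.execute("DROP SCHEMA IF EXISTS (%s)", (sf_identifier.schema_name,))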

def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
sf_identifier = SnowflakeCatalog._SnowflakeIdentifier.schema_from_string(
namespace if isinstance(namespace, str) else ".".join(namespace)
)

schema_query = "SHOW ICEBERG TABLES IN SCHEMA (%s)"
db_query = "SHOW ICEBERG TABLES IN DATABASE (%s)"

with self.connection.cursor(DictCursor) as cursor:
if sf_identifier.database:
cursor.execute(db_query, (sf_identifier.database,))
else:
cursor.execute(schema_query, (sf_identifier.schema,))

return [(row["database_name"], row["schema_name"], row["table_name"]) for row in cursor.fetchall()]

def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]:
raise NotImplementedError

def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties:
raise NotImplementedError

def update_namespace_properties(
self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT
) -> PropertiesUpdateSummary:
raise NotImplementedError

def create_table(
self,
identifier: Union[str, Identifier],
schema: Union[Schema, pa.Schema],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: sorting.SortOrder = sorting.UNSORTED_SORT_ORDER,
properties: Properties = EMPTY_DICT,
) -> Table:
raise NotImplementedError
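
As a rough usage sketch of the read path defined above: load_table fetches the metadata location from Snowflake and builds a StaticTable whose S3 properties come from the local boto3 session, so AWS credentials are assumed to be resolvable through boto3's default chain. The table identifier below is made up, and catalog refers to the load_catalog sketch earlier in this diff:

# Hypothetical read; the identifier is illustrative.
table = catalog.load_table("MY_DB.MY_SCHEMA.MY_ICEBERG_TABLE")
arrow_table = table.scan().to_arrow()  # materialize the Iceberg table as a pyarrow.Table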