diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index ddbf6a4e01..49e2c4fbdd 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -18,9 +18,10 @@ //! This module contains the iceberg REST catalog implementation. use std::any::Any; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::future::Future; use std::str::FromStr; +use std::sync::OnceLock; use async_trait::async_trait; use iceberg::io::{self, FileIO}; @@ -37,6 +38,7 @@ use reqwest::{Client, Method, StatusCode, Url}; use tokio::sync::OnceCell; use typed_builder::TypedBuilder; +use crate::Endpoint; use crate::client::{ HttpClient, deserialize_catalog_response, deserialize_unexpected_catalog_error, }; @@ -55,6 +57,33 @@ const ICEBERG_REST_SPEC_VERSION: &str = "0.14.1"; const CARGO_PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); const PATH_V1: &str = "v1"; +static DEFAULT_ENDPOINTS: OnceLock> = OnceLock::new(); + +fn default_endpoints() -> &'static HashSet { + DEFAULT_ENDPOINTS.get_or_init(|| { + [ + Endpoint::v1_config(), + Endpoint::v1_list_namespaces(), + Endpoint::v1_create_namespace(), + Endpoint::v1_load_namespace(), + Endpoint::v1_update_namespace(), + Endpoint::v1_delete_namespace(), + Endpoint::v1_list_tables(), + Endpoint::v1_create_table(), + Endpoint::v1_load_table(), + Endpoint::v1_update_table(), + Endpoint::v1_delete_table(), + Endpoint::v1_rename_table(), + Endpoint::v1_register_table(), + Endpoint::v1_report_metrics(), + Endpoint::v1_commit_transaction(), + ] + .into_iter() + .cloned() + .collect() + }) +} + /// Builder for [`RestCatalog`]. 
#[derive(Debug)] pub struct RestCatalogBuilder(RestCatalogConfig); @@ -67,6 +96,7 @@ impl Default for RestCatalogBuilder { warehouse: None, props: HashMap::new(), client: None, + endpoints: default_endpoints().clone(), }) } } @@ -142,6 +172,9 @@ pub(crate) struct RestCatalogConfig { #[builder(default)] client: Option, + + #[builder(default)] + endpoints: HashSet, } impl RestCatalogConfig { @@ -304,6 +337,12 @@ impl RestCatalogConfig { props.extend(config.overrides); self.props = props; + self.endpoints = if config.endpoints.is_empty() { + default_endpoints().clone() + } else { + config.endpoints + }; + self } } @@ -353,7 +392,6 @@ impl RestCatalog { let catalog_config = RestCatalog::load_config(&client, &self.user_config).await?; let config = self.user_config.clone().merge_with_config(catalog_config); let client = client.update_with(&config)?; - Ok(RestContext { config, client }) }) .await @@ -442,6 +480,7 @@ impl Catalog for RestCatalog { parent: Option<&NamespaceIdent>, ) -> Result> { let context = self.context().await?; + Endpoint::check_supported(&context.config.endpoints, Endpoint::v1_list_namespaces())?; let endpoint = context.config.namespaces_endpoint(); let mut namespaces = Vec::new(); let mut next_token = None; @@ -492,6 +531,7 @@ impl Catalog for RestCatalog { properties: HashMap, ) -> Result { let context = self.context().await?; + Endpoint::check_supported(&context.config.endpoints, Endpoint::v1_create_namespace())?; let request = context .client @@ -520,6 +560,7 @@ impl Catalog for RestCatalog { async fn get_namespace(&self, namespace: &NamespaceIdent) -> Result { let context = self.context().await?; + Endpoint::check_supported(&context.config.endpoints, Endpoint::v1_load_namespace())?; let request = context .client @@ -544,6 +585,7 @@ impl Catalog for RestCatalog { async fn namespace_exists(&self, ns: &NamespaceIdent) -> Result { let context = self.context().await?; + 
Endpoint::check_supported(&context.config.endpoints, Endpoint::v1_namespace_exists())?; let request = context .client @@ -572,6 +614,7 @@ impl Catalog for RestCatalog { async fn drop_namespace(&self, namespace: &NamespaceIdent) -> Result<()> { let context = self.context().await?; + Endpoint::check_supported(&context.config.endpoints, Endpoint::v1_delete_namespace())?; let request = context .client @@ -592,6 +635,7 @@ impl Catalog for RestCatalog { async fn list_tables(&self, namespace: &NamespaceIdent) -> Result> { let context = self.context().await?; + Endpoint::check_supported(&context.config.endpoints, Endpoint::v1_list_tables())?; let endpoint = context.config.tables_endpoint(namespace); let mut identifiers = Vec::new(); let mut next_token = None; @@ -642,6 +686,7 @@ impl Catalog for RestCatalog { creation: TableCreation, ) -> Result { let context = self.context().await?; + Endpoint::check_supported(&context.config.endpoints, Endpoint::v1_create_table())?; let table_ident = TableIdent::new(namespace.clone(), creation.name.clone()); @@ -714,6 +759,7 @@ impl Catalog for RestCatalog { /// provided locally to the `RestCatalog` will take precedence. async fn load_table(&self, table_ident: &TableIdent) -> Result
{ let context = self.context().await?; + Endpoint::check_supported(&context.config.endpoints, Endpoint::v1_load_table())?; let request = context .client @@ -760,6 +806,7 @@ impl Catalog for RestCatalog { /// Drop a table from the catalog. async fn drop_table(&self, table: &TableIdent) -> Result<()> { let context = self.context().await?; + Endpoint::check_supported(&context.config.endpoints, Endpoint::v1_delete_table())?; let request = context .client @@ -781,6 +828,7 @@ impl Catalog for RestCatalog { /// Check if a table exists in the catalog. async fn table_exists(&self, table: &TableIdent) -> Result { let context = self.context().await?; + Endpoint::check_supported(&context.config.endpoints, Endpoint::v1_table_exists())?; let request = context .client @@ -799,6 +847,7 @@ impl Catalog for RestCatalog { /// Rename a table in the catalog. async fn rename_table(&self, src: &TableIdent, dest: &TableIdent) -> Result<()> { let context = self.context().await?; + Endpoint::check_supported(&context.config.endpoints, Endpoint::v1_rename_table())?; let request = context .client @@ -831,6 +880,7 @@ impl Catalog for RestCatalog { metadata_location: String, ) -> Result
{ let context = self.context().await?; + Endpoint::check_supported(&context.config.endpoints, Endpoint::v1_register_table())?; let request = context .client @@ -885,6 +935,7 @@ impl Catalog for RestCatalog { async fn update_table(&self, mut commit: TableCommit) -> Result
{ let context = self.context().await?; + Endpoint::check_supported(&context.config.endpoints, Endpoint::v1_update_table())?; let request = context .client diff --git a/crates/catalog/rest/src/types.rs b/crates/catalog/rest/src/types.rs index ab44c40ee3..cf79772d98 100644 --- a/crates/catalog/rest/src/types.rs +++ b/crates/catalog/rest/src/types.rs @@ -15,20 +15,244 @@ // specific language governing permissions and limitations // under the License. -//! Request and response types for the Iceberg REST API. - -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; +use std::sync::OnceLock; +use http::Method; use iceberg::spec::{Schema, SortOrder, TableMetadata, UnboundPartitionSpec}; use iceberg::{ Error, ErrorKind, Namespace, NamespaceIdent, TableIdent, TableRequirement, TableUpdate, }; -use serde_derive::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; #[derive(Clone, Debug, Serialize, Deserialize)] pub(super) struct CatalogConfig { pub(super) overrides: HashMap, pub(super) defaults: HashMap, + #[serde(default, skip_serializing_if = "HashSet::is_empty")] + pub(super) endpoints: HashSet, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +/// The HTTP method and path of a REST catalog endpoint, e.g. `GET /v1/config` +pub struct Endpoint { + method: Method, + path: String, +} + +impl Endpoint { + /// Returns the HTTP method of this endpoint + pub fn method(&self) -> &Method { + &self.method + } + + /// Returns the path of this endpoint + pub fn path(&self) -> &str { + &self.path + } + + /// Returns `Ok(())` if `endpoint` is in `supported_endpoints`, else a `FeatureUnsupported` error + pub fn check_supported( + supported_endpoints: &HashSet, + endpoint: &Endpoint, + ) -> Result<(), Error> { + if supported_endpoints.contains(endpoint) { + Ok(()) + } else { + Err(Error::new( + ErrorKind::FeatureUnsupported, + format!( + "Endpoint '{}' is not supported by the server", + endpoint.as_str() + ), + )) + } + } + + /// Parse a `METHOD PATH` string (e.g. `GET /v1/config`) into an [`Endpoint`] + pub fn 
parse(endpoint: &str) -> Result { + let parts = endpoint.split_once(' '); + let Some((verb, path)) = parts else { + return Err(Error::new( + ErrorKind::DataInvalid, + format!("Invalid endpoint format: '{endpoint}'. Expected '<METHOD> <PATH>'"), + )); + }; + + let method: Method = Method::from_bytes(verb.as_bytes()).map_err(|_| { + Error::new( + ErrorKind::DataInvalid, + format!("Invalid HTTP method: '{}'", verb), + ) + })?; + + // Validate that the method is one of the standard HTTP methods since from_bytes allows for http extensions + match method { + Method::GET + | Method::POST + | Method::PUT + | Method::DELETE + | Method::HEAD + | Method::OPTIONS + | Method::PATCH => {} + _ => { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Invalid HTTP method: '{}'. Must be one of GET, POST, PUT, DELETE, HEAD, OPTIONS, PATCH", + verb + ), + )); + } + } + + Ok(Self { + method, + path: path.to_string(), + }) + } + + /// Config endpoint + pub fn v1_config() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| Endpoint::parse("GET /v1/config").unwrap()) + } + + /// List namespaces endpoint + pub fn v1_list_namespaces() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| Endpoint::parse("GET /v1/{prefix}/namespaces").unwrap()) + } + + /// Create namespace endpoint + pub fn v1_create_namespace() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| Endpoint::parse("POST /v1/{prefix}/namespaces").unwrap()) + } + + /// Load namespace endpoint + pub fn v1_load_namespace() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| Endpoint::parse("GET /v1/{prefix}/namespaces/{namespace}").unwrap()) + } + + /// Update namespace endpoint + pub fn v1_update_namespace() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| { + Endpoint::parse("POST 
/v1/{prefix}/namespaces/{namespace}/properties").unwrap() + }) + } + + /// Delete namespace endpoint + pub fn v1_delete_namespace() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT + .get_or_init(|| Endpoint::parse("DELETE /v1/{prefix}/namespaces/{namespace}").unwrap()) + } + + /// List tables endpoint + pub fn v1_list_tables() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| { + Endpoint::parse("GET /v1/{prefix}/namespaces/{namespace}/tables").unwrap() + }) + } + + /// Create table endpoint + pub fn v1_create_table() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| { + Endpoint::parse("POST /v1/{prefix}/namespaces/{namespace}/tables").unwrap() + }) + } + + /// Load table endpoint + pub fn v1_load_table() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| { + Endpoint::parse("GET /v1/{prefix}/namespaces/{namespace}/tables/{table}").unwrap() + }) + } + + /// Update table endpoint + pub fn v1_update_table() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| { + Endpoint::parse("POST /v1/{prefix}/namespaces/{namespace}/tables/{table}").unwrap() + }) + } + + /// Delete table endpoint + pub fn v1_delete_table() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| { + Endpoint::parse("DELETE /v1/{prefix}/namespaces/{namespace}/tables/{table}").unwrap() + }) + } + + /// Rename table endpoint + pub fn v1_rename_table() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| Endpoint::parse("POST /v1/{prefix}/tables/rename").unwrap()) + } + + /// Register table endpoint + pub fn v1_register_table() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| { + Endpoint::parse("POST /v1/{prefix}/namespaces/{namespace}/register").unwrap() + }) + 
} + + /// Report metrics endpoint + pub fn v1_report_metrics() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| { + Endpoint::parse("POST /v1/{prefix}/namespaces/{namespace}/tables/{table}/metrics") + .unwrap() + }) + } + + /// Commit transaction endpoint + pub fn v1_commit_transaction() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| Endpoint::parse("POST /v1/{prefix}/transactions/commit").unwrap()) + } + + /// Check namespace exists endpoint + pub fn v1_namespace_exists() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT + .get_or_init(|| Endpoint::parse("HEAD /v1/{prefix}/namespaces/{namespace}").unwrap()) + } + + /// Check table exists endpoint + pub fn v1_table_exists() -> &'static Endpoint { + static ENDPOINT: OnceLock = OnceLock::new(); + ENDPOINT.get_or_init(|| { + Endpoint::parse("HEAD /v1/{prefix}/namespaces/{namespace}/tables/{table}").unwrap() + }) + } + + fn as_str(&self) -> String { + format!("{} {}", self.method, self.path) + } +} + +impl<'de> Deserialize<'de> for Endpoint { + fn deserialize(deserializer: D) -> Result + where D: Deserializer<'de> { + let s = String::deserialize(deserializer)?; + Endpoint::parse(&s).map_err(serde::de::Error::custom) + } +} + +impl Serialize for Endpoint { + fn serialize(&self, serializer: S) -> Result + where S: Serializer { + serializer.serialize_str(&format!("{} {}", self.method, self.path)) + } } #[derive(Debug, Serialize, Deserialize)] @@ -315,6 +539,193 @@ pub struct RegisterTableRequest { mod tests { use super::*; + #[test] + fn test_endpoint_parse_valid() { + let endpoint = Endpoint::parse("GET /v1/config").unwrap(); + assert_eq!(endpoint.method(), &Method::GET); + assert_eq!(endpoint.path(), "/v1/config"); + + let endpoint = Endpoint::parse("POST /v1/namespaces").unwrap(); + assert_eq!(endpoint.method(), &Method::POST); + assert_eq!(endpoint.path(), "/v1/namespaces"); + + let endpoint = 
Endpoint::parse("DELETE /v1/namespaces/{namespace}").unwrap(); + assert_eq!(endpoint.method(), &Method::DELETE); + assert_eq!(endpoint.path(), "/v1/namespaces/{namespace}"); + + let endpoint = Endpoint::parse("HEAD /v1/namespaces/{namespace}/tables/{table}").unwrap(); + assert_eq!(endpoint.method(), &Method::HEAD); + assert_eq!(endpoint.path(), "/v1/namespaces/{namespace}/tables/{table}"); + } + + #[test] + fn test_endpoint_parse_invalid_format() { + // Missing space separator + let result = Endpoint::parse("GET/v1/config"); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid endpoint format") + ); + + // Missing method + let result = Endpoint::parse("/v1/config"); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid endpoint format") + ); + + // Missing path + let result = Endpoint::parse("GET"); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid endpoint format") + ); + + // Empty string + let result = Endpoint::parse(""); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid endpoint format") + ); + } + + #[test] + fn test_endpoint_parse_invalid_method() { + let result = Endpoint::parse("INVALID /v1/config"); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid HTTP method") + ); + + let result = Endpoint::parse("get /v1/config"); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid HTTP method") + ); + } + + #[test] + fn test_endpoint_serialize() { + let endpoint = Endpoint::parse("GET /v1/config").unwrap(); + let serialized = serde_json::to_string(&endpoint).unwrap(); + assert_eq!(serialized, "\"GET /v1/config\""); + + let endpoint = Endpoint::parse("POST /v1/namespaces/{namespace}/tables").unwrap(); + let serialized = serde_json::to_string(&endpoint).unwrap(); + 
assert_eq!(serialized, "\"POST /v1/namespaces/{namespace}/tables\""); + } + + #[test] + fn test_endpoint_deserialize() { + let json = "\"GET /v1/config\""; + let endpoint: Endpoint = serde_json::from_str(json).unwrap(); + assert_eq!(endpoint.method(), &Method::GET); + assert_eq!(endpoint.path(), "/v1/config"); + + let json = "\"DELETE /v1/namespaces/{namespace}\""; + let endpoint: Endpoint = serde_json::from_str(json).unwrap(); + assert_eq!(endpoint.method(), &Method::DELETE); + assert_eq!(endpoint.path(), "/v1/namespaces/{namespace}"); + } + + #[test] + fn test_endpoint_deserialize_invalid() { + let json = "\"INVALID /v1/config\""; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + + let json = "\"GET\""; + let result: Result = serde_json::from_str(json); + assert!(result.is_err()); + } + + #[test] + fn test_endpoint_check_supported() { + let mut supported = HashSet::new(); + supported.insert(Endpoint::parse("GET /v1/config").unwrap()); + supported.insert(Endpoint::parse("GET /v1/namespaces").unwrap()); + supported.insert(Endpoint::parse("POST /v1/namespaces").unwrap()); + + // Supported endpoint + let endpoint = Endpoint::parse("GET /v1/config").unwrap(); + assert!(Endpoint::check_supported(&supported, &endpoint).is_ok()); + + // Unsupported endpoint + let endpoint = Endpoint::parse("DELETE /v1/namespaces/{namespace}").unwrap(); + let result = Endpoint::check_supported(&supported, &endpoint); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("not supported by the server") + ); + } + + #[test] + fn test_catalog_config_endpoints_serde() { + // Test with endpoints + let json = serde_json::json!({ + "overrides": {}, + "defaults": {}, + "endpoints": [ + "GET /v1/config", + "POST /v1/namespaces", + "GET /v1/namespaces/{namespace}" + ] + }); + + let config: CatalogConfig = serde_json::from_value(json.clone()).unwrap(); + assert_eq!(config.endpoints.len(), 3); + assert!( + config + .endpoints + 
.contains(&Endpoint::parse("GET /v1/config").unwrap()) + ); + assert!( + config + .endpoints + .contains(&Endpoint::parse("POST /v1/namespaces").unwrap()) + ); + assert!( + config + .endpoints + .contains(&Endpoint::parse("GET /v1/namespaces/{namespace}").unwrap()) + ); + + let serialized = serde_json::to_value(&config).unwrap(); + assert_eq!(serialized["endpoints"].as_array().unwrap().len(), 3); + + let json_no_endpoints = serde_json::json!({ + "overrides": {}, + "defaults": {} + }); + + let config: CatalogConfig = serde_json::from_value(json_no_endpoints.clone()).unwrap(); + assert!(config.endpoints.is_empty()); + + let serialized = serde_json::to_value(&config).unwrap(); + assert!(serialized.get("endpoints").is_none()); + } + #[test] fn test_namespace_response_serde() { let json = serde_json::json!({