diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index f97311d6f9e3..9650031d8b5f 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -17,7 +17,7 @@ use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll}; -use crate::{error::Result, FlightData, SchemaAsIpc}; +use crate::{error::Result, FlightData, FlightDescriptor, SchemaAsIpc}; use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions}; use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; @@ -72,6 +72,8 @@ pub struct FlightDataEncoderBuilder { app_metadata: Bytes, /// Optional schema, if known before data. schema: Option, + /// Optional flight descriptor, if known before data. + descriptor: Option, } /// Default target size for encoded [`FlightData`]. @@ -87,6 +89,7 @@ impl Default for FlightDataEncoderBuilder { options: IpcWriteOptions::default(), app_metadata: Bytes::new(), schema: None, + descriptor: None, } } } @@ -134,6 +137,15 @@ impl FlightDataEncoderBuilder { self } + /// Specify a flight descriptor in the first FlightData message. + pub fn with_flight_descriptor( + mut self, + descriptor: Option, + ) -> Self { + self.descriptor = descriptor; + self + } + /// Return a [`Stream`](futures::Stream) of [`FlightData`], /// consuming self. More details on [`FlightDataEncoder`] pub fn build(self, input: S) -> FlightDataEncoder @@ -145,6 +157,7 @@ impl FlightDataEncoderBuilder { options, app_metadata, schema, + descriptor, } = self; FlightDataEncoder::new( @@ -153,6 +166,7 @@ impl FlightDataEncoderBuilder { max_flight_data_size, options, app_metadata, + descriptor, ) } } @@ -176,6 +190,8 @@ pub struct FlightDataEncoder { queue: VecDeque, /// Is this stream done (inner is empty or errored) done: bool, + /// cleared after the first FlightData message is sent + descriptor: Option, } impl FlightDataEncoder { @@ -185,6 +201,7 @@ impl FlightDataEncoder { max_flight_data_size: usize, options: IpcWriteOptions, app_metadata: Bytes, + descriptor: Option, ) -> Self { let mut encoder = Self { inner, @@ -194,17 +211,22 @@ impl FlightDataEncoder { app_metadata: Some(app_metadata), queue: VecDeque::new(), done: false, + descriptor, }; // If schema is known up front, enqueue it immediately if let Some(schema) = schema { encoder.encode_schema(&schema); } + encoder } /// Place the `FlightData` in the queue to send - fn queue_message(&mut self, data: FlightData) { + fn queue_message(&mut self, mut data: FlightData) { + if let Some(descriptor) = self.descriptor.take() { + data.flight_descriptor = Some(descriptor); + } self.queue.push_back(data); } diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index 90fa2b7a6832..4f1a8e667ffc 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -22,6 +22,8 @@ use std::{collections::HashMap, sync::Arc}; use arrow_array::types::Int32Type; use arrow_array::{ArrayRef, DictionaryArray, Float64Array, RecordBatch, UInt8Array}; use arrow_cast::pretty::pretty_format_batches; +use arrow_flight::flight_descriptor::DescriptorType; +use arrow_flight::FlightDescriptor; use arrow_flight::{ decode::{DecodedPayload, FlightDataDecoder, FlightRecordBatchStream}, encode::FlightDataEncoderBuilder, @@ -136,6 +138,29 @@ async fn test_zero_batches_schema_specified() { assert_eq!(decoder.schema(), Some(&schema)); } +#[tokio::test] +async fn test_with_flight_descriptor() { + let stream = futures::stream::iter(vec![Ok(make_dictionary_batch(5))]); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + + let descriptor = Some(FlightDescriptor { + r#type: DescriptorType::Path.into(), + path: vec!["table_name".to_string()], + cmd: Bytes::default(), + }); + + let encoder = FlightDataEncoderBuilder::default() + .with_schema(schema.clone()) + .with_flight_descriptor(descriptor.clone()); + + let mut encoder = encoder.build(stream); + + // First batch should be the schema + let first_batch = encoder.next().await.unwrap().unwrap(); + + assert_eq!(first_batch.flight_descriptor, descriptor); +} + #[tokio::test] async fn test_zero_batches_dictionary_schema_specified() { let schema = Arc::new(Schema::new(vec![