Skip to content

Commit 0130af3

Browse files
djandersonalamb
andauthored
Expose bulk ingest in flight sql client and server (#6201)
* Expose CommandStatementIngest as pub in sql module * Add do_put_statement_ingest to FlightSqlService Dispatch this handler for the new CommandStatementIngest command. * Sort list * Implement stub do_put_statement_ingest in example * Refactor helper functions into tests/common/utils * Implement execute_ingest for flight sql client I referenced the C++ implementation here: apache/arrow@0d1ea5d * Add integration test for sql client execute_ingest * Fix lint clippy::new_without_default * Allow streaming ingest for FlightClient::execute_ingest * Properly return client errors --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 8d1f0f5 commit 0130af3

File tree

10 files changed

+319
-112
lines changed

10 files changed

+319
-112
lines changed

arrow-flight/examples/flight_sql_server.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ use arrow_flight::sql::{
4646
ActionEndTransactionRequest, Any, CommandGetCatalogs, CommandGetCrossReference,
4747
CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys,
4848
CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo,
49-
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
50-
CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable, ProstMessageExt, Searchable,
51-
SqlInfo, TicketStatementQuery, XdbcDataType,
49+
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementIngest,
50+
CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable,
51+
ProstMessageExt, Searchable, SqlInfo, TicketStatementQuery, XdbcDataType,
5252
};
5353
use arrow_flight::utils::batches_to_flight_data;
5454
use arrow_flight::{
@@ -615,6 +615,14 @@ impl FlightSqlService for FlightSqlServiceImpl {
615615
Ok(FAKE_UPDATE_RESULT)
616616
}
617617

618+
async fn do_put_statement_ingest(
619+
&self,
620+
_ticket: CommandStatementIngest,
621+
_request: Request<PeekableFlightDataStream>,
622+
) -> Result<i64, Status> {
623+
Ok(FAKE_UPDATE_RESULT)
624+
}
625+
618626
async fn do_put_substrait_plan(
619627
&self,
620628
_ticket: CommandStatementSubstraitPlan,

arrow-flight/src/client.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -679,15 +679,15 @@ impl FlightClient {
679679
/// it encounters an error it uses the oneshot sender to
680680
/// notify the error and stop any further streaming. See `do_put` or
681681
/// `do_exchange` for it's uses.
682-
struct FallibleRequestStream<T, E> {
682+
pub(crate) struct FallibleRequestStream<T, E> {
683683
/// sender to notify error
684684
sender: Option<Sender<E>>,
685685
/// fallible stream
686686
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
687687
}
688688

689689
impl<T, E> FallibleRequestStream<T, E> {
690-
fn new(
690+
pub(crate) fn new(
691691
sender: Sender<E>,
692692
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
693693
) -> Self {

arrow-flight/src/sql/client.rs

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use std::collections::HashMap;
2424
use std::str::FromStr;
2525
use tonic::metadata::AsciiMetadataKey;
2626

27+
use crate::client::FallibleRequestStream;
2728
use crate::decode::FlightRecordBatchStream;
2829
use crate::encode::FlightDataEncoderBuilder;
2930
use crate::error::FlightError;
@@ -39,8 +40,8 @@ use crate::sql::{
3940
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
4041
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
4142
CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
42-
CommandStatementQuery, CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult,
43-
ProstMessageExt, SqlInfo,
43+
CommandStatementIngest, CommandStatementQuery, CommandStatementUpdate,
44+
DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
4445
};
4546
use crate::trailers::extract_lazy_trailers;
4647
use crate::{
@@ -53,10 +54,10 @@ use arrow_ipc::convert::fb_to_schema;
5354
use arrow_ipc::reader::read_record_batch;
5455
use arrow_ipc::{root_as_message, MessageHeader};
5556
use arrow_schema::{ArrowError, Schema, SchemaRef};
56-
use futures::{stream, TryStreamExt};
57+
use futures::{stream, Stream, TryStreamExt};
5758
use prost::Message;
5859
use tonic::transport::Channel;
59-
use tonic::{IntoRequest, Streaming};
60+
use tonic::{IntoRequest, IntoStreamingRequest, Streaming};
6061

6162
/// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow data
6263
/// by FlightSQL protocol.
@@ -227,6 +228,52 @@ impl FlightSqlServiceClient<Channel> {
227228
Ok(result.record_count)
228229
}
229230

231+
/// Execute a bulk ingest on the server and return the number of records added
232+
pub async fn execute_ingest<S>(
233+
&mut self,
234+
command: CommandStatementIngest,
235+
stream: S,
236+
) -> Result<i64, ArrowError>
237+
where
238+
S: Stream<Item = crate::error::Result<RecordBatch>> + Send + 'static,
239+
{
240+
let (sender, receiver) = futures::channel::oneshot::channel();
241+
242+
let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec());
243+
let flight_data = FlightDataEncoderBuilder::new()
244+
.with_flight_descriptor(Some(descriptor))
245+
.build(stream);
246+
247+
// Intercept client errors and send them to the one shot channel above
248+
let flight_data = Box::pin(flight_data);
249+
let flight_data: FallibleRequestStream<FlightData, FlightError> =
250+
FallibleRequestStream::new(sender, flight_data);
251+
252+
let req = self.set_request_headers(flight_data.into_streaming_request())?;
253+
let mut result = self
254+
.flight_client
255+
.do_put(req)
256+
.await
257+
.map_err(status_to_arrow_error)?
258+
.into_inner();
259+
260+
// check if the there were any errors in the input stream provided note
261+
// if receiver.await fails, it means the sender was dropped and there is
262+
// no message to return.
263+
if let Ok(msg) = receiver.await {
264+
return Err(ArrowError::ExternalError(Box::new(msg)));
265+
}
266+
267+
let result = result
268+
.message()
269+
.await
270+
.map_err(status_to_arrow_error)?
271+
.unwrap();
272+
let any = Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
273+
let result: DoPutUpdateResult = any.unpack()?.unwrap();
274+
Ok(result.record_count)
275+
}
276+
230277
/// Request a list of catalogs as tabular FlightInfo results
231278
pub async fn get_catalogs(&mut self) -> Result<FlightInfo, ArrowError> {
232279
self.get_flight_info_for_command(CommandGetCatalogs {})

arrow-flight/src/sql/mod.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ mod gen {
5050
}
5151

5252
pub use gen::action_end_transaction_request::EndTransaction;
53+
pub use gen::command_statement_ingest::table_definition_options::{
54+
TableExistsOption, TableNotExistOption,
55+
};
56+
pub use gen::command_statement_ingest::TableDefinitionOptions;
5357
pub use gen::ActionBeginSavepointRequest;
5458
pub use gen::ActionBeginSavepointResult;
5559
pub use gen::ActionBeginTransactionRequest;
@@ -74,6 +78,7 @@ pub use gen::CommandGetTables;
7478
pub use gen::CommandGetXdbcTypeInfo;
7579
pub use gen::CommandPreparedStatementQuery;
7680
pub use gen::CommandPreparedStatementUpdate;
81+
pub use gen::CommandStatementIngest;
7782
pub use gen::CommandStatementQuery;
7883
pub use gen::CommandStatementSubstraitPlan;
7984
pub use gen::CommandStatementUpdate;
@@ -250,11 +255,12 @@ prost_message_ext!(
250255
CommandGetXdbcTypeInfo,
251256
CommandPreparedStatementQuery,
252257
CommandPreparedStatementUpdate,
258+
CommandStatementIngest,
253259
CommandStatementQuery,
254260
CommandStatementSubstraitPlan,
255261
CommandStatementUpdate,
256-
DoPutUpdateResult,
257262
DoPutPreparedStatementResult,
263+
DoPutUpdateResult,
258264
TicketStatementQuery,
259265
);
260266

arrow-flight/src/sql/server.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ use super::{
3232
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
3333
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
3434
CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
35-
CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate,
36-
DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
37-
TicketStatementQuery,
35+
CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan,
36+
CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt,
37+
SqlInfo, TicketStatementQuery,
3838
};
3939
use crate::{
4040
flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty,
@@ -397,6 +397,17 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
397397
))
398398
}
399399

400+
/// Execute a bulk ingestion.
401+
async fn do_put_statement_ingest(
402+
&self,
403+
_ticket: CommandStatementIngest,
404+
_request: Request<PeekableFlightDataStream>,
405+
) -> Result<i64, Status> {
406+
Err(Status::unimplemented(
407+
"do_put_statement_ingest has no default implementation",
408+
))
409+
}
410+
400411
/// Bind parameters to given prepared statement.
401412
///
402413
/// Returns an opaque handle that the client should pass
@@ -713,6 +724,14 @@ where
713724
})]);
714725
Ok(Response::new(Box::pin(output)))
715726
}
727+
Command::CommandStatementIngest(command) => {
728+
let record_count = self.do_put_statement_ingest(command, request).await?;
729+
let result = DoPutUpdateResult { record_count };
730+
let output = futures::stream::iter(vec![Ok(PutResult {
731+
app_metadata: result.as_any().encode_to_vec().into(),
732+
})]);
733+
Ok(Response::new(Box::pin(output)))
734+
}
716735
Command::CommandPreparedStatementQuery(command) => {
717736
let result = self
718737
.do_put_prepared_statement_query(command, request)

arrow-flight/tests/common/fixture.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ pub struct TestFixture {
4141

4242
impl TestFixture {
4343
/// create a new test fixture from the server
44+
#[allow(dead_code)]
4445
pub async fn new<T: FlightService>(test_server: FlightServiceServer<T>) -> Self {
4546
// let OS choose a free port
4647
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();

arrow-flight/tests/common/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@
1818
pub mod fixture;
1919
pub mod server;
2020
pub mod trailers_layer;
21+
pub mod utils;

arrow-flight/tests/common/utils.rs

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Common utilities for testing flight clients and servers
19+
20+
use std::sync::Arc;
21+
22+
use arrow_array::{
23+
types::Int32Type, ArrayRef, BinaryViewArray, DictionaryArray, Float64Array, RecordBatch,
24+
StringViewArray, UInt8Array,
25+
};
26+
use arrow_schema::{DataType, Field, Schema};
27+
28+
/// Make a primitive batch for testing
29+
///
30+
/// Example:
31+
/// i: 0, 1, None, 3, 4
32+
/// f: 5.0, 4.0, None, 2.0, 1.0
33+
#[allow(dead_code)]
34+
pub fn make_primitive_batch(num_rows: usize) -> RecordBatch {
35+
let i: UInt8Array = (0..num_rows)
36+
.map(|i| {
37+
if i == num_rows / 2 {
38+
None
39+
} else {
40+
Some(i.try_into().unwrap())
41+
}
42+
})
43+
.collect();
44+
45+
let f: Float64Array = (0..num_rows)
46+
.map(|i| {
47+
if i == num_rows / 2 {
48+
None
49+
} else {
50+
Some((num_rows - i) as f64)
51+
}
52+
})
53+
.collect();
54+
55+
RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f", Arc::new(f))]).unwrap()
56+
}
57+
58+
/// Make a dictionary batch for testing
59+
///
60+
/// Example:
61+
/// a: value0, value1, value2, None, value1, value2
62+
#[allow(dead_code)]
63+
pub fn make_dictionary_batch(num_rows: usize) -> RecordBatch {
64+
let values: Vec<_> = (0..num_rows)
65+
.map(|i| {
66+
if i == num_rows / 2 {
67+
None
68+
} else {
69+
// repeat some values for low cardinality
70+
let v = i / 3;
71+
Some(format!("value{v}"))
72+
}
73+
})
74+
.collect();
75+
76+
let a: DictionaryArray<Int32Type> = values
77+
.iter()
78+
.map(|s| s.as_ref().map(|s| s.as_str()))
79+
.collect();
80+
81+
RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap()
82+
}
83+
84+
#[allow(dead_code)]
85+
pub fn make_view_batches(num_rows: usize) -> RecordBatch {
86+
const LONG_TEST_STRING: &str =
87+
"This is a long string to make sure binary view array handles it";
88+
let schema = Schema::new(vec![
89+
Field::new("field1", DataType::BinaryView, true),
90+
Field::new("field2", DataType::Utf8View, true),
91+
]);
92+
93+
let string_view_values: Vec<Option<&str>> = (0..num_rows)
94+
.map(|i| match i % 3 {
95+
0 => None,
96+
1 => Some("foo"),
97+
2 => Some(LONG_TEST_STRING),
98+
_ => unreachable!(),
99+
})
100+
.collect();
101+
102+
let bin_view_values: Vec<Option<&[u8]>> = (0..num_rows)
103+
.map(|i| match i % 3 {
104+
0 => None,
105+
1 => Some("bar".as_bytes()),
106+
2 => Some(LONG_TEST_STRING.as_bytes()),
107+
_ => unreachable!(),
108+
})
109+
.collect();
110+
111+
let binary_array = BinaryViewArray::from_iter(bin_view_values);
112+
let utf8_array = StringViewArray::from_iter(string_view_values);
113+
RecordBatch::try_new(
114+
Arc::new(schema.clone()),
115+
vec![Arc::new(binary_array), Arc::new(utf8_array)],
116+
)
117+
.unwrap()
118+
}

0 commit comments

Comments
 (0)