13 changes: 8 additions & 5 deletions arrow-flight/src/sql/client.rs
@@ -53,7 +53,7 @@ use arrow_ipc::convert::fb_to_schema
 use arrow_ipc::reader::read_record_batch;
 use arrow_ipc::{root_as_message, MessageHeader};
 use arrow_schema::{ArrowError, Schema, SchemaRef};
-use futures::{stream, StreamExt, TryStreamExt};
+use futures::{stream, Stream, StreamExt, TryStreamExt};
 use prost::Message;
 use tonic::transport::Channel;
 use tonic::{IntoRequest, IntoStreamingRequest, Streaming};
@@ -228,15 +228,18 @@ impl FlightSqlServiceClient<Channel> {
     }
 
     /// Execute a bulk ingest on the server and return the number of records added
-    pub async fn execute_ingest(
+    pub async fn execute_ingest<S>(
         &mut self,
         command: CommandStatementIngest,
-        batches: Vec<RecordBatch>,
-    ) -> Result<i64, ArrowError> {
+        stream: S,
+    ) -> Result<i64, ArrowError>
+    where
+        S: Stream<Item = crate::error::Result<RecordBatch>> + Send + 'static,
+    {
         let descriptor = FlightDescriptor::new_cmd(command.as_any().encode_to_vec());
         let flight_data_encoder = FlightDataEncoderBuilder::new()
             .with_flight_descriptor(Some(descriptor))
-            .build(stream::iter(batches).map(Ok));
+            .build(stream);
         // Note: the encoder yields Result, and the input is no longer wrapped in
         // Ok above, so this unwrap panics if the caller's stream yields an error.
         let flight_data = flight_data_encoder.map(|fd| fd.unwrap());
         let req = self.set_request_headers(flight_data.into_streaming_request())?;
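For context, a caller migrating from the old `Vec<RecordBatch>` signature might adapt as in the sketch below; the `ingest_batches` helper is hypothetical and only illustrates the stream shape the new bound requires:

```rust
use arrow_array::RecordBatch;
use arrow_flight::error::FlightError;
use arrow_flight::sql::client::FlightSqlServiceClient;
use arrow_flight::sql::CommandStatementIngest;
use arrow_schema::ArrowError;
use futures::StreamExt;
use tonic::transport::Channel;

// Hypothetical helper, not part of this PR: wraps an in-memory Vec in the
// `Stream<Item = Result<RecordBatch, FlightError>>` the new bound expects.
async fn ingest_batches(
    client: &mut FlightSqlServiceClient<Channel>,
    cmd: CommandStatementIngest,
    batches: Vec<RecordBatch>,
) -> Result<i64, ArrowError> {
    // `stream::iter` turns the Vec into a Stream; `map(Ok)` supplies the
    // Result wrapper the old implementation added internally.
    let stream = futures::stream::iter(batches).map(Ok::<RecordBatch, FlightError>);
    client.execute_ingest(cmd, stream).await
}
```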
1 change: 1 addition & 0 deletions arrow-flight/tests/common/fixture.rs
@@ -41,6 +41,7 @@ pub struct TestFixture {
 
 impl TestFixture {
     /// create a new test fixture from the server
+    #[allow(dead_code)]
     pub async fn new<T: FlightService>(test_server: FlightServiceServer<T>) -> Self {
         // let OS choose a free port
         let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
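Presumably the `#[allow(dead_code)]` is needed because each file under `tests/` builds as its own test binary and compiles `common/fixture.rs` independently, so any test binary that never calls `TestFixture::new` would otherwise emit a dead_code warning.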
14 changes: 9 additions & 5 deletions arrow-flight/tests/flight_sql_client.rs
@@ -31,7 +31,7 @@ use arrow_flight::sql::{
     TableNotExistOption,
 };
 use arrow_flight::Action;
-use futures::TryStreamExt;
+use futures::{StreamExt, TryStreamExt};
 use std::collections::HashMap;
 use std::sync::Arc;
 use tokio::sync::Mutex;
@@ -40,9 +40,7 @@ use uuid::Uuid;
 
 #[tokio::test]
 pub async fn test_begin_end_transaction() {
-    let test_server = FlightSqlServiceImpl {
-        transactions: Arc::new(Mutex::new(HashMap::new())),
-    };
+    let test_server = FlightSqlServiceImpl::new();
     let fixture = TestFixture::new(test_server.service()).await;
     let channel = fixture.channel().await;
     let mut flight_sql_client = FlightSqlServiceClient::new(channel);
@@ -94,21 +92,26 @@ pub async fn test_execute_ingest() {
         make_primitive_batch(2),
     ];
     let actual_rows = flight_sql_client
-        .execute_ingest(cmd, batches)
+        .execute_ingest(cmd, futures::stream::iter(batches.clone()).map(Ok))
         .await
         .expect("ingest should succeed");
     assert_eq!(actual_rows, expected_rows);
+    // make sure the batches made it through to the server
+    let ingested_batches = test_server.ingested_batches.lock().await.clone();
+    assert_eq!(ingested_batches, batches);
 }
 
 #[derive(Clone)]
 pub struct FlightSqlServiceImpl {
     transactions: Arc<Mutex<HashMap<String, ()>>>,
+    ingested_batches: Arc<Mutex<Vec<RecordBatch>>>,
 }
 
 impl FlightSqlServiceImpl {
     pub fn new() -> Self {
         Self {
             transactions: Arc::new(Mutex::new(HashMap::new())),
+            ingested_batches: Arc::new(Mutex::new(Vec::new())),
         }
     }
 
@@ -177,6 +180,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
             .try_collect()
             .await?;
         let affected_rows = batches.iter().map(|batch| batch.num_rows() as i64).sum();
+        *self.ingested_batches.lock().await = batches;
         Ok(affected_rows)
     }
 }
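The motivation for the signature change is that callers no longer need to materialize every batch before ingesting. A minimal sketch of a lazy producer, assuming a hypothetical `lazy_batches` helper and a single `Int64Array` column:

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int64Array, RecordBatch};
use arrow_flight::error::FlightError;
use futures::{stream, Stream};

// Hypothetical producer: builds `n` batches on demand rather than
// collecting them into a Vec up front.
fn lazy_batches(n: usize) -> impl Stream<Item = Result<RecordBatch, FlightError>> {
    stream::iter((0..n).map(|i| {
        let column: ArrayRef = Arc::new(Int64Array::from(vec![i as i64; 1024]));
        // Batch construction can fail; the item type lets the error flow
        // into `execute_ingest` rather than forcing a panic here.
        RecordBatch::try_from_iter([("v", column)]).map_err(FlightError::Arrow)
    }))
}
```

Note that with the client code as of this revision, an `Err` item still reaches the `unwrap` inside `execute_ingest` and panics, so error-producing streams only pay off once that unwrap is replaced with proper propagation.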