-
Notifications
You must be signed in to change notification settings - Fork 894
Expand file tree
/
Copy pathaddressed_router.rs
More file actions
230 lines (200 loc) · 8.61 KB
/
addressed_router.rs
File metadata and controls
230 lines (200 loc) · 8.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use async_nats::client::Client;
use async_nats::{HeaderMap, HeaderValue};
use tracing as log;
use super::*;
use crate::logging::DistributedTraceContext;
use crate::logging::get_distributed_tracing_context;
use crate::{Result, protocols::maybe_error::MaybeError};
use tokio_stream::{StreamExt, StreamNotifyClose, wrappers::ReceiverStream};
use tracing::Instrument;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum RequestType {
SingleIn,
ManyIn,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ResponseType {
SingleOut,
ManyOut,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RequestControlMessage {
id: String,
request_type: RequestType,
response_type: ResponseType,
connection_info: ConnectionInfo,
}
pub struct AddressedRequest<T> {
request: T,
address: String,
}
impl<T> AddressedRequest<T> {
pub fn new(request: T, address: String) -> Self {
Self { request, address }
}
fn into_parts(self) -> (T, String) {
(self.request, self.address)
}
}
pub struct AddressedPushRouter {
// todo: generalize with a generic
req_transport: Client,
// todo: generalize with a generic
resp_transport: Arc<tcp::server::TcpStreamServer>,
}
impl AddressedPushRouter {
pub fn new(
req_transport: Client,
resp_transport: Arc<tcp::server::TcpStreamServer>,
) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
req_transport,
resp_transport,
}))
}
}
#[async_trait]
impl<T, U> AsyncEngine<SingleIn<AddressedRequest<T>>, ManyOut<U>, Error> for AddressedPushRouter
where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de> + MaybeError,
{
async fn generate(&self, request: SingleIn<AddressedRequest<T>>) -> Result<ManyOut<U>, Error> {
let request_id = request.context().id().to_string();
let (addressed_request, context) = request.transfer(());
let (request, address) = addressed_request.into_parts();
let engine_ctx = context.context();
let engine_ctx_ = engine_ctx.clone();
// registration options for the data plane in a singe in / many out configuration
let options = StreamOptions::builder()
.context(engine_ctx.clone())
.enable_request_stream(false)
.enable_response_stream(true)
.build()
.unwrap();
// register our needs with the data plane
// todo - generalize this with a generic data plane object which hides the specific transports
let pending_connections: PendingConnections = self.resp_transport.register(options).await;
// validate and unwrap the RegisteredStream object
let pending_response_stream = match pending_connections.into_parts() {
(None, Some(recv_stream)) => recv_stream,
_ => {
panic!("Invalid data plane registration for a SingleIn/ManyOut transport");
}
};
// separate out the connection info and the stream provider from the registered stream
let (connection_info, response_stream_provider) = pending_response_stream.into_parts();
// package up the connection info as part of the "header" component of the two part message
// used to issue the request on the
// todo -- this object should be automatically created by the register call, and achieved by to the two into_parts()
// calls. all the information here is provided by the [`StreamOptions`] object and/or the dataplane object
let control_message = RequestControlMessage {
id: engine_ctx.id().to_string(),
request_type: RequestType::SingleIn,
response_type: ResponseType::ManyOut,
connection_info,
};
// next build the two part message where we package the connection info and the request into
// a single Vec<u8> that can be sent over the wire.
// --- package this up in the WorkQueuePublisher ---
let ctrl = serde_json::to_vec(&control_message)?;
let data = serde_json::to_vec(&request)?;
log::trace!(
request_id,
"packaging two-part message; ctrl: {} bytes, data: {} bytes",
ctrl.len(),
data.len()
);
let msg = TwoPartMessage::from_parts(ctrl.into(), data.into());
// the request plane / work queue should provide a two part message codec that can be used
// or it should take a two part message directly
// todo - update this
let codec = TwoPartCodec::default();
let buffer = codec.encode_message(msg)?;
// TRANSPORT ABSTRACT REQUIRED - END HERE
log::trace!(request_id, "enqueueing two-part message to nats");
// Insert Trace Context into Headers
// Enables span to be created in push_endpoint before
// payload is parsed
let mut headers = HeaderMap::new();
if let Some(trace_context) = get_distributed_tracing_context() {
headers.insert("traceparent", trace_context.create_traceparent());
if let Some(tracestate) = trace_context.tracestate {
headers.insert("tracestate", tracestate);
}
if let Some(x_request_id) = trace_context.x_request_id {
headers.insert("x-request-id", x_request_id);
}
if let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id {
headers.insert("x-dynamo-request-id", x_dynamo_request_id);
}
}
// we might need to add a timeout on this if there is no subscriber to the subject; however, I think nats
// will handle this for us
let _response = self
.req_transport
.request_with_headers(address.to_string(), headers, buffer)
.await?;
log::trace!(request_id, "awaiting transport handshake");
let response_stream = response_stream_provider
.await
.map_err(|_| PipelineError::DetachedStreamReceiver)?
.map_err(PipelineError::ConnectionFailed)?;
// TODO: Detect end-of-stream using Server-Sent Events (SSE)
let mut is_complete_final = false;
let stream = tokio_stream::StreamNotifyClose::new(
tokio_stream::wrappers::ReceiverStream::new(response_stream.rx),
)
.filter_map(move |res| {
if let Some(res_bytes) = res {
if is_complete_final {
return Some(U::from_err(
Error::msg(
"Response received after generation ended - this should never happen",
)
.into(),
));
}
match serde_json::from_slice::<NetworkStreamWrapper<U>>(&res_bytes) {
Ok(item) => {
is_complete_final = item.complete_final;
if let Some(data) = item.data {
Some(data)
} else if is_complete_final {
None
} else {
Some(U::from_err(
Error::msg("Empty response received - this should never happen")
.into(),
))
}
}
Err(err) => {
// legacy log print
let json_str = String::from_utf8_lossy(&res_bytes);
log::warn!(%err, %json_str, "Failed deserializing JSON to response");
Some(U::from_err(Error::new(err).into()))
}
}
} else if is_complete_final {
// end of stream
None
} else if engine_ctx_.is_stopped() {
// Gracefully end the stream if 'stop_generating()' was called. Do NOT check for
// 'is_killed()' here because it implies the stream ended abnormally which should be
// handled by the error branch below.
log::debug!("Request cancelled and then trying to read a response");
None
} else {
// stream ended unexpectedly
log::debug!("{STREAM_ERR_MSG}");
Some(U::from_err(Error::msg(STREAM_ERR_MSG).into()))
}
});
Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
}
}