Skip to content

Commit 7f96c3c

Browse files
committed
shim: move conn module to a sep file.
1 parent 40153d6 commit 7f96c3c

File tree

2 files changed

+218
-220
lines changed

2 files changed

+218
-220
lines changed
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
//! Connection management module.
2+
3+
use super::FinalizedHeadWatcher;
4+
use std::{
5+
mem,
6+
sync::{atomic::AtomicU64, Arc},
7+
};
8+
use subxt::backend::rpc::RpcClient;
9+
use sugondat_subxt::sugondat::is_codegen_valid_for;
10+
use tokio::sync::{oneshot, Mutex};
11+
12+
// Contains the RPC client structures that are assumed to be connected.
13+
pub struct Conn {
14+
/// Connection id. For diagnostics purposes only.
15+
pub conn_id: u64,
16+
pub raw: RpcClient,
17+
pub subxt: sugondat_subxt::Client,
18+
pub finalized: FinalizedHeadWatcher,
19+
}
20+
21+
impl Conn {
22+
async fn connect(conn_id: u64, rpc_url: &str) -> anyhow::Result<Arc<Self>> {
23+
let raw = RpcClient::from_url(rpc_url).await?;
24+
let subxt = sugondat_subxt::Client::from_rpc_client(raw.clone()).await?;
25+
check_if_compatible(&subxt)?;
26+
if !is_codegen_valid_for(&subxt.metadata()) {
27+
const WARN_WRONG_VERSION: &str = "connected to a sugondat node with a newer runtime than the one this shim was compiled against. Update the shim lest you run into problems. https://github.com/thrumdev/sugondat";
28+
tracing::warn!(WARN_WRONG_VERSION);
29+
}
30+
let finalized = FinalizedHeadWatcher::spawn(subxt.clone()).await;
31+
Ok(Arc::new(Self {
32+
conn_id,
33+
raw,
34+
subxt,
35+
finalized,
36+
}))
37+
}
38+
}
39+
40+
/// Tries to find the `Blob` pallet in the runtime metadata. If it's not there, then we are not
41+
/// connected to a Sugondat node.
42+
fn check_if_compatible(client: &sugondat_subxt::Client) -> anyhow::Result<()> {
43+
assert!(sugondat_subxt::sugondat::PALLETS.contains(&"Blob"));
44+
if let Some(pallet) = client.metadata().pallet_by_name("Blob") {
45+
if pallet.call_variant_by_name("submit_blob").is_some() {
46+
return Ok(());
47+
}
48+
}
49+
Err(anyhow::anyhow!(
50+
"connected to a Substrate node that is not Sugondat"
51+
))
52+
}
53+
54+
enum State {
55+
/// The client is known to be connected.
56+
///
57+
/// When the client experiences an error, there could be a brief state where the client is
58+
/// disconnected, but the connection has not been reset yet.
59+
Connected(Arc<Conn>),
60+
/// The client is currently connecting. The waiters are notified when the connection is
61+
/// established.
62+
Connecting {
63+
conn_id: u64,
64+
waiting: Vec<oneshot::Sender<Arc<Conn>>>,
65+
},
66+
/// Mostly used for as a dummy state during initialization, because the client should always
67+
/// be connected or connecting.
68+
Disconnected,
69+
}
70+
71+
/// A struct that abstracts the connection concerns.
72+
///
73+
/// Allows to wait for a connection to be established and to reset the connection if we detect
74+
/// that it's broken.
75+
pub struct Connector {
76+
state: Arc<Mutex<State>>,
77+
next_conn_id: AtomicU64,
78+
rpc_url: Arc<String>,
79+
}
80+
81+
impl Connector {
82+
pub fn new(rpc_url: Arc<String>) -> Self {
83+
Self {
84+
state: Arc::new(Mutex::new(State::Disconnected)),
85+
next_conn_id: AtomicU64::new(0),
86+
rpc_url,
87+
}
88+
}
89+
90+
/// Makes sure that the client is connected. Returns the connection handle.
91+
pub async fn ensure_connected(&self) -> Arc<Conn> {
92+
let mut state = self.state.lock().await;
93+
match &mut *state {
94+
State::Connected(conn) => {
95+
let conn_id = conn.conn_id;
96+
tracing::debug!(?conn_id, "reusing existing connection");
97+
conn.clone()
98+
}
99+
State::Connecting {
100+
conn_id,
101+
ref mut waiting,
102+
} => {
103+
// Somebody else is already connecting, let them cook.
104+
tracing::debug!(?conn_id, "waiting for existing connection");
105+
let (tx, rx) = oneshot::channel();
106+
waiting.push(tx);
107+
drop(state);
108+
109+
rx.await.expect("cannot be dropped")
110+
}
111+
State::Disconnected => {
112+
// We are the first to connect.
113+
//
114+
// Important part: if the task performing the connection is cancelled,
115+
// the `waiters` won't be notified and will wait forever unless we implement
116+
// mitigation measures.
117+
//
118+
// Instead, we just spawn a new task here and the current task will wait
119+
// similarly to the other waiters.
120+
121+
// Step 1: set the state to `Connecting` registering ourselves as a waiter.
122+
let conn_id = self.gen_conn_id();
123+
let (tx, rx) = oneshot::channel();
124+
*state = State::Connecting {
125+
conn_id,
126+
waiting: vec![tx],
127+
};
128+
129+
// Step 2: spawn the connection task.
130+
self.spawn_connection_task(conn_id);
131+
drop(state);
132+
133+
// Step 3: wait for the connection to be established.
134+
rx.await.expect("cannot be dropped")
135+
}
136+
}
137+
}
138+
139+
/// Drop the current connection and start a new connection task.
140+
pub async fn reset(&self) {
141+
let mut state = self.state.lock().await;
142+
match *state {
143+
State::Connecting { conn_id, .. } => {
144+
// Guard against initiating a new connection when one is already in progress.
145+
tracing::debug!(?conn_id, "reset: reconnection already in progress");
146+
return;
147+
}
148+
State::Connected(ref conn) => {
149+
let conn_id = conn.conn_id;
150+
tracing::debug!(?conn_id, "reset: dropping connection");
151+
}
152+
State::Disconnected => (),
153+
}
154+
let conn_id = self.gen_conn_id();
155+
tracing::debug!(?conn_id, "reset: initiating new connection");
156+
*state = State::Connecting {
157+
conn_id,
158+
waiting: vec![],
159+
};
160+
self.spawn_connection_task(conn_id);
161+
drop(state);
162+
}
163+
164+
fn gen_conn_id(&self) -> u64 {
165+
use std::sync::atomic::Ordering;
166+
let conn_id = self.next_conn_id.fetch_add(1, Ordering::Relaxed);
167+
conn_id
168+
}
169+
170+
/// Spawns a task that will connect to the sugondat node and notify all waiters.
171+
fn spawn_connection_task(&self, conn_id: u64) {
172+
let state = self.state.clone();
173+
let rpc_url = self.rpc_url.clone();
174+
let _ = tokio::spawn(async move {
175+
tracing::debug!(?conn_id, ?rpc_url, "connecting to sugondat node");
176+
let conn = loop {
177+
match Conn::connect(conn_id, &rpc_url).await {
178+
Ok(conn) => break conn,
179+
Err(e) => {
180+
tracing::error!(?conn_id, "failed to connect to sugondat node: {}\n", e);
181+
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
182+
}
183+
}
184+
};
185+
186+
let mut state = state.lock().await;
187+
let waiters = match &mut *state {
188+
State::Connected(_) => {
189+
// only one task is allowed to connect, and in this case it's us.
190+
unreachable!()
191+
}
192+
State::Connecting {
193+
conn_id: actual_conn_id,
194+
ref mut waiting,
195+
} => {
196+
debug_assert_eq!(conn_id, *actual_conn_id);
197+
mem::take(waiting)
198+
}
199+
State::Disconnected => {
200+
debug_assert!(false, "unexpected state");
201+
vec![]
202+
}
203+
};
204+
205+
// Finally, set the state to `Connected`, notify all waiters and explicitly
206+
// release the mutex.
207+
for tx in waiters {
208+
let _ = tx.send(conn.clone());
209+
}
210+
*state = State::Connected(conn);
211+
drop(state);
212+
213+
tracing::info!(?conn_id, "connected to sugondat node");
214+
});
215+
}
216+
}

0 commit comments

Comments
 (0)