150 lines
4.3 KiB
Rust
150 lines
4.3 KiB
Rust
use crate::{
|
|
app::state::{ExecResult, PendingExec},
|
|
transport::protocol::{ClientMsg, ServerMsg},
|
|
};
|
|
use dashmap::DashMap;
|
|
use futures::{SinkExt, StreamExt};
|
|
use rocket_ws::{Message, stream::DuplexStream};
|
|
use std::{sync::Arc, time::Duration};
|
|
use tokio::sync::{mpsc, oneshot};
|
|
use tokio_util::sync::CancellationToken;
|
|
|
|
/// Reply channels for in-flight executions, keyed by `exec_id`.
///
/// Shared (via `Arc`) between the socket loop, which fulfils entries when results arrive,
/// and the per-execution timeout tasks spawned by `dispatch`, which evict stale entries.
type PendingReplies = std::sync::Arc<DashMap<String, oneshot::Sender<ExecResult>>>;
|
|
|
|
pub async fn run_loop(
|
|
socket: &mut DuplexStream,
|
|
mut exec_rx: mpsc::Receiver<PendingExec>,
|
|
config: &crate::core::config::Config,
|
|
cancel: &CancellationToken,
|
|
) {
|
|
let pending = Arc::new(DashMap::new());
|
|
let mut ping_interval = tokio::time::interval_at(
|
|
tokio::time::Instant::now() + config.ping_interval,
|
|
config.ping_interval,
|
|
);
|
|
let pong_timer = tokio::time::sleep(Duration::from_secs(0));
|
|
tokio::pin!(pong_timer);
|
|
let mut awaiting_pong = false;
|
|
|
|
loop {
|
|
tokio::select! {
|
|
_ = ping_interval.tick() => {
|
|
if awaiting_pong || socket.send(Message::Ping(vec![])).await.is_err() {
|
|
break;
|
|
}
|
|
|
|
awaiting_pong = true;
|
|
pong_timer.as_mut().reset(tokio::time::Instant::now() + config.ping_timeout);
|
|
}
|
|
_ = &mut pong_timer, if awaiting_pong => break,
|
|
Some(exec) = exec_rx.recv() => {
|
|
if !dispatch(socket, &pending, exec, config.max_execution_time, cancel.clone()).await {
|
|
break;
|
|
}
|
|
}
|
|
message = socket.next() => {
|
|
match handle_socket_message(socket, message, &pending, config.max_result_bytes).await {
|
|
SocketAction::Break => break,
|
|
SocketAction::ClearPongDeadline => awaiting_pong = false,
|
|
SocketAction::Continue => {}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pending.clear();
|
|
}
|
|
|
|
/// What `run_loop` should do after one incoming websocket item has been handled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SocketAction {
    /// End the session.
    Break,
    /// A `Pong` arrived, so the outstanding heartbeat has been answered.
    ClearPongDeadline,
    /// Nothing for the loop to change; keep going.
    Continue,
}
|
|
|
|
async fn handle_socket_message(
|
|
socket: &mut DuplexStream,
|
|
message: Option<Result<Message, rocket_ws::result::Error>>,
|
|
pending: &PendingReplies,
|
|
max_result_bytes: usize,
|
|
) -> SocketAction {
|
|
match message {
|
|
Some(Ok(Message::Pong(_))) => SocketAction::ClearPongDeadline,
|
|
Some(Ok(Message::Ping(data))) => {
|
|
if socket.send(Message::Pong(data.clone())).await.is_err() {
|
|
return SocketAction::Break;
|
|
}
|
|
|
|
SocketAction::Continue
|
|
}
|
|
Some(Ok(Message::Text(text))) => {
|
|
if let Ok(ClientMsg::ExecResult {
|
|
exec_id,
|
|
exit_code,
|
|
mut stdout,
|
|
mut stderr,
|
|
}) = serde_json::from_str(&text)
|
|
&& let Some((_, reply)) = pending.remove(&exec_id)
|
|
{
|
|
truncate_utf8(&mut stdout, max_result_bytes);
|
|
truncate_utf8(&mut stderr, max_result_bytes);
|
|
reply
|
|
.send(ExecResult {
|
|
exit_code,
|
|
stdout,
|
|
stderr,
|
|
})
|
|
.ok();
|
|
}
|
|
|
|
SocketAction::Continue
|
|
}
|
|
Some(Ok(Message::Close(_))) | None => SocketAction::Break,
|
|
_ => SocketAction::Continue,
|
|
}
|
|
}
|
|
|
|
/// Shortens `text` in place to at most `max_bytes` bytes without splitting a UTF-8
/// character.
///
/// When `text` is longer than the limit, the cut is placed at the nearest character
/// boundary at or before `max_bytes`. Strings already within the limit are untouched.
fn truncate_utf8(text: &mut String, max_bytes: usize) {
    if text.len() > max_bytes {
        // Index 0 is always a character boundary, so the search cannot come up empty.
        let cut = (0..=max_bytes)
            .rev()
            .find(|&idx| text.is_char_boundary(idx))
            .unwrap_or(0);
        text.truncate(cut);
    }
}
|
|
|
|
async fn dispatch(
|
|
socket: &mut DuplexStream,
|
|
pending: &PendingReplies,
|
|
exec: PendingExec,
|
|
max_execution_time: Duration,
|
|
cancel: CancellationToken,
|
|
) -> bool {
|
|
let exec_id = exec.exec_id.clone();
|
|
pending.insert(exec_id.clone(), exec.reply);
|
|
|
|
let pending_for_timeout = Arc::clone(pending);
|
|
tokio::spawn(async move {
|
|
tokio::select! {
|
|
_ = cancel.cancelled() => {}
|
|
_ = tokio::time::sleep(max_execution_time) => {
|
|
pending_for_timeout.remove(&exec_id);
|
|
}
|
|
}
|
|
});
|
|
|
|
socket
|
|
.send(Message::Text(
|
|
serde_json::to_string(&ServerMsg::Exec {
|
|
exec_id: exec.exec_id,
|
|
command: exec.command,
|
|
})
|
|
.expect("serializing server exec message should not fail"),
|
|
))
|
|
.await
|
|
.is_ok()
|
|
}
|