Files
2026-05-22 08:55:16 +02:00

150 lines
4.3 KiB
Rust

use crate::{
app::state::{ExecResult, PendingExec},
transport::protocol::{ClientMsg, ServerMsg},
};
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use rocket_ws::{Message, stream::DuplexStream};
use std::{sync::Arc, time::Duration};
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
/// Reply channels for in-flight exec requests, keyed by `exec_id`.
///
/// Shared via `Arc` between the session loop and the per-exec timeout tasks spawned in
/// `dispatch`. An entry is removed when its result arrives or when its timeout fires, and
/// the whole map is cleared when the session loop exits.
type PendingReplies = Arc<DashMap<String, oneshot::Sender<ExecResult>>>;
pub async fn run_loop(
socket: &mut DuplexStream,
mut exec_rx: mpsc::Receiver<PendingExec>,
config: &crate::core::config::Config,
cancel: &CancellationToken,
) {
let pending = Arc::new(DashMap::new());
let mut ping_interval = tokio::time::interval_at(
tokio::time::Instant::now() + config.ping_interval,
config.ping_interval,
);
let pong_timer = tokio::time::sleep(Duration::from_secs(0));
tokio::pin!(pong_timer);
let mut awaiting_pong = false;
loop {
tokio::select! {
_ = ping_interval.tick() => {
if awaiting_pong || socket.send(Message::Ping(vec![])).await.is_err() {
break;
}
awaiting_pong = true;
pong_timer.as_mut().reset(tokio::time::Instant::now() + config.ping_timeout);
}
_ = &mut pong_timer, if awaiting_pong => break,
Some(exec) = exec_rx.recv() => {
if !dispatch(socket, &pending, exec, config.max_execution_time, cancel.clone()).await {
break;
}
}
message = socket.next() => {
match handle_socket_message(socket, message, &pending, config.max_result_bytes).await {
SocketAction::Break => break,
SocketAction::ClearPongDeadline => awaiting_pong = false,
SocketAction::Continue => {}
}
}
}
}
pending.clear();
}
/// Instruction returned by [`handle_socket_message`] to the session loop.
enum SocketAction {
    /// Nothing for the loop to do; keep serving the session.
    Continue,
    /// A pong arrived: the outstanding ping has been answered.
    ClearPongDeadline,
    /// The session is over and the loop must exit.
    Break,
}
/// Handles one item read from the socket and tells the session loop what to do next.
///
/// * `Pong`: clears the outstanding-ping deadline.
/// * `Ping`: answers with a `Pong` carrying the same payload; a failed send ends the
///   session.
/// * `Text`: parsed as a [`ClientMsg`]. An `ExecResult` whose `exec_id` is still in
///   `pending` has its `stdout`/`stderr` truncated to at most `max_result_bytes` bytes
///   (on a UTF-8 character boundary) and is delivered to the waiting caller. Unparseable
///   text, other client messages, and results for unknown or expired ids are ignored.
/// * `Close`, end of stream, or a read error: ends the session.
/// * Anything else (e.g. binary messages): ignored.
async fn handle_socket_message(
    socket: &mut DuplexStream,
    message: Option<Result<Message, rocket_ws::result::Error>>,
    pending: &PendingReplies,
    max_result_bytes: usize,
) -> SocketAction {
    match message {
        Some(Ok(Message::Pong(_))) => SocketAction::ClearPongDeadline,
        Some(Ok(Message::Ping(data))) => {
            // `data` is owned, so it is moved into the reply rather than copied.
            // NOTE(review): the underlying WebSocket library may already queue pong
            // replies on its own; an extra unsolicited pong is permitted by RFC 6455.
            if socket.send(Message::Pong(data)).await.is_err() {
                SocketAction::Break
            } else {
                SocketAction::Continue
            }
        }
        Some(Ok(Message::Text(text))) => {
            if let Ok(ClientMsg::ExecResult {
                exec_id,
                exit_code,
                mut stdout,
                mut stderr,
            }) = serde_json::from_str(&text)
                && let Some((_, reply)) = pending.remove(&exec_id)
            {
                truncate_utf8(&mut stdout, max_result_bytes);
                truncate_utf8(&mut stderr, max_result_bytes);
                // The caller may have stopped waiting; a dropped receiver is not an error.
                let _ = reply.send(ExecResult {
                    exit_code,
                    stdout,
                    stderr,
                });
            }
            SocketAction::Continue
        }
        // Read errors are treated as fatal, like a close, instead of polling a stream
        // that has already failed again on the next loop iteration.
        Some(Ok(Message::Close(_))) | Some(Err(_)) | None => SocketAction::Break,
        Some(Ok(_)) => SocketAction::Continue,
    }
}
/// Shortens `text` in place to at most `max_bytes` bytes without splitting a UTF-8 character.
///
/// If `max_bytes` falls inside a multi-byte character, the cut moves back to the nearest
/// preceding character boundary, so the result may be a few bytes shorter than the limit.
/// Strings already within the limit are left untouched.
fn truncate_utf8(text: &mut String, max_bytes: usize) {
    if text.len() > max_bytes {
        // Offset 0 is always a char boundary, so the search always finds something.
        let cut = (0..=max_bytes)
            .rev()
            .find(|&idx| text.is_char_boundary(idx))
            .unwrap_or(0);
        text.truncate(cut);
    }
}
async fn dispatch(
socket: &mut DuplexStream,
pending: &PendingReplies,
exec: PendingExec,
max_execution_time: Duration,
cancel: CancellationToken,
) -> bool {
let exec_id = exec.exec_id.clone();
pending.insert(exec_id.clone(), exec.reply);
let pending_for_timeout = Arc::clone(pending);
tokio::spawn(async move {
tokio::select! {
_ = cancel.cancelled() => {}
_ = tokio::time::sleep(max_execution_time) => {
pending_for_timeout.remove(&exec_id);
}
}
});
socket
.send(Message::Text(
serde_json::to_string(&ServerMsg::Exec {
exec_id: exec.exec_id,
command: exec.command,
})
.expect("serializing server exec message should not fail"),
))
.await
.is_ok()
}