use crate::{ app::state::{ExecResult, PendingExec}, transport::protocol::{ClientMsg, ServerMsg}, }; use dashmap::DashMap; use futures::{SinkExt, StreamExt}; use rocket_ws::{Message, stream::DuplexStream}; use std::{sync::Arc, time::Duration}; use tokio::sync::{mpsc, oneshot}; use tokio_util::sync::CancellationToken; type PendingReplies = Arc>>; pub async fn run_loop( socket: &mut DuplexStream, mut exec_rx: mpsc::Receiver, config: &crate::core::config::Config, cancel: &CancellationToken, ) { let pending = Arc::new(DashMap::new()); let mut ping_interval = tokio::time::interval_at( tokio::time::Instant::now() + config.ping_interval, config.ping_interval, ); let pong_timer = tokio::time::sleep(Duration::from_secs(0)); tokio::pin!(pong_timer); let mut awaiting_pong = false; loop { tokio::select! { _ = ping_interval.tick() => { if awaiting_pong || socket.send(Message::Ping(vec![])).await.is_err() { break; } awaiting_pong = true; pong_timer.as_mut().reset(tokio::time::Instant::now() + config.ping_timeout); } _ = &mut pong_timer, if awaiting_pong => break, Some(exec) = exec_rx.recv() => { if !dispatch(socket, &pending, exec, config.max_execution_time, cancel.clone()).await { break; } } message = socket.next() => { match handle_socket_message(socket, message, &pending, config.max_result_bytes).await { SocketAction::Break => break, SocketAction::ClearPongDeadline => awaiting_pong = false, SocketAction::Continue => {} } } } } pending.clear(); } enum SocketAction { Break, ClearPongDeadline, Continue, } async fn handle_socket_message( socket: &mut DuplexStream, message: Option>, pending: &PendingReplies, max_result_bytes: usize, ) -> SocketAction { match message { Some(Ok(Message::Pong(_))) => SocketAction::ClearPongDeadline, Some(Ok(Message::Ping(data))) => { if socket.send(Message::Pong(data.clone())).await.is_err() { return SocketAction::Break; } SocketAction::Continue } Some(Ok(Message::Text(text))) => { if let Ok(ClientMsg::ExecResult { exec_id, exit_code, mut stdout, mut stderr, }) = serde_json::from_str(&text) && let Some((_, reply)) = pending.remove(&exec_id) { truncate_utf8(&mut stdout, max_result_bytes); truncate_utf8(&mut stderr, max_result_bytes); reply .send(ExecResult { exit_code, stdout, stderr, }) .ok(); } SocketAction::Continue } Some(Ok(Message::Close(_))) | None => SocketAction::Break, _ => SocketAction::Continue, } } fn truncate_utf8(text: &mut String, max_bytes: usize) { if text.len() <= max_bytes { return; } let mut end = max_bytes; while end > 0 && !text.is_char_boundary(end) { end -= 1; } text.truncate(end); } async fn dispatch( socket: &mut DuplexStream, pending: &PendingReplies, exec: PendingExec, max_execution_time: Duration, cancel: CancellationToken, ) -> bool { let exec_id = exec.exec_id.clone(); pending.insert(exec_id.clone(), exec.reply); let pending_for_timeout = Arc::clone(pending); tokio::spawn(async move { tokio::select! { _ = cancel.cancelled() => {} _ = tokio::time::sleep(max_execution_time) => { pending_for_timeout.remove(&exec_id); } } }); socket .send(Message::Text( serde_json::to_string(&ServerMsg::Exec { exec_id: exec.exec_id, command: exec.command, }) .expect("serializing server exec message should not fail"), )) .await .is_ok() }