use crate::{ AppState, Database, DeviceCounts, ExecResult, PendingExec, Registry, RegistryEntry, api::auth::{BearerToken, MaybeDeviceId}, core::config::Config, core::db::lookup_auth_token, transport::protocol::{ClientMsg, ServerMsg}, }; use dashmap::{DashMap, mapref::entry::Entry}; use futures::{SinkExt, StreamExt}; use log::{info, warn}; use rocket::{State, get}; use rocket_ws::{Channel, Message, WebSocket, stream::DuplexStream}; use std::{ sync::{ Arc, atomic::{AtomicUsize, Ordering}, }, time::Duration, }; use tokio::sync::{mpsc, oneshot}; use tokio_util::sync::CancellationToken; type PendingReplies = Arc>>; #[get("/connect")] pub fn connect<'r>( ws: WebSocket, token: BearerToken, device_id: MaybeDeviceId, state: &'r State, ) -> Channel<'r> { let database = state.database.clone(); let registry = state.registry.clone(); let device_counts = state.device_counts.clone(); let config = state.config.clone(); ws.channel(move |socket| { Box::pin(handle( database, registry, device_counts, config, socket, token.0, device_id.0, )) }) } async fn handle( database: Database, registry: Arc, device_counts: Arc, config: Arc, mut socket: DuplexStream, token: String, device_id: Result, ) -> Result<(), rocket_ws::result::Error> { let device_id = match device_id { Ok(device_id) => device_id, Err(reason) => { warn!("Rejecting websocket connection: {reason}"); return deny_connection(&mut socket, reason).await; } }; let user_id = match lookup_auth_token(&database, &token).await { Ok(user_id) => user_id, Err(reason) => { warn!("Rejecting websocket connection for device {device_id}: {reason}"); return deny_connection(&mut socket, reason).await; } }; let device_counter = match claim_device_slot(&device_counts, &user_id, &config) { Ok(counter) => counter, Err(reason) => { warn!("Rejecting websocket connection for user {user_id} device {device_id}: {reason}"); return deny_connection(&mut socket, reason).await; } }; let (exec_tx, exec_rx) = mpsc::channel::(config.max_concurrent_executions); let key = (user_id.clone(), device_id.clone()); if let Err(reason) = register_connection(®istry, key.clone(), exec_tx) { release_device_slot(&device_counts, &user_id, device_counter); warn!("Rejecting duplicate websocket connection for {user_id}/{device_id}: {reason}"); return deny_connection(&mut socket, reason).await; } info!("Accepted websocket connection for {user_id}/{device_id}"); send(&mut socket, &ServerMsg::AuthOk).await.ok(); let pending = Arc::new(DashMap::new()); let cancel = CancellationToken::new(); run_loop(&mut socket, exec_rx, &pending, &config, &cancel).await; cancel.cancel(); registry.remove(&key); pending.clear(); release_device_slot(&device_counts, &user_id, device_counter); info!("Closed websocket connection for {user_id}/{device_id}"); Ok(()) } async fn deny_connection( socket: &mut DuplexStream, reason: String, ) -> Result<(), rocket_ws::result::Error> { send(socket, &ServerMsg::AuthError { reason }).await.ok(); Ok(()) } fn register_connection( registry: &Arc, key: (String, String), sender: mpsc::Sender, ) -> Result<(), String> { match registry.entry(key) { Entry::Occupied(_) => Err("Device already connected".into()), Entry::Vacant(slot) => { slot.insert(RegistryEntry { sender, in_flight: Arc::new(AtomicUsize::new(0)), }); Ok(()) } } } fn claim_device_slot( device_counts: &Arc, user_id: &str, config: &Config, ) -> Result, String> { let counter = { let entry = device_counts .entry(user_id.to_owned()) .or_insert_with(|| Arc::new(AtomicUsize::new(0))); Arc::clone(&*entry) }; let previous = counter.fetch_add(1, Ordering::SeqCst); if previous >= config.max_connected_devices { counter.fetch_sub(1, Ordering::SeqCst); return Err("Too many devices connected for this account".into()); } Ok(counter) } fn release_device_slot( device_counts: &Arc, user_id: &str, counter: Arc, ) { let previous = counter.fetch_sub(1, Ordering::SeqCst); if previous == 1 { device_counts.remove_if(user_id, |_, value| value.load(Ordering::SeqCst) == 0); } } async fn run_loop( socket: &mut DuplexStream, mut exec_rx: mpsc::Receiver, pending: &PendingReplies, config: &Config, cancel: &CancellationToken, ) { let mut ping_interval = tokio::time::interval_at( tokio::time::Instant::now() + config.ping_interval, config.ping_interval, ); let pong_timer = tokio::time::sleep(Duration::from_secs(0)); tokio::pin!(pong_timer); let mut awaiting_pong = false; loop { tokio::select! { _ = ping_interval.tick() => { if awaiting_pong { break; } if send_ping(socket).await.is_err() { break; } awaiting_pong = true; pong_timer.as_mut().reset( tokio::time::Instant::now() + config.ping_timeout, ); } _ = &mut pong_timer, if awaiting_pong => { break; } Some(exec) = exec_rx.recv() => { if !dispatch(socket, pending, exec, config.max_execution_time, cancel.clone()).await { break; } } message = socket.next() => { let action = handle_socket_message(socket, message, pending).await; if action.break_loop { break; } if action.clear_pong_deadline { awaiting_pong = false; } } } } } struct SocketAction { break_loop: bool, clear_pong_deadline: bool, } async fn handle_socket_message( socket: &mut DuplexStream, message: Option>, pending: &PendingReplies, ) -> SocketAction { match message { Some(Ok(Message::Pong(_))) => SocketAction { break_loop: false, clear_pong_deadline: true, }, Some(Ok(Message::Ping(data))) => SocketAction { break_loop: socket.send(Message::Pong(data.clone())).await.is_err(), clear_pong_deadline: false, }, Some(Ok(Message::Text(text))) => { resolve_pending(&text, pending); SocketAction { break_loop: false, clear_pong_deadline: false, } } Some(Ok(Message::Close(_))) | None => SocketAction { break_loop: true, clear_pong_deadline: false, }, _ => SocketAction { break_loop: false, clear_pong_deadline: false, }, } } async fn send_ping(socket: &mut DuplexStream) -> Result<(), rocket_ws::result::Error> { socket.send(Message::Ping(vec![])).await } async fn dispatch( socket: &mut DuplexStream, pending: &PendingReplies, exec: PendingExec, max_execution_time: Duration, cancel: CancellationToken, ) -> bool { let exec_id = exec.exec_id.clone(); pending.insert(exec_id.clone(), exec.reply); let pending_for_timeout = Arc::clone(pending); tokio::spawn(async move { tokio::select! { _ = cancel.cancelled() => {} _ = tokio::time::sleep(max_execution_time) => { pending_for_timeout.remove(&exec_id); } } }); send( socket, &ServerMsg::Exec { exec_id: exec.exec_id, command: exec.command, }, ) .await .is_ok() } fn resolve_pending(text: &str, pending: &PendingReplies) { let Ok(ClientMsg::ExecResult { exec_id, exit_code, stdout, stderr, }) = serde_json::from_str(text) else { return; }; if let Some((_, reply)) = pending.remove(&exec_id) { reply .send(ExecResult { exit_code, stdout, stderr, }) .ok(); } } async fn send( socket: &mut DuplexStream, message: &ServerMsg, ) -> Result<(), Box> { let text = serde_json::to_string(message)?; socket.send(Message::Text(text)).await?; Ok(()) }