This commit is contained in:
2026-05-22 08:24:17 +02:00
parent d893e4c609
commit 749638b75c
16 changed files with 770 additions and 782 deletions
+131
View File
@@ -0,0 +1,131 @@
use super::loop_state::run_loop;
use crate::{
app::state::{Database, DeviceCounts, PendingExec, Registry, RegistryEntry},
core::{config::Config, db::lookup_auth_token},
transport::protocol::ServerMsg,
};
use dashmap::mapref::entry::Entry;
use futures::SinkExt;
use log::{info, warn};
use rocket_ws::{Message, result::Error, stream::DuplexStream};
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
pub async fn handle(
database: Database,
registry: Arc<Registry>,
device_counts: Arc<DeviceCounts>,
config: Arc<Config>,
mut socket: DuplexStream,
token: String,
device_id: Result<String, String>,
) -> Result<(), Error> {
let device_id = match device_id {
Ok(device_id) => device_id,
Err(reason) => {
warn!("Rejecting websocket connection: {reason}");
return deny_connection(&mut socket, reason).await;
}
};
let user_id = match lookup_auth_token(&database, &token).await {
Ok(user_id) => user_id,
Err(reason) => {
warn!("Rejecting websocket connection for device {device_id}: {reason}");
return deny_connection(&mut socket, reason).await;
}
};
let device_counter = match claim_device_slot(&device_counts, &user_id, &config) {
Ok(counter) => counter,
Err(reason) => {
warn!("Rejecting websocket connection for user {user_id} device {device_id}: {reason}");
return deny_connection(&mut socket, reason).await;
}
};
let key = (user_id.clone(), device_id.clone());
let (exec_tx, exec_rx) = mpsc::channel::<PendingExec>(config.max_concurrent_executions);
if let Err(reason) = register_connection(&registry, key.clone(), exec_tx) {
release_device_slot(&device_counts, &user_id, device_counter);
warn!("Rejecting duplicate websocket connection for {user_id}/{device_id}: {reason}");
return deny_connection(&mut socket, reason).await;
}
info!("Accepted websocket connection for {user_id}/{device_id}");
send(&mut socket, &ServerMsg::AuthOk).await.ok();
let cancel = CancellationToken::new();
run_loop(&mut socket, exec_rx, &config, &cancel).await;
cancel.cancel();
registry.remove(&key);
release_device_slot(&device_counts, &user_id, device_counter);
info!("Closed websocket connection for {user_id}/{device_id}");
Ok(())
}
fn register_connection(
registry: &Arc<Registry>,
key: (String, String),
sender: mpsc::Sender<PendingExec>,
) -> Result<(), String> {
match registry.entry(key) {
Entry::Occupied(_) => Err("Device already connected".into()),
Entry::Vacant(slot) => {
slot.insert(RegistryEntry {
sender,
in_flight: Arc::new(AtomicUsize::new(0)),
});
Ok(())
}
}
}
fn claim_device_slot(
device_counts: &Arc<DeviceCounts>,
user_id: &str,
config: &Config,
) -> Result<Arc<AtomicUsize>, String> {
let counter = Arc::clone(
&*device_counts
.entry(user_id.to_owned())
.or_insert_with(|| Arc::new(AtomicUsize::new(0))),
);
if counter.fetch_add(1, Ordering::SeqCst) < config.max_connected_devices {
return Ok(counter);
}
counter.fetch_sub(1, Ordering::SeqCst);
Err("Too many devices connected for this account".into())
}
fn release_device_slot(
device_counts: &Arc<DeviceCounts>,
user_id: &str,
counter: Arc<AtomicUsize>,
) {
if counter.fetch_sub(1, Ordering::SeqCst) == 1 {
device_counts.remove_if(user_id, |_, value| value.load(Ordering::SeqCst) == 0);
}
}
async fn deny_connection(socket: &mut DuplexStream, reason: String) -> Result<(), Error> {
send(socket, &ServerMsg::AuthError { reason }).await.ok();
Ok(())
}
async fn send(
socket: &mut DuplexStream,
message: &ServerMsg,
) -> Result<(), Box<dyn std::error::Error>> {
socket
.send(Message::Text(serde_json::to_string(message)?))
.await?;
Ok(())
}
+131
View File
@@ -0,0 +1,131 @@
use crate::{
app::state::{ExecResult, PendingExec},
transport::protocol::{ClientMsg, ServerMsg},
};
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use rocket_ws::{Message, stream::DuplexStream};
use std::{sync::Arc, time::Duration};
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
type PendingReplies = Arc<DashMap<String, oneshot::Sender<ExecResult>>>;
pub async fn run_loop(
socket: &mut DuplexStream,
mut exec_rx: mpsc::Receiver<PendingExec>,
config: &crate::core::config::Config,
cancel: &CancellationToken,
) {
let pending = Arc::new(DashMap::new());
let mut ping_interval = tokio::time::interval_at(
tokio::time::Instant::now() + config.ping_interval,
config.ping_interval,
);
let pong_timer = tokio::time::sleep(Duration::from_secs(0));
tokio::pin!(pong_timer);
let mut awaiting_pong = false;
loop {
tokio::select! {
_ = ping_interval.tick() => {
if awaiting_pong || socket.send(Message::Ping(vec![])).await.is_err() {
break;
}
awaiting_pong = true;
pong_timer.as_mut().reset(tokio::time::Instant::now() + config.ping_timeout);
}
_ = &mut pong_timer, if awaiting_pong => break,
Some(exec) = exec_rx.recv() => {
if !dispatch(socket, &pending, exec, config.max_execution_time, cancel.clone()).await {
break;
}
}
message = socket.next() => match handle_socket_message(socket, message, &pending).await {
SocketAction::Break => break,
SocketAction::ClearPongDeadline => awaiting_pong = false,
SocketAction::Continue => {}
}
}
}
pending.clear();
}
enum SocketAction {
Break,
ClearPongDeadline,
Continue,
}
async fn handle_socket_message(
socket: &mut DuplexStream,
message: Option<Result<Message, rocket_ws::result::Error>>,
pending: &PendingReplies,
) -> SocketAction {
match message {
Some(Ok(Message::Pong(_))) => SocketAction::ClearPongDeadline,
Some(Ok(Message::Ping(data))) => {
if socket.send(Message::Pong(data.clone())).await.is_err() {
return SocketAction::Break;
}
SocketAction::Continue
}
Some(Ok(Message::Text(text))) => {
if let Ok(ClientMsg::ExecResult {
exec_id,
exit_code,
stdout,
stderr,
}) = serde_json::from_str(&text)
&& let Some((_, reply)) = pending.remove(&exec_id)
{
reply
.send(ExecResult {
exit_code,
stdout,
stderr,
})
.ok();
}
SocketAction::Continue
}
Some(Ok(Message::Close(_))) | None => SocketAction::Break,
_ => SocketAction::Continue,
}
}
async fn dispatch(
socket: &mut DuplexStream,
pending: &PendingReplies,
exec: PendingExec,
max_execution_time: Duration,
cancel: CancellationToken,
) -> bool {
let exec_id = exec.exec_id.clone();
pending.insert(exec_id.clone(), exec.reply);
let pending_for_timeout = Arc::clone(pending);
tokio::spawn(async move {
tokio::select! {
_ = cancel.cancelled() => {}
_ = tokio::time::sleep(max_execution_time) => {
pending_for_timeout.remove(&exec_id);
}
}
});
socket
.send(Message::Text(
serde_json::to_string(&ServerMsg::Exec {
exec_id: exec.exec_id,
command: exec.command,
})
.expect("serializing server exec message should not fail"),
))
.await
.is_ok()
}
+34
View File
@@ -0,0 +1,34 @@
mod connection;
mod loop_state;
use crate::{
api::auth::{BearerToken, MaybeDeviceId},
app::state::AppState,
};
use rocket::{State, get};
use rocket_ws::{Channel, WebSocket};
#[get("/connect")]
pub fn connect<'r>(
ws: WebSocket,
token: BearerToken,
device_id: MaybeDeviceId,
state: &'r State<AppState>,
) -> Channel<'r> {
let database = state.database.clone();
let registry = state.registry.clone();
let device_counts = state.device_counts.clone();
let config = state.config.clone();
ws.channel(move |socket| {
Box::pin(connection::handle(
database,
registry,
device_counts,
config,
socket,
token.0,
device_id.0,
))
})
}