first commit
This commit is contained in:
@@ -0,0 +1,317 @@
|
||||
use crate::{
|
||||
AppState, Database, DeviceCounts, ExecResult, PendingExec, Registry, RegistryEntry,
|
||||
api::auth::{BearerToken, MaybeDeviceId},
|
||||
core::config::Config,
|
||||
core::db::lookup_auth_token,
|
||||
transport::protocol::{ClientMsg, ServerMsg},
|
||||
};
|
||||
use dashmap::{DashMap, mapref::entry::Entry};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use log::{info, warn};
|
||||
use rocket::{State, get};
|
||||
use rocket_ws::{Channel, Message, WebSocket, stream::DuplexStream};
|
||||
use std::{
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
type PendingReplies = Arc<DashMap<String, oneshot::Sender<ExecResult>>>;
|
||||
|
||||
#[get("/connect")]
|
||||
pub fn connect<'r>(
|
||||
ws: WebSocket,
|
||||
token: BearerToken,
|
||||
device_id: MaybeDeviceId,
|
||||
state: &'r State<AppState>,
|
||||
) -> Channel<'r> {
|
||||
let database = state.database.clone();
|
||||
let registry = state.registry.clone();
|
||||
let device_counts = state.device_counts.clone();
|
||||
let config = state.config.clone();
|
||||
ws.channel(move |socket| {
|
||||
Box::pin(handle(
|
||||
database,
|
||||
registry,
|
||||
device_counts,
|
||||
config,
|
||||
socket,
|
||||
token.0,
|
||||
device_id.0,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
database: Database,
|
||||
registry: Arc<Registry>,
|
||||
device_counts: Arc<DeviceCounts>,
|
||||
config: Arc<Config>,
|
||||
mut socket: DuplexStream,
|
||||
token: String,
|
||||
device_id: Result<String, String>,
|
||||
) -> Result<(), rocket_ws::result::Error> {
|
||||
let device_id = match device_id {
|
||||
Ok(device_id) => device_id,
|
||||
Err(reason) => {
|
||||
warn!("Rejecting websocket connection: {reason}");
|
||||
return deny_connection(&mut socket, reason).await;
|
||||
}
|
||||
};
|
||||
|
||||
let user_id = match lookup_auth_token(&database, &token).await {
|
||||
Ok(user_id) => user_id,
|
||||
Err(reason) => {
|
||||
warn!("Rejecting websocket connection for device {device_id}: {reason}");
|
||||
return deny_connection(&mut socket, reason).await;
|
||||
}
|
||||
};
|
||||
|
||||
let device_counter = match claim_device_slot(&device_counts, &user_id, &config) {
|
||||
Ok(counter) => counter,
|
||||
Err(reason) => {
|
||||
warn!("Rejecting websocket connection for user {user_id} device {device_id}: {reason}");
|
||||
return deny_connection(&mut socket, reason).await;
|
||||
}
|
||||
};
|
||||
|
||||
let (exec_tx, exec_rx) = mpsc::channel::<PendingExec>(config.max_concurrent_executions);
|
||||
let key = (user_id.clone(), device_id.clone());
|
||||
|
||||
if let Err(reason) = register_connection(®istry, key.clone(), exec_tx) {
|
||||
release_device_slot(&device_counts, &user_id, device_counter);
|
||||
warn!("Rejecting duplicate websocket connection for {user_id}/{device_id}: {reason}");
|
||||
return deny_connection(&mut socket, reason).await;
|
||||
}
|
||||
|
||||
info!("Accepted websocket connection for {user_id}/{device_id}");
|
||||
|
||||
send(&mut socket, &ServerMsg::AuthOk).await.ok();
|
||||
|
||||
let pending = Arc::new(DashMap::new());
|
||||
let cancel = CancellationToken::new();
|
||||
run_loop(&mut socket, exec_rx, &pending, &config, &cancel).await;
|
||||
cancel.cancel();
|
||||
|
||||
registry.remove(&key);
|
||||
pending.clear();
|
||||
release_device_slot(&device_counts, &user_id, device_counter);
|
||||
info!("Closed websocket connection for {user_id}/{device_id}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn deny_connection(
|
||||
socket: &mut DuplexStream,
|
||||
reason: String,
|
||||
) -> Result<(), rocket_ws::result::Error> {
|
||||
send(socket, &ServerMsg::AuthError { reason }).await.ok();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn register_connection(
|
||||
registry: &Arc<Registry>,
|
||||
key: (String, String),
|
||||
sender: mpsc::Sender<PendingExec>,
|
||||
) -> Result<(), String> {
|
||||
match registry.entry(key) {
|
||||
Entry::Occupied(_) => Err("Device already connected".into()),
|
||||
Entry::Vacant(slot) => {
|
||||
slot.insert(RegistryEntry {
|
||||
sender,
|
||||
in_flight: Arc::new(AtomicUsize::new(0)),
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn claim_device_slot(
|
||||
device_counts: &Arc<DeviceCounts>,
|
||||
user_id: &str,
|
||||
config: &Config,
|
||||
) -> Result<Arc<AtomicUsize>, String> {
|
||||
let counter = {
|
||||
let entry = device_counts
|
||||
.entry(user_id.to_owned())
|
||||
.or_insert_with(|| Arc::new(AtomicUsize::new(0)));
|
||||
Arc::clone(&*entry)
|
||||
};
|
||||
|
||||
let previous = counter.fetch_add(1, Ordering::SeqCst);
|
||||
if previous >= config.max_connected_devices {
|
||||
counter.fetch_sub(1, Ordering::SeqCst);
|
||||
return Err("Too many devices connected for this account".into());
|
||||
}
|
||||
|
||||
Ok(counter)
|
||||
}
|
||||
|
||||
fn release_device_slot(
|
||||
device_counts: &Arc<DeviceCounts>,
|
||||
user_id: &str,
|
||||
counter: Arc<AtomicUsize>,
|
||||
) {
|
||||
let previous = counter.fetch_sub(1, Ordering::SeqCst);
|
||||
if previous == 1 {
|
||||
device_counts.remove_if(user_id, |_, value| value.load(Ordering::SeqCst) == 0);
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_loop(
|
||||
socket: &mut DuplexStream,
|
||||
mut exec_rx: mpsc::Receiver<PendingExec>,
|
||||
pending: &PendingReplies,
|
||||
config: &Config,
|
||||
cancel: &CancellationToken,
|
||||
) {
|
||||
let mut ping_interval = tokio::time::interval_at(
|
||||
tokio::time::Instant::now() + config.ping_interval,
|
||||
config.ping_interval,
|
||||
);
|
||||
|
||||
let pong_timer = tokio::time::sleep(Duration::from_secs(0));
|
||||
tokio::pin!(pong_timer);
|
||||
let mut awaiting_pong = false;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = ping_interval.tick() => {
|
||||
if awaiting_pong {
|
||||
break;
|
||||
}
|
||||
if send_ping(socket).await.is_err() {
|
||||
break;
|
||||
}
|
||||
awaiting_pong = true;
|
||||
pong_timer.as_mut().reset(
|
||||
tokio::time::Instant::now() + config.ping_timeout,
|
||||
);
|
||||
}
|
||||
_ = &mut pong_timer, if awaiting_pong => {
|
||||
break;
|
||||
}
|
||||
Some(exec) = exec_rx.recv() => {
|
||||
if !dispatch(socket, pending, exec, config.max_execution_time, cancel.clone()).await {
|
||||
break;
|
||||
}
|
||||
}
|
||||
message = socket.next() => {
|
||||
let action = handle_socket_message(socket, message, pending).await;
|
||||
|
||||
if action.break_loop {
|
||||
break;
|
||||
}
|
||||
|
||||
if action.clear_pong_deadline {
|
||||
awaiting_pong = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct SocketAction {
|
||||
break_loop: bool,
|
||||
clear_pong_deadline: bool,
|
||||
}
|
||||
|
||||
async fn handle_socket_message(
|
||||
socket: &mut DuplexStream,
|
||||
message: Option<Result<Message, rocket_ws::result::Error>>,
|
||||
pending: &PendingReplies,
|
||||
) -> SocketAction {
|
||||
match message {
|
||||
Some(Ok(Message::Pong(_))) => SocketAction {
|
||||
break_loop: false,
|
||||
clear_pong_deadline: true,
|
||||
},
|
||||
Some(Ok(Message::Ping(data))) => SocketAction {
|
||||
break_loop: socket.send(Message::Pong(data.clone())).await.is_err(),
|
||||
clear_pong_deadline: false,
|
||||
},
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
resolve_pending(&text, pending);
|
||||
SocketAction {
|
||||
break_loop: false,
|
||||
clear_pong_deadline: false,
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => SocketAction {
|
||||
break_loop: true,
|
||||
clear_pong_deadline: false,
|
||||
},
|
||||
_ => SocketAction {
|
||||
break_loop: false,
|
||||
clear_pong_deadline: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_ping(socket: &mut DuplexStream) -> Result<(), rocket_ws::result::Error> {
|
||||
socket.send(Message::Ping(vec![])).await
|
||||
}
|
||||
|
||||
async fn dispatch(
|
||||
socket: &mut DuplexStream,
|
||||
pending: &PendingReplies,
|
||||
exec: PendingExec,
|
||||
max_execution_time: Duration,
|
||||
cancel: CancellationToken,
|
||||
) -> bool {
|
||||
let exec_id = exec.exec_id.clone();
|
||||
pending.insert(exec_id.clone(), exec.reply);
|
||||
|
||||
let pending_for_timeout = Arc::clone(pending);
|
||||
tokio::spawn(async move {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {}
|
||||
_ = tokio::time::sleep(max_execution_time) => {
|
||||
pending_for_timeout.remove(&exec_id);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
send(
|
||||
socket,
|
||||
&ServerMsg::Exec {
|
||||
exec_id: exec.exec_id,
|
||||
command: exec.command,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
fn resolve_pending(text: &str, pending: &PendingReplies) {
|
||||
let Ok(ClientMsg::ExecResult {
|
||||
exec_id,
|
||||
exit_code,
|
||||
stdout,
|
||||
stderr,
|
||||
}) = serde_json::from_str(text)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
if let Some((_, reply)) = pending.remove(&exec_id) {
|
||||
reply
|
||||
.send(ExecResult {
|
||||
exit_code,
|
||||
stdout,
|
||||
stderr,
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
async fn send(
|
||||
socket: &mut DuplexStream,
|
||||
message: &ServerMsg,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = serde_json::to_string(message)?;
|
||||
socket.send(Message::Text(text)).await?;
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user