From 546406a85a6cf648538a51039d3c5717f7c356d5 Mon Sep 17 00:00:00 2001 From: ZeroZipp Date: Thu, 21 May 2026 23:47:36 +0200 Subject: [PATCH] first commit --- .gitignore | 3 + Cargo.toml | 33 ++++ Dockerfile | 23 +++ schema.sql | 21 +++ src/api/auth.rs | 59 +++++++ src/api/catchers.rs | 6 + src/api/mcp.rs | 354 ++++++++++++++++++++++++++++++++++++++ src/core/config.rs | 52 ++++++ src/core/db.rs | 34 ++++ src/core/validation.rs | 15 ++ src/main.rs | 104 +++++++++++ src/transport/execute.rs | 122 +++++++++++++ src/transport/protocol.rs | 20 +++ src/transport/socket.rs | 317 ++++++++++++++++++++++++++++++++++ 14 files changed, 1163 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 Dockerfile create mode 100644 schema.sql create mode 100644 src/api/auth.rs create mode 100644 src/api/catchers.rs create mode 100644 src/api/mcp.rs create mode 100644 src/core/config.rs create mode 100644 src/core/db.rs create mode 100644 src/core/validation.rs create mode 100644 src/main.rs create mode 100644 src/transport/execute.rs create mode 100644 src/transport/protocol.rs create mode 100644 src/transport/socket.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..149901c --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/server.db +/Cargo.lock +/target diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..a9e63e9 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "server" +version = "0.1.0" +edition = "2024" + +[dependencies.tokio] +features = ["full"] +version = "1.52.3" + +[dependencies.serde] +features = ["derive"] +version = "1.0.228" + +[dependencies.rocket] +features = ["json"] +version = "0.5.1" + +[dependencies.uuid] +features = ["v4"] +version = "1.23.1" + +[dependencies.sqlx] +features = ["runtime-tokio", "sqlite"] +version = "0.9.0" + +[dependencies] +dashmap = "6.2.1" +env_logger = "0.11.10" +futures = "0.3.32" +log = "0.4.29" +rocket_ws = "0.1.1" +serde_json = "1.0.150" +tokio-util = "0.7.18" diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..222ac4b --- /dev/null +++ b/Dockerfile @@ -0,0 +1,23 @@ +# Build stage +FROM rust:alpine AS builder +RUN apk add --no-cache pkgconfig +RUN apk add --no-cache openssl-dev musl-dev +RUN apk add --no-cache openssl-libs-static +ARG build="cargo build --release" +COPY Cargo.toml /app/Cargo.toml +COPY schema.sql /app/schema.sql +COPY src /app/src +WORKDIR "/app" +RUN $build + +# Runtime stage +FROM alpine:latest +RUN apk add --no-cache libgcc +RUN apk add --no-cache openssl-libs-static +ARG binary="/app/target/release/server" +ARG target="/usr/local/bin/server" +COPY --from=builder $binary $target +ENV ROCKET_ADDRESS="0.0.0.0" +VOLUME ["/data"] +WORKDIR "/data" +CMD ["server"] diff --git a/schema.sql b/schema.sql new file mode 100644 index 0000000..7817d7c --- /dev/null +++ b/schema.sql @@ -0,0 +1,21 @@ +CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + expires_at TEXT +); + +CREATE TABLE IF NOT EXISTS auth_tokens ( + token TEXT PRIMARY KEY, + user_id TEXT NOT NULL REFERENCES users(id), + expires_at TEXT +); + +CREATE TABLE IF NOT EXISTS api_tokens ( + token TEXT PRIMARY KEY, + user_id TEXT NOT NULL REFERENCES users(id), + expires_at TEXT +); + +CREATE INDEX IF NOT EXISTS +idx_auth_tokens_user ON auth_tokens(user_id); +CREATE INDEX IF NOT EXISTS +idx_api_tokens_user ON api_tokens(user_id); diff --git a/src/api/auth.rs b/src/api/auth.rs new file mode 100644 index 0000000..8d3f5f7 --- /dev/null +++ b/src/api/auth.rs @@ -0,0 +1,59 @@ +use crate::core::validation::validate_device_id; +use rocket::{ + Request, + http::Status, + request::{FromRequest, Outcome}, +}; + +const AUTHORIZATION_HEADER: &str = "Authorization"; +const DEVICE_ID_HEADER: &str = "X-Device-ID"; +const BEARER_PREFIX: &str = "Bearer "; + +pub struct BearerToken(pub String); +pub struct DeviceId(pub String); +pub struct MaybeDeviceId(pub Result); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for BearerToken { + type Error = &'static str; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let Some(header) = request.headers().get_one(AUTHORIZATION_HEADER) else { + return Outcome::Error((Status::Unauthorized, "Missing Authorization header")); + }; + let Some(token) = header.strip_prefix(BEARER_PREFIX) else { + return Outcome::Error((Status::Unauthorized, "Invalid Authorization header format")); + }; + Outcome::Success(BearerToken(token.to_owned())) + } +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for DeviceId { + type Error = String; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + match extract_device_id(request) { + Ok(device_id) => Outcome::Success(DeviceId(device_id)), + Err(reason) => Outcome::Error((Status::BadRequest, reason)), + } + } +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for MaybeDeviceId { + type Error = std::convert::Infallible; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + Outcome::Success(MaybeDeviceId(extract_device_id(request))) + } +} + +fn extract_device_id(request: &Request<'_>) -> Result { + let Some(device_id) = request.headers().get_one(DEVICE_ID_HEADER) else { + return Err("Missing X-Device-ID header".into()); + }; + + validate_device_id(device_id)?; + Ok(device_id.to_owned()) +} diff --git a/src/api/catchers.rs b/src/api/catchers.rs new file mode 100644 index 0000000..e1ef702 --- /dev/null +++ b/src/api/catchers.rs @@ -0,0 +1,6 @@ +use rocket::{Request, catch, http::Status, response::content::RawJson}; + +#[catch(default)] +pub fn default_catcher(status: Status, _request: &Request) -> RawJson { + RawJson(serde_json::json!({ "error": status.reason_lossy() }).to_string()) +} diff --git a/src/api/mcp.rs b/src/api/mcp.rs new file mode 100644 index 0000000..3c1faea --- /dev/null +++ b/src/api/mcp.rs @@ -0,0 +1,354 @@ +use crate::{ + AppState, + api::auth::BearerToken, + core::db::lookup_api_token, + transport::execute::{ExecuteRequest, execute_for_token}, +}; +use rocket::{State, http::Status, post, response::content::RawJson, serde::json::Json}; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; + +type ApiResponse = (Status, RawJson); + +const JSONRPC_VERSION: &str = "2.0"; +const MCP_PROTOCOL_VERSION: &str = "2024-11-05"; + +#[derive(Deserialize)] +pub struct JsonRpcRequest { + pub jsonrpc: String, + pub id: Option, + pub method: String, + #[serde(default)] + pub params: Option, +} + +#[derive(Serialize)] +struct JsonRpcResponse { + jsonrpc: &'static str, + id: Option, + result: T, +} + +#[derive(Serialize)] +struct JsonRpcErrorResponse { + jsonrpc: &'static str, + id: Option, + error: JsonRpcError, +} + +#[derive(Serialize)] +struct JsonRpcError { + code: i64, + message: String, +} + +#[derive(Serialize)] +struct InitializeResult { + #[serde(rename = "protocolVersion")] + protocol_version: &'static str, + capabilities: ServerCapabilities, + #[serde(rename = "serverInfo")] + server_info: ServerInfo, +} + +#[derive(Serialize)] +struct ServerCapabilities { + tools: ToolsCapability, +} + +#[derive(Serialize)] +struct ToolsCapability { + list_changed: bool, +} + +#[derive(Serialize)] +struct ServerInfo { + name: &'static str, + version: &'static str, +} + +#[derive(Serialize)] +struct ToolsListResult { + tools: Vec, +} + +#[derive(Serialize)] +struct ToolDefinition { + name: &'static str, + description: &'static str, + #[serde(rename = "inputSchema")] + input_schema: Value, +} + +#[derive(Deserialize)] +struct ToolCallParams { + name: String, + arguments: Value, +} + +#[derive(Serialize)] +struct ToolCallResult { + content: Vec, + #[serde(rename = "isError")] + is_error: bool, +} + +#[derive(Serialize)] +struct ConnectedDevicesResult { + device_ids: Vec, +} + +enum ToolCall { + Execute(ExecuteRequest), + ListDevices, +} + +#[derive(Serialize)] +struct TextContent { + #[serde(rename = "type")] + kind: &'static str, + text: String, +} + +#[post("/mcp", data = "")] +pub async fn route( + body: Json, + token: BearerToken, + state: &State, +) -> ApiResponse { + let request = body.into_inner(); + + if request.jsonrpc != JSONRPC_VERSION { + return jsonrpc_error( + Status::BadRequest, + request.id, + -32600, + "Invalid JSON-RPC version", + ); + } + + match request.method.as_str() { + "initialize" => jsonrpc_ok( + request.id, + InitializeResult { + protocol_version: MCP_PROTOCOL_VERSION, + capabilities: ServerCapabilities { + tools: ToolsCapability { + list_changed: false, + }, + }, + server_info: ServerInfo { + name: "folyo-bridge", + version: env!("CARGO_PKG_VERSION"), + }, + }, + ), + "notifications/initialized" => jsonrpc_notification_ok(), + "tools/list" => jsonrpc_ok( + request.id, + ToolsListResult { + tools: tool_definitions(), + }, + ), + "tools/call" => handle_tool_call(request.id, request.params, &token.0, state).await, + _ if request.id.is_none() => jsonrpc_notification_ok(), + _ => jsonrpc_error(Status::BadRequest, request.id, -32601, "Method not found"), + } +} + +async fn handle_tool_call( + id: Option, + params: Option, + token: &str, + state: &State, +) -> ApiResponse { + let Some(params) = params else { + return jsonrpc_error(Status::BadRequest, id, -32602, "Missing params"); + }; + + let tool_call = match parse_tool_call(params) { + Ok(call) => call, + Err(_) => return jsonrpc_error(Status::BadRequest, id, -32602, "Invalid params"), + }; + + match tool_call { + ToolCall::Execute(request) => execute_tool_call(id, token, state, request).await, + ToolCall::ListDevices => list_devices_tool_call(id, token, state).await, + } +} + +fn tool_definitions() -> Vec { + vec![ + ToolDefinition { + name: "execute", + description: "Execute a shell command on a connected device owned by the authenticated user.", + input_schema: json!({ + "type": "object", + "required": ["device_id", "command"], + "properties": { + "device_id": { + "type": "string", + "description": "Identifier of the connected device." + }, + "command": { + "type": "string", + "description": "Shell command to execute on the target device." + } + }, + "additionalProperties": false + }), + }, + ToolDefinition { + name: "list_devices", + description: "List currently connected device IDs owned by the authenticated user.", + input_schema: json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + }), + }, + ] +} + +fn parse_tool_call(params: Value) -> Result { + let call: ToolCallParams = serde_json::from_value(params).map_err(|_| "Invalid params")?; + + match call.name.as_str() { + "execute" => serde_json::from_value(call.arguments) + .map(ToolCall::Execute) + .map_err(|_| "Invalid execute arguments"), + "list_devices" => parse_list_devices_arguments(call.arguments), + _ => Err("Unknown tool"), + } +} + +fn parse_list_devices_arguments(arguments: Value) -> Result { + match arguments { + Value::Object(map) if map.is_empty() => Ok(ToolCall::ListDevices), + _ => Err("Invalid list_devices arguments"), + } +} + +fn connected_device_ids_for_user(state: &State, user_id: &str) -> Vec { + state + .registry + .iter() + .filter_map(|entry| { + let ((entry_user_id, device_id), _) = entry.pair(); + (entry_user_id == user_id).then(|| device_id.clone()) + }) + .collect() +} + +async fn execute_tool_call( + id: Option, + token: &str, + state: &State, + request: ExecuteRequest, +) -> ApiResponse { + match execute_for_token(state, token, request).await { + Ok(result) => json_tool_result(id, &result), + Err((status, message)) if status == Status::Unauthorized => { + jsonrpc_error(Status::Unauthorized, id, -32001, message) + } + Err((_, message)) => jsonrpc_ok(id, tool_error(message)), + } +} + +async fn list_devices_tool_call( + id: Option, + token: &str, + state: &State, +) -> ApiResponse { + let user_id = match lookup_api_token(&state.database, token).await { + Ok(user_id) => user_id, + Err(message) => return jsonrpc_error(Status::Unauthorized, id, -32001, message), + }; + + let mut device_ids = connected_device_ids_for_user(state, &user_id); + device_ids.sort(); + + json_tool_result(id, &ConnectedDevicesResult { device_ids }) +} + +fn json_tool_result(id: Option, payload: &T) -> ApiResponse { + match serde_json::to_string(payload) { + Ok(text) => jsonrpc_ok(id, ToolCallResult::success(text)), + Err(error) => jsonrpc_error(Status::InternalServerError, id, -32603, error.to_string()), + } +} + +fn tool_error(message: String) -> ToolCallResult { + ToolCallResult::error(message) +} + +impl ToolCallResult { + fn success(text: String) -> Self { + Self { + content: vec![TextContent::text(text)], + is_error: false, + } + } + + fn error(text: String) -> Self { + Self { + content: vec![TextContent::text(text)], + is_error: true, + } + } +} + +impl TextContent { + fn text(text: String) -> Self { + Self { kind: "text", text } + } +} + +fn jsonrpc_ok(id: Option, result: T) -> ApiResponse { + match serde_json::to_string(&JsonRpcResponse { + jsonrpc: JSONRPC_VERSION, + id, + result, + }) { + Ok(body) => (Status::Ok, RawJson(body)), + Err(e) => jsonrpc_error(Status::InternalServerError, None, -32603, e.to_string()), + } +} + +fn jsonrpc_notification_ok() -> ApiResponse { + (Status::Accepted, RawJson(String::new())) +} + +fn jsonrpc_error( + status: Status, + id: Option, + code: i64, + message: impl Into, +) -> ApiResponse { + let response = JsonRpcErrorResponse { + jsonrpc: JSONRPC_VERSION, + id, + error: JsonRpcError { + code, + message: message.into(), + }, + }; + + match serde_json::to_string(&response) { + Ok(body) => (status, RawJson(body)), + Err(e) => ( + Status::InternalServerError, + RawJson( + json!({ + "jsonrpc": JSONRPC_VERSION, + "id": null, + "error": { + "code": -32603, + "message": e.to_string() + } + }) + .to_string(), + ), + ), + } +} diff --git a/src/core/config.rs b/src/core/config.rs new file mode 100644 index 0000000..7447234 --- /dev/null +++ b/src/core/config.rs @@ -0,0 +1,52 @@ +use std::time::Duration; + +const DEFAULT_MAX_CONCURRENT_EXECUTIONS: usize = 16; +const DEFAULT_MAX_CONNECTED_DEVICES: usize = 10; +const DEFAULT_MAX_EXECUTION_SECS: u64 = 3600; +const DEFAULT_MAX_COMMAND_LENGTH: usize = 65_536; +const DEFAULT_PING_INTERVAL_SECS: u64 = 30; +const DEFAULT_PING_TIMEOUT_SECS: u64 = 10; + +pub struct Config { + pub max_concurrent_executions: usize, + pub max_connected_devices: usize, + pub max_execution_time: Duration, + pub max_command_length: usize, + pub ping_interval: Duration, + pub ping_timeout: Duration, +} + +impl Config { + pub fn from_env() -> Self { + Self { + max_concurrent_executions: parse_env( + "MAX_CONCURRENT_EXECUTIONS", + DEFAULT_MAX_CONCURRENT_EXECUTIONS, + ), + max_connected_devices: parse_env( + "MAX_CONNECTED_DEVICES", + DEFAULT_MAX_CONNECTED_DEVICES, + ), + max_execution_time: Duration::from_secs(parse_env( + "MAX_EXECUTION_SECS", + DEFAULT_MAX_EXECUTION_SECS, + )), + max_command_length: parse_env("MAX_COMMAND_LENGTH", DEFAULT_MAX_COMMAND_LENGTH), + ping_interval: Duration::from_secs(parse_env( + "PING_INTERVAL_SECS", + DEFAULT_PING_INTERVAL_SECS, + )), + ping_timeout: Duration::from_secs(parse_env( + "PING_TIMEOUT_SECS", + DEFAULT_PING_TIMEOUT_SECS, + )), + } + } +} + +fn parse_env(key: &str, default: T) -> T { + std::env::var(key) + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(default) +} diff --git a/src/core/db.rs b/src/core/db.rs new file mode 100644 index 0000000..4e78885 --- /dev/null +++ b/src/core/db.rs @@ -0,0 +1,34 @@ +use crate::Database; + +const AUTH_TOKENS_TABLE: &str = "auth_tokens"; +const API_TOKENS_TABLE: &str = "api_tokens"; + +pub async fn lookup_auth_token(database: &Database, token: &str) -> Result { + lookup_token(database, AUTH_TOKENS_TABLE, token).await +} + +pub async fn lookup_api_token(database: &Database, token: &str) -> Result { + lookup_token(database, API_TOKENS_TABLE, token).await +} + +async fn lookup_token( + database: &Database, + table: &'static str, + token: &str, +) -> Result { + let sql = format!( + "SELECT t.user_id \ + FROM {table} t \ + JOIN users u ON u.id = t.user_id \ + WHERE t.token = ? \ + AND (t.expires_at IS NULL OR t.expires_at > DATETIME('now')) \ + AND (u.expires_at IS NULL OR u.expires_at > DATETIME('now'))" + ); + + sqlx::query_scalar::<_, String>(sqlx::AssertSqlSafe(sql)) + .bind(token) + .fetch_optional(database) + .await + .map_err(|e| e.to_string())? + .ok_or_else(|| "Invalid or expired token".to_string()) +} diff --git a/src/core/validation.rs b/src/core/validation.rs new file mode 100644 index 0000000..195e5a4 --- /dev/null +++ b/src/core/validation.rs @@ -0,0 +1,15 @@ +pub fn validate_device_id(id: &str) -> Result<(), String> { + if id.is_empty() { + return Err("device_id cannot be empty".into()); + } + if id.len() > 64 { + return Err("device_id too long (max 64 characters)".into()); + } + if !id + .chars() + .all(|c| c.is_alphanumeric() || c == '-' || c == '_') + { + return Err("device_id contains invalid characters (allowed: a-z A-Z 0-9 - _)".into()); + } + Ok(()) +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..fb506b2 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,104 @@ +pub mod api { + pub mod auth; + pub mod catchers; + pub mod mcp; +} + +pub mod core { + pub mod config; + pub mod db; + pub mod validation; +} + +pub mod transport { + pub mod execute; + pub mod protocol; + pub mod socket; +} + +use core::config::Config; +use dashmap::DashMap; +use log::info; +use rocket::routes; +use serde::Serialize; +use sqlx::SqlitePool; +use sqlx::sqlite::SqliteConnectOptions; +use std::sync::{Arc, atomic::AtomicUsize}; +use tokio::sync::{mpsc, oneshot}; + +const DATABASE_PATH: &str = "server.db"; + +pub type Database = SqlitePool; +pub type DeviceCounts = DashMap>; +pub type Registry = DashMap<(String, String), RegistryEntry>; + +pub struct RegistryEntry { + pub sender: mpsc::Sender, + pub in_flight: Arc, +} + +#[derive(Debug, Serialize)] +pub struct ExecResult { + pub exit_code: i32, + pub stdout: String, + pub stderr: String, +} + +pub struct PendingExec { + pub exec_id: String, + pub command: String, + pub reply: oneshot::Sender, +} + +pub struct AppState { + pub database: Database, + pub registry: Arc, + pub device_counts: Arc, + pub config: Arc, +} + +#[rocket::main] +async fn main() -> Result<(), Box> { + env_logger::init(); + + let database = connect_database().await?; + apply_schema(&database).await?; + + info!("Launching server"); + + build_rocket(database).launch().await?; + + Ok(()) +} + +async fn connect_database() -> Result { + let options = SqliteConnectOptions::new() + .filename(DATABASE_PATH) + .create_if_missing(true); + + SqlitePool::connect_with(options).await +} + +async fn apply_schema(database: &Database) -> Result<(), sqlx::Error> { + for statement in include_str!("../schema.sql") + .split(';') + .map(str::trim) + .filter(|statement| !statement.is_empty()) + { + sqlx::query(statement).execute(database).await?; + } + + Ok(()) +} + +fn build_rocket(database: Database) -> rocket::Rocket { + rocket::build() + .manage(AppState { + database, + registry: Arc::new(DashMap::new()), + device_counts: Arc::new(DashMap::new()), + config: Arc::new(Config::from_env()), + }) + .mount("/", routes![transport::socket::connect, api::mcp::route]) + .register("/", rocket::catchers![api::catchers::default_catcher]) +} diff --git a/src/transport/execute.rs b/src/transport/execute.rs new file mode 100644 index 0000000..4cd62d8 --- /dev/null +++ b/src/transport/execute.rs @@ -0,0 +1,122 @@ +use crate::{ + AppState, ExecResult, PendingExec, Registry, + core::{config::Config, db::lookup_api_token, validation::validate_device_id}, +}; +use rocket::{State, http::Status}; +use serde::Deserialize; +use std::sync::{Arc, atomic::Ordering}; +use tokio::sync::{mpsc, oneshot}; +use uuid::Uuid; + +#[derive(Deserialize)] +pub struct ExecuteRequest { + pub device_id: String, + pub command: String, +} + +pub async fn execute_for_token( + state: &State, + token: &str, + request: ExecuteRequest, +) -> Result { + let user_id = lookup_api_token(&state.database, token) + .await + .map_err(|e| (Status::Unauthorized, e))?; + + execute_for_user(&state.registry, &user_id, request, &state.config).await +} + +pub async fn execute_for_user( + registry: &Arc, + user_id: &str, + request: ExecuteRequest, + config: &Config, +) -> Result { + let ExecuteRequest { device_id, command } = request; + + validate_execute_request(&device_id, &command, config)?; + + execute(registry, user_id, &device_id, command, config) + .await + .map_err(|e| (Status::BadGateway, e)) +} + +fn validate_execute_request( + device_id: &str, + command: &str, + config: &Config, +) -> Result<(), (Status, String)> { + validate_device_id(device_id).map_err(|error| (Status::BadRequest, error))?; + + if command.len() <= config.max_command_length { + return Ok(()); + } + + Err(( + Status::PayloadTooLarge, + format!( + "Command exceeds maximum length of {} bytes", + config.max_command_length + ), + )) +} + +async fn execute( + registry: &Arc, + user_id: &str, + device_id: &str, + command: String, + config: &Config, +) -> Result { + let key = (user_id.to_owned(), device_id.to_owned()); + + let (sender, in_flight_counter) = registry + .get(&key) + .map(|entry| (entry.sender.clone(), Arc::clone(&entry.in_flight))) + .ok_or_else(|| "Device not connected".to_string())?; + + claim_execution_slot(&in_flight_counter, config.max_concurrent_executions)?; + + let result = send_and_await(sender, command, config).await; + in_flight_counter.fetch_sub(1, Ordering::SeqCst); + + result +} + +async fn send_and_await( + sender: mpsc::Sender, + command: String, + config: &Config, +) -> Result { + let (reply_tx, reply_rx) = oneshot::channel(); + + sender + .send(PendingExec { + exec_id: Uuid::new_v4().to_string(), + command, + reply: reply_tx, + }) + .await + .map_err(|_| "Device disconnected".to_string())?; + + tokio::time::timeout(config.max_execution_time, reply_rx) + .await + .map_err(|_| "Command timed out".to_string())? + .map_err(|_| "Device disconnected".to_string()) +} + +fn claim_execution_slot( + counter: &Arc, + limit: usize, +) -> Result<(), String> { + let previous = counter.fetch_add(1, Ordering::SeqCst); + if previous < limit { + return Ok(()); + } + + counter.fetch_sub(1, Ordering::SeqCst); + Err(format!( + "Device has reached the limit of {} concurrent executions", + limit, + )) +} diff --git a/src/transport/protocol.rs b/src/transport/protocol.rs new file mode 100644 index 0000000..9f55540 --- /dev/null +++ b/src/transport/protocol.rs @@ -0,0 +1,20 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ClientMsg { + ExecResult { + exec_id: String, + exit_code: i32, + stdout: String, + stderr: String, + }, +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ServerMsg { + AuthOk, + AuthError { reason: String }, + Exec { exec_id: String, command: String }, +} diff --git a/src/transport/socket.rs b/src/transport/socket.rs new file mode 100644 index 0000000..e92bcb8 --- /dev/null +++ b/src/transport/socket.rs @@ -0,0 +1,317 @@ +use crate::{ + AppState, Database, DeviceCounts, ExecResult, PendingExec, Registry, RegistryEntry, + api::auth::{BearerToken, MaybeDeviceId}, + core::config::Config, + core::db::lookup_auth_token, + transport::protocol::{ClientMsg, ServerMsg}, +}; +use dashmap::{DashMap, mapref::entry::Entry}; +use futures::{SinkExt, StreamExt}; +use log::{info, warn}; +use rocket::{State, get}; +use rocket_ws::{Channel, Message, WebSocket, stream::DuplexStream}; +use std::{ + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + time::Duration, +}; +use tokio::sync::{mpsc, oneshot}; +use tokio_util::sync::CancellationToken; + +type PendingReplies = Arc>>; + +#[get("/connect")] +pub fn connect<'r>( + ws: WebSocket, + token: BearerToken, + device_id: MaybeDeviceId, + state: &'r State, +) -> Channel<'r> { + let database = state.database.clone(); + let registry = state.registry.clone(); + let device_counts = state.device_counts.clone(); + let config = state.config.clone(); + ws.channel(move |socket| { + Box::pin(handle( + database, + registry, + device_counts, + config, + socket, + token.0, + device_id.0, + )) + }) +} + +async fn handle( + database: Database, + registry: Arc, + device_counts: Arc, + config: Arc, + mut socket: DuplexStream, + token: String, + device_id: Result, +) -> Result<(), rocket_ws::result::Error> { + let device_id = match device_id { + Ok(device_id) => device_id, + Err(reason) => { + warn!("Rejecting websocket connection: {reason}"); + return deny_connection(&mut socket, reason).await; + } + }; + + let user_id = match lookup_auth_token(&database, &token).await { + Ok(user_id) => user_id, + Err(reason) => { + warn!("Rejecting websocket connection for device {device_id}: {reason}"); + return deny_connection(&mut socket, reason).await; + } + }; + + let device_counter = match claim_device_slot(&device_counts, &user_id, &config) { + Ok(counter) => counter, + Err(reason) => { + warn!("Rejecting websocket connection for user {user_id} device {device_id}: {reason}"); + return deny_connection(&mut socket, reason).await; + } + }; + + let (exec_tx, exec_rx) = mpsc::channel::(config.max_concurrent_executions); + let key = (user_id.clone(), device_id.clone()); + + if let Err(reason) = register_connection(®istry, key.clone(), exec_tx) { + release_device_slot(&device_counts, &user_id, device_counter); + warn!("Rejecting duplicate websocket connection for {user_id}/{device_id}: {reason}"); + return deny_connection(&mut socket, reason).await; + } + + info!("Accepted websocket connection for {user_id}/{device_id}"); + + send(&mut socket, &ServerMsg::AuthOk).await.ok(); + + let pending = Arc::new(DashMap::new()); + let cancel = CancellationToken::new(); + run_loop(&mut socket, exec_rx, &pending, &config, &cancel).await; + cancel.cancel(); + + registry.remove(&key); + pending.clear(); + release_device_slot(&device_counts, &user_id, device_counter); + info!("Closed websocket connection for {user_id}/{device_id}"); + Ok(()) +} + +async fn deny_connection( + socket: &mut DuplexStream, + reason: String, +) -> Result<(), rocket_ws::result::Error> { + send(socket, &ServerMsg::AuthError { reason }).await.ok(); + Ok(()) +} + +fn register_connection( + registry: &Arc, + key: (String, String), + sender: mpsc::Sender, +) -> Result<(), String> { + match registry.entry(key) { + Entry::Occupied(_) => Err("Device already connected".into()), + Entry::Vacant(slot) => { + slot.insert(RegistryEntry { + sender, + in_flight: Arc::new(AtomicUsize::new(0)), + }); + Ok(()) + } + } +} + +fn claim_device_slot( + device_counts: &Arc, + user_id: &str, + config: &Config, +) -> Result, String> { + let counter = { + let entry = device_counts + .entry(user_id.to_owned()) + .or_insert_with(|| Arc::new(AtomicUsize::new(0))); + Arc::clone(&*entry) + }; + + let previous = counter.fetch_add(1, Ordering::SeqCst); + if previous >= config.max_connected_devices { + counter.fetch_sub(1, Ordering::SeqCst); + return Err("Too many devices connected for this account".into()); + } + + Ok(counter) +} + +fn release_device_slot( + device_counts: &Arc, + user_id: &str, + counter: Arc, +) { + let previous = counter.fetch_sub(1, Ordering::SeqCst); + if previous == 1 { + device_counts.remove_if(user_id, |_, value| value.load(Ordering::SeqCst) == 0); + } +} + +async fn run_loop( + socket: &mut DuplexStream, + mut exec_rx: mpsc::Receiver, + pending: &PendingReplies, + config: &Config, + cancel: &CancellationToken, +) { + let mut ping_interval = tokio::time::interval_at( + tokio::time::Instant::now() + config.ping_interval, + config.ping_interval, + ); + + let pong_timer = tokio::time::sleep(Duration::from_secs(0)); + tokio::pin!(pong_timer); + let mut awaiting_pong = false; + + loop { + tokio::select! { + _ = ping_interval.tick() => { + if awaiting_pong { + break; + } + if send_ping(socket).await.is_err() { + break; + } + awaiting_pong = true; + pong_timer.as_mut().reset( + tokio::time::Instant::now() + config.ping_timeout, + ); + } + _ = &mut pong_timer, if awaiting_pong => { + break; + } + Some(exec) = exec_rx.recv() => { + if !dispatch(socket, pending, exec, config.max_execution_time, cancel.clone()).await { + break; + } + } + message = socket.next() => { + let action = handle_socket_message(socket, message, pending).await; + + if action.break_loop { + break; + } + + if action.clear_pong_deadline { + awaiting_pong = false; + } + } + } + } +} + +struct SocketAction { + break_loop: bool, + clear_pong_deadline: bool, +} + +async fn handle_socket_message( + socket: &mut DuplexStream, + message: Option>, + pending: &PendingReplies, +) -> SocketAction { + match message { + Some(Ok(Message::Pong(_))) => SocketAction { + break_loop: false, + clear_pong_deadline: true, + }, + Some(Ok(Message::Ping(data))) => SocketAction { + break_loop: socket.send(Message::Pong(data.clone())).await.is_err(), + clear_pong_deadline: false, + }, + Some(Ok(Message::Text(text))) => { + resolve_pending(&text, pending); + SocketAction { + break_loop: false, + clear_pong_deadline: false, + } + } + Some(Ok(Message::Close(_))) | None => SocketAction { + break_loop: true, + clear_pong_deadline: false, + }, + _ => SocketAction { + break_loop: false, + clear_pong_deadline: false, + }, + } +} + +async fn send_ping(socket: &mut DuplexStream) -> Result<(), rocket_ws::result::Error> { + socket.send(Message::Ping(vec![])).await +} + +async fn dispatch( + socket: &mut DuplexStream, + pending: &PendingReplies, + exec: PendingExec, + max_execution_time: Duration, + cancel: CancellationToken, +) -> bool { + let exec_id = exec.exec_id.clone(); + pending.insert(exec_id.clone(), exec.reply); + + let pending_for_timeout = Arc::clone(pending); + tokio::spawn(async move { + tokio::select! { + _ = cancel.cancelled() => {} + _ = tokio::time::sleep(max_execution_time) => { + pending_for_timeout.remove(&exec_id); + } + } + }); + + send( + socket, + &ServerMsg::Exec { + exec_id: exec.exec_id, + command: exec.command, + }, + ) + .await + .is_ok() +} + +fn resolve_pending(text: &str, pending: &PendingReplies) { + let Ok(ClientMsg::ExecResult { + exec_id, + exit_code, + stdout, + stderr, + }) = serde_json::from_str(text) + else { + return; + }; + if let Some((_, reply)) = pending.remove(&exec_id) { + reply + .send(ExecResult { + exit_code, + stdout, + stderr, + }) + .ok(); + } +} + +async fn send( + socket: &mut DuplexStream, + message: &ServerMsg, +) -> Result<(), Box> { + let text = serde_json::to_string(message)?; + socket.send(Message::Text(text)).await?; + Ok(()) +}