From 749638b75cdf60070918726033ce74c3023690b2 Mon Sep 17 00:00:00 2001 From: ZeroZipp Date: Fri, 22 May 2026 08:24:17 +0200 Subject: [PATCH] refactor --- src/api/mcp.rs | 375 ----------------------------- src/api/mcp/mod.rs | 54 +++++ src/api/mcp/protocol.rs | 128 ++++++++++ src/api/mcp/tools.rs | 199 +++++++++++++++ src/api/mod.rs | 3 + src/app/mod.rs | 42 ++++ src/app/state.rs | 35 +++ src/core/db.rs | 2 +- src/core/mod.rs | 3 + src/main.rs | 93 +------ src/transport/execute.rs | 2 +- src/transport/mod.rs | 3 + src/transport/socket.rs | 317 ------------------------ src/transport/socket/connection.rs | 131 ++++++++++ src/transport/socket/loop_state.rs | 131 ++++++++++ src/transport/socket/mod.rs | 34 +++ 16 files changed, 770 insertions(+), 782 deletions(-) delete mode 100644 src/api/mcp.rs create mode 100644 src/api/mcp/mod.rs create mode 100644 src/api/mcp/protocol.rs create mode 100644 src/api/mcp/tools.rs create mode 100644 src/api/mod.rs create mode 100644 src/app/mod.rs create mode 100644 src/app/state.rs create mode 100644 src/core/mod.rs create mode 100644 src/transport/mod.rs delete mode 100644 src/transport/socket.rs create mode 100644 src/transport/socket/connection.rs create mode 100644 src/transport/socket/loop_state.rs create mode 100644 src/transport/socket/mod.rs diff --git a/src/api/mcp.rs b/src/api/mcp.rs deleted file mode 100644 index 1d4d98f..0000000 --- a/src/api/mcp.rs +++ /dev/null @@ -1,375 +0,0 @@ -use crate::{ - AppState, - api::auth::MaybeBearerToken, - core::db::lookup_api_token, - transport::execute::{ExecuteRequest, execute_for_token}, -}; -use rocket::{State, http::Status, post, response::content::RawJson, serde::json::Json}; -use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; - -type ApiResponse = (Status, RawJson); - -const JSONRPC_VERSION: &str = "2.0"; -const MCP_PROTOCOL_VERSION: &str = "2025-11-25"; - -#[derive(Deserialize)] -pub struct JsonRpcRequest { - pub jsonrpc: String, - pub id: Option, - pub method: String, - #[serde(default)] - pub params: Option, -} - -#[derive(Serialize)] -struct JsonRpcResponse { - jsonrpc: &'static str, - id: Option, - result: T, -} - -#[derive(Serialize)] -struct JsonRpcErrorResponse { - jsonrpc: &'static str, - id: Option, - error: JsonRpcError, -} - -#[derive(Serialize)] -struct JsonRpcError { - code: i64, - message: String, -} - -#[derive(Serialize)] -struct InitializeResult { - #[serde(rename = "protocolVersion")] - protocol_version: &'static str, - capabilities: ServerCapabilities, - #[serde(rename = "serverInfo")] - server_info: ServerInfo, -} - -#[derive(Serialize)] -struct ServerCapabilities { - tools: ToolsCapability, -} - -#[derive(Serialize)] -struct ToolsCapability { - list_changed: bool, -} - -#[derive(Serialize)] -struct ServerInfo { - name: &'static str, - version: &'static str, -} - -#[derive(Serialize)] -struct ToolsListResult { - tools: Vec, -} - -#[derive(Serialize)] -struct ToolDefinition { - name: &'static str, - description: &'static str, - #[serde(rename = "inputSchema")] - input_schema: Value, -} - -#[derive(Deserialize)] -struct ToolCallParams { - name: String, - arguments: Value, -} - -#[derive(Serialize)] -struct ToolCallResult { - content: Vec, - #[serde(rename = "isError")] - is_error: bool, -} - -#[derive(Serialize)] -struct ConnectedDevicesResult { - device_ids: Vec, -} - -enum ToolCall { - Execute(ExecuteRequest), - ListDevices, -} - -#[derive(Serialize)] -struct TextContent { - #[serde(rename = "type")] - kind: &'static str, - text: String, -} - -#[post("/mcp", data = "")] -pub async fn route( - body: Json, - token: MaybeBearerToken, - state: &State, -) -> ApiResponse { - let request = body.into_inner(); - - if request.jsonrpc != JSONRPC_VERSION { - return jsonrpc_error( - Status::BadRequest, - request.id, - -32600, - "Invalid JSON-RPC version", - ); - } - - match request.method.as_str() { - "initialize" => jsonrpc_ok( - request.id, - InitializeResult { - protocol_version: MCP_PROTOCOL_VERSION, - capabilities: ServerCapabilities { - tools: ToolsCapability { - list_changed: false, - }, - }, - server_info: ServerInfo { - name: env!("CARGO_PKG_NAME"), - version: env!("CARGO_PKG_VERSION"), - }, - }, - ), - "notifications/initialized" => jsonrpc_notification_ok(), - "tools/list" => match require_bearer_token(token.0.as_deref(), request.id.clone()) { - Ok(_) => jsonrpc_ok( - request.id, - ToolsListResult { - tools: tool_definitions(), - }, - ), - Err(response) => response, - }, - "tools/call" => { - handle_tool_call(request.id, request.params, token.0.as_deref(), state).await - } - _ if request.id.is_none() => jsonrpc_notification_ok(), - _ => jsonrpc_error(Status::BadRequest, request.id, -32601, "Method not found"), - } -} - -async fn handle_tool_call( - id: Option, - params: Option, - token: Option<&str>, - state: &State, -) -> ApiResponse { - let token = match require_bearer_token(token, id.clone()) { - Ok(token) => token, - Err(response) => return response, - }; - - let Some(params) = params else { - return jsonrpc_error(Status::BadRequest, id, -32602, "Missing params"); - }; - - let tool_call = match parse_tool_call(params) { - Ok(call) => call, - Err(_) => return jsonrpc_error(Status::BadRequest, id, -32602, "Invalid params"), - }; - - match tool_call { - ToolCall::Execute(request) => execute_tool_call(id, token, state, request).await, - ToolCall::ListDevices => list_devices_tool_call(id, token, state).await, - } -} - -fn require_bearer_token(token: Option<&str>, id: Option) -> Result<&str, ApiResponse> { - token.ok_or_else(|| { - jsonrpc_error( - Status::Unauthorized, - id, - -32001, - "Missing Authorization header", - ) - }) -} - -fn tool_definitions() -> Vec { - vec![ - ToolDefinition { - name: "execute", - description: "Execute a shell command on a connected device.", - input_schema: json!({ - "type": "object", - "required": ["device_id", "command"], - "properties": { - "device_id": { - "type": "string", - "description": "Identifier of the connected device." - }, - "command": { - "type": "string", - "description": "Shell command to execute on the target device." - } - }, - "additionalProperties": false - }), - }, - ToolDefinition { - name: "list_devices", - description: "List currently connected device identifiers.", - input_schema: json!({ - "type": "object", - "properties": {}, - "additionalProperties": false - }), - }, - ] -} - -fn parse_tool_call(params: Value) -> Result { - let call: ToolCallParams = serde_json::from_value(params).map_err(|_| "Invalid params")?; - - match call.name.as_str() { - "execute" => serde_json::from_value(call.arguments) - .map(ToolCall::Execute) - .map_err(|_| "Invalid execute arguments"), - "list_devices" => parse_list_devices_arguments(call.arguments), - _ => Err("Unknown tool"), - } -} - -fn parse_list_devices_arguments(arguments: Value) -> Result { - match arguments { - Value::Object(map) if map.is_empty() => Ok(ToolCall::ListDevices), - _ => Err("Invalid list_devices arguments"), - } -} - -fn connected_device_ids_for_user(state: &State, user_id: &str) -> Vec { - state - .registry - .iter() - .filter_map(|entry| { - let ((entry_user_id, device_id), _) = entry.pair(); - (entry_user_id == user_id).then(|| device_id.clone()) - }) - .collect() -} - -async fn execute_tool_call( - id: Option, - token: &str, - state: &State, - request: ExecuteRequest, -) -> ApiResponse { - match execute_for_token(state, token, request).await { - Ok(result) => json_tool_result(id, &result), - Err((status, message)) if status == Status::Unauthorized => { - jsonrpc_error(Status::Unauthorized, id, -32001, message) - } - Err((_, message)) => jsonrpc_ok(id, tool_error(message)), - } -} - -async fn list_devices_tool_call( - id: Option, - token: &str, - state: &State, -) -> ApiResponse { - let user_id = match lookup_api_token(&state.database, token).await { - Ok(user_id) => user_id, - Err(message) => return jsonrpc_error(Status::Unauthorized, id, -32001, message), - }; - - let mut device_ids = connected_device_ids_for_user(state, &user_id); - device_ids.sort(); - - json_tool_result(id, &ConnectedDevicesResult { device_ids }) -} - -fn json_tool_result(id: Option, payload: &T) -> ApiResponse { - match serde_json::to_string(payload) { - Ok(text) => jsonrpc_ok(id, ToolCallResult::success(text)), - Err(error) => jsonrpc_error(Status::InternalServerError, id, -32603, error.to_string()), - } -} - -fn tool_error(message: String) -> ToolCallResult { - ToolCallResult::error(message) -} - -impl ToolCallResult { - fn success(text: String) -> Self { - Self { - content: vec![TextContent::text(text)], - is_error: false, - } - } - - fn error(text: String) -> Self { - Self { - content: vec![TextContent::text(text)], - is_error: true, - } - } -} - -impl TextContent { - fn text(text: String) -> Self { - Self { kind: "text", text } - } -} - -fn jsonrpc_ok(id: Option, result: T) -> ApiResponse { - match serde_json::to_string(&JsonRpcResponse { - jsonrpc: JSONRPC_VERSION, - id, - result, - }) { - Ok(body) => (Status::Ok, RawJson(body)), - Err(e) => jsonrpc_error(Status::InternalServerError, None, -32603, e.to_string()), - } -} - -fn jsonrpc_notification_ok() -> ApiResponse { - (Status::Accepted, RawJson(String::new())) -} - -fn jsonrpc_error( - status: Status, - id: Option, - code: i64, - message: impl Into, -) -> ApiResponse { - let response = JsonRpcErrorResponse { - jsonrpc: JSONRPC_VERSION, - id, - error: JsonRpcError { - code, - message: message.into(), - }, - }; - - match serde_json::to_string(&response) { - Ok(body) => (status, RawJson(body)), - Err(e) => ( - Status::InternalServerError, - RawJson( - json!({ - "jsonrpc": JSONRPC_VERSION, - "id": null, - "error": { - "code": -32603, - "message": e.to_string() - } - }) - .to_string(), - ), - ), - } -} diff --git a/src/api/mcp/mod.rs b/src/api/mcp/mod.rs new file mode 100644 index 0000000..e23f800 --- /dev/null +++ b/src/api/mcp/mod.rs @@ -0,0 +1,54 @@ +mod protocol; +mod tools; + +use crate::{api::auth::MaybeBearerToken, app::state::AppState}; +use protocol::{ + ApiResponse, InitializeResult, JSONRPC_VERSION, jsonrpc_error, jsonrpc_notification_ok, + jsonrpc_ok, +}; +use rocket::{State, http::Status, post, serde::json::Json}; +use serde_json::Value; +use tools::{handle_tool_call, tool_definitions}; + +pub use protocol::JsonRpcRequest; + +#[post("/mcp", data = "")] +pub async fn route( + body: Json, + token: MaybeBearerToken, + state: &State, +) -> ApiResponse { + let request = body.into_inner(); + + if request.jsonrpc != JSONRPC_VERSION { + return jsonrpc_error( + Status::BadRequest, + request.id, + -32600, + "Invalid JSON-RPC version", + ); + } + + match request.method.as_str() { + "initialize" => jsonrpc_ok(request.id, InitializeResult::new()), + "notifications/initialized" => jsonrpc_notification_ok(), + "tools/list" => match token.0 { + Some(_) => jsonrpc_ok(request.id, tool_definitions()), + None => missing_authorization(request.id), + }, + "tools/call" => { + handle_tool_call(request.id, request.params, token.0.as_deref(), state).await + } + _ if request.id.is_none() => jsonrpc_notification_ok(), + _ => jsonrpc_error(Status::BadRequest, request.id, -32601, "Method not found"), + } +} + +fn missing_authorization(id: Option) -> ApiResponse { + jsonrpc_error( + Status::Unauthorized, + id, + -32001, + "Missing Authorization header", + ) +} diff --git a/src/api/mcp/protocol.rs b/src/api/mcp/protocol.rs new file mode 100644 index 0000000..f7d3d9e --- /dev/null +++ b/src/api/mcp/protocol.rs @@ -0,0 +1,128 @@ +use rocket::{http::Status, response::content::RawJson}; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; + +pub type ApiResponse = (Status, RawJson); + +pub const JSONRPC_VERSION: &str = "2.0"; +const MCP_PROTOCOL_VERSION: &str = "2025-11-25"; + +#[derive(Deserialize)] +pub struct JsonRpcRequest { + pub jsonrpc: String, + pub id: Option, + pub method: String, + #[serde(default)] + pub params: Option, +} + +#[derive(Serialize)] +struct JsonRpcResponse { + jsonrpc: &'static str, + id: Option, + result: T, +} + +#[derive(Serialize)] +struct JsonRpcErrorResponse { + jsonrpc: &'static str, + id: Option, + error: JsonRpcError, +} + +#[derive(Serialize)] +struct JsonRpcError { + code: i64, + message: String, +} + +#[derive(Serialize)] +pub struct InitializeResult { + #[serde(rename = "protocolVersion")] + protocol_version: &'static str, + capabilities: ServerCapabilities, + #[serde(rename = "serverInfo")] + server_info: ServerInfo, +} + +#[derive(Serialize)] +struct ServerCapabilities { + tools: ToolsCapability, +} + +#[derive(Serialize)] +struct ToolsCapability { + list_changed: bool, +} + +#[derive(Serialize)] +struct ServerInfo { + name: &'static str, + version: &'static str, +} + +impl InitializeResult { + pub fn new() -> Self { + Self { + protocol_version: MCP_PROTOCOL_VERSION, + capabilities: ServerCapabilities { + tools: ToolsCapability { + list_changed: false, + }, + }, + server_info: ServerInfo { + name: env!("CARGO_PKG_NAME"), + version: env!("CARGO_PKG_VERSION"), + }, + } + } +} + +pub fn jsonrpc_ok(id: Option, result: T) -> ApiResponse { + match serde_json::to_string(&JsonRpcResponse { + jsonrpc: JSONRPC_VERSION, + id, + result, + }) { + Ok(body) => (Status::Ok, RawJson(body)), + Err(error) => jsonrpc_error(Status::InternalServerError, None, -32603, error.to_string()), + } +} + +pub fn jsonrpc_notification_ok() -> ApiResponse { + (Status::Accepted, RawJson(String::new())) +} + +pub fn jsonrpc_error( + status: Status, + id: Option, + code: i64, + message: impl Into, +) -> ApiResponse { + let response = JsonRpcErrorResponse { + jsonrpc: JSONRPC_VERSION, + id, + error: JsonRpcError { + code, + message: message.into(), + }, + }; + + match serde_json::to_string(&response) { + Ok(body) => (status, RawJson(body)), + Err(error) => ( + Status::InternalServerError, + RawJson( + json!({ + "jsonrpc": JSONRPC_VERSION, + "id": null, + "error": { + "code": -32603, + "message": error.to_string() + } + }) + .to_string(), + ), + ), + } +} diff --git a/src/api/mcp/tools.rs b/src/api/mcp/tools.rs new file mode 100644 index 0000000..7dec1a6 --- /dev/null +++ b/src/api/mcp/tools.rs @@ -0,0 +1,199 @@ +use crate::{ + api::mcp::protocol::{ApiResponse, jsonrpc_error, jsonrpc_ok}, + app::state::AppState, + core::db::lookup_api_token, + transport::execute::{ExecuteRequest, execute_for_token}, +}; +use rocket::{State, http::Status}; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; + +#[derive(Serialize)] +pub struct ToolsListResult { + tools: Vec, +} + +#[derive(Serialize)] +struct ToolDefinition { + name: &'static str, + description: &'static str, + #[serde(rename = "inputSchema")] + input_schema: Value, +} + +#[derive(Deserialize)] +struct ToolCallParams { + name: String, + arguments: Value, +} + +#[derive(Serialize)] +struct ToolCallResult { + content: Vec, + #[serde(rename = "isError")] + is_error: bool, +} + +#[derive(Serialize)] +struct ConnectedDevicesResult { + device_ids: Vec, +} + +enum ToolCall { + Execute(ExecuteRequest), + ListDevices, +} + +#[derive(Serialize)] +struct TextContent { + #[serde(rename = "type")] + kind: &'static str, + text: String, +} + +pub fn tool_definitions() -> ToolsListResult { + ToolsListResult { + tools: vec![ + ToolDefinition { + name: "execute", + description: "Execute a shell command on a connected device.", + input_schema: json!({ + "type": "object", + "required": ["device_id", "command"], + "properties": { + "device_id": { + "type": "string", + "description": "Identifier of the connected device." + }, + "command": { + "type": "string", + "description": "Shell command to execute on the target device." + } + }, + "additionalProperties": false + }), + }, + ToolDefinition { + name: "list_devices", + description: "List currently connected device identifiers.", + input_schema: json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + }), + }, + ], + } +} + +pub async fn handle_tool_call( + id: Option, + params: Option, + token: Option<&str>, + state: &State, +) -> ApiResponse { + let Some(token) = token else { + return jsonrpc_error( + Status::Unauthorized, + id, + -32001, + "Missing Authorization header", + ); + }; + + let Some(params) = params else { + return jsonrpc_error(Status::BadRequest, id, -32602, "Missing params"); + }; + + let tool_call = match parse_tool_call(params) { + Ok(call) => call, + Err(_) => return jsonrpc_error(Status::BadRequest, id, -32602, "Invalid params"), + }; + + match tool_call { + ToolCall::Execute(request) => execute_tool_call(id, token, state, request).await, + ToolCall::ListDevices => list_devices_tool_call(id, token, state).await, + } +} + +fn parse_tool_call(params: Value) -> Result { + let call: ToolCallParams = serde_json::from_value(params).map_err(|_| ())?; + + match call.name.as_str() { + "execute" => serde_json::from_value(call.arguments) + .map(ToolCall::Execute) + .map_err(|_| ()), + "list_devices" => match call.arguments { + Value::Object(map) if map.is_empty() => Ok(ToolCall::ListDevices), + _ => Err(()), + }, + _ => Err(()), + } +} + +async fn execute_tool_call( + id: Option, + token: &str, + state: &State, + request: ExecuteRequest, +) -> ApiResponse { + match execute_for_token(state, token, request).await { + Ok(result) => json_tool_result(id, &result), + Err((status, message)) if status == Status::Unauthorized => { + jsonrpc_error(Status::Unauthorized, id, -32001, message) + } + Err((_, message)) => jsonrpc_ok(id, ToolCallResult::error(message)), + } +} + +async fn list_devices_tool_call( + id: Option, + token: &str, + state: &State, +) -> ApiResponse { + let user_id = match lookup_api_token(&state.database, token).await { + Ok(user_id) => user_id, + Err(message) => return jsonrpc_error(Status::Unauthorized, id, -32001, message), + }; + + let mut device_ids = state + .registry + .iter() + .filter_map(|entry| { + let ((entry_user_id, device_id), _) = entry.pair(); + (entry_user_id == &user_id).then(|| device_id.clone()) + }) + .collect::>(); + device_ids.sort(); + + json_tool_result(id, &ConnectedDevicesResult { device_ids }) +} + +fn json_tool_result(id: Option, payload: &T) -> ApiResponse { + match serde_json::to_string(payload) { + Ok(text) => jsonrpc_ok(id, ToolCallResult::success(text)), + Err(error) => jsonrpc_error(Status::InternalServerError, id, -32603, error.to_string()), + } +} + +impl ToolCallResult { + fn success(text: String) -> Self { + Self { + content: vec![TextContent::text(text)], + is_error: false, + } + } + + fn error(text: String) -> Self { + Self { + content: vec![TextContent::text(text)], + is_error: true, + } + } +} + +impl TextContent { + fn text(text: String) -> Self { + Self { kind: "text", text } + } +} diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 0000000..d32fd1e --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1,3 @@ +pub mod auth; +pub mod catchers; +pub mod mcp; diff --git a/src/app/mod.rs b/src/app/mod.rs new file mode 100644 index 0000000..ac3f05d --- /dev/null +++ b/src/app/mod.rs @@ -0,0 +1,42 @@ +pub mod state; + +use crate::{api, core::config::Config, transport}; +use dashmap::DashMap; +use rocket::routes; +use sqlx::{SqlitePool, sqlite::SqliteConnectOptions}; +use std::sync::Arc; + +const DATABASE_PATH: &str = "server.db"; + +pub async fn connect_database() -> Result { + SqlitePool::connect_with( + SqliteConnectOptions::new() + .filename(DATABASE_PATH) + .create_if_missing(true), + ) + .await +} + +pub async fn apply_schema(database: &SqlitePool) -> Result<(), sqlx::Error> { + for statement in include_str!("../../schema.sql") + .split(';') + .map(str::trim) + .filter(|statement| !statement.is_empty()) + { + sqlx::query(statement).execute(database).await?; + } + + Ok(()) +} + +pub fn build_rocket(database: SqlitePool) -> rocket::Rocket { + rocket::build() + .manage(state::AppState { + database, + registry: Arc::new(DashMap::new()), + device_counts: Arc::new(DashMap::new()), + config: Arc::new(Config::from_env()), + }) + .mount("/", routes![transport::socket::connect, api::mcp::route]) + .register("/", rocket::catchers![api::catchers::default_catcher]) +} diff --git a/src/app/state.rs b/src/app/state.rs new file mode 100644 index 0000000..6bbf0cf --- /dev/null +++ b/src/app/state.rs @@ -0,0 +1,35 @@ +use crate::core::config::Config; +use dashmap::DashMap; +use serde::Serialize; +use sqlx::SqlitePool; +use std::sync::{Arc, atomic::AtomicUsize}; +use tokio::sync::{mpsc, oneshot}; + +pub type Database = SqlitePool; +pub type DeviceCounts = DashMap>; +pub type Registry = DashMap<(String, String), RegistryEntry>; + +pub struct RegistryEntry { + pub sender: mpsc::Sender, + pub in_flight: Arc, +} + +#[derive(Debug, Serialize)] +pub struct ExecResult { + pub exit_code: i32, + pub stdout: String, + pub stderr: String, +} + +pub struct PendingExec { + pub exec_id: String, + pub command: String, + pub reply: oneshot::Sender, +} + +pub struct AppState { + pub database: Database, + pub registry: Arc, + pub device_counts: Arc, + pub config: Arc, +} diff --git a/src/core/db.rs b/src/core/db.rs index 4e78885..41fc5c2 100644 --- a/src/core/db.rs +++ b/src/core/db.rs @@ -1,4 +1,4 @@ -use crate::Database; +use crate::app::state::Database; const AUTH_TOKENS_TABLE: &str = "auth_tokens"; const API_TOKENS_TABLE: &str = "api_tokens"; diff --git a/src/core/mod.rs b/src/core/mod.rs new file mode 100644 index 0000000..b7f192c --- /dev/null +++ b/src/core/mod.rs @@ -0,0 +1,3 @@ +pub mod config; +pub mod db; +pub mod validation; diff --git a/src/main.rs b/src/main.rs index fb506b2..6d43eb5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,61 +1,10 @@ -pub mod api { - pub mod auth; - pub mod catchers; - pub mod mcp; -} +pub mod api; +pub mod app; +pub mod core; +pub mod transport; -pub mod core { - pub mod config; - pub mod db; - pub mod validation; -} - -pub mod transport { - pub mod execute; - pub mod protocol; - pub mod socket; -} - -use core::config::Config; -use dashmap::DashMap; +use app::{apply_schema, build_rocket, connect_database}; use log::info; -use rocket::routes; -use serde::Serialize; -use sqlx::SqlitePool; -use sqlx::sqlite::SqliteConnectOptions; -use std::sync::{Arc, atomic::AtomicUsize}; -use tokio::sync::{mpsc, oneshot}; - -const DATABASE_PATH: &str = "server.db"; - -pub type Database = SqlitePool; -pub type DeviceCounts = DashMap>; -pub type Registry = DashMap<(String, String), RegistryEntry>; - -pub struct RegistryEntry { - pub sender: mpsc::Sender, - pub in_flight: Arc, -} - -#[derive(Debug, Serialize)] -pub struct ExecResult { - pub exit_code: i32, - pub stdout: String, - pub stderr: String, -} - -pub struct PendingExec { - pub exec_id: String, - pub command: String, - pub reply: oneshot::Sender, -} - -pub struct AppState { - pub database: Database, - pub registry: Arc, - pub device_counts: Arc, - pub config: Arc, -} #[rocket::main] async fn main() -> Result<(), Box> { @@ -70,35 +19,3 @@ async fn main() -> Result<(), Box> { Ok(()) } - -async fn connect_database() -> Result { - let options = SqliteConnectOptions::new() - .filename(DATABASE_PATH) - .create_if_missing(true); - - SqlitePool::connect_with(options).await -} - -async fn apply_schema(database: &Database) -> Result<(), sqlx::Error> { - for statement in include_str!("../schema.sql") - .split(';') - .map(str::trim) - .filter(|statement| !statement.is_empty()) - { - sqlx::query(statement).execute(database).await?; - } - - Ok(()) -} - -fn build_rocket(database: Database) -> rocket::Rocket { - rocket::build() - .manage(AppState { - database, - registry: Arc::new(DashMap::new()), - device_counts: Arc::new(DashMap::new()), - config: Arc::new(Config::from_env()), - }) - .mount("/", routes![transport::socket::connect, api::mcp::route]) - .register("/", rocket::catchers![api::catchers::default_catcher]) -} diff --git a/src/transport/execute.rs b/src/transport/execute.rs index 4cd62d8..f8474b2 100644 --- a/src/transport/execute.rs +++ b/src/transport/execute.rs @@ -1,5 +1,5 @@ use crate::{ - AppState, ExecResult, PendingExec, Registry, + app::state::{AppState, ExecResult, PendingExec, Registry}, core::{config::Config, db::lookup_api_token, validation::validate_device_id}, }; use rocket::{State, http::Status}; diff --git a/src/transport/mod.rs b/src/transport/mod.rs new file mode 100644 index 0000000..ef00198 --- /dev/null +++ b/src/transport/mod.rs @@ -0,0 +1,3 @@ +pub mod execute; +pub mod protocol; +pub mod socket; diff --git a/src/transport/socket.rs b/src/transport/socket.rs deleted file mode 100644 index e92bcb8..0000000 --- a/src/transport/socket.rs +++ /dev/null @@ -1,317 +0,0 @@ -use crate::{ - AppState, Database, DeviceCounts, ExecResult, PendingExec, Registry, RegistryEntry, - api::auth::{BearerToken, MaybeDeviceId}, - core::config::Config, - core::db::lookup_auth_token, - transport::protocol::{ClientMsg, ServerMsg}, -}; -use dashmap::{DashMap, mapref::entry::Entry}; -use futures::{SinkExt, StreamExt}; -use log::{info, warn}; -use rocket::{State, get}; -use rocket_ws::{Channel, Message, WebSocket, stream::DuplexStream}; -use std::{ - sync::{ - Arc, - atomic::{AtomicUsize, Ordering}, - }, - time::Duration, -}; -use tokio::sync::{mpsc, oneshot}; -use tokio_util::sync::CancellationToken; - -type PendingReplies = Arc>>; - -#[get("/connect")] -pub fn connect<'r>( - ws: WebSocket, - token: BearerToken, - device_id: MaybeDeviceId, - state: &'r State, -) -> Channel<'r> { - let database = state.database.clone(); - let registry = state.registry.clone(); - let device_counts = state.device_counts.clone(); - let config = state.config.clone(); - ws.channel(move |socket| { - Box::pin(handle( - database, - registry, - device_counts, - config, - socket, - token.0, - device_id.0, - )) - }) -} - -async fn handle( - database: Database, - registry: Arc, - device_counts: Arc, - config: Arc, - mut socket: DuplexStream, - token: String, - device_id: Result, -) -> Result<(), rocket_ws::result::Error> { - let device_id = match device_id { - Ok(device_id) => device_id, - Err(reason) => { - warn!("Rejecting websocket connection: {reason}"); - return deny_connection(&mut socket, reason).await; - } - }; - - let user_id = match lookup_auth_token(&database, &token).await { - Ok(user_id) => user_id, - Err(reason) => { - warn!("Rejecting websocket connection for device {device_id}: {reason}"); - return deny_connection(&mut socket, reason).await; - } - }; - - let device_counter = match claim_device_slot(&device_counts, &user_id, &config) { - Ok(counter) => counter, - Err(reason) => { - warn!("Rejecting websocket connection for user {user_id} device {device_id}: {reason}"); - return deny_connection(&mut socket, reason).await; - } - }; - - let (exec_tx, exec_rx) = mpsc::channel::(config.max_concurrent_executions); - let key = (user_id.clone(), device_id.clone()); - - if let Err(reason) = register_connection(®istry, key.clone(), exec_tx) { - release_device_slot(&device_counts, &user_id, device_counter); - warn!("Rejecting duplicate websocket connection for {user_id}/{device_id}: {reason}"); - return deny_connection(&mut socket, reason).await; - } - - info!("Accepted websocket connection for {user_id}/{device_id}"); - - send(&mut socket, &ServerMsg::AuthOk).await.ok(); - - let pending = Arc::new(DashMap::new()); - let cancel = CancellationToken::new(); - run_loop(&mut socket, exec_rx, &pending, &config, &cancel).await; - cancel.cancel(); - - registry.remove(&key); - pending.clear(); - release_device_slot(&device_counts, &user_id, device_counter); - info!("Closed websocket connection for {user_id}/{device_id}"); - Ok(()) -} - -async fn deny_connection( - socket: &mut DuplexStream, - reason: String, -) -> Result<(), rocket_ws::result::Error> { - send(socket, &ServerMsg::AuthError { reason }).await.ok(); - Ok(()) -} - -fn register_connection( - registry: &Arc, - key: (String, String), - sender: mpsc::Sender, -) -> Result<(), String> { - match registry.entry(key) { - Entry::Occupied(_) => Err("Device already connected".into()), - Entry::Vacant(slot) => { - slot.insert(RegistryEntry { - sender, - in_flight: Arc::new(AtomicUsize::new(0)), - }); - Ok(()) - } - } -} - -fn claim_device_slot( - device_counts: &Arc, - user_id: &str, - config: &Config, -) -> Result, String> { - let counter = { - let entry = device_counts - .entry(user_id.to_owned()) - .or_insert_with(|| Arc::new(AtomicUsize::new(0))); - Arc::clone(&*entry) - }; - - let previous = counter.fetch_add(1, Ordering::SeqCst); - if previous >= config.max_connected_devices { - counter.fetch_sub(1, Ordering::SeqCst); - return Err("Too many devices connected for this account".into()); - } - - Ok(counter) -} - -fn release_device_slot( - device_counts: &Arc, - user_id: &str, - counter: Arc, -) { - let previous = counter.fetch_sub(1, Ordering::SeqCst); - if previous == 1 { - device_counts.remove_if(user_id, |_, value| value.load(Ordering::SeqCst) == 0); - } -} - -async fn run_loop( - socket: &mut DuplexStream, - mut exec_rx: mpsc::Receiver, - pending: &PendingReplies, - config: &Config, - cancel: &CancellationToken, -) { - let mut ping_interval = tokio::time::interval_at( - tokio::time::Instant::now() + config.ping_interval, - config.ping_interval, - ); - - let pong_timer = tokio::time::sleep(Duration::from_secs(0)); - tokio::pin!(pong_timer); - let mut awaiting_pong = false; - - loop { - tokio::select! { - _ = ping_interval.tick() => { - if awaiting_pong { - break; - } - if send_ping(socket).await.is_err() { - break; - } - awaiting_pong = true; - pong_timer.as_mut().reset( - tokio::time::Instant::now() + config.ping_timeout, - ); - } - _ = &mut pong_timer, if awaiting_pong => { - break; - } - Some(exec) = exec_rx.recv() => { - if !dispatch(socket, pending, exec, config.max_execution_time, cancel.clone()).await { - break; - } - } - message = socket.next() => { - let action = handle_socket_message(socket, message, pending).await; - - if action.break_loop { - break; - } - - if action.clear_pong_deadline { - awaiting_pong = false; - } - } - } - } -} - -struct SocketAction { - break_loop: bool, - clear_pong_deadline: bool, -} - -async fn handle_socket_message( - socket: &mut DuplexStream, - message: Option>, - pending: &PendingReplies, -) -> SocketAction { - match message { - Some(Ok(Message::Pong(_))) => SocketAction { - break_loop: false, - clear_pong_deadline: true, - }, - Some(Ok(Message::Ping(data))) => SocketAction { - break_loop: socket.send(Message::Pong(data.clone())).await.is_err(), - clear_pong_deadline: false, - }, - Some(Ok(Message::Text(text))) => { - resolve_pending(&text, pending); - SocketAction { - break_loop: false, - clear_pong_deadline: false, - } - } - Some(Ok(Message::Close(_))) | None => SocketAction { - break_loop: true, - clear_pong_deadline: false, - }, - _ => SocketAction { - break_loop: false, - clear_pong_deadline: false, - }, - } -} - -async fn send_ping(socket: &mut DuplexStream) -> Result<(), rocket_ws::result::Error> { - socket.send(Message::Ping(vec![])).await -} - -async fn dispatch( - socket: &mut DuplexStream, - pending: &PendingReplies, - exec: PendingExec, - max_execution_time: Duration, - cancel: CancellationToken, -) -> bool { - let exec_id = exec.exec_id.clone(); - pending.insert(exec_id.clone(), exec.reply); - - let pending_for_timeout = Arc::clone(pending); - tokio::spawn(async move { - tokio::select! { - _ = cancel.cancelled() => {} - _ = tokio::time::sleep(max_execution_time) => { - pending_for_timeout.remove(&exec_id); - } - } - }); - - send( - socket, - &ServerMsg::Exec { - exec_id: exec.exec_id, - command: exec.command, - }, - ) - .await - .is_ok() -} - -fn resolve_pending(text: &str, pending: &PendingReplies) { - let Ok(ClientMsg::ExecResult { - exec_id, - exit_code, - stdout, - stderr, - }) = serde_json::from_str(text) - else { - return; - }; - if let Some((_, reply)) = pending.remove(&exec_id) { - reply - .send(ExecResult { - exit_code, - stdout, - stderr, - }) - .ok(); - } -} - -async fn send( - socket: &mut DuplexStream, - message: &ServerMsg, -) -> Result<(), Box> { - let text = serde_json::to_string(message)?; - socket.send(Message::Text(text)).await?; - Ok(()) -} diff --git a/src/transport/socket/connection.rs b/src/transport/socket/connection.rs new file mode 100644 index 0000000..d9c310e --- /dev/null +++ b/src/transport/socket/connection.rs @@ -0,0 +1,131 @@ +use super::loop_state::run_loop; +use crate::{ + app::state::{Database, DeviceCounts, PendingExec, Registry, RegistryEntry}, + core::{config::Config, db::lookup_auth_token}, + transport::protocol::ServerMsg, +}; +use dashmap::mapref::entry::Entry; +use futures::SinkExt; +use log::{info, warn}; +use rocket_ws::{Message, result::Error, stream::DuplexStream}; +use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, +}; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; + +pub async fn handle( + database: Database, + registry: Arc, + device_counts: Arc, + config: Arc, + mut socket: DuplexStream, + token: String, + device_id: Result, +) -> Result<(), Error> { + let device_id = match device_id { + Ok(device_id) => device_id, + Err(reason) => { + warn!("Rejecting websocket connection: {reason}"); + return deny_connection(&mut socket, reason).await; + } + }; + + let user_id = match lookup_auth_token(&database, &token).await { + Ok(user_id) => user_id, + Err(reason) => { + warn!("Rejecting websocket connection for device {device_id}: {reason}"); + return deny_connection(&mut socket, reason).await; + } + }; + + let device_counter = match claim_device_slot(&device_counts, &user_id, &config) { + Ok(counter) => counter, + Err(reason) => { + warn!("Rejecting websocket connection for user {user_id} device {device_id}: {reason}"); + return deny_connection(&mut socket, reason).await; + } + }; + + let key = (user_id.clone(), device_id.clone()); + let (exec_tx, exec_rx) = mpsc::channel::(config.max_concurrent_executions); + if let Err(reason) = register_connection(®istry, key.clone(), exec_tx) { + release_device_slot(&device_counts, &user_id, device_counter); + warn!("Rejecting duplicate websocket connection for {user_id}/{device_id}: {reason}"); + return deny_connection(&mut socket, reason).await; + } + + info!("Accepted websocket connection for {user_id}/{device_id}"); + send(&mut socket, &ServerMsg::AuthOk).await.ok(); + + let cancel = CancellationToken::new(); + run_loop(&mut socket, exec_rx, &config, &cancel).await; + cancel.cancel(); + + registry.remove(&key); + release_device_slot(&device_counts, &user_id, device_counter); + info!("Closed websocket connection for {user_id}/{device_id}"); + Ok(()) +} + +fn register_connection( + registry: &Arc, + key: (String, String), + sender: mpsc::Sender, +) -> Result<(), String> { + match registry.entry(key) { + Entry::Occupied(_) => Err("Device already connected".into()), + Entry::Vacant(slot) => { + slot.insert(RegistryEntry { + sender, + in_flight: Arc::new(AtomicUsize::new(0)), + }); + Ok(()) + } + } +} + +fn claim_device_slot( + device_counts: &Arc, + user_id: &str, + config: &Config, +) -> Result, String> { + let counter = Arc::clone( + &*device_counts + .entry(user_id.to_owned()) + .or_insert_with(|| Arc::new(AtomicUsize::new(0))), + ); + + if counter.fetch_add(1, Ordering::SeqCst) < config.max_connected_devices { + return Ok(counter); + } + + counter.fetch_sub(1, Ordering::SeqCst); + Err("Too many devices connected for this account".into()) +} + +fn release_device_slot( + device_counts: &Arc, + user_id: &str, + counter: Arc, +) { + if counter.fetch_sub(1, Ordering::SeqCst) == 1 { + device_counts.remove_if(user_id, |_, value| value.load(Ordering::SeqCst) == 0); + } +} + +async fn deny_connection(socket: &mut DuplexStream, reason: String) -> Result<(), Error> { + send(socket, &ServerMsg::AuthError { reason }).await.ok(); + Ok(()) +} + +async fn send( + socket: &mut DuplexStream, + message: &ServerMsg, +) -> Result<(), Box> { + socket + .send(Message::Text(serde_json::to_string(message)?)) + .await?; + Ok(()) +} diff --git a/src/transport/socket/loop_state.rs b/src/transport/socket/loop_state.rs new file mode 100644 index 0000000..f459ccf --- /dev/null +++ b/src/transport/socket/loop_state.rs @@ -0,0 +1,131 @@ +use crate::{ + app::state::{ExecResult, PendingExec}, + transport::protocol::{ClientMsg, ServerMsg}, +}; +use dashmap::DashMap; +use futures::{SinkExt, StreamExt}; +use rocket_ws::{Message, stream::DuplexStream}; +use std::{sync::Arc, time::Duration}; +use tokio::sync::{mpsc, oneshot}; +use tokio_util::sync::CancellationToken; + +type PendingReplies = Arc>>; + +pub async fn run_loop( + socket: &mut DuplexStream, + mut exec_rx: mpsc::Receiver, + config: &crate::core::config::Config, + cancel: &CancellationToken, +) { + let pending = Arc::new(DashMap::new()); + let mut ping_interval = tokio::time::interval_at( + tokio::time::Instant::now() + config.ping_interval, + config.ping_interval, + ); + let pong_timer = tokio::time::sleep(Duration::from_secs(0)); + tokio::pin!(pong_timer); + let mut awaiting_pong = false; + + loop { + tokio::select! { + _ = ping_interval.tick() => { + if awaiting_pong || socket.send(Message::Ping(vec![])).await.is_err() { + break; + } + + awaiting_pong = true; + pong_timer.as_mut().reset(tokio::time::Instant::now() + config.ping_timeout); + } + _ = &mut pong_timer, if awaiting_pong => break, + Some(exec) = exec_rx.recv() => { + if !dispatch(socket, &pending, exec, config.max_execution_time, cancel.clone()).await { + break; + } + } + message = socket.next() => match handle_socket_message(socket, message, &pending).await { + SocketAction::Break => break, + SocketAction::ClearPongDeadline => awaiting_pong = false, + SocketAction::Continue => {} + } + } + } + + pending.clear(); +} + +enum SocketAction { + Break, + ClearPongDeadline, + Continue, +} + +async fn handle_socket_message( + socket: &mut DuplexStream, + message: Option>, + pending: &PendingReplies, +) -> SocketAction { + match message { + Some(Ok(Message::Pong(_))) => SocketAction::ClearPongDeadline, + Some(Ok(Message::Ping(data))) => { + if socket.send(Message::Pong(data.clone())).await.is_err() { + return SocketAction::Break; + } + + SocketAction::Continue + } + Some(Ok(Message::Text(text))) => { + if let Ok(ClientMsg::ExecResult { + exec_id, + exit_code, + stdout, + stderr, + }) = serde_json::from_str(&text) + && let Some((_, reply)) = pending.remove(&exec_id) + { + reply + .send(ExecResult { + exit_code, + stdout, + stderr, + }) + .ok(); + } + + SocketAction::Continue + } + Some(Ok(Message::Close(_))) | None => SocketAction::Break, + _ => SocketAction::Continue, + } +} + +async fn dispatch( + socket: &mut DuplexStream, + pending: &PendingReplies, + exec: PendingExec, + max_execution_time: Duration, + cancel: CancellationToken, +) -> bool { + let exec_id = exec.exec_id.clone(); + pending.insert(exec_id.clone(), exec.reply); + + let pending_for_timeout = Arc::clone(pending); + tokio::spawn(async move { + tokio::select! { + _ = cancel.cancelled() => {} + _ = tokio::time::sleep(max_execution_time) => { + pending_for_timeout.remove(&exec_id); + } + } + }); + + socket + .send(Message::Text( + serde_json::to_string(&ServerMsg::Exec { + exec_id: exec.exec_id, + command: exec.command, + }) + .expect("serializing server exec message should not fail"), + )) + .await + .is_ok() +} diff --git a/src/transport/socket/mod.rs b/src/transport/socket/mod.rs new file mode 100644 index 0000000..952076f --- /dev/null +++ b/src/transport/socket/mod.rs @@ -0,0 +1,34 @@ +mod connection; +mod loop_state; + +use crate::{ + api::auth::{BearerToken, MaybeDeviceId}, + app::state::AppState, +}; +use rocket::{State, get}; +use rocket_ws::{Channel, WebSocket}; + +#[get("/connect")] +pub fn connect<'r>( + ws: WebSocket, + token: BearerToken, + device_id: MaybeDeviceId, + state: &'r State, +) -> Channel<'r> { + let database = state.database.clone(); + let registry = state.registry.clone(); + let device_counts = state.device_counts.clone(); + let config = state.config.clone(); + + ws.channel(move |socket| { + Box::pin(connection::handle( + database, + registry, + device_counts, + config, + socket, + token.0, + device_id.0, + )) + }) +}