diff --git a/src/api/auth.rs b/src/api/auth.rs index 74586cb..3b53d84 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -22,9 +22,11 @@ impl<'r> FromRequest<'r> for BearerToken { let Some(header) = request.headers().get_one(AUTHORIZATION_HEADER) else { return Outcome::Error((Status::Unauthorized, "Missing Authorization header")); }; + let Some(token) = header.strip_prefix(BEARER_PREFIX) else { return Outcome::Error((Status::Unauthorized, "Invalid Authorization header format")); }; + Outcome::Success(BearerToken(token.to_owned())) } } diff --git a/src/api/mcp/mod.rs b/src/api/mcp/mod.rs index e23f800..8561a71 100644 --- a/src/api/mcp/mod.rs +++ b/src/api/mcp/mod.rs @@ -3,15 +3,13 @@ mod tools; use crate::{api::auth::MaybeBearerToken, app::state::AppState}; use protocol::{ - ApiResponse, InitializeResult, JSONRPC_VERSION, jsonrpc_error, jsonrpc_notification_ok, - jsonrpc_ok, + ApiResponse, InitializeResult, JSONRPC_VERSION, JsonRpcRequest, jsonrpc_error, + jsonrpc_notification_ok, jsonrpc_ok, }; use rocket::{State, http::Status, post, serde::json::Json}; use serde_json::Value; use tools::{handle_tool_call, tool_definitions}; -pub use protocol::JsonRpcRequest; - #[post("/mcp", data = "")] pub async fn route( body: Json, @@ -32,10 +30,13 @@ pub async fn route( match request.method.as_str() { "initialize" => jsonrpc_ok(request.id, InitializeResult::new()), "notifications/initialized" => jsonrpc_notification_ok(), - "tools/list" => match token.0 { - Some(_) => jsonrpc_ok(request.id, tool_definitions()), - None => missing_authorization(request.id), - }, + "tools/list" => { + if token.0.is_some() { + jsonrpc_ok(request.id, tool_definitions()) + } else { + missing_authorization(request.id) + } + } "tools/call" => { handle_tool_call(request.id, request.params, token.0.as_deref(), state).await } diff --git a/src/api/mcp/protocol.rs b/src/api/mcp/protocol.rs index f7d3d9e..6abe52a 100644 --- a/src/api/mcp/protocol.rs +++ b/src/api/mcp/protocol.rs @@ -65,11 +65,7 @@ impl InitializeResult { pub fn new() -> Self { Self { protocol_version: MCP_PROTOCOL_VERSION, - capabilities: ServerCapabilities { - tools: ToolsCapability { - list_changed: false, - }, - }, + capabilities: ServerCapabilities::new(), server_info: ServerInfo { name: env!("CARGO_PKG_NAME"), version: env!("CARGO_PKG_VERSION"), @@ -78,6 +74,16 @@ impl InitializeResult { } } +impl ServerCapabilities { + fn new() -> Self { + Self { + tools: ToolsCapability { + list_changed: false, + }, + } + } +} + pub fn jsonrpc_ok(id: Option, result: T) -> ApiResponse { match serde_json::to_string(&JsonRpcResponse { jsonrpc: JSONRPC_VERSION, diff --git a/src/api/mcp/tools.rs b/src/api/mcp/tools.rs index 7dec1a6..0d1a53d 100644 --- a/src/api/mcp/tools.rs +++ b/src/api/mcp/tools.rs @@ -178,16 +178,17 @@ fn json_tool_result(id: Option, payload: &T) -> ApiResponse impl ToolCallResult { fn success(text: String) -> Self { - Self { - content: vec![TextContent::text(text)], - is_error: false, - } + Self::new(text, false) } fn error(text: String) -> Self { + Self::new(text, true) + } + + fn new(text: String, is_error: bool) -> Self { Self { content: vec![TextContent::text(text)], - is_error: true, + is_error, } } } diff --git a/src/api/mod.rs b/src/api/mod.rs deleted file mode 100644 index d32fd1e..0000000 --- a/src/api/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod auth; -pub mod catchers; -pub mod mcp; diff --git a/src/app/mod.rs b/src/app/mod.rs index ac3f05d..15ac86f 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -4,6 +4,7 @@ use crate::{api, core::config::Config, transport}; use dashmap::DashMap; use rocket::routes; use sqlx::{SqlitePool, sqlite::SqliteConnectOptions}; +use state::AppState; use std::sync::Arc; const DATABASE_PATH: &str = "server.db"; @@ -31,12 +32,16 @@ pub async fn apply_schema(database: &SqlitePool) -> Result<(), sqlx::Error> { pub fn build_rocket(database: SqlitePool) -> rocket::Rocket { rocket::build() - .manage(state::AppState { - database, - registry: Arc::new(DashMap::new()), - device_counts: Arc::new(DashMap::new()), - config: Arc::new(Config::from_env()), - }) + .manage(build_state(database)) .mount("/", routes![transport::socket::connect, api::mcp::route]) .register("/", rocket::catchers![api::catchers::default_catcher]) } + +fn build_state(database: SqlitePool) -> AppState { + AppState { + database, + registry: Arc::new(DashMap::new()), + device_counts: Arc::new(DashMap::new()), + config: Arc::new(Config::from_env()), + } +} diff --git a/src/app/state.rs b/src/app/state.rs index 6bbf0cf..afff8f2 100644 --- a/src/app/state.rs +++ b/src/app/state.rs @@ -9,6 +9,13 @@ pub type Database = SqlitePool; pub type DeviceCounts = DashMap>; pub type Registry = DashMap<(String, String), RegistryEntry>; +pub struct AppState { + pub database: Database, + pub registry: Arc, + pub device_counts: Arc, + pub config: Arc, +} + pub struct RegistryEntry { pub sender: mpsc::Sender, pub in_flight: Arc, @@ -26,10 +33,3 @@ pub struct PendingExec { pub command: String, pub reply: oneshot::Sender, } - -pub struct AppState { - pub database: Database, - pub registry: Arc, - pub device_counts: Arc, - pub config: Arc, -} diff --git a/src/core/mod.rs b/src/core/mod.rs deleted file mode 100644 index b7f192c..0000000 --- a/src/core/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod config; -pub mod db; -pub mod validation; diff --git a/src/main.rs b/src/main.rs index 6d43eb5..cab156f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,21 @@ -pub mod api; +pub mod api { + pub mod auth; + pub mod catchers; + pub mod mcp; +} + pub mod app; -pub mod core; -pub mod transport; +pub mod core { + pub mod config; + pub mod db; + pub mod validation; +} + +pub mod transport { + pub mod execute; + pub mod protocol; + pub mod socket; +} use app::{apply_schema, build_rocket, connect_database}; use log::info; @@ -14,8 +28,6 @@ async fn main() -> Result<(), Box> { apply_schema(&database).await?; info!("Launching server"); - build_rocket(database).launch().await?; - Ok(()) } diff --git a/src/transport/execute.rs b/src/transport/execute.rs index f8474b2..140874f 100644 --- a/src/transport/execute.rs +++ b/src/transport/execute.rs @@ -33,7 +33,6 @@ pub async fn execute_for_user( config: &Config, ) -> Result { let ExecuteRequest { device_id, command } = request; - validate_execute_request(&device_id, &command, config)?; execute(registry, user_id, &device_id, command, config) @@ -69,14 +68,12 @@ async fn execute( config: &Config, ) -> Result { let key = (user_id.to_owned(), device_id.to_owned()); - let (sender, in_flight_counter) = registry .get(&key) .map(|entry| (entry.sender.clone(), Arc::clone(&entry.in_flight))) .ok_or_else(|| "Device not connected".to_string())?; claim_execution_slot(&in_flight_counter, config.max_concurrent_executions)?; - let result = send_and_await(sender, command, config).await; in_flight_counter.fetch_sub(1, Ordering::SeqCst); diff --git a/src/transport/mod.rs b/src/transport/mod.rs deleted file mode 100644 index ef00198..0000000 --- a/src/transport/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod execute; -pub mod protocol; -pub mod socket; diff --git a/src/transport/socket/loop_state.rs b/src/transport/socket/loop_state.rs index f459ccf..6f47192 100644 --- a/src/transport/socket/loop_state.rs +++ b/src/transport/socket/loop_state.rs @@ -42,10 +42,12 @@ pub async fn run_loop( break; } } - message = socket.next() => match handle_socket_message(socket, message, &pending).await { - SocketAction::Break => break, - SocketAction::ClearPongDeadline => awaiting_pong = false, - SocketAction::Continue => {} + message = socket.next() => { + match handle_socket_message(socket, message, &pending).await { + SocketAction::Break => break, + SocketAction::ClearPongDeadline => awaiting_pong = false, + SocketAction::Continue => {} + } } } } diff --git a/src/transport/socket/mod.rs b/src/transport/socket/mod.rs index 952076f..77a285c 100644 --- a/src/transport/socket/mod.rs +++ b/src/transport/socket/mod.rs @@ -15,10 +15,11 @@ pub fn connect<'r>( device_id: MaybeDeviceId, state: &'r State, ) -> Channel<'r> { - let database = state.database.clone(); - let registry = state.registry.clone(); - let device_counts = state.device_counts.clone(); - let config = state.config.clone(); + let app_state = state.inner(); + let database = app_state.database.clone(); + let registry = app_state.registry.clone(); + let device_counts = app_state.device_counts.clone(); + let config = app_state.config.clone(); ws.channel(move |socket| { Box::pin(connection::handle(