diff --git a/src/api/auth.rs b/src/api/auth.rs
index 74586cb..3b53d84 100644
--- a/src/api/auth.rs
+++ b/src/api/auth.rs
@@ -22,9 +22,11 @@ impl<'r> FromRequest<'r> for BearerToken {
let Some(header) = request.headers().get_one(AUTHORIZATION_HEADER) else {
return Outcome::Error((Status::Unauthorized, "Missing Authorization header"));
};
+
let Some(token) = header.strip_prefix(BEARER_PREFIX) else {
return Outcome::Error((Status::Unauthorized, "Invalid Authorization header format"));
};
+
Outcome::Success(BearerToken(token.to_owned()))
}
}
diff --git a/src/api/mcp/mod.rs b/src/api/mcp/mod.rs
index e23f800..8561a71 100644
--- a/src/api/mcp/mod.rs
+++ b/src/api/mcp/mod.rs
@@ -3,15 +3,13 @@ mod tools;
use crate::{api::auth::MaybeBearerToken, app::state::AppState};
use protocol::{
- ApiResponse, InitializeResult, JSONRPC_VERSION, jsonrpc_error, jsonrpc_notification_ok,
- jsonrpc_ok,
+ ApiResponse, InitializeResult, JSONRPC_VERSION, JsonRpcRequest, jsonrpc_error,
+ jsonrpc_notification_ok, jsonrpc_ok,
};
use rocket::{State, http::Status, post, serde::json::Json};
use serde_json::Value;
use tools::{handle_tool_call, tool_definitions};
-pub use protocol::JsonRpcRequest;
-
#[post("/mcp", data = "
")]
pub async fn route(
body: Json,
@@ -32,10 +30,13 @@ pub async fn route(
match request.method.as_str() {
"initialize" => jsonrpc_ok(request.id, InitializeResult::new()),
"notifications/initialized" => jsonrpc_notification_ok(),
- "tools/list" => match token.0 {
- Some(_) => jsonrpc_ok(request.id, tool_definitions()),
- None => missing_authorization(request.id),
- },
+ "tools/list" => {
+ if token.0.is_some() {
+ jsonrpc_ok(request.id, tool_definitions())
+ } else {
+ missing_authorization(request.id)
+ }
+ }
"tools/call" => {
handle_tool_call(request.id, request.params, token.0.as_deref(), state).await
}
diff --git a/src/api/mcp/protocol.rs b/src/api/mcp/protocol.rs
index f7d3d9e..6abe52a 100644
--- a/src/api/mcp/protocol.rs
+++ b/src/api/mcp/protocol.rs
@@ -65,11 +65,7 @@ impl InitializeResult {
pub fn new() -> Self {
Self {
protocol_version: MCP_PROTOCOL_VERSION,
- capabilities: ServerCapabilities {
- tools: ToolsCapability {
- list_changed: false,
- },
- },
+ capabilities: ServerCapabilities::new(),
server_info: ServerInfo {
name: env!("CARGO_PKG_NAME"),
version: env!("CARGO_PKG_VERSION"),
@@ -78,6 +74,16 @@ impl InitializeResult {
}
}
+impl ServerCapabilities {
+ fn new() -> Self {
+ Self {
+ tools: ToolsCapability {
+ list_changed: false,
+ },
+ }
+ }
+}
+
pub fn jsonrpc_ok(id: Option, result: T) -> ApiResponse {
match serde_json::to_string(&JsonRpcResponse {
jsonrpc: JSONRPC_VERSION,
diff --git a/src/api/mcp/tools.rs b/src/api/mcp/tools.rs
index 7dec1a6..0d1a53d 100644
--- a/src/api/mcp/tools.rs
+++ b/src/api/mcp/tools.rs
@@ -178,16 +178,17 @@ fn json_tool_result(id: Option, payload: &T) -> ApiResponse
impl ToolCallResult {
fn success(text: String) -> Self {
- Self {
- content: vec![TextContent::text(text)],
- is_error: false,
- }
+ Self::new(text, false)
}
fn error(text: String) -> Self {
+ Self::new(text, true)
+ }
+
+ fn new(text: String, is_error: bool) -> Self {
Self {
content: vec![TextContent::text(text)],
- is_error: true,
+ is_error,
}
}
}
diff --git a/src/api/mod.rs b/src/api/mod.rs
deleted file mode 100644
index d32fd1e..0000000
--- a/src/api/mod.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-pub mod auth;
-pub mod catchers;
-pub mod mcp;
diff --git a/src/app/mod.rs b/src/app/mod.rs
index ac3f05d..15ac86f 100644
--- a/src/app/mod.rs
+++ b/src/app/mod.rs
@@ -4,6 +4,7 @@ use crate::{api, core::config::Config, transport};
use dashmap::DashMap;
use rocket::routes;
use sqlx::{SqlitePool, sqlite::SqliteConnectOptions};
+use state::AppState;
use std::sync::Arc;
const DATABASE_PATH: &str = "server.db";
@@ -31,12 +32,16 @@ pub async fn apply_schema(database: &SqlitePool) -> Result<(), sqlx::Error> {
pub fn build_rocket(database: SqlitePool) -> rocket::Rocket {
rocket::build()
- .manage(state::AppState {
- database,
- registry: Arc::new(DashMap::new()),
- device_counts: Arc::new(DashMap::new()),
- config: Arc::new(Config::from_env()),
- })
+ .manage(build_state(database))
.mount("/", routes![transport::socket::connect, api::mcp::route])
.register("/", rocket::catchers![api::catchers::default_catcher])
}
+
+fn build_state(database: SqlitePool) -> AppState {
+ AppState {
+ database,
+ registry: Arc::new(DashMap::new()),
+ device_counts: Arc::new(DashMap::new()),
+ config: Arc::new(Config::from_env()),
+ }
+}
diff --git a/src/app/state.rs b/src/app/state.rs
index 6bbf0cf..afff8f2 100644
--- a/src/app/state.rs
+++ b/src/app/state.rs
@@ -9,6 +9,13 @@ pub type Database = SqlitePool;
pub type DeviceCounts = DashMap>;
pub type Registry = DashMap<(String, String), RegistryEntry>;
+pub struct AppState {
+ pub database: Database,
+ pub registry: Arc,
+ pub device_counts: Arc,
+ pub config: Arc,
+}
+
pub struct RegistryEntry {
pub sender: mpsc::Sender,
pub in_flight: Arc,
@@ -26,10 +33,3 @@ pub struct PendingExec {
pub command: String,
pub reply: oneshot::Sender,
}
-
-pub struct AppState {
- pub database: Database,
- pub registry: Arc,
- pub device_counts: Arc,
- pub config: Arc,
-}
diff --git a/src/core/mod.rs b/src/core/mod.rs
deleted file mode 100644
index b7f192c..0000000
--- a/src/core/mod.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-pub mod config;
-pub mod db;
-pub mod validation;
diff --git a/src/main.rs b/src/main.rs
index 6d43eb5..cab156f 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,7 +1,21 @@
-pub mod api;
+pub mod api {
+ pub mod auth;
+ pub mod catchers;
+ pub mod mcp;
+}
+
pub mod app;
-pub mod core;
-pub mod transport;
+pub mod core {
+ pub mod config;
+ pub mod db;
+ pub mod validation;
+}
+
+pub mod transport {
+ pub mod execute;
+ pub mod protocol;
+ pub mod socket;
+}
use app::{apply_schema, build_rocket, connect_database};
use log::info;
@@ -14,8 +28,6 @@ async fn main() -> Result<(), Box> {
apply_schema(&database).await?;
info!("Launching server");
-
build_rocket(database).launch().await?;
-
Ok(())
}
diff --git a/src/transport/execute.rs b/src/transport/execute.rs
index f8474b2..140874f 100644
--- a/src/transport/execute.rs
+++ b/src/transport/execute.rs
@@ -33,7 +33,6 @@ pub async fn execute_for_user(
config: &Config,
) -> Result {
let ExecuteRequest { device_id, command } = request;
-
validate_execute_request(&device_id, &command, config)?;
execute(registry, user_id, &device_id, command, config)
@@ -69,14 +68,12 @@ async fn execute(
config: &Config,
) -> Result {
let key = (user_id.to_owned(), device_id.to_owned());
-
let (sender, in_flight_counter) = registry
.get(&key)
.map(|entry| (entry.sender.clone(), Arc::clone(&entry.in_flight)))
.ok_or_else(|| "Device not connected".to_string())?;
claim_execution_slot(&in_flight_counter, config.max_concurrent_executions)?;
-
let result = send_and_await(sender, command, config).await;
in_flight_counter.fetch_sub(1, Ordering::SeqCst);
diff --git a/src/transport/mod.rs b/src/transport/mod.rs
deleted file mode 100644
index ef00198..0000000
--- a/src/transport/mod.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-pub mod execute;
-pub mod protocol;
-pub mod socket;
diff --git a/src/transport/socket/loop_state.rs b/src/transport/socket/loop_state.rs
index f459ccf..6f47192 100644
--- a/src/transport/socket/loop_state.rs
+++ b/src/transport/socket/loop_state.rs
@@ -42,10 +42,12 @@ pub async fn run_loop(
break;
}
}
- message = socket.next() => match handle_socket_message(socket, message, &pending).await {
- SocketAction::Break => break,
- SocketAction::ClearPongDeadline => awaiting_pong = false,
- SocketAction::Continue => {}
+ message = socket.next() => {
+ match handle_socket_message(socket, message, &pending).await {
+ SocketAction::Break => break,
+ SocketAction::ClearPongDeadline => awaiting_pong = false,
+ SocketAction::Continue => {}
+ }
}
}
}
diff --git a/src/transport/socket/mod.rs b/src/transport/socket/mod.rs
index 952076f..77a285c 100644
--- a/src/transport/socket/mod.rs
+++ b/src/transport/socket/mod.rs
@@ -15,10 +15,11 @@ pub fn connect<'r>(
device_id: MaybeDeviceId,
state: &'r State,
) -> Channel<'r> {
- let database = state.database.clone();
- let registry = state.registry.clone();
- let device_counts = state.device_counts.clone();
- let config = state.config.clone();
+ let app_state = state.inner();
+ let database = app_state.database.clone();
+ let registry = app_state.registry.clone();
+ let device_counts = app_state.device_counts.clone();
+ let config = app_state.config.clone();
ws.channel(move |socket| {
Box::pin(connection::handle(