This commit is contained in:
2026-05-22 08:40:04 +02:00
parent 749638b75c
commit 777fcf168e
13 changed files with 74 additions and 56 deletions
+2
View File
@@ -22,9 +22,11 @@ impl<'r> FromRequest<'r> for BearerToken {
let Some(header) = request.headers().get_one(AUTHORIZATION_HEADER) else { let Some(header) = request.headers().get_one(AUTHORIZATION_HEADER) else {
return Outcome::Error((Status::Unauthorized, "Missing Authorization header")); return Outcome::Error((Status::Unauthorized, "Missing Authorization header"));
}; };
let Some(token) = header.strip_prefix(BEARER_PREFIX) else { let Some(token) = header.strip_prefix(BEARER_PREFIX) else {
return Outcome::Error((Status::Unauthorized, "Invalid Authorization header format")); return Outcome::Error((Status::Unauthorized, "Invalid Authorization header format"));
}; };
Outcome::Success(BearerToken(token.to_owned())) Outcome::Success(BearerToken(token.to_owned()))
} }
} }
+9 -8
View File
@@ -3,15 +3,13 @@ mod tools;
use crate::{api::auth::MaybeBearerToken, app::state::AppState}; use crate::{api::auth::MaybeBearerToken, app::state::AppState};
use protocol::{ use protocol::{
ApiResponse, InitializeResult, JSONRPC_VERSION, jsonrpc_error, jsonrpc_notification_ok, ApiResponse, InitializeResult, JSONRPC_VERSION, JsonRpcRequest, jsonrpc_error,
jsonrpc_ok, jsonrpc_notification_ok, jsonrpc_ok,
}; };
use rocket::{State, http::Status, post, serde::json::Json}; use rocket::{State, http::Status, post, serde::json::Json};
use serde_json::Value; use serde_json::Value;
use tools::{handle_tool_call, tool_definitions}; use tools::{handle_tool_call, tool_definitions};
pub use protocol::JsonRpcRequest;
#[post("/mcp", data = "<body>")] #[post("/mcp", data = "<body>")]
pub async fn route( pub async fn route(
body: Json<JsonRpcRequest>, body: Json<JsonRpcRequest>,
@@ -32,10 +30,13 @@ pub async fn route(
match request.method.as_str() { match request.method.as_str() {
"initialize" => jsonrpc_ok(request.id, InitializeResult::new()), "initialize" => jsonrpc_ok(request.id, InitializeResult::new()),
"notifications/initialized" => jsonrpc_notification_ok(), "notifications/initialized" => jsonrpc_notification_ok(),
"tools/list" => match token.0 { "tools/list" => {
Some(_) => jsonrpc_ok(request.id, tool_definitions()), if token.0.is_some() {
None => missing_authorization(request.id), jsonrpc_ok(request.id, tool_definitions())
}, } else {
missing_authorization(request.id)
}
}
"tools/call" => { "tools/call" => {
handle_tool_call(request.id, request.params, token.0.as_deref(), state).await handle_tool_call(request.id, request.params, token.0.as_deref(), state).await
} }
+11 -5
View File
@@ -65,11 +65,7 @@ impl InitializeResult {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
protocol_version: MCP_PROTOCOL_VERSION, protocol_version: MCP_PROTOCOL_VERSION,
capabilities: ServerCapabilities { capabilities: ServerCapabilities::new(),
tools: ToolsCapability {
list_changed: false,
},
},
server_info: ServerInfo { server_info: ServerInfo {
name: env!("CARGO_PKG_NAME"), name: env!("CARGO_PKG_NAME"),
version: env!("CARGO_PKG_VERSION"), version: env!("CARGO_PKG_VERSION"),
@@ -78,6 +74,16 @@ impl InitializeResult {
} }
} }
impl ServerCapabilities {
fn new() -> Self {
Self {
tools: ToolsCapability {
list_changed: false,
},
}
}
}
pub fn jsonrpc_ok<T: Serialize>(id: Option<Value>, result: T) -> ApiResponse { pub fn jsonrpc_ok<T: Serialize>(id: Option<Value>, result: T) -> ApiResponse {
match serde_json::to_string(&JsonRpcResponse { match serde_json::to_string(&JsonRpcResponse {
jsonrpc: JSONRPC_VERSION, jsonrpc: JSONRPC_VERSION,
+6 -5
View File
@@ -178,16 +178,17 @@ fn json_tool_result<T: Serialize>(id: Option<Value>, payload: &T) -> ApiResponse
impl ToolCallResult { impl ToolCallResult {
fn success(text: String) -> Self { fn success(text: String) -> Self {
Self { Self::new(text, false)
content: vec![TextContent::text(text)],
is_error: false,
}
} }
fn error(text: String) -> Self { fn error(text: String) -> Self {
Self::new(text, true)
}
fn new(text: String, is_error: bool) -> Self {
Self { Self {
content: vec![TextContent::text(text)], content: vec![TextContent::text(text)],
is_error: true, is_error,
} }
} }
} }
-3
View File
@@ -1,3 +0,0 @@
pub mod auth;
pub mod catchers;
pub mod mcp;
+9 -4
View File
@@ -4,6 +4,7 @@ use crate::{api, core::config::Config, transport};
use dashmap::DashMap; use dashmap::DashMap;
use rocket::routes; use rocket::routes;
use sqlx::{SqlitePool, sqlite::SqliteConnectOptions}; use sqlx::{SqlitePool, sqlite::SqliteConnectOptions};
use state::AppState;
use std::sync::Arc; use std::sync::Arc;
const DATABASE_PATH: &str = "server.db"; const DATABASE_PATH: &str = "server.db";
@@ -31,12 +32,16 @@ pub async fn apply_schema(database: &SqlitePool) -> Result<(), sqlx::Error> {
pub fn build_rocket(database: SqlitePool) -> rocket::Rocket<rocket::Build> { pub fn build_rocket(database: SqlitePool) -> rocket::Rocket<rocket::Build> {
rocket::build() rocket::build()
.manage(state::AppState { .manage(build_state(database))
.mount("/", routes![transport::socket::connect, api::mcp::route])
.register("/", rocket::catchers![api::catchers::default_catcher])
}
fn build_state(database: SqlitePool) -> AppState {
AppState {
database, database,
registry: Arc::new(DashMap::new()), registry: Arc::new(DashMap::new()),
device_counts: Arc::new(DashMap::new()), device_counts: Arc::new(DashMap::new()),
config: Arc::new(Config::from_env()), config: Arc::new(Config::from_env()),
}) }
.mount("/", routes![transport::socket::connect, api::mcp::route])
.register("/", rocket::catchers![api::catchers::default_catcher])
} }
+7 -7
View File
@@ -9,6 +9,13 @@ pub type Database = SqlitePool;
pub type DeviceCounts = DashMap<String, Arc<AtomicUsize>>; pub type DeviceCounts = DashMap<String, Arc<AtomicUsize>>;
pub type Registry = DashMap<(String, String), RegistryEntry>; pub type Registry = DashMap<(String, String), RegistryEntry>;
pub struct AppState {
pub database: Database,
pub registry: Arc<Registry>,
pub device_counts: Arc<DeviceCounts>,
pub config: Arc<Config>,
}
pub struct RegistryEntry { pub struct RegistryEntry {
pub sender: mpsc::Sender<PendingExec>, pub sender: mpsc::Sender<PendingExec>,
pub in_flight: Arc<AtomicUsize>, pub in_flight: Arc<AtomicUsize>,
@@ -26,10 +33,3 @@ pub struct PendingExec {
pub command: String, pub command: String,
pub reply: oneshot::Sender<ExecResult>, pub reply: oneshot::Sender<ExecResult>,
} }
pub struct AppState {
pub database: Database,
pub registry: Arc<Registry>,
pub device_counts: Arc<DeviceCounts>,
pub config: Arc<Config>,
}
-3
View File
@@ -1,3 +0,0 @@
pub mod config;
pub mod db;
pub mod validation;
+17 -5
View File
@@ -1,7 +1,21 @@
pub mod api; pub mod api {
pub mod auth;
pub mod catchers;
pub mod mcp;
}
pub mod app; pub mod app;
pub mod core; pub mod core {
pub mod transport; pub mod config;
pub mod db;
pub mod validation;
}
pub mod transport {
pub mod execute;
pub mod protocol;
pub mod socket;
}
use app::{apply_schema, build_rocket, connect_database}; use app::{apply_schema, build_rocket, connect_database};
use log::info; use log::info;
@@ -14,8 +28,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
apply_schema(&database).await?; apply_schema(&database).await?;
info!("Launching server"); info!("Launching server");
build_rocket(database).launch().await?; build_rocket(database).launch().await?;
Ok(()) Ok(())
} }
-3
View File
@@ -33,7 +33,6 @@ pub async fn execute_for_user(
config: &Config, config: &Config,
) -> Result<ExecResult, (Status, String)> { ) -> Result<ExecResult, (Status, String)> {
let ExecuteRequest { device_id, command } = request; let ExecuteRequest { device_id, command } = request;
validate_execute_request(&device_id, &command, config)?; validate_execute_request(&device_id, &command, config)?;
execute(registry, user_id, &device_id, command, config) execute(registry, user_id, &device_id, command, config)
@@ -69,14 +68,12 @@ async fn execute(
config: &Config, config: &Config,
) -> Result<ExecResult, String> { ) -> Result<ExecResult, String> {
let key = (user_id.to_owned(), device_id.to_owned()); let key = (user_id.to_owned(), device_id.to_owned());
let (sender, in_flight_counter) = registry let (sender, in_flight_counter) = registry
.get(&key) .get(&key)
.map(|entry| (entry.sender.clone(), Arc::clone(&entry.in_flight))) .map(|entry| (entry.sender.clone(), Arc::clone(&entry.in_flight)))
.ok_or_else(|| "Device not connected".to_string())?; .ok_or_else(|| "Device not connected".to_string())?;
claim_execution_slot(&in_flight_counter, config.max_concurrent_executions)?; claim_execution_slot(&in_flight_counter, config.max_concurrent_executions)?;
let result = send_and_await(sender, command, config).await; let result = send_and_await(sender, command, config).await;
in_flight_counter.fetch_sub(1, Ordering::SeqCst); in_flight_counter.fetch_sub(1, Ordering::SeqCst);
-3
View File
@@ -1,3 +0,0 @@
pub mod execute;
pub mod protocol;
pub mod socket;
+3 -1
View File
@@ -42,13 +42,15 @@ pub async fn run_loop(
break; break;
} }
} }
message = socket.next() => match handle_socket_message(socket, message, &pending).await { message = socket.next() => {
match handle_socket_message(socket, message, &pending).await {
SocketAction::Break => break, SocketAction::Break => break,
SocketAction::ClearPongDeadline => awaiting_pong = false, SocketAction::ClearPongDeadline => awaiting_pong = false,
SocketAction::Continue => {} SocketAction::Continue => {}
} }
} }
} }
}
pending.clear(); pending.clear();
} }
+5 -4
View File
@@ -15,10 +15,11 @@ pub fn connect<'r>(
device_id: MaybeDeviceId, device_id: MaybeDeviceId,
state: &'r State<AppState>, state: &'r State<AppState>,
) -> Channel<'r> { ) -> Channel<'r> {
let database = state.database.clone(); let app_state = state.inner();
let registry = state.registry.clone(); let database = app_state.database.clone();
let device_counts = state.device_counts.clone(); let registry = app_state.registry.clone();
let config = state.config.clone(); let device_counts = app_state.device_counts.clone();
let config = app_state.config.clone();
ws.channel(move |socket| { ws.channel(move |socket| {
Box::pin(connection::handle( Box::pin(connection::handle(