cleanup
This commit is contained in:
@@ -22,9 +22,11 @@ impl<'r> FromRequest<'r> for BearerToken {
|
|||||||
let Some(header) = request.headers().get_one(AUTHORIZATION_HEADER) else {
|
let Some(header) = request.headers().get_one(AUTHORIZATION_HEADER) else {
|
||||||
return Outcome::Error((Status::Unauthorized, "Missing Authorization header"));
|
return Outcome::Error((Status::Unauthorized, "Missing Authorization header"));
|
||||||
};
|
};
|
||||||
|
|
||||||
let Some(token) = header.strip_prefix(BEARER_PREFIX) else {
|
let Some(token) = header.strip_prefix(BEARER_PREFIX) else {
|
||||||
return Outcome::Error((Status::Unauthorized, "Invalid Authorization header format"));
|
return Outcome::Error((Status::Unauthorized, "Invalid Authorization header format"));
|
||||||
};
|
};
|
||||||
|
|
||||||
Outcome::Success(BearerToken(token.to_owned()))
|
Outcome::Success(BearerToken(token.to_owned()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+9
-8
@@ -3,15 +3,13 @@ mod tools;
|
|||||||
|
|
||||||
use crate::{api::auth::MaybeBearerToken, app::state::AppState};
|
use crate::{api::auth::MaybeBearerToken, app::state::AppState};
|
||||||
use protocol::{
|
use protocol::{
|
||||||
ApiResponse, InitializeResult, JSONRPC_VERSION, jsonrpc_error, jsonrpc_notification_ok,
|
ApiResponse, InitializeResult, JSONRPC_VERSION, JsonRpcRequest, jsonrpc_error,
|
||||||
jsonrpc_ok,
|
jsonrpc_notification_ok, jsonrpc_ok,
|
||||||
};
|
};
|
||||||
use rocket::{State, http::Status, post, serde::json::Json};
|
use rocket::{State, http::Status, post, serde::json::Json};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use tools::{handle_tool_call, tool_definitions};
|
use tools::{handle_tool_call, tool_definitions};
|
||||||
|
|
||||||
pub use protocol::JsonRpcRequest;
|
|
||||||
|
|
||||||
#[post("/mcp", data = "<body>")]
|
#[post("/mcp", data = "<body>")]
|
||||||
pub async fn route(
|
pub async fn route(
|
||||||
body: Json<JsonRpcRequest>,
|
body: Json<JsonRpcRequest>,
|
||||||
@@ -32,10 +30,13 @@ pub async fn route(
|
|||||||
match request.method.as_str() {
|
match request.method.as_str() {
|
||||||
"initialize" => jsonrpc_ok(request.id, InitializeResult::new()),
|
"initialize" => jsonrpc_ok(request.id, InitializeResult::new()),
|
||||||
"notifications/initialized" => jsonrpc_notification_ok(),
|
"notifications/initialized" => jsonrpc_notification_ok(),
|
||||||
"tools/list" => match token.0 {
|
"tools/list" => {
|
||||||
Some(_) => jsonrpc_ok(request.id, tool_definitions()),
|
if token.0.is_some() {
|
||||||
None => missing_authorization(request.id),
|
jsonrpc_ok(request.id, tool_definitions())
|
||||||
},
|
} else {
|
||||||
|
missing_authorization(request.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
"tools/call" => {
|
"tools/call" => {
|
||||||
handle_tool_call(request.id, request.params, token.0.as_deref(), state).await
|
handle_tool_call(request.id, request.params, token.0.as_deref(), state).await
|
||||||
}
|
}
|
||||||
|
|||||||
+11
-5
@@ -65,11 +65,7 @@ impl InitializeResult {
|
|||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
protocol_version: MCP_PROTOCOL_VERSION,
|
protocol_version: MCP_PROTOCOL_VERSION,
|
||||||
capabilities: ServerCapabilities {
|
capabilities: ServerCapabilities::new(),
|
||||||
tools: ToolsCapability {
|
|
||||||
list_changed: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
server_info: ServerInfo {
|
server_info: ServerInfo {
|
||||||
name: env!("CARGO_PKG_NAME"),
|
name: env!("CARGO_PKG_NAME"),
|
||||||
version: env!("CARGO_PKG_VERSION"),
|
version: env!("CARGO_PKG_VERSION"),
|
||||||
@@ -78,6 +74,16 @@ impl InitializeResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ServerCapabilities {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
tools: ToolsCapability {
|
||||||
|
list_changed: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn jsonrpc_ok<T: Serialize>(id: Option<Value>, result: T) -> ApiResponse {
|
pub fn jsonrpc_ok<T: Serialize>(id: Option<Value>, result: T) -> ApiResponse {
|
||||||
match serde_json::to_string(&JsonRpcResponse {
|
match serde_json::to_string(&JsonRpcResponse {
|
||||||
jsonrpc: JSONRPC_VERSION,
|
jsonrpc: JSONRPC_VERSION,
|
||||||
|
|||||||
@@ -178,16 +178,17 @@ fn json_tool_result<T: Serialize>(id: Option<Value>, payload: &T) -> ApiResponse
|
|||||||
|
|
||||||
impl ToolCallResult {
|
impl ToolCallResult {
|
||||||
fn success(text: String) -> Self {
|
fn success(text: String) -> Self {
|
||||||
Self {
|
Self::new(text, false)
|
||||||
content: vec![TextContent::text(text)],
|
|
||||||
is_error: false,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn error(text: String) -> Self {
|
fn error(text: String) -> Self {
|
||||||
|
Self::new(text, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new(text: String, is_error: bool) -> Self {
|
||||||
Self {
|
Self {
|
||||||
content: vec![TextContent::text(text)],
|
content: vec![TextContent::text(text)],
|
||||||
is_error: true,
|
is_error,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
pub mod auth;
|
|
||||||
pub mod catchers;
|
|
||||||
pub mod mcp;
|
|
||||||
+9
-4
@@ -4,6 +4,7 @@ use crate::{api, core::config::Config, transport};
|
|||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use rocket::routes;
|
use rocket::routes;
|
||||||
use sqlx::{SqlitePool, sqlite::SqliteConnectOptions};
|
use sqlx::{SqlitePool, sqlite::SqliteConnectOptions};
|
||||||
|
use state::AppState;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
const DATABASE_PATH: &str = "server.db";
|
const DATABASE_PATH: &str = "server.db";
|
||||||
@@ -31,12 +32,16 @@ pub async fn apply_schema(database: &SqlitePool) -> Result<(), sqlx::Error> {
|
|||||||
|
|
||||||
pub fn build_rocket(database: SqlitePool) -> rocket::Rocket<rocket::Build> {
|
pub fn build_rocket(database: SqlitePool) -> rocket::Rocket<rocket::Build> {
|
||||||
rocket::build()
|
rocket::build()
|
||||||
.manage(state::AppState {
|
.manage(build_state(database))
|
||||||
|
.mount("/", routes![transport::socket::connect, api::mcp::route])
|
||||||
|
.register("/", rocket::catchers![api::catchers::default_catcher])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_state(database: SqlitePool) -> AppState {
|
||||||
|
AppState {
|
||||||
database,
|
database,
|
||||||
registry: Arc::new(DashMap::new()),
|
registry: Arc::new(DashMap::new()),
|
||||||
device_counts: Arc::new(DashMap::new()),
|
device_counts: Arc::new(DashMap::new()),
|
||||||
config: Arc::new(Config::from_env()),
|
config: Arc::new(Config::from_env()),
|
||||||
})
|
}
|
||||||
.mount("/", routes![transport::socket::connect, api::mcp::route])
|
|
||||||
.register("/", rocket::catchers![api::catchers::default_catcher])
|
|
||||||
}
|
}
|
||||||
|
|||||||
+7
-7
@@ -9,6 +9,13 @@ pub type Database = SqlitePool;
|
|||||||
pub type DeviceCounts = DashMap<String, Arc<AtomicUsize>>;
|
pub type DeviceCounts = DashMap<String, Arc<AtomicUsize>>;
|
||||||
pub type Registry = DashMap<(String, String), RegistryEntry>;
|
pub type Registry = DashMap<(String, String), RegistryEntry>;
|
||||||
|
|
||||||
|
pub struct AppState {
|
||||||
|
pub database: Database,
|
||||||
|
pub registry: Arc<Registry>,
|
||||||
|
pub device_counts: Arc<DeviceCounts>,
|
||||||
|
pub config: Arc<Config>,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct RegistryEntry {
|
pub struct RegistryEntry {
|
||||||
pub sender: mpsc::Sender<PendingExec>,
|
pub sender: mpsc::Sender<PendingExec>,
|
||||||
pub in_flight: Arc<AtomicUsize>,
|
pub in_flight: Arc<AtomicUsize>,
|
||||||
@@ -26,10 +33,3 @@ pub struct PendingExec {
|
|||||||
pub command: String,
|
pub command: String,
|
||||||
pub reply: oneshot::Sender<ExecResult>,
|
pub reply: oneshot::Sender<ExecResult>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct AppState {
|
|
||||||
pub database: Database,
|
|
||||||
pub registry: Arc<Registry>,
|
|
||||||
pub device_counts: Arc<DeviceCounts>,
|
|
||||||
pub config: Arc<Config>,
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
pub mod config;
|
|
||||||
pub mod db;
|
|
||||||
pub mod validation;
|
|
||||||
+17
-5
@@ -1,7 +1,21 @@
|
|||||||
pub mod api;
|
pub mod api {
|
||||||
|
pub mod auth;
|
||||||
|
pub mod catchers;
|
||||||
|
pub mod mcp;
|
||||||
|
}
|
||||||
|
|
||||||
pub mod app;
|
pub mod app;
|
||||||
pub mod core;
|
pub mod core {
|
||||||
pub mod transport;
|
pub mod config;
|
||||||
|
pub mod db;
|
||||||
|
pub mod validation;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub mod transport {
|
||||||
|
pub mod execute;
|
||||||
|
pub mod protocol;
|
||||||
|
pub mod socket;
|
||||||
|
}
|
||||||
|
|
||||||
use app::{apply_schema, build_rocket, connect_database};
|
use app::{apply_schema, build_rocket, connect_database};
|
||||||
use log::info;
|
use log::info;
|
||||||
@@ -14,8 +28,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
apply_schema(&database).await?;
|
apply_schema(&database).await?;
|
||||||
|
|
||||||
info!("Launching server");
|
info!("Launching server");
|
||||||
|
|
||||||
build_rocket(database).launch().await?;
|
build_rocket(database).launch().await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ pub async fn execute_for_user(
|
|||||||
config: &Config,
|
config: &Config,
|
||||||
) -> Result<ExecResult, (Status, String)> {
|
) -> Result<ExecResult, (Status, String)> {
|
||||||
let ExecuteRequest { device_id, command } = request;
|
let ExecuteRequest { device_id, command } = request;
|
||||||
|
|
||||||
validate_execute_request(&device_id, &command, config)?;
|
validate_execute_request(&device_id, &command, config)?;
|
||||||
|
|
||||||
execute(registry, user_id, &device_id, command, config)
|
execute(registry, user_id, &device_id, command, config)
|
||||||
@@ -69,14 +68,12 @@ async fn execute(
|
|||||||
config: &Config,
|
config: &Config,
|
||||||
) -> Result<ExecResult, String> {
|
) -> Result<ExecResult, String> {
|
||||||
let key = (user_id.to_owned(), device_id.to_owned());
|
let key = (user_id.to_owned(), device_id.to_owned());
|
||||||
|
|
||||||
let (sender, in_flight_counter) = registry
|
let (sender, in_flight_counter) = registry
|
||||||
.get(&key)
|
.get(&key)
|
||||||
.map(|entry| (entry.sender.clone(), Arc::clone(&entry.in_flight)))
|
.map(|entry| (entry.sender.clone(), Arc::clone(&entry.in_flight)))
|
||||||
.ok_or_else(|| "Device not connected".to_string())?;
|
.ok_or_else(|| "Device not connected".to_string())?;
|
||||||
|
|
||||||
claim_execution_slot(&in_flight_counter, config.max_concurrent_executions)?;
|
claim_execution_slot(&in_flight_counter, config.max_concurrent_executions)?;
|
||||||
|
|
||||||
let result = send_and_await(sender, command, config).await;
|
let result = send_and_await(sender, command, config).await;
|
||||||
in_flight_counter.fetch_sub(1, Ordering::SeqCst);
|
in_flight_counter.fetch_sub(1, Ordering::SeqCst);
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
pub mod execute;
|
|
||||||
pub mod protocol;
|
|
||||||
pub mod socket;
|
|
||||||
@@ -42,13 +42,15 @@ pub async fn run_loop(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
message = socket.next() => match handle_socket_message(socket, message, &pending).await {
|
message = socket.next() => {
|
||||||
|
match handle_socket_message(socket, message, &pending).await {
|
||||||
SocketAction::Break => break,
|
SocketAction::Break => break,
|
||||||
SocketAction::ClearPongDeadline => awaiting_pong = false,
|
SocketAction::ClearPongDeadline => awaiting_pong = false,
|
||||||
SocketAction::Continue => {}
|
SocketAction::Continue => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pending.clear();
|
pending.clear();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,10 +15,11 @@ pub fn connect<'r>(
|
|||||||
device_id: MaybeDeviceId,
|
device_id: MaybeDeviceId,
|
||||||
state: &'r State<AppState>,
|
state: &'r State<AppState>,
|
||||||
) -> Channel<'r> {
|
) -> Channel<'r> {
|
||||||
let database = state.database.clone();
|
let app_state = state.inner();
|
||||||
let registry = state.registry.clone();
|
let database = app_state.database.clone();
|
||||||
let device_counts = state.device_counts.clone();
|
let registry = app_state.registry.clone();
|
||||||
let config = state.config.clone();
|
let device_counts = app_state.device_counts.clone();
|
||||||
|
let config = app_state.config.clone();
|
||||||
|
|
||||||
ws.channel(move |socket| {
|
ws.channel(move |socket| {
|
||||||
Box::pin(connection::handle(
|
Box::pin(connection::handle(
|
||||||
|
|||||||
Reference in New Issue
Block a user