diff --git a/src/api/auth.rs b/src/api/auth.rs index 8d3f5f7..74586cb 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -10,6 +10,7 @@ const DEVICE_ID_HEADER: &str = "X-Device-ID"; const BEARER_PREFIX: &str = "Bearer "; pub struct BearerToken(pub String); +pub struct MaybeBearerToken(pub Option); pub struct DeviceId(pub String); pub struct MaybeDeviceId(pub Result); @@ -28,6 +29,21 @@ impl<'r> FromRequest<'r> for BearerToken { } } +#[rocket::async_trait] +impl<'r> FromRequest<'r> for MaybeBearerToken { + type Error = std::convert::Infallible; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let token = request + .headers() + .get_one(AUTHORIZATION_HEADER) + .and_then(|header| header.strip_prefix(BEARER_PREFIX)) + .map(str::to_owned); + + Outcome::Success(MaybeBearerToken(token)) + } +} + #[rocket::async_trait] impl<'r> FromRequest<'r> for DeviceId { type Error = String; diff --git a/src/api/mcp.rs b/src/api/mcp.rs index 3c1faea..6cde5a8 100644 --- a/src/api/mcp.rs +++ b/src/api/mcp.rs @@ -1,6 +1,6 @@ use crate::{ AppState, - api::auth::BearerToken, + api::auth::MaybeBearerToken, core::db::lookup_api_token, transport::execute::{ExecuteRequest, execute_for_token}, }; @@ -113,7 +113,7 @@ struct TextContent { #[post("/mcp", data = "")] pub async fn route( body: Json, - token: BearerToken, + token: MaybeBearerToken, state: &State, ) -> ApiResponse { let request = body.into_inner(); @@ -144,13 +144,18 @@ pub async fn route( }, ), "notifications/initialized" => jsonrpc_notification_ok(), - "tools/list" => jsonrpc_ok( - request.id, - ToolsListResult { - tools: tool_definitions(), - }, - ), - "tools/call" => handle_tool_call(request.id, request.params, &token.0, state).await, + "tools/list" => match require_bearer_token(token.0.as_deref(), request.id.clone()) { + Ok(_) => jsonrpc_ok( + request.id, + ToolsListResult { + tools: tool_definitions(), + }, + ), + Err(response) => response, + }, + "tools/call" => { + handle_tool_call(request.id, request.params, token.0.as_deref(), state).await + } _ if request.id.is_none() => jsonrpc_notification_ok(), _ => jsonrpc_error(Status::BadRequest, request.id, -32601, "Method not found"), } @@ -159,9 +164,14 @@ pub async fn route( async fn handle_tool_call( id: Option, params: Option, - token: &str, + token: Option<&str>, state: &State, ) -> ApiResponse { + let token = match require_bearer_token(token, id.clone()) { + Ok(token) => token, + Err(response) => return response, + }; + let Some(params) = params else { return jsonrpc_error(Status::BadRequest, id, -32602, "Missing params"); }; @@ -177,6 +187,17 @@ async fn handle_tool_call( } } +fn require_bearer_token(token: Option<&str>, id: Option) -> Result<&str, ApiResponse> { + token.ok_or_else(|| { + jsonrpc_error( + Status::Unauthorized, + id, + -32001, + "Missing Authorization header", + ) + }) +} + fn tool_definitions() -> Vec { vec![ ToolDefinition {