auth bug fix
This commit is contained in:
@@ -10,6 +10,7 @@ const DEVICE_ID_HEADER: &str = "X-Device-ID";
|
||||
const BEARER_PREFIX: &str = "Bearer ";
|
||||
|
||||
pub struct BearerToken(pub String);
|
||||
pub struct MaybeBearerToken(pub Option<String>);
|
||||
pub struct DeviceId(pub String);
|
||||
pub struct MaybeDeviceId(pub Result<String, String>);
|
||||
|
||||
@@ -28,6 +29,21 @@ impl<'r> FromRequest<'r> for BearerToken {
|
||||
}
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for MaybeBearerToken {
|
||||
type Error = std::convert::Infallible;
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||
let token = request
|
||||
.headers()
|
||||
.get_one(AUTHORIZATION_HEADER)
|
||||
.and_then(|header| header.strip_prefix(BEARER_PREFIX))
|
||||
.map(str::to_owned);
|
||||
|
||||
Outcome::Success(MaybeBearerToken(token))
|
||||
}
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for DeviceId {
|
||||
type Error = String;
|
||||
|
||||
+26
-5
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
AppState,
|
||||
api::auth::BearerToken,
|
||||
api::auth::MaybeBearerToken,
|
||||
core::db::lookup_api_token,
|
||||
transport::execute::{ExecuteRequest, execute_for_token},
|
||||
};
|
||||
@@ -113,7 +113,7 @@ struct TextContent {
|
||||
#[post("/mcp", data = "<body>")]
|
||||
pub async fn route(
|
||||
body: Json<JsonRpcRequest>,
|
||||
token: BearerToken,
|
||||
token: MaybeBearerToken,
|
||||
state: &State<AppState>,
|
||||
) -> ApiResponse {
|
||||
let request = body.into_inner();
|
||||
@@ -144,13 +144,18 @@ pub async fn route(
|
||||
},
|
||||
),
|
||||
"notifications/initialized" => jsonrpc_notification_ok(),
|
||||
"tools/list" => jsonrpc_ok(
|
||||
"tools/list" => match require_bearer_token(token.0.as_deref(), request.id.clone()) {
|
||||
Ok(_) => jsonrpc_ok(
|
||||
request.id,
|
||||
ToolsListResult {
|
||||
tools: tool_definitions(),
|
||||
},
|
||||
),
|
||||
"tools/call" => handle_tool_call(request.id, request.params, &token.0, state).await,
|
||||
Err(response) => response,
|
||||
},
|
||||
"tools/call" => {
|
||||
handle_tool_call(request.id, request.params, token.0.as_deref(), state).await
|
||||
}
|
||||
_ if request.id.is_none() => jsonrpc_notification_ok(),
|
||||
_ => jsonrpc_error(Status::BadRequest, request.id, -32601, "Method not found"),
|
||||
}
|
||||
@@ -159,9 +164,14 @@ pub async fn route(
|
||||
async fn handle_tool_call(
|
||||
id: Option<Value>,
|
||||
params: Option<Value>,
|
||||
token: &str,
|
||||
token: Option<&str>,
|
||||
state: &State<AppState>,
|
||||
) -> ApiResponse {
|
||||
let token = match require_bearer_token(token, id.clone()) {
|
||||
Ok(token) => token,
|
||||
Err(response) => return response,
|
||||
};
|
||||
|
||||
let Some(params) = params else {
|
||||
return jsonrpc_error(Status::BadRequest, id, -32602, "Missing params");
|
||||
};
|
||||
@@ -177,6 +187,17 @@ async fn handle_tool_call(
|
||||
}
|
||||
}
|
||||
|
||||
fn require_bearer_token(token: Option<&str>, id: Option<Value>) -> Result<&str, ApiResponse> {
|
||||
token.ok_or_else(|| {
|
||||
jsonrpc_error(
|
||||
Status::Unauthorized,
|
||||
id,
|
||||
-32001,
|
||||
"Missing Authorization header",
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn tool_definitions() -> Vec<ToolDefinition> {
|
||||
vec![
|
||||
ToolDefinition {
|
||||
|
||||
Reference in New Issue
Block a user