80 lines
2.2 KiB
Rust
80 lines
2.2 KiB
Rust
mod protocol;
|
|
mod tools;
|
|
|
|
use crate::{api::auth::MaybeBearerToken, app::state::AppState};
|
|
use protocol::{
|
|
ApiResponse, InitializeResult, JSONRPC_VERSION, JsonRpcRequest, jsonrpc_error,
|
|
jsonrpc_notification_ok, jsonrpc_ok,
|
|
};
|
|
use rocket::{
|
|
Shutdown, State, get,
|
|
http::Status,
|
|
post,
|
|
response::stream::{Event, EventStream},
|
|
serde::json::Json,
|
|
};
|
|
use serde_json::Value;
|
|
use tools::{handle_tool_call, tool_definitions};
|
|
|
|
const MCP_MESSAGE_ENDPOINT_EVENT: &str = "endpoint";
|
|
|
|
#[get("/mcp")]
|
|
pub fn sse(shutdown: Shutdown) -> EventStream![] {
|
|
EventStream! {
|
|
yield Event::json(&serde_json::json!({ "path": "/mcp" })).event(MCP_MESSAGE_ENDPOINT_EVENT);
|
|
|
|
loop {
|
|
tokio::select! {
|
|
_ = shutdown.clone() => break,
|
|
_ = tokio::time::sleep(std::time::Duration::from_secs(15)) => {
|
|
yield Event::comment("keepalive");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[post("/mcp", data = "<body>")]
|
|
pub async fn route(
|
|
body: Json<JsonRpcRequest>,
|
|
token: MaybeBearerToken,
|
|
state: &State<AppState>,
|
|
) -> ApiResponse {
|
|
let request = body.into_inner();
|
|
|
|
if request.jsonrpc != JSONRPC_VERSION {
|
|
return jsonrpc_error(
|
|
Status::BadRequest,
|
|
request.id,
|
|
-32600,
|
|
"Invalid JSON-RPC version",
|
|
);
|
|
}
|
|
|
|
match request.method.as_str() {
|
|
"initialize" => jsonrpc_ok(request.id, InitializeResult::new()),
|
|
"notifications/initialized" => jsonrpc_notification_ok(),
|
|
"tools/list" => {
|
|
if token.0.is_some() {
|
|
jsonrpc_ok(request.id, tool_definitions())
|
|
} else {
|
|
missing_authorization(request.id)
|
|
}
|
|
}
|
|
"tools/call" => {
|
|
handle_tool_call(request.id, request.params, token.0.as_deref(), state).await
|
|
}
|
|
_ if request.id.is_none() => jsonrpc_notification_ok(),
|
|
_ => jsonrpc_error(Status::BadRequest, request.id, -32601, "Method not found"),
|
|
}
|
|
}
|
|
|
|
fn missing_authorization(id: Option<Value>) -> ApiResponse {
|
|
jsonrpc_error(
|
|
Status::Unauthorized,
|
|
id,
|
|
-32001,
|
|
"Missing Authorization header",
|
|
)
|
|
}
|