From 042fe0cae288d1bee172624ca3922a5298163f96 Mon Sep 17 00:00:00 2001 From: ZeroZipp Date: Fri, 22 May 2026 10:07:25 +0200 Subject: [PATCH] sse endpoint --- src/api/mcp/mod.rs | 52 ++++++++++++++++++++++++++++++++++++++++++---- src/app/mod.rs | 8 ++++++- src/app/state.rs | 2 ++ 3 files changed, 57 insertions(+), 5 deletions(-) diff --git a/src/api/mcp/mod.rs b/src/api/mcp/mod.rs index 1e1aa36..0e4348e 100644 --- a/src/api/mcp/mod.rs +++ b/src/api/mcp/mod.rs @@ -14,23 +14,37 @@ use rocket::{ serde::json::Json, }; use serde_json::Value; +use tokio::sync::mpsc; use tools::{handle_tool_call, tool_definitions}; - -const MCP_MESSAGE_ENDPOINT_EVENT: &str = "endpoint"; +use uuid::Uuid; #[get("/mcp")] -pub fn sse(shutdown: Shutdown) -> EventStream![] { +pub fn sse(state: &State, shutdown: Shutdown) -> EventStream![] { + let connections = state.mcp_connections.clone(); + let connection_id = Uuid::new_v4().to_string(); + let endpoint = format!("/mcp/{}", connection_id); + let (sender, mut receiver) = mpsc::unbounded_channel(); + connections.insert(connection_id.clone(), sender); + EventStream! { - yield Event::json(&serde_json::json!({ "path": "/mcp" })).event(MCP_MESSAGE_ENDPOINT_EVENT); + yield Event::data(endpoint).event("endpoint"); loop { tokio::select! { _ = shutdown.clone() => break, + message = receiver.recv() => { + match message { + Some(message) => yield Event::data(message).event("message"), + None => break, + } + } _ = tokio::time::sleep(std::time::Duration::from_secs(15)) => { yield Event::comment("keepalive"); } } } + + connections.remove(&connection_id); } } @@ -39,6 +53,36 @@ pub async fn route( body: Json, token: MaybeBearerToken, state: &State, +) -> ApiResponse { + handle_request(body, token, state).await +} + +#[post("/mcp/", data = "")] +pub async fn route_sse( + connection_id: &str, + body: Json, + token: MaybeBearerToken, + state: &State, +) -> ApiResponse { + let response = handle_request(body, token, state).await; + let Some(sender) = state.mcp_connections.get(connection_id) else { + return jsonrpc_error(Status::NotFound, None, -32004, "MCP connection not found"); + }; + + if response.0 != Status::Accepted { + let _ = sender.send(response.1.0.clone()); + } + + ( + Status::Accepted, + rocket::response::content::RawJson(String::new()), + ) +} + +async fn handle_request( + body: Json, + token: MaybeBearerToken, + state: &State, ) -> ApiResponse { let request = body.into_inner(); diff --git a/src/app/mod.rs b/src/app/mod.rs index 6f3d725..2368bde 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -35,7 +35,12 @@ pub fn build_rocket(database: SqlitePool) -> rocket::Rocket { .manage(build_state(database)) .mount( "/", - routes![transport::socket::connect, api::mcp::route, api::mcp::sse], + routes![ + transport::socket::connect, + api::mcp::route, + api::mcp::route_sse, + api::mcp::sse + ], ) .register("/", rocket::catchers![api::catchers::default_catcher]) } @@ -45,6 +50,7 @@ fn build_state(database: SqlitePool) -> AppState { database, registry: Arc::new(DashMap::new()), device_counts: Arc::new(DashMap::new()), + mcp_connections: Arc::new(DashMap::new()), config: Arc::new(Config::from_env()), } } diff --git a/src/app/state.rs b/src/app/state.rs index afff8f2..4e7ade8 100644 --- a/src/app/state.rs +++ b/src/app/state.rs @@ -8,11 +8,13 @@ use tokio::sync::{mpsc, oneshot}; pub type Database = SqlitePool; pub type DeviceCounts = DashMap>; pub type Registry = DashMap<(String, String), RegistryEntry>; +pub type McpConnections = DashMap>; pub struct AppState { pub database: Database, pub registry: Arc, pub device_counts: Arc, + pub mcp_connections: Arc, pub config: Arc, }