sse endpoint

This commit is contained in:
2026-05-22 10:07:25 +02:00
parent ca1315bfaf
commit 042fe0cae2
3 changed files with 57 additions and 5 deletions
+48 -4
View File
@@ -14,23 +14,37 @@ use rocket::{
serde::json::Json,
};
use serde_json::Value;
use tokio::sync::mpsc;
use tools::{handle_tool_call, tool_definitions};
const MCP_MESSAGE_ENDPOINT_EVENT: &str = "endpoint";
use uuid::Uuid;
#[get("/mcp")]
pub fn sse(shutdown: Shutdown) -> EventStream![] {
pub fn sse(state: &State<AppState>, shutdown: Shutdown) -> EventStream![] {
let connections = state.mcp_connections.clone();
let connection_id = Uuid::new_v4().to_string();
let endpoint = format!("/mcp/{}", connection_id);
let (sender, mut receiver) = mpsc::unbounded_channel();
connections.insert(connection_id.clone(), sender);
EventStream! {
yield Event::json(&serde_json::json!({ "path": "/mcp" })).event(MCP_MESSAGE_ENDPOINT_EVENT);
yield Event::data(endpoint).event("endpoint");
loop {
tokio::select! {
_ = shutdown.clone() => break,
message = receiver.recv() => {
match message {
Some(message) => yield Event::data(message).event("message"),
None => break,
}
}
_ = tokio::time::sleep(std::time::Duration::from_secs(15)) => {
yield Event::comment("keepalive");
}
}
}
connections.remove(&connection_id);
}
}
@@ -39,6 +53,36 @@ pub async fn route(
body: Json<JsonRpcRequest>,
token: MaybeBearerToken,
state: &State<AppState>,
) -> ApiResponse {
handle_request(body, token, state).await
}
#[post("/mcp/<connection_id>", data = "<body>")]
pub async fn route_sse(
connection_id: &str,
body: Json<JsonRpcRequest>,
token: MaybeBearerToken,
state: &State<AppState>,
) -> ApiResponse {
let response = handle_request(body, token, state).await;
let Some(sender) = state.mcp_connections.get(connection_id) else {
return jsonrpc_error(Status::NotFound, None, -32004, "MCP connection not found");
};
if response.0 != Status::Accepted {
let _ = sender.send(response.1.0.clone());
}
(
Status::Accepted,
rocket::response::content::RawJson(String::new()),
)
}
async fn handle_request(
body: Json<JsonRpcRequest>,
token: MaybeBearerToken,
state: &State<AppState>,
) -> ApiResponse {
let request = body.into_inner();
+7 -1
View File
@@ -35,7 +35,12 @@ pub fn build_rocket(database: SqlitePool) -> rocket::Rocket<rocket::Build> {
.manage(build_state(database))
.mount(
"/",
routes![transport::socket::connect, api::mcp::route, api::mcp::sse],
routes![
transport::socket::connect,
api::mcp::route,
api::mcp::route_sse,
api::mcp::sse
],
)
.register("/", rocket::catchers![api::catchers::default_catcher])
}
@@ -45,6 +50,7 @@ fn build_state(database: SqlitePool) -> AppState {
database,
registry: Arc::new(DashMap::new()),
device_counts: Arc::new(DashMap::new()),
mcp_connections: Arc::new(DashMap::new()),
config: Arc::new(Config::from_env()),
}
}
+2
View File
@@ -8,11 +8,13 @@ use tokio::sync::{mpsc, oneshot};
pub type Database = SqlitePool;
pub type DeviceCounts = DashMap<String, Arc<AtomicUsize>>;
pub type Registry = DashMap<(String, String), RegistryEntry>;
pub type McpConnections = DashMap<String, mpsc::UnboundedSender<String>>;
pub struct AppState {
pub database: Database,
pub registry: Arc<Registry>,
pub device_counts: Arc<DeviceCounts>,
pub mcp_connections: Arc<McpConnections>,
pub config: Arc<Config>,
}