sse endpoint
This commit is contained in:
+48
-4
@@ -14,23 +14,37 @@ use rocket::{
|
||||
serde::json::Json,
|
||||
};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc;
|
||||
use tools::{handle_tool_call, tool_definitions};
|
||||
|
||||
const MCP_MESSAGE_ENDPOINT_EVENT: &str = "endpoint";
|
||||
use uuid::Uuid;
|
||||
|
||||
#[get("/mcp")]
|
||||
pub fn sse(shutdown: Shutdown) -> EventStream![] {
|
||||
pub fn sse(state: &State<AppState>, shutdown: Shutdown) -> EventStream![] {
|
||||
let connections = state.mcp_connections.clone();
|
||||
let connection_id = Uuid::new_v4().to_string();
|
||||
let endpoint = format!("/mcp/{}", connection_id);
|
||||
let (sender, mut receiver) = mpsc::unbounded_channel();
|
||||
connections.insert(connection_id.clone(), sender);
|
||||
|
||||
EventStream! {
|
||||
yield Event::json(&serde_json::json!({ "path": "/mcp" })).event(MCP_MESSAGE_ENDPOINT_EVENT);
|
||||
yield Event::data(endpoint).event("endpoint");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown.clone() => break,
|
||||
message = receiver.recv() => {
|
||||
match message {
|
||||
Some(message) => yield Event::data(message).event("message"),
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
_ = tokio::time::sleep(std::time::Duration::from_secs(15)) => {
|
||||
yield Event::comment("keepalive");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
connections.remove(&connection_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,6 +53,36 @@ pub async fn route(
|
||||
body: Json<JsonRpcRequest>,
|
||||
token: MaybeBearerToken,
|
||||
state: &State<AppState>,
|
||||
) -> ApiResponse {
|
||||
handle_request(body, token, state).await
|
||||
}
|
||||
|
||||
#[post("/mcp/<connection_id>", data = "<body>")]
|
||||
pub async fn route_sse(
|
||||
connection_id: &str,
|
||||
body: Json<JsonRpcRequest>,
|
||||
token: MaybeBearerToken,
|
||||
state: &State<AppState>,
|
||||
) -> ApiResponse {
|
||||
let response = handle_request(body, token, state).await;
|
||||
let Some(sender) = state.mcp_connections.get(connection_id) else {
|
||||
return jsonrpc_error(Status::NotFound, None, -32004, "MCP connection not found");
|
||||
};
|
||||
|
||||
if response.0 != Status::Accepted {
|
||||
let _ = sender.send(response.1.0.clone());
|
||||
}
|
||||
|
||||
(
|
||||
Status::Accepted,
|
||||
rocket::response::content::RawJson(String::new()),
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_request(
|
||||
body: Json<JsonRpcRequest>,
|
||||
token: MaybeBearerToken,
|
||||
state: &State<AppState>,
|
||||
) -> ApiResponse {
|
||||
let request = body.into_inner();
|
||||
|
||||
|
||||
+7
-1
@@ -35,7 +35,12 @@ pub fn build_rocket(database: SqlitePool) -> rocket::Rocket<rocket::Build> {
|
||||
.manage(build_state(database))
|
||||
.mount(
|
||||
"/",
|
||||
routes![transport::socket::connect, api::mcp::route, api::mcp::sse],
|
||||
routes![
|
||||
transport::socket::connect,
|
||||
api::mcp::route,
|
||||
api::mcp::route_sse,
|
||||
api::mcp::sse
|
||||
],
|
||||
)
|
||||
.register("/", rocket::catchers![api::catchers::default_catcher])
|
||||
}
|
||||
@@ -45,6 +50,7 @@ fn build_state(database: SqlitePool) -> AppState {
|
||||
database,
|
||||
registry: Arc::new(DashMap::new()),
|
||||
device_counts: Arc::new(DashMap::new()),
|
||||
mcp_connections: Arc::new(DashMap::new()),
|
||||
config: Arc::new(Config::from_env()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,11 +8,13 @@ use tokio::sync::{mpsc, oneshot};
|
||||
pub type Database = SqlitePool;
|
||||
pub type DeviceCounts = DashMap<String, Arc<AtomicUsize>>;
|
||||
pub type Registry = DashMap<(String, String), RegistryEntry>;
|
||||
pub type McpConnections = DashMap<String, mpsc::UnboundedSender<String>>;
|
||||
|
||||
pub struct AppState {
|
||||
pub database: Database,
|
||||
pub registry: Arc<Registry>,
|
||||
pub device_counts: Arc<DeviceCounts>,
|
||||
pub mcp_connections: Arc<McpConnections>,
|
||||
pub config: Arc<Config>,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user