sse endpoint
This commit is contained in:
+48
-4
@@ -14,23 +14,37 @@ use rocket::{
|
|||||||
serde::json::Json,
|
serde::json::Json,
|
||||||
};
|
};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
use tools::{handle_tool_call, tool_definitions};
|
use tools::{handle_tool_call, tool_definitions};
|
||||||
|
use uuid::Uuid;
|
||||||
const MCP_MESSAGE_ENDPOINT_EVENT: &str = "endpoint";
|
|
||||||
|
|
||||||
#[get("/mcp")]
|
#[get("/mcp")]
|
||||||
pub fn sse(shutdown: Shutdown) -> EventStream![] {
|
pub fn sse(state: &State<AppState>, shutdown: Shutdown) -> EventStream![] {
|
||||||
|
let connections = state.mcp_connections.clone();
|
||||||
|
let connection_id = Uuid::new_v4().to_string();
|
||||||
|
let endpoint = format!("/mcp/{}", connection_id);
|
||||||
|
let (sender, mut receiver) = mpsc::unbounded_channel();
|
||||||
|
connections.insert(connection_id.clone(), sender);
|
||||||
|
|
||||||
EventStream! {
|
EventStream! {
|
||||||
yield Event::json(&serde_json::json!({ "path": "/mcp" })).event(MCP_MESSAGE_ENDPOINT_EVENT);
|
yield Event::data(endpoint).event("endpoint");
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
_ = shutdown.clone() => break,
|
_ = shutdown.clone() => break,
|
||||||
|
message = receiver.recv() => {
|
||||||
|
match message {
|
||||||
|
Some(message) => yield Event::data(message).event("message"),
|
||||||
|
None => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
_ = tokio::time::sleep(std::time::Duration::from_secs(15)) => {
|
_ = tokio::time::sleep(std::time::Duration::from_secs(15)) => {
|
||||||
yield Event::comment("keepalive");
|
yield Event::comment("keepalive");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
connections.remove(&connection_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,6 +53,36 @@ pub async fn route(
|
|||||||
body: Json<JsonRpcRequest>,
|
body: Json<JsonRpcRequest>,
|
||||||
token: MaybeBearerToken,
|
token: MaybeBearerToken,
|
||||||
state: &State<AppState>,
|
state: &State<AppState>,
|
||||||
|
) -> ApiResponse {
|
||||||
|
handle_request(body, token, state).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/mcp/<connection_id>", data = "<body>")]
|
||||||
|
pub async fn route_sse(
|
||||||
|
connection_id: &str,
|
||||||
|
body: Json<JsonRpcRequest>,
|
||||||
|
token: MaybeBearerToken,
|
||||||
|
state: &State<AppState>,
|
||||||
|
) -> ApiResponse {
|
||||||
|
let response = handle_request(body, token, state).await;
|
||||||
|
let Some(sender) = state.mcp_connections.get(connection_id) else {
|
||||||
|
return jsonrpc_error(Status::NotFound, None, -32004, "MCP connection not found");
|
||||||
|
};
|
||||||
|
|
||||||
|
if response.0 != Status::Accepted {
|
||||||
|
let _ = sender.send(response.1.0.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
(
|
||||||
|
Status::Accepted,
|
||||||
|
rocket::response::content::RawJson(String::new()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_request(
|
||||||
|
body: Json<JsonRpcRequest>,
|
||||||
|
token: MaybeBearerToken,
|
||||||
|
state: &State<AppState>,
|
||||||
) -> ApiResponse {
|
) -> ApiResponse {
|
||||||
let request = body.into_inner();
|
let request = body.into_inner();
|
||||||
|
|
||||||
|
|||||||
+7
-1
@@ -35,7 +35,12 @@ pub fn build_rocket(database: SqlitePool) -> rocket::Rocket<rocket::Build> {
|
|||||||
.manage(build_state(database))
|
.manage(build_state(database))
|
||||||
.mount(
|
.mount(
|
||||||
"/",
|
"/",
|
||||||
routes![transport::socket::connect, api::mcp::route, api::mcp::sse],
|
routes![
|
||||||
|
transport::socket::connect,
|
||||||
|
api::mcp::route,
|
||||||
|
api::mcp::route_sse,
|
||||||
|
api::mcp::sse
|
||||||
|
],
|
||||||
)
|
)
|
||||||
.register("/", rocket::catchers![api::catchers::default_catcher])
|
.register("/", rocket::catchers![api::catchers::default_catcher])
|
||||||
}
|
}
|
||||||
@@ -45,6 +50,7 @@ fn build_state(database: SqlitePool) -> AppState {
|
|||||||
database,
|
database,
|
||||||
registry: Arc::new(DashMap::new()),
|
registry: Arc::new(DashMap::new()),
|
||||||
device_counts: Arc::new(DashMap::new()),
|
device_counts: Arc::new(DashMap::new()),
|
||||||
|
mcp_connections: Arc::new(DashMap::new()),
|
||||||
config: Arc::new(Config::from_env()),
|
config: Arc::new(Config::from_env()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,11 +8,13 @@ use tokio::sync::{mpsc, oneshot};
|
|||||||
pub type Database = SqlitePool;
|
pub type Database = SqlitePool;
|
||||||
pub type DeviceCounts = DashMap<String, Arc<AtomicUsize>>;
|
pub type DeviceCounts = DashMap<String, Arc<AtomicUsize>>;
|
||||||
pub type Registry = DashMap<(String, String), RegistryEntry>;
|
pub type Registry = DashMap<(String, String), RegistryEntry>;
|
||||||
|
pub type McpConnections = DashMap<String, mpsc::UnboundedSender<String>>;
|
||||||
|
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub database: Database,
|
pub database: Database,
|
||||||
pub registry: Arc<Registry>,
|
pub registry: Arc<Registry>,
|
||||||
pub device_counts: Arc<DeviceCounts>,
|
pub device_counts: Arc<DeviceCounts>,
|
||||||
|
pub mcp_connections: Arc<McpConnections>,
|
||||||
pub config: Arc<Config>,
|
pub config: Arc<Config>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user