refactor
This commit is contained in:
-375
@@ -1,375 +0,0 @@
|
|||||||
use crate::{
|
|
||||||
AppState,
|
|
||||||
api::auth::MaybeBearerToken,
|
|
||||||
core::db::lookup_api_token,
|
|
||||||
transport::execute::{ExecuteRequest, execute_for_token},
|
|
||||||
};
|
|
||||||
use rocket::{State, http::Status, post, response::content::RawJson, serde::json::Json};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::{Value, json};
|
|
||||||
|
|
||||||
type ApiResponse = (Status, RawJson<String>);
|
|
||||||
|
|
||||||
const JSONRPC_VERSION: &str = "2.0";
|
|
||||||
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct JsonRpcRequest {
|
|
||||||
pub jsonrpc: String,
|
|
||||||
pub id: Option<Value>,
|
|
||||||
pub method: String,
|
|
||||||
#[serde(default)]
|
|
||||||
pub params: Option<Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct JsonRpcResponse<T> {
|
|
||||||
jsonrpc: &'static str,
|
|
||||||
id: Option<Value>,
|
|
||||||
result: T,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct JsonRpcErrorResponse {
|
|
||||||
jsonrpc: &'static str,
|
|
||||||
id: Option<Value>,
|
|
||||||
error: JsonRpcError,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct JsonRpcError {
|
|
||||||
code: i64,
|
|
||||||
message: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct InitializeResult {
|
|
||||||
#[serde(rename = "protocolVersion")]
|
|
||||||
protocol_version: &'static str,
|
|
||||||
capabilities: ServerCapabilities,
|
|
||||||
#[serde(rename = "serverInfo")]
|
|
||||||
server_info: ServerInfo,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct ServerCapabilities {
|
|
||||||
tools: ToolsCapability,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct ToolsCapability {
|
|
||||||
list_changed: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct ServerInfo {
|
|
||||||
name: &'static str,
|
|
||||||
version: &'static str,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct ToolsListResult {
|
|
||||||
tools: Vec<ToolDefinition>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct ToolDefinition {
|
|
||||||
name: &'static str,
|
|
||||||
description: &'static str,
|
|
||||||
#[serde(rename = "inputSchema")]
|
|
||||||
input_schema: Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct ToolCallParams {
|
|
||||||
name: String,
|
|
||||||
arguments: Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct ToolCallResult {
|
|
||||||
content: Vec<TextContent>,
|
|
||||||
#[serde(rename = "isError")]
|
|
||||||
is_error: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct ConnectedDevicesResult {
|
|
||||||
device_ids: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
enum ToolCall {
|
|
||||||
Execute(ExecuteRequest),
|
|
||||||
ListDevices,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct TextContent {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
kind: &'static str,
|
|
||||||
text: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/mcp", data = "<body>")]
|
|
||||||
pub async fn route(
|
|
||||||
body: Json<JsonRpcRequest>,
|
|
||||||
token: MaybeBearerToken,
|
|
||||||
state: &State<AppState>,
|
|
||||||
) -> ApiResponse {
|
|
||||||
let request = body.into_inner();
|
|
||||||
|
|
||||||
if request.jsonrpc != JSONRPC_VERSION {
|
|
||||||
return jsonrpc_error(
|
|
||||||
Status::BadRequest,
|
|
||||||
request.id,
|
|
||||||
-32600,
|
|
||||||
"Invalid JSON-RPC version",
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
match request.method.as_str() {
|
|
||||||
"initialize" => jsonrpc_ok(
|
|
||||||
request.id,
|
|
||||||
InitializeResult {
|
|
||||||
protocol_version: MCP_PROTOCOL_VERSION,
|
|
||||||
capabilities: ServerCapabilities {
|
|
||||||
tools: ToolsCapability {
|
|
||||||
list_changed: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
server_info: ServerInfo {
|
|
||||||
name: env!("CARGO_PKG_NAME"),
|
|
||||||
version: env!("CARGO_PKG_VERSION"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
"notifications/initialized" => jsonrpc_notification_ok(),
|
|
||||||
"tools/list" => match require_bearer_token(token.0.as_deref(), request.id.clone()) {
|
|
||||||
Ok(_) => jsonrpc_ok(
|
|
||||||
request.id,
|
|
||||||
ToolsListResult {
|
|
||||||
tools: tool_definitions(),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
Err(response) => response,
|
|
||||||
},
|
|
||||||
"tools/call" => {
|
|
||||||
handle_tool_call(request.id, request.params, token.0.as_deref(), state).await
|
|
||||||
}
|
|
||||||
_ if request.id.is_none() => jsonrpc_notification_ok(),
|
|
||||||
_ => jsonrpc_error(Status::BadRequest, request.id, -32601, "Method not found"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_tool_call(
|
|
||||||
id: Option<Value>,
|
|
||||||
params: Option<Value>,
|
|
||||||
token: Option<&str>,
|
|
||||||
state: &State<AppState>,
|
|
||||||
) -> ApiResponse {
|
|
||||||
let token = match require_bearer_token(token, id.clone()) {
|
|
||||||
Ok(token) => token,
|
|
||||||
Err(response) => return response,
|
|
||||||
};
|
|
||||||
|
|
||||||
let Some(params) = params else {
|
|
||||||
return jsonrpc_error(Status::BadRequest, id, -32602, "Missing params");
|
|
||||||
};
|
|
||||||
|
|
||||||
let tool_call = match parse_tool_call(params) {
|
|
||||||
Ok(call) => call,
|
|
||||||
Err(_) => return jsonrpc_error(Status::BadRequest, id, -32602, "Invalid params"),
|
|
||||||
};
|
|
||||||
|
|
||||||
match tool_call {
|
|
||||||
ToolCall::Execute(request) => execute_tool_call(id, token, state, request).await,
|
|
||||||
ToolCall::ListDevices => list_devices_tool_call(id, token, state).await,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn require_bearer_token(token: Option<&str>, id: Option<Value>) -> Result<&str, ApiResponse> {
|
|
||||||
token.ok_or_else(|| {
|
|
||||||
jsonrpc_error(
|
|
||||||
Status::Unauthorized,
|
|
||||||
id,
|
|
||||||
-32001,
|
|
||||||
"Missing Authorization header",
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn tool_definitions() -> Vec<ToolDefinition> {
|
|
||||||
vec![
|
|
||||||
ToolDefinition {
|
|
||||||
name: "execute",
|
|
||||||
description: "Execute a shell command on a connected device.",
|
|
||||||
input_schema: json!({
|
|
||||||
"type": "object",
|
|
||||||
"required": ["device_id", "command"],
|
|
||||||
"properties": {
|
|
||||||
"device_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Identifier of the connected device."
|
|
||||||
},
|
|
||||||
"command": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Shell command to execute on the target device."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"additionalProperties": false
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
ToolDefinition {
|
|
||||||
name: "list_devices",
|
|
||||||
description: "List currently connected device identifiers.",
|
|
||||||
input_schema: json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {},
|
|
||||||
"additionalProperties": false
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_tool_call(params: Value) -> Result<ToolCall, &'static str> {
|
|
||||||
let call: ToolCallParams = serde_json::from_value(params).map_err(|_| "Invalid params")?;
|
|
||||||
|
|
||||||
match call.name.as_str() {
|
|
||||||
"execute" => serde_json::from_value(call.arguments)
|
|
||||||
.map(ToolCall::Execute)
|
|
||||||
.map_err(|_| "Invalid execute arguments"),
|
|
||||||
"list_devices" => parse_list_devices_arguments(call.arguments),
|
|
||||||
_ => Err("Unknown tool"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parse_list_devices_arguments(arguments: Value) -> Result<ToolCall, &'static str> {
|
|
||||||
match arguments {
|
|
||||||
Value::Object(map) if map.is_empty() => Ok(ToolCall::ListDevices),
|
|
||||||
_ => Err("Invalid list_devices arguments"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn connected_device_ids_for_user(state: &State<AppState>, user_id: &str) -> Vec<String> {
|
|
||||||
state
|
|
||||||
.registry
|
|
||||||
.iter()
|
|
||||||
.filter_map(|entry| {
|
|
||||||
let ((entry_user_id, device_id), _) = entry.pair();
|
|
||||||
(entry_user_id == user_id).then(|| device_id.clone())
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute_tool_call(
|
|
||||||
id: Option<Value>,
|
|
||||||
token: &str,
|
|
||||||
state: &State<AppState>,
|
|
||||||
request: ExecuteRequest,
|
|
||||||
) -> ApiResponse {
|
|
||||||
match execute_for_token(state, token, request).await {
|
|
||||||
Ok(result) => json_tool_result(id, &result),
|
|
||||||
Err((status, message)) if status == Status::Unauthorized => {
|
|
||||||
jsonrpc_error(Status::Unauthorized, id, -32001, message)
|
|
||||||
}
|
|
||||||
Err((_, message)) => jsonrpc_ok(id, tool_error(message)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn list_devices_tool_call(
|
|
||||||
id: Option<Value>,
|
|
||||||
token: &str,
|
|
||||||
state: &State<AppState>,
|
|
||||||
) -> ApiResponse {
|
|
||||||
let user_id = match lookup_api_token(&state.database, token).await {
|
|
||||||
Ok(user_id) => user_id,
|
|
||||||
Err(message) => return jsonrpc_error(Status::Unauthorized, id, -32001, message),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut device_ids = connected_device_ids_for_user(state, &user_id);
|
|
||||||
device_ids.sort();
|
|
||||||
|
|
||||||
json_tool_result(id, &ConnectedDevicesResult { device_ids })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn json_tool_result<T: Serialize>(id: Option<Value>, payload: &T) -> ApiResponse {
|
|
||||||
match serde_json::to_string(payload) {
|
|
||||||
Ok(text) => jsonrpc_ok(id, ToolCallResult::success(text)),
|
|
||||||
Err(error) => jsonrpc_error(Status::InternalServerError, id, -32603, error.to_string()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn tool_error(message: String) -> ToolCallResult {
|
|
||||||
ToolCallResult::error(message)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ToolCallResult {
|
|
||||||
fn success(text: String) -> Self {
|
|
||||||
Self {
|
|
||||||
content: vec![TextContent::text(text)],
|
|
||||||
is_error: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn error(text: String) -> Self {
|
|
||||||
Self {
|
|
||||||
content: vec![TextContent::text(text)],
|
|
||||||
is_error: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TextContent {
|
|
||||||
fn text(text: String) -> Self {
|
|
||||||
Self { kind: "text", text }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn jsonrpc_ok<T: Serialize>(id: Option<Value>, result: T) -> ApiResponse {
|
|
||||||
match serde_json::to_string(&JsonRpcResponse {
|
|
||||||
jsonrpc: JSONRPC_VERSION,
|
|
||||||
id,
|
|
||||||
result,
|
|
||||||
}) {
|
|
||||||
Ok(body) => (Status::Ok, RawJson(body)),
|
|
||||||
Err(e) => jsonrpc_error(Status::InternalServerError, None, -32603, e.to_string()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn jsonrpc_notification_ok() -> ApiResponse {
|
|
||||||
(Status::Accepted, RawJson(String::new()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn jsonrpc_error(
|
|
||||||
status: Status,
|
|
||||||
id: Option<Value>,
|
|
||||||
code: i64,
|
|
||||||
message: impl Into<String>,
|
|
||||||
) -> ApiResponse {
|
|
||||||
let response = JsonRpcErrorResponse {
|
|
||||||
jsonrpc: JSONRPC_VERSION,
|
|
||||||
id,
|
|
||||||
error: JsonRpcError {
|
|
||||||
code,
|
|
||||||
message: message.into(),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
match serde_json::to_string(&response) {
|
|
||||||
Ok(body) => (status, RawJson(body)),
|
|
||||||
Err(e) => (
|
|
||||||
Status::InternalServerError,
|
|
||||||
RawJson(
|
|
||||||
json!({
|
|
||||||
"jsonrpc": JSONRPC_VERSION,
|
|
||||||
"id": null,
|
|
||||||
"error": {
|
|
||||||
"code": -32603,
|
|
||||||
"message": e.to_string()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.to_string(),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
mod protocol;
|
||||||
|
mod tools;
|
||||||
|
|
||||||
|
use crate::{api::auth::MaybeBearerToken, app::state::AppState};
|
||||||
|
use protocol::{
|
||||||
|
ApiResponse, InitializeResult, JSONRPC_VERSION, jsonrpc_error, jsonrpc_notification_ok,
|
||||||
|
jsonrpc_ok,
|
||||||
|
};
|
||||||
|
use rocket::{State, http::Status, post, serde::json::Json};
|
||||||
|
use serde_json::Value;
|
||||||
|
use tools::{handle_tool_call, tool_definitions};
|
||||||
|
|
||||||
|
pub use protocol::JsonRpcRequest;
|
||||||
|
|
||||||
|
#[post("/mcp", data = "<body>")]
|
||||||
|
pub async fn route(
|
||||||
|
body: Json<JsonRpcRequest>,
|
||||||
|
token: MaybeBearerToken,
|
||||||
|
state: &State<AppState>,
|
||||||
|
) -> ApiResponse {
|
||||||
|
let request = body.into_inner();
|
||||||
|
|
||||||
|
if request.jsonrpc != JSONRPC_VERSION {
|
||||||
|
return jsonrpc_error(
|
||||||
|
Status::BadRequest,
|
||||||
|
request.id,
|
||||||
|
-32600,
|
||||||
|
"Invalid JSON-RPC version",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
match request.method.as_str() {
|
||||||
|
"initialize" => jsonrpc_ok(request.id, InitializeResult::new()),
|
||||||
|
"notifications/initialized" => jsonrpc_notification_ok(),
|
||||||
|
"tools/list" => match token.0 {
|
||||||
|
Some(_) => jsonrpc_ok(request.id, tool_definitions()),
|
||||||
|
None => missing_authorization(request.id),
|
||||||
|
},
|
||||||
|
"tools/call" => {
|
||||||
|
handle_tool_call(request.id, request.params, token.0.as_deref(), state).await
|
||||||
|
}
|
||||||
|
_ if request.id.is_none() => jsonrpc_notification_ok(),
|
||||||
|
_ => jsonrpc_error(Status::BadRequest, request.id, -32601, "Method not found"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn missing_authorization(id: Option<Value>) -> ApiResponse {
|
||||||
|
jsonrpc_error(
|
||||||
|
Status::Unauthorized,
|
||||||
|
id,
|
||||||
|
-32001,
|
||||||
|
"Missing Authorization header",
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -0,0 +1,128 @@
|
|||||||
|
use rocket::{http::Status, response::content::RawJson};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::{Value, json};
|
||||||
|
|
||||||
|
pub type ApiResponse = (Status, RawJson<String>);
|
||||||
|
|
||||||
|
pub const JSONRPC_VERSION: &str = "2.0";
|
||||||
|
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct JsonRpcRequest {
|
||||||
|
pub jsonrpc: String,
|
||||||
|
pub id: Option<Value>,
|
||||||
|
pub method: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub params: Option<Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct JsonRpcResponse<T> {
|
||||||
|
jsonrpc: &'static str,
|
||||||
|
id: Option<Value>,
|
||||||
|
result: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct JsonRpcErrorResponse {
|
||||||
|
jsonrpc: &'static str,
|
||||||
|
id: Option<Value>,
|
||||||
|
error: JsonRpcError,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct JsonRpcError {
|
||||||
|
code: i64,
|
||||||
|
message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct InitializeResult {
|
||||||
|
#[serde(rename = "protocolVersion")]
|
||||||
|
protocol_version: &'static str,
|
||||||
|
capabilities: ServerCapabilities,
|
||||||
|
#[serde(rename = "serverInfo")]
|
||||||
|
server_info: ServerInfo,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct ServerCapabilities {
|
||||||
|
tools: ToolsCapability,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct ToolsCapability {
|
||||||
|
list_changed: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct ServerInfo {
|
||||||
|
name: &'static str,
|
||||||
|
version: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InitializeResult {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
protocol_version: MCP_PROTOCOL_VERSION,
|
||||||
|
capabilities: ServerCapabilities {
|
||||||
|
tools: ToolsCapability {
|
||||||
|
list_changed: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
server_info: ServerInfo {
|
||||||
|
name: env!("CARGO_PKG_NAME"),
|
||||||
|
version: env!("CARGO_PKG_VERSION"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn jsonrpc_ok<T: Serialize>(id: Option<Value>, result: T) -> ApiResponse {
|
||||||
|
match serde_json::to_string(&JsonRpcResponse {
|
||||||
|
jsonrpc: JSONRPC_VERSION,
|
||||||
|
id,
|
||||||
|
result,
|
||||||
|
}) {
|
||||||
|
Ok(body) => (Status::Ok, RawJson(body)),
|
||||||
|
Err(error) => jsonrpc_error(Status::InternalServerError, None, -32603, error.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn jsonrpc_notification_ok() -> ApiResponse {
|
||||||
|
(Status::Accepted, RawJson(String::new()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn jsonrpc_error(
|
||||||
|
status: Status,
|
||||||
|
id: Option<Value>,
|
||||||
|
code: i64,
|
||||||
|
message: impl Into<String>,
|
||||||
|
) -> ApiResponse {
|
||||||
|
let response = JsonRpcErrorResponse {
|
||||||
|
jsonrpc: JSONRPC_VERSION,
|
||||||
|
id,
|
||||||
|
error: JsonRpcError {
|
||||||
|
code,
|
||||||
|
message: message.into(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
match serde_json::to_string(&response) {
|
||||||
|
Ok(body) => (status, RawJson(body)),
|
||||||
|
Err(error) => (
|
||||||
|
Status::InternalServerError,
|
||||||
|
RawJson(
|
||||||
|
json!({
|
||||||
|
"jsonrpc": JSONRPC_VERSION,
|
||||||
|
"id": null,
|
||||||
|
"error": {
|
||||||
|
"code": -32603,
|
||||||
|
"message": error.to_string()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.to_string(),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
use crate::{
|
||||||
|
api::mcp::protocol::{ApiResponse, jsonrpc_error, jsonrpc_ok},
|
||||||
|
app::state::AppState,
|
||||||
|
core::db::lookup_api_token,
|
||||||
|
transport::execute::{ExecuteRequest, execute_for_token},
|
||||||
|
};
|
||||||
|
use rocket::{State, http::Status};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::{Value, json};
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct ToolsListResult {
|
||||||
|
tools: Vec<ToolDefinition>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct ToolDefinition {
|
||||||
|
name: &'static str,
|
||||||
|
description: &'static str,
|
||||||
|
#[serde(rename = "inputSchema")]
|
||||||
|
input_schema: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct ToolCallParams {
|
||||||
|
name: String,
|
||||||
|
arguments: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct ToolCallResult {
|
||||||
|
content: Vec<TextContent>,
|
||||||
|
#[serde(rename = "isError")]
|
||||||
|
is_error: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct ConnectedDevicesResult {
|
||||||
|
device_ids: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ToolCall {
|
||||||
|
Execute(ExecuteRequest),
|
||||||
|
ListDevices,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct TextContent {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
kind: &'static str,
|
||||||
|
text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tool_definitions() -> ToolsListResult {
|
||||||
|
ToolsListResult {
|
||||||
|
tools: vec![
|
||||||
|
ToolDefinition {
|
||||||
|
name: "execute",
|
||||||
|
description: "Execute a shell command on a connected device.",
|
||||||
|
input_schema: json!({
|
||||||
|
"type": "object",
|
||||||
|
"required": ["device_id", "command"],
|
||||||
|
"properties": {
|
||||||
|
"device_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Identifier of the connected device."
|
||||||
|
},
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Shell command to execute on the target device."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
ToolDefinition {
|
||||||
|
name: "list_devices",
|
||||||
|
description: "List currently connected device identifiers.",
|
||||||
|
input_schema: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"additionalProperties": false
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_tool_call(
|
||||||
|
id: Option<Value>,
|
||||||
|
params: Option<Value>,
|
||||||
|
token: Option<&str>,
|
||||||
|
state: &State<AppState>,
|
||||||
|
) -> ApiResponse {
|
||||||
|
let Some(token) = token else {
|
||||||
|
return jsonrpc_error(
|
||||||
|
Status::Unauthorized,
|
||||||
|
id,
|
||||||
|
-32001,
|
||||||
|
"Missing Authorization header",
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(params) = params else {
|
||||||
|
return jsonrpc_error(Status::BadRequest, id, -32602, "Missing params");
|
||||||
|
};
|
||||||
|
|
||||||
|
let tool_call = match parse_tool_call(params) {
|
||||||
|
Ok(call) => call,
|
||||||
|
Err(_) => return jsonrpc_error(Status::BadRequest, id, -32602, "Invalid params"),
|
||||||
|
};
|
||||||
|
|
||||||
|
match tool_call {
|
||||||
|
ToolCall::Execute(request) => execute_tool_call(id, token, state, request).await,
|
||||||
|
ToolCall::ListDevices => list_devices_tool_call(id, token, state).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_tool_call(params: Value) -> Result<ToolCall, ()> {
|
||||||
|
let call: ToolCallParams = serde_json::from_value(params).map_err(|_| ())?;
|
||||||
|
|
||||||
|
match call.name.as_str() {
|
||||||
|
"execute" => serde_json::from_value(call.arguments)
|
||||||
|
.map(ToolCall::Execute)
|
||||||
|
.map_err(|_| ()),
|
||||||
|
"list_devices" => match call.arguments {
|
||||||
|
Value::Object(map) if map.is_empty() => Ok(ToolCall::ListDevices),
|
||||||
|
_ => Err(()),
|
||||||
|
},
|
||||||
|
_ => Err(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute_tool_call(
|
||||||
|
id: Option<Value>,
|
||||||
|
token: &str,
|
||||||
|
state: &State<AppState>,
|
||||||
|
request: ExecuteRequest,
|
||||||
|
) -> ApiResponse {
|
||||||
|
match execute_for_token(state, token, request).await {
|
||||||
|
Ok(result) => json_tool_result(id, &result),
|
||||||
|
Err((status, message)) if status == Status::Unauthorized => {
|
||||||
|
jsonrpc_error(Status::Unauthorized, id, -32001, message)
|
||||||
|
}
|
||||||
|
Err((_, message)) => jsonrpc_ok(id, ToolCallResult::error(message)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_devices_tool_call(
|
||||||
|
id: Option<Value>,
|
||||||
|
token: &str,
|
||||||
|
state: &State<AppState>,
|
||||||
|
) -> ApiResponse {
|
||||||
|
let user_id = match lookup_api_token(&state.database, token).await {
|
||||||
|
Ok(user_id) => user_id,
|
||||||
|
Err(message) => return jsonrpc_error(Status::Unauthorized, id, -32001, message),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut device_ids = state
|
||||||
|
.registry
|
||||||
|
.iter()
|
||||||
|
.filter_map(|entry| {
|
||||||
|
let ((entry_user_id, device_id), _) = entry.pair();
|
||||||
|
(entry_user_id == &user_id).then(|| device_id.clone())
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
device_ids.sort();
|
||||||
|
|
||||||
|
json_tool_result(id, &ConnectedDevicesResult { device_ids })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn json_tool_result<T: Serialize>(id: Option<Value>, payload: &T) -> ApiResponse {
|
||||||
|
match serde_json::to_string(payload) {
|
||||||
|
Ok(text) => jsonrpc_ok(id, ToolCallResult::success(text)),
|
||||||
|
Err(error) => jsonrpc_error(Status::InternalServerError, id, -32603, error.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolCallResult {
|
||||||
|
fn success(text: String) -> Self {
|
||||||
|
Self {
|
||||||
|
content: vec![TextContent::text(text)],
|
||||||
|
is_error: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn error(text: String) -> Self {
|
||||||
|
Self {
|
||||||
|
content: vec![TextContent::text(text)],
|
||||||
|
is_error: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextContent {
|
||||||
|
fn text(text: String) -> Self {
|
||||||
|
Self { kind: "text", text }
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
pub mod auth;
|
||||||
|
pub mod catchers;
|
||||||
|
pub mod mcp;
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
pub mod state;
|
||||||
|
|
||||||
|
use crate::{api, core::config::Config, transport};
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use rocket::routes;
|
||||||
|
use sqlx::{SqlitePool, sqlite::SqliteConnectOptions};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
const DATABASE_PATH: &str = "server.db";
|
||||||
|
|
||||||
|
pub async fn connect_database() -> Result<SqlitePool, sqlx::Error> {
|
||||||
|
SqlitePool::connect_with(
|
||||||
|
SqliteConnectOptions::new()
|
||||||
|
.filename(DATABASE_PATH)
|
||||||
|
.create_if_missing(true),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn apply_schema(database: &SqlitePool) -> Result<(), sqlx::Error> {
|
||||||
|
for statement in include_str!("../../schema.sql")
|
||||||
|
.split(';')
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|statement| !statement.is_empty())
|
||||||
|
{
|
||||||
|
sqlx::query(statement).execute(database).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_rocket(database: SqlitePool) -> rocket::Rocket<rocket::Build> {
|
||||||
|
rocket::build()
|
||||||
|
.manage(state::AppState {
|
||||||
|
database,
|
||||||
|
registry: Arc::new(DashMap::new()),
|
||||||
|
device_counts: Arc::new(DashMap::new()),
|
||||||
|
config: Arc::new(Config::from_env()),
|
||||||
|
})
|
||||||
|
.mount("/", routes![transport::socket::connect, api::mcp::route])
|
||||||
|
.register("/", rocket::catchers![api::catchers::default_catcher])
|
||||||
|
}
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
use crate::core::config::Config;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use serde::Serialize;
|
||||||
|
use sqlx::SqlitePool;
|
||||||
|
use std::sync::{Arc, atomic::AtomicUsize};
|
||||||
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
|
||||||
|
pub type Database = SqlitePool;
|
||||||
|
pub type DeviceCounts = DashMap<String, Arc<AtomicUsize>>;
|
||||||
|
pub type Registry = DashMap<(String, String), RegistryEntry>;
|
||||||
|
|
||||||
|
pub struct RegistryEntry {
|
||||||
|
pub sender: mpsc::Sender<PendingExec>,
|
||||||
|
pub in_flight: Arc<AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct ExecResult {
|
||||||
|
pub exit_code: i32,
|
||||||
|
pub stdout: String,
|
||||||
|
pub stderr: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PendingExec {
|
||||||
|
pub exec_id: String,
|
||||||
|
pub command: String,
|
||||||
|
pub reply: oneshot::Sender<ExecResult>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AppState {
|
||||||
|
pub database: Database,
|
||||||
|
pub registry: Arc<Registry>,
|
||||||
|
pub device_counts: Arc<DeviceCounts>,
|
||||||
|
pub config: Arc<Config>,
|
||||||
|
}
|
||||||
+1
-1
@@ -1,4 +1,4 @@
|
|||||||
use crate::Database;
|
use crate::app::state::Database;
|
||||||
|
|
||||||
const AUTH_TOKENS_TABLE: &str = "auth_tokens";
|
const AUTH_TOKENS_TABLE: &str = "auth_tokens";
|
||||||
const API_TOKENS_TABLE: &str = "api_tokens";
|
const API_TOKENS_TABLE: &str = "api_tokens";
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
pub mod config;
|
||||||
|
pub mod db;
|
||||||
|
pub mod validation;
|
||||||
+5
-88
@@ -1,61 +1,10 @@
|
|||||||
pub mod api {
|
pub mod api;
|
||||||
pub mod auth;
|
pub mod app;
|
||||||
pub mod catchers;
|
pub mod core;
|
||||||
pub mod mcp;
|
pub mod transport;
|
||||||
}
|
|
||||||
|
|
||||||
pub mod core {
|
use app::{apply_schema, build_rocket, connect_database};
|
||||||
pub mod config;
|
|
||||||
pub mod db;
|
|
||||||
pub mod validation;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub mod transport {
|
|
||||||
pub mod execute;
|
|
||||||
pub mod protocol;
|
|
||||||
pub mod socket;
|
|
||||||
}
|
|
||||||
|
|
||||||
use core::config::Config;
|
|
||||||
use dashmap::DashMap;
|
|
||||||
use log::info;
|
use log::info;
|
||||||
use rocket::routes;
|
|
||||||
use serde::Serialize;
|
|
||||||
use sqlx::SqlitePool;
|
|
||||||
use sqlx::sqlite::SqliteConnectOptions;
|
|
||||||
use std::sync::{Arc, atomic::AtomicUsize};
|
|
||||||
use tokio::sync::{mpsc, oneshot};
|
|
||||||
|
|
||||||
const DATABASE_PATH: &str = "server.db";
|
|
||||||
|
|
||||||
pub type Database = SqlitePool;
|
|
||||||
pub type DeviceCounts = DashMap<String, Arc<AtomicUsize>>;
|
|
||||||
pub type Registry = DashMap<(String, String), RegistryEntry>;
|
|
||||||
|
|
||||||
pub struct RegistryEntry {
|
|
||||||
pub sender: mpsc::Sender<PendingExec>,
|
|
||||||
pub in_flight: Arc<AtomicUsize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
|
||||||
pub struct ExecResult {
|
|
||||||
pub exit_code: i32,
|
|
||||||
pub stdout: String,
|
|
||||||
pub stderr: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct PendingExec {
|
|
||||||
pub exec_id: String,
|
|
||||||
pub command: String,
|
|
||||||
pub reply: oneshot::Sender<ExecResult>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct AppState {
|
|
||||||
pub database: Database,
|
|
||||||
pub registry: Arc<Registry>,
|
|
||||||
pub device_counts: Arc<DeviceCounts>,
|
|
||||||
pub config: Arc<Config>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[rocket::main]
|
#[rocket::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
@@ -70,35 +19,3 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn connect_database() -> Result<Database, sqlx::Error> {
|
|
||||||
let options = SqliteConnectOptions::new()
|
|
||||||
.filename(DATABASE_PATH)
|
|
||||||
.create_if_missing(true);
|
|
||||||
|
|
||||||
SqlitePool::connect_with(options).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn apply_schema(database: &Database) -> Result<(), sqlx::Error> {
|
|
||||||
for statement in include_str!("../schema.sql")
|
|
||||||
.split(';')
|
|
||||||
.map(str::trim)
|
|
||||||
.filter(|statement| !statement.is_empty())
|
|
||||||
{
|
|
||||||
sqlx::query(statement).execute(database).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_rocket(database: Database) -> rocket::Rocket<rocket::Build> {
|
|
||||||
rocket::build()
|
|
||||||
.manage(AppState {
|
|
||||||
database,
|
|
||||||
registry: Arc::new(DashMap::new()),
|
|
||||||
device_counts: Arc::new(DashMap::new()),
|
|
||||||
config: Arc::new(Config::from_env()),
|
|
||||||
})
|
|
||||||
.mount("/", routes![transport::socket::connect, api::mcp::route])
|
|
||||||
.register("/", rocket::catchers![api::catchers::default_catcher])
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
AppState, ExecResult, PendingExec, Registry,
|
app::state::{AppState, ExecResult, PendingExec, Registry},
|
||||||
core::{config::Config, db::lookup_api_token, validation::validate_device_id},
|
core::{config::Config, db::lookup_api_token, validation::validate_device_id},
|
||||||
};
|
};
|
||||||
use rocket::{State, http::Status};
|
use rocket::{State, http::Status};
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
pub mod execute;
|
||||||
|
pub mod protocol;
|
||||||
|
pub mod socket;
|
||||||
@@ -1,317 +0,0 @@
|
|||||||
use crate::{
|
|
||||||
AppState, Database, DeviceCounts, ExecResult, PendingExec, Registry, RegistryEntry,
|
|
||||||
api::auth::{BearerToken, MaybeDeviceId},
|
|
||||||
core::config::Config,
|
|
||||||
core::db::lookup_auth_token,
|
|
||||||
transport::protocol::{ClientMsg, ServerMsg},
|
|
||||||
};
|
|
||||||
use dashmap::{DashMap, mapref::entry::Entry};
|
|
||||||
use futures::{SinkExt, StreamExt};
|
|
||||||
use log::{info, warn};
|
|
||||||
use rocket::{State, get};
|
|
||||||
use rocket_ws::{Channel, Message, WebSocket, stream::DuplexStream};
|
|
||||||
use std::{
|
|
||||||
sync::{
|
|
||||||
Arc,
|
|
||||||
atomic::{AtomicUsize, Ordering},
|
|
||||||
},
|
|
||||||
time::Duration,
|
|
||||||
};
|
|
||||||
use tokio::sync::{mpsc, oneshot};
|
|
||||||
use tokio_util::sync::CancellationToken;
|
|
||||||
|
|
||||||
type PendingReplies = Arc<DashMap<String, oneshot::Sender<ExecResult>>>;
|
|
||||||
|
|
||||||
#[get("/connect")]
|
|
||||||
pub fn connect<'r>(
|
|
||||||
ws: WebSocket,
|
|
||||||
token: BearerToken,
|
|
||||||
device_id: MaybeDeviceId,
|
|
||||||
state: &'r State<AppState>,
|
|
||||||
) -> Channel<'r> {
|
|
||||||
let database = state.database.clone();
|
|
||||||
let registry = state.registry.clone();
|
|
||||||
let device_counts = state.device_counts.clone();
|
|
||||||
let config = state.config.clone();
|
|
||||||
ws.channel(move |socket| {
|
|
||||||
Box::pin(handle(
|
|
||||||
database,
|
|
||||||
registry,
|
|
||||||
device_counts,
|
|
||||||
config,
|
|
||||||
socket,
|
|
||||||
token.0,
|
|
||||||
device_id.0,
|
|
||||||
))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle(
|
|
||||||
database: Database,
|
|
||||||
registry: Arc<Registry>,
|
|
||||||
device_counts: Arc<DeviceCounts>,
|
|
||||||
config: Arc<Config>,
|
|
||||||
mut socket: DuplexStream,
|
|
||||||
token: String,
|
|
||||||
device_id: Result<String, String>,
|
|
||||||
) -> Result<(), rocket_ws::result::Error> {
|
|
||||||
let device_id = match device_id {
|
|
||||||
Ok(device_id) => device_id,
|
|
||||||
Err(reason) => {
|
|
||||||
warn!("Rejecting websocket connection: {reason}");
|
|
||||||
return deny_connection(&mut socket, reason).await;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let user_id = match lookup_auth_token(&database, &token).await {
|
|
||||||
Ok(user_id) => user_id,
|
|
||||||
Err(reason) => {
|
|
||||||
warn!("Rejecting websocket connection for device {device_id}: {reason}");
|
|
||||||
return deny_connection(&mut socket, reason).await;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let device_counter = match claim_device_slot(&device_counts, &user_id, &config) {
|
|
||||||
Ok(counter) => counter,
|
|
||||||
Err(reason) => {
|
|
||||||
warn!("Rejecting websocket connection for user {user_id} device {device_id}: {reason}");
|
|
||||||
return deny_connection(&mut socket, reason).await;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let (exec_tx, exec_rx) = mpsc::channel::<PendingExec>(config.max_concurrent_executions);
|
|
||||||
let key = (user_id.clone(), device_id.clone());
|
|
||||||
|
|
||||||
if let Err(reason) = register_connection(®istry, key.clone(), exec_tx) {
|
|
||||||
release_device_slot(&device_counts, &user_id, device_counter);
|
|
||||||
warn!("Rejecting duplicate websocket connection for {user_id}/{device_id}: {reason}");
|
|
||||||
return deny_connection(&mut socket, reason).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
info!("Accepted websocket connection for {user_id}/{device_id}");
|
|
||||||
|
|
||||||
send(&mut socket, &ServerMsg::AuthOk).await.ok();
|
|
||||||
|
|
||||||
let pending = Arc::new(DashMap::new());
|
|
||||||
let cancel = CancellationToken::new();
|
|
||||||
run_loop(&mut socket, exec_rx, &pending, &config, &cancel).await;
|
|
||||||
cancel.cancel();
|
|
||||||
|
|
||||||
registry.remove(&key);
|
|
||||||
pending.clear();
|
|
||||||
release_device_slot(&device_counts, &user_id, device_counter);
|
|
||||||
info!("Closed websocket connection for {user_id}/{device_id}");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn deny_connection(
|
|
||||||
socket: &mut DuplexStream,
|
|
||||||
reason: String,
|
|
||||||
) -> Result<(), rocket_ws::result::Error> {
|
|
||||||
send(socket, &ServerMsg::AuthError { reason }).await.ok();
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn register_connection(
|
|
||||||
registry: &Arc<Registry>,
|
|
||||||
key: (String, String),
|
|
||||||
sender: mpsc::Sender<PendingExec>,
|
|
||||||
) -> Result<(), String> {
|
|
||||||
match registry.entry(key) {
|
|
||||||
Entry::Occupied(_) => Err("Device already connected".into()),
|
|
||||||
Entry::Vacant(slot) => {
|
|
||||||
slot.insert(RegistryEntry {
|
|
||||||
sender,
|
|
||||||
in_flight: Arc::new(AtomicUsize::new(0)),
|
|
||||||
});
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn claim_device_slot(
|
|
||||||
device_counts: &Arc<DeviceCounts>,
|
|
||||||
user_id: &str,
|
|
||||||
config: &Config,
|
|
||||||
) -> Result<Arc<AtomicUsize>, String> {
|
|
||||||
let counter = {
|
|
||||||
let entry = device_counts
|
|
||||||
.entry(user_id.to_owned())
|
|
||||||
.or_insert_with(|| Arc::new(AtomicUsize::new(0)));
|
|
||||||
Arc::clone(&*entry)
|
|
||||||
};
|
|
||||||
|
|
||||||
let previous = counter.fetch_add(1, Ordering::SeqCst);
|
|
||||||
if previous >= config.max_connected_devices {
|
|
||||||
counter.fetch_sub(1, Ordering::SeqCst);
|
|
||||||
return Err("Too many devices connected for this account".into());
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(counter)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn release_device_slot(
|
|
||||||
device_counts: &Arc<DeviceCounts>,
|
|
||||||
user_id: &str,
|
|
||||||
counter: Arc<AtomicUsize>,
|
|
||||||
) {
|
|
||||||
let previous = counter.fetch_sub(1, Ordering::SeqCst);
|
|
||||||
if previous == 1 {
|
|
||||||
device_counts.remove_if(user_id, |_, value| value.load(Ordering::SeqCst) == 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run_loop(
|
|
||||||
socket: &mut DuplexStream,
|
|
||||||
mut exec_rx: mpsc::Receiver<PendingExec>,
|
|
||||||
pending: &PendingReplies,
|
|
||||||
config: &Config,
|
|
||||||
cancel: &CancellationToken,
|
|
||||||
) {
|
|
||||||
let mut ping_interval = tokio::time::interval_at(
|
|
||||||
tokio::time::Instant::now() + config.ping_interval,
|
|
||||||
config.ping_interval,
|
|
||||||
);
|
|
||||||
|
|
||||||
let pong_timer = tokio::time::sleep(Duration::from_secs(0));
|
|
||||||
tokio::pin!(pong_timer);
|
|
||||||
let mut awaiting_pong = false;
|
|
||||||
|
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
_ = ping_interval.tick() => {
|
|
||||||
if awaiting_pong {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if send_ping(socket).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
awaiting_pong = true;
|
|
||||||
pong_timer.as_mut().reset(
|
|
||||||
tokio::time::Instant::now() + config.ping_timeout,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
_ = &mut pong_timer, if awaiting_pong => {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Some(exec) = exec_rx.recv() => {
|
|
||||||
if !dispatch(socket, pending, exec, config.max_execution_time, cancel.clone()).await {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
message = socket.next() => {
|
|
||||||
let action = handle_socket_message(socket, message, pending).await;
|
|
||||||
|
|
||||||
if action.break_loop {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if action.clear_pong_deadline {
|
|
||||||
awaiting_pong = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct SocketAction {
|
|
||||||
break_loop: bool,
|
|
||||||
clear_pong_deadline: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_socket_message(
|
|
||||||
socket: &mut DuplexStream,
|
|
||||||
message: Option<Result<Message, rocket_ws::result::Error>>,
|
|
||||||
pending: &PendingReplies,
|
|
||||||
) -> SocketAction {
|
|
||||||
match message {
|
|
||||||
Some(Ok(Message::Pong(_))) => SocketAction {
|
|
||||||
break_loop: false,
|
|
||||||
clear_pong_deadline: true,
|
|
||||||
},
|
|
||||||
Some(Ok(Message::Ping(data))) => SocketAction {
|
|
||||||
break_loop: socket.send(Message::Pong(data.clone())).await.is_err(),
|
|
||||||
clear_pong_deadline: false,
|
|
||||||
},
|
|
||||||
Some(Ok(Message::Text(text))) => {
|
|
||||||
resolve_pending(&text, pending);
|
|
||||||
SocketAction {
|
|
||||||
break_loop: false,
|
|
||||||
clear_pong_deadline: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(Ok(Message::Close(_))) | None => SocketAction {
|
|
||||||
break_loop: true,
|
|
||||||
clear_pong_deadline: false,
|
|
||||||
},
|
|
||||||
_ => SocketAction {
|
|
||||||
break_loop: false,
|
|
||||||
clear_pong_deadline: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_ping(socket: &mut DuplexStream) -> Result<(), rocket_ws::result::Error> {
|
|
||||||
socket.send(Message::Ping(vec![])).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn dispatch(
|
|
||||||
socket: &mut DuplexStream,
|
|
||||||
pending: &PendingReplies,
|
|
||||||
exec: PendingExec,
|
|
||||||
max_execution_time: Duration,
|
|
||||||
cancel: CancellationToken,
|
|
||||||
) -> bool {
|
|
||||||
let exec_id = exec.exec_id.clone();
|
|
||||||
pending.insert(exec_id.clone(), exec.reply);
|
|
||||||
|
|
||||||
let pending_for_timeout = Arc::clone(pending);
|
|
||||||
tokio::spawn(async move {
|
|
||||||
tokio::select! {
|
|
||||||
_ = cancel.cancelled() => {}
|
|
||||||
_ = tokio::time::sleep(max_execution_time) => {
|
|
||||||
pending_for_timeout.remove(&exec_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
send(
|
|
||||||
socket,
|
|
||||||
&ServerMsg::Exec {
|
|
||||||
exec_id: exec.exec_id,
|
|
||||||
command: exec.command,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.is_ok()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn resolve_pending(text: &str, pending: &PendingReplies) {
|
|
||||||
let Ok(ClientMsg::ExecResult {
|
|
||||||
exec_id,
|
|
||||||
exit_code,
|
|
||||||
stdout,
|
|
||||||
stderr,
|
|
||||||
}) = serde_json::from_str(text)
|
|
||||||
else {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
if let Some((_, reply)) = pending.remove(&exec_id) {
|
|
||||||
reply
|
|
||||||
.send(ExecResult {
|
|
||||||
exit_code,
|
|
||||||
stdout,
|
|
||||||
stderr,
|
|
||||||
})
|
|
||||||
.ok();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send(
|
|
||||||
socket: &mut DuplexStream,
|
|
||||||
message: &ServerMsg,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
let text = serde_json::to_string(message)?;
|
|
||||||
socket.send(Message::Text(text)).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,131 @@
|
|||||||
|
use super::loop_state::run_loop;
|
||||||
|
use crate::{
|
||||||
|
app::state::{Database, DeviceCounts, PendingExec, Registry, RegistryEntry},
|
||||||
|
core::{config::Config, db::lookup_auth_token},
|
||||||
|
transport::protocol::ServerMsg,
|
||||||
|
};
|
||||||
|
use dashmap::mapref::entry::Entry;
|
||||||
|
use futures::SinkExt;
|
||||||
|
use log::{info, warn};
|
||||||
|
use rocket_ws::{Message, result::Error, stream::DuplexStream};
|
||||||
|
use std::sync::{
|
||||||
|
Arc,
|
||||||
|
atomic::{AtomicUsize, Ordering},
|
||||||
|
};
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
pub async fn handle(
|
||||||
|
database: Database,
|
||||||
|
registry: Arc<Registry>,
|
||||||
|
device_counts: Arc<DeviceCounts>,
|
||||||
|
config: Arc<Config>,
|
||||||
|
mut socket: DuplexStream,
|
||||||
|
token: String,
|
||||||
|
device_id: Result<String, String>,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let device_id = match device_id {
|
||||||
|
Ok(device_id) => device_id,
|
||||||
|
Err(reason) => {
|
||||||
|
warn!("Rejecting websocket connection: {reason}");
|
||||||
|
return deny_connection(&mut socket, reason).await;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let user_id = match lookup_auth_token(&database, &token).await {
|
||||||
|
Ok(user_id) => user_id,
|
||||||
|
Err(reason) => {
|
||||||
|
warn!("Rejecting websocket connection for device {device_id}: {reason}");
|
||||||
|
return deny_connection(&mut socket, reason).await;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let device_counter = match claim_device_slot(&device_counts, &user_id, &config) {
|
||||||
|
Ok(counter) => counter,
|
||||||
|
Err(reason) => {
|
||||||
|
warn!("Rejecting websocket connection for user {user_id} device {device_id}: {reason}");
|
||||||
|
return deny_connection(&mut socket, reason).await;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let key = (user_id.clone(), device_id.clone());
|
||||||
|
let (exec_tx, exec_rx) = mpsc::channel::<PendingExec>(config.max_concurrent_executions);
|
||||||
|
if let Err(reason) = register_connection(®istry, key.clone(), exec_tx) {
|
||||||
|
release_device_slot(&device_counts, &user_id, device_counter);
|
||||||
|
warn!("Rejecting duplicate websocket connection for {user_id}/{device_id}: {reason}");
|
||||||
|
return deny_connection(&mut socket, reason).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Accepted websocket connection for {user_id}/{device_id}");
|
||||||
|
send(&mut socket, &ServerMsg::AuthOk).await.ok();
|
||||||
|
|
||||||
|
let cancel = CancellationToken::new();
|
||||||
|
run_loop(&mut socket, exec_rx, &config, &cancel).await;
|
||||||
|
cancel.cancel();
|
||||||
|
|
||||||
|
registry.remove(&key);
|
||||||
|
release_device_slot(&device_counts, &user_id, device_counter);
|
||||||
|
info!("Closed websocket connection for {user_id}/{device_id}");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn register_connection(
|
||||||
|
registry: &Arc<Registry>,
|
||||||
|
key: (String, String),
|
||||||
|
sender: mpsc::Sender<PendingExec>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
match registry.entry(key) {
|
||||||
|
Entry::Occupied(_) => Err("Device already connected".into()),
|
||||||
|
Entry::Vacant(slot) => {
|
||||||
|
slot.insert(RegistryEntry {
|
||||||
|
sender,
|
||||||
|
in_flight: Arc::new(AtomicUsize::new(0)),
|
||||||
|
});
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn claim_device_slot(
|
||||||
|
device_counts: &Arc<DeviceCounts>,
|
||||||
|
user_id: &str,
|
||||||
|
config: &Config,
|
||||||
|
) -> Result<Arc<AtomicUsize>, String> {
|
||||||
|
let counter = Arc::clone(
|
||||||
|
&*device_counts
|
||||||
|
.entry(user_id.to_owned())
|
||||||
|
.or_insert_with(|| Arc::new(AtomicUsize::new(0))),
|
||||||
|
);
|
||||||
|
|
||||||
|
if counter.fetch_add(1, Ordering::SeqCst) < config.max_connected_devices {
|
||||||
|
return Ok(counter);
|
||||||
|
}
|
||||||
|
|
||||||
|
counter.fetch_sub(1, Ordering::SeqCst);
|
||||||
|
Err("Too many devices connected for this account".into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn release_device_slot(
|
||||||
|
device_counts: &Arc<DeviceCounts>,
|
||||||
|
user_id: &str,
|
||||||
|
counter: Arc<AtomicUsize>,
|
||||||
|
) {
|
||||||
|
if counter.fetch_sub(1, Ordering::SeqCst) == 1 {
|
||||||
|
device_counts.remove_if(user_id, |_, value| value.load(Ordering::SeqCst) == 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn deny_connection(socket: &mut DuplexStream, reason: String) -> Result<(), Error> {
|
||||||
|
send(socket, &ServerMsg::AuthError { reason }).await.ok();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send(
|
||||||
|
socket: &mut DuplexStream,
|
||||||
|
message: &ServerMsg,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
socket
|
||||||
|
.send(Message::Text(serde_json::to_string(message)?))
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -0,0 +1,131 @@
|
|||||||
|
use crate::{
|
||||||
|
app::state::{ExecResult, PendingExec},
|
||||||
|
transport::protocol::{ClientMsg, ServerMsg},
|
||||||
|
};
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use futures::{SinkExt, StreamExt};
|
||||||
|
use rocket_ws::{Message, stream::DuplexStream};
|
||||||
|
use std::{sync::Arc, time::Duration};
|
||||||
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
type PendingReplies = Arc<DashMap<String, oneshot::Sender<ExecResult>>>;
|
||||||
|
|
||||||
|
pub async fn run_loop(
|
||||||
|
socket: &mut DuplexStream,
|
||||||
|
mut exec_rx: mpsc::Receiver<PendingExec>,
|
||||||
|
config: &crate::core::config::Config,
|
||||||
|
cancel: &CancellationToken,
|
||||||
|
) {
|
||||||
|
let pending = Arc::new(DashMap::new());
|
||||||
|
let mut ping_interval = tokio::time::interval_at(
|
||||||
|
tokio::time::Instant::now() + config.ping_interval,
|
||||||
|
config.ping_interval,
|
||||||
|
);
|
||||||
|
let pong_timer = tokio::time::sleep(Duration::from_secs(0));
|
||||||
|
tokio::pin!(pong_timer);
|
||||||
|
let mut awaiting_pong = false;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
_ = ping_interval.tick() => {
|
||||||
|
if awaiting_pong || socket.send(Message::Ping(vec![])).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
awaiting_pong = true;
|
||||||
|
pong_timer.as_mut().reset(tokio::time::Instant::now() + config.ping_timeout);
|
||||||
|
}
|
||||||
|
_ = &mut pong_timer, if awaiting_pong => break,
|
||||||
|
Some(exec) = exec_rx.recv() => {
|
||||||
|
if !dispatch(socket, &pending, exec, config.max_execution_time, cancel.clone()).await {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
message = socket.next() => match handle_socket_message(socket, message, &pending).await {
|
||||||
|
SocketAction::Break => break,
|
||||||
|
SocketAction::ClearPongDeadline => awaiting_pong = false,
|
||||||
|
SocketAction::Continue => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pending.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
enum SocketAction {
|
||||||
|
Break,
|
||||||
|
ClearPongDeadline,
|
||||||
|
Continue,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_socket_message(
|
||||||
|
socket: &mut DuplexStream,
|
||||||
|
message: Option<Result<Message, rocket_ws::result::Error>>,
|
||||||
|
pending: &PendingReplies,
|
||||||
|
) -> SocketAction {
|
||||||
|
match message {
|
||||||
|
Some(Ok(Message::Pong(_))) => SocketAction::ClearPongDeadline,
|
||||||
|
Some(Ok(Message::Ping(data))) => {
|
||||||
|
if socket.send(Message::Pong(data.clone())).await.is_err() {
|
||||||
|
return SocketAction::Break;
|
||||||
|
}
|
||||||
|
|
||||||
|
SocketAction::Continue
|
||||||
|
}
|
||||||
|
Some(Ok(Message::Text(text))) => {
|
||||||
|
if let Ok(ClientMsg::ExecResult {
|
||||||
|
exec_id,
|
||||||
|
exit_code,
|
||||||
|
stdout,
|
||||||
|
stderr,
|
||||||
|
}) = serde_json::from_str(&text)
|
||||||
|
&& let Some((_, reply)) = pending.remove(&exec_id)
|
||||||
|
{
|
||||||
|
reply
|
||||||
|
.send(ExecResult {
|
||||||
|
exit_code,
|
||||||
|
stdout,
|
||||||
|
stderr,
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
SocketAction::Continue
|
||||||
|
}
|
||||||
|
Some(Ok(Message::Close(_))) | None => SocketAction::Break,
|
||||||
|
_ => SocketAction::Continue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn dispatch(
|
||||||
|
socket: &mut DuplexStream,
|
||||||
|
pending: &PendingReplies,
|
||||||
|
exec: PendingExec,
|
||||||
|
max_execution_time: Duration,
|
||||||
|
cancel: CancellationToken,
|
||||||
|
) -> bool {
|
||||||
|
let exec_id = exec.exec_id.clone();
|
||||||
|
pending.insert(exec_id.clone(), exec.reply);
|
||||||
|
|
||||||
|
let pending_for_timeout = Arc::clone(pending);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
tokio::select! {
|
||||||
|
_ = cancel.cancelled() => {}
|
||||||
|
_ = tokio::time::sleep(max_execution_time) => {
|
||||||
|
pending_for_timeout.remove(&exec_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
socket
|
||||||
|
.send(Message::Text(
|
||||||
|
serde_json::to_string(&ServerMsg::Exec {
|
||||||
|
exec_id: exec.exec_id,
|
||||||
|
command: exec.command,
|
||||||
|
})
|
||||||
|
.expect("serializing server exec message should not fail"),
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.is_ok()
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
mod connection;
|
||||||
|
mod loop_state;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
api::auth::{BearerToken, MaybeDeviceId},
|
||||||
|
app::state::AppState,
|
||||||
|
};
|
||||||
|
use rocket::{State, get};
|
||||||
|
use rocket_ws::{Channel, WebSocket};
|
||||||
|
|
||||||
|
#[get("/connect")]
|
||||||
|
pub fn connect<'r>(
|
||||||
|
ws: WebSocket,
|
||||||
|
token: BearerToken,
|
||||||
|
device_id: MaybeDeviceId,
|
||||||
|
state: &'r State<AppState>,
|
||||||
|
) -> Channel<'r> {
|
||||||
|
let database = state.database.clone();
|
||||||
|
let registry = state.registry.clone();
|
||||||
|
let device_counts = state.device_counts.clone();
|
||||||
|
let config = state.config.clone();
|
||||||
|
|
||||||
|
ws.channel(move |socket| {
|
||||||
|
Box::pin(connection::handle(
|
||||||
|
database,
|
||||||
|
registry,
|
||||||
|
device_counts,
|
||||||
|
config,
|
||||||
|
socket,
|
||||||
|
token.0,
|
||||||
|
device_id.0,
|
||||||
|
))
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user