first commit

This commit is contained in:
2026-05-21 23:47:36 +02:00
commit 546406a85a
14 changed files with 1163 additions and 0 deletions
+59
View File
@@ -0,0 +1,59 @@
use crate::core::validation::validate_device_id;
use rocket::{
Request,
http::Status,
request::{FromRequest, Outcome},
};
const AUTHORIZATION_HEADER: &str = "Authorization";
const DEVICE_ID_HEADER: &str = "X-Device-ID";
const BEARER_PREFIX: &str = "Bearer ";
pub struct BearerToken(pub String);
pub struct DeviceId(pub String);
pub struct MaybeDeviceId(pub Result<String, String>);
#[rocket::async_trait]
impl<'r> FromRequest<'r> for BearerToken {
type Error = &'static str;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let Some(header) = request.headers().get_one(AUTHORIZATION_HEADER) else {
return Outcome::Error((Status::Unauthorized, "Missing Authorization header"));
};
let Some(token) = header.strip_prefix(BEARER_PREFIX) else {
return Outcome::Error((Status::Unauthorized, "Invalid Authorization header format"));
};
Outcome::Success(BearerToken(token.to_owned()))
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for DeviceId {
type Error = String;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match extract_device_id(request) {
Ok(device_id) => Outcome::Success(DeviceId(device_id)),
Err(reason) => Outcome::Error((Status::BadRequest, reason)),
}
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for MaybeDeviceId {
type Error = std::convert::Infallible;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
Outcome::Success(MaybeDeviceId(extract_device_id(request)))
}
}
fn extract_device_id(request: &Request<'_>) -> Result<String, String> {
let Some(device_id) = request.headers().get_one(DEVICE_ID_HEADER) else {
return Err("Missing X-Device-ID header".into());
};
validate_device_id(device_id)?;
Ok(device_id.to_owned())
}
+6
View File
@@ -0,0 +1,6 @@
use rocket::{Request, catch, http::Status, response::content::RawJson};
#[catch(default)]
pub fn default_catcher(status: Status, _request: &Request) -> RawJson<String> {
RawJson(serde_json::json!({ "error": status.reason_lossy() }).to_string())
}
+354
View File
@@ -0,0 +1,354 @@
use crate::{
AppState,
api::auth::BearerToken,
core::db::lookup_api_token,
transport::execute::{ExecuteRequest, execute_for_token},
};
use rocket::{State, http::Status, post, response::content::RawJson, serde::json::Json};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
type ApiResponse = (Status, RawJson<String>);
const JSONRPC_VERSION: &str = "2.0";
const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
#[derive(Deserialize)]
pub struct JsonRpcRequest {
pub jsonrpc: String,
pub id: Option<Value>,
pub method: String,
#[serde(default)]
pub params: Option<Value>,
}
#[derive(Serialize)]
struct JsonRpcResponse<T> {
jsonrpc: &'static str,
id: Option<Value>,
result: T,
}
#[derive(Serialize)]
struct JsonRpcErrorResponse {
jsonrpc: &'static str,
id: Option<Value>,
error: JsonRpcError,
}
#[derive(Serialize)]
struct JsonRpcError {
code: i64,
message: String,
}
#[derive(Serialize)]
struct InitializeResult {
#[serde(rename = "protocolVersion")]
protocol_version: &'static str,
capabilities: ServerCapabilities,
#[serde(rename = "serverInfo")]
server_info: ServerInfo,
}
#[derive(Serialize)]
struct ServerCapabilities {
tools: ToolsCapability,
}
#[derive(Serialize)]
struct ToolsCapability {
list_changed: bool,
}
#[derive(Serialize)]
struct ServerInfo {
name: &'static str,
version: &'static str,
}
#[derive(Serialize)]
struct ToolsListResult {
tools: Vec<ToolDefinition>,
}
#[derive(Serialize)]
struct ToolDefinition {
name: &'static str,
description: &'static str,
#[serde(rename = "inputSchema")]
input_schema: Value,
}
#[derive(Deserialize)]
struct ToolCallParams {
name: String,
arguments: Value,
}
#[derive(Serialize)]
struct ToolCallResult {
content: Vec<TextContent>,
#[serde(rename = "isError")]
is_error: bool,
}
#[derive(Serialize)]
struct ConnectedDevicesResult {
device_ids: Vec<String>,
}
enum ToolCall {
Execute(ExecuteRequest),
ListDevices,
}
#[derive(Serialize)]
struct TextContent {
#[serde(rename = "type")]
kind: &'static str,
text: String,
}
#[post("/mcp", data = "<body>")]
pub async fn route(
body: Json<JsonRpcRequest>,
token: BearerToken,
state: &State<AppState>,
) -> ApiResponse {
let request = body.into_inner();
if request.jsonrpc != JSONRPC_VERSION {
return jsonrpc_error(
Status::BadRequest,
request.id,
-32600,
"Invalid JSON-RPC version",
);
}
match request.method.as_str() {
"initialize" => jsonrpc_ok(
request.id,
InitializeResult {
protocol_version: MCP_PROTOCOL_VERSION,
capabilities: ServerCapabilities {
tools: ToolsCapability {
list_changed: false,
},
},
server_info: ServerInfo {
name: "folyo-bridge",
version: env!("CARGO_PKG_VERSION"),
},
},
),
"notifications/initialized" => jsonrpc_notification_ok(),
"tools/list" => jsonrpc_ok(
request.id,
ToolsListResult {
tools: tool_definitions(),
},
),
"tools/call" => handle_tool_call(request.id, request.params, &token.0, state).await,
_ if request.id.is_none() => jsonrpc_notification_ok(),
_ => jsonrpc_error(Status::BadRequest, request.id, -32601, "Method not found"),
}
}
async fn handle_tool_call(
id: Option<Value>,
params: Option<Value>,
token: &str,
state: &State<AppState>,
) -> ApiResponse {
let Some(params) = params else {
return jsonrpc_error(Status::BadRequest, id, -32602, "Missing params");
};
let tool_call = match parse_tool_call(params) {
Ok(call) => call,
Err(_) => return jsonrpc_error(Status::BadRequest, id, -32602, "Invalid params"),
};
match tool_call {
ToolCall::Execute(request) => execute_tool_call(id, token, state, request).await,
ToolCall::ListDevices => list_devices_tool_call(id, token, state).await,
}
}
fn tool_definitions() -> Vec<ToolDefinition> {
vec![
ToolDefinition {
name: "execute",
description: "Execute a shell command on a connected device owned by the authenticated user.",
input_schema: json!({
"type": "object",
"required": ["device_id", "command"],
"properties": {
"device_id": {
"type": "string",
"description": "Identifier of the connected device."
},
"command": {
"type": "string",
"description": "Shell command to execute on the target device."
}
},
"additionalProperties": false
}),
},
ToolDefinition {
name: "list_devices",
description: "List currently connected device IDs owned by the authenticated user.",
input_schema: json!({
"type": "object",
"properties": {},
"additionalProperties": false
}),
},
]
}
fn parse_tool_call(params: Value) -> Result<ToolCall, &'static str> {
let call: ToolCallParams = serde_json::from_value(params).map_err(|_| "Invalid params")?;
match call.name.as_str() {
"execute" => serde_json::from_value(call.arguments)
.map(ToolCall::Execute)
.map_err(|_| "Invalid execute arguments"),
"list_devices" => parse_list_devices_arguments(call.arguments),
_ => Err("Unknown tool"),
}
}
fn parse_list_devices_arguments(arguments: Value) -> Result<ToolCall, &'static str> {
match arguments {
Value::Object(map) if map.is_empty() => Ok(ToolCall::ListDevices),
_ => Err("Invalid list_devices arguments"),
}
}
fn connected_device_ids_for_user(state: &State<AppState>, user_id: &str) -> Vec<String> {
state
.registry
.iter()
.filter_map(|entry| {
let ((entry_user_id, device_id), _) = entry.pair();
(entry_user_id == user_id).then(|| device_id.clone())
})
.collect()
}
async fn execute_tool_call(
id: Option<Value>,
token: &str,
state: &State<AppState>,
request: ExecuteRequest,
) -> ApiResponse {
match execute_for_token(state, token, request).await {
Ok(result) => json_tool_result(id, &result),
Err((status, message)) if status == Status::Unauthorized => {
jsonrpc_error(Status::Unauthorized, id, -32001, message)
}
Err((_, message)) => jsonrpc_ok(id, tool_error(message)),
}
}
async fn list_devices_tool_call(
id: Option<Value>,
token: &str,
state: &State<AppState>,
) -> ApiResponse {
let user_id = match lookup_api_token(&state.database, token).await {
Ok(user_id) => user_id,
Err(message) => return jsonrpc_error(Status::Unauthorized, id, -32001, message),
};
let mut device_ids = connected_device_ids_for_user(state, &user_id);
device_ids.sort();
json_tool_result(id, &ConnectedDevicesResult { device_ids })
}
fn json_tool_result<T: Serialize>(id: Option<Value>, payload: &T) -> ApiResponse {
match serde_json::to_string(payload) {
Ok(text) => jsonrpc_ok(id, ToolCallResult::success(text)),
Err(error) => jsonrpc_error(Status::InternalServerError, id, -32603, error.to_string()),
}
}
fn tool_error(message: String) -> ToolCallResult {
ToolCallResult::error(message)
}
impl ToolCallResult {
fn success(text: String) -> Self {
Self {
content: vec![TextContent::text(text)],
is_error: false,
}
}
fn error(text: String) -> Self {
Self {
content: vec![TextContent::text(text)],
is_error: true,
}
}
}
impl TextContent {
fn text(text: String) -> Self {
Self { kind: "text", text }
}
}
fn jsonrpc_ok<T: Serialize>(id: Option<Value>, result: T) -> ApiResponse {
match serde_json::to_string(&JsonRpcResponse {
jsonrpc: JSONRPC_VERSION,
id,
result,
}) {
Ok(body) => (Status::Ok, RawJson(body)),
Err(e) => jsonrpc_error(Status::InternalServerError, None, -32603, e.to_string()),
}
}
fn jsonrpc_notification_ok() -> ApiResponse {
(Status::Accepted, RawJson(String::new()))
}
fn jsonrpc_error(
status: Status,
id: Option<Value>,
code: i64,
message: impl Into<String>,
) -> ApiResponse {
let response = JsonRpcErrorResponse {
jsonrpc: JSONRPC_VERSION,
id,
error: JsonRpcError {
code,
message: message.into(),
},
};
match serde_json::to_string(&response) {
Ok(body) => (status, RawJson(body)),
Err(e) => (
Status::InternalServerError,
RawJson(
json!({
"jsonrpc": JSONRPC_VERSION,
"id": null,
"error": {
"code": -32603,
"message": e.to_string()
}
})
.to_string(),
),
),
}
}