first commit
This commit is contained in:
@@ -0,0 +1,59 @@
|
||||
use crate::core::validation::validate_device_id;
|
||||
use rocket::{
|
||||
Request,
|
||||
http::Status,
|
||||
request::{FromRequest, Outcome},
|
||||
};
|
||||
|
||||
const AUTHORIZATION_HEADER: &str = "Authorization";
|
||||
const DEVICE_ID_HEADER: &str = "X-Device-ID";
|
||||
const BEARER_PREFIX: &str = "Bearer ";
|
||||
|
||||
pub struct BearerToken(pub String);
|
||||
pub struct DeviceId(pub String);
|
||||
pub struct MaybeDeviceId(pub Result<String, String>);
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for BearerToken {
|
||||
type Error = &'static str;
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||
let Some(header) = request.headers().get_one(AUTHORIZATION_HEADER) else {
|
||||
return Outcome::Error((Status::Unauthorized, "Missing Authorization header"));
|
||||
};
|
||||
let Some(token) = header.strip_prefix(BEARER_PREFIX) else {
|
||||
return Outcome::Error((Status::Unauthorized, "Invalid Authorization header format"));
|
||||
};
|
||||
Outcome::Success(BearerToken(token.to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for DeviceId {
|
||||
type Error = String;
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||
match extract_device_id(request) {
|
||||
Ok(device_id) => Outcome::Success(DeviceId(device_id)),
|
||||
Err(reason) => Outcome::Error((Status::BadRequest, reason)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[rocket::async_trait]
|
||||
impl<'r> FromRequest<'r> for MaybeDeviceId {
|
||||
type Error = std::convert::Infallible;
|
||||
|
||||
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||
Outcome::Success(MaybeDeviceId(extract_device_id(request)))
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_device_id(request: &Request<'_>) -> Result<String, String> {
|
||||
let Some(device_id) = request.headers().get_one(DEVICE_ID_HEADER) else {
|
||||
return Err("Missing X-Device-ID header".into());
|
||||
};
|
||||
|
||||
validate_device_id(device_id)?;
|
||||
Ok(device_id.to_owned())
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
use rocket::{Request, catch, http::Status, response::content::RawJson};
|
||||
|
||||
#[catch(default)]
|
||||
pub fn default_catcher(status: Status, _request: &Request) -> RawJson<String> {
|
||||
RawJson(serde_json::json!({ "error": status.reason_lossy() }).to_string())
|
||||
}
|
||||
+354
@@ -0,0 +1,354 @@
|
||||
use crate::{
|
||||
AppState,
|
||||
api::auth::BearerToken,
|
||||
core::db::lookup_api_token,
|
||||
transport::execute::{ExecuteRequest, execute_for_token},
|
||||
};
|
||||
use rocket::{State, http::Status, post, response::content::RawJson, serde::json::Json};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Value, json};
|
||||
|
||||
type ApiResponse = (Status, RawJson<String>);
|
||||
|
||||
const JSONRPC_VERSION: &str = "2.0";
|
||||
const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct JsonRpcRequest {
|
||||
pub jsonrpc: String,
|
||||
pub id: Option<Value>,
|
||||
pub method: String,
|
||||
#[serde(default)]
|
||||
pub params: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct JsonRpcResponse<T> {
|
||||
jsonrpc: &'static str,
|
||||
id: Option<Value>,
|
||||
result: T,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct JsonRpcErrorResponse {
|
||||
jsonrpc: &'static str,
|
||||
id: Option<Value>,
|
||||
error: JsonRpcError,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct JsonRpcError {
|
||||
code: i64,
|
||||
message: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct InitializeResult {
|
||||
#[serde(rename = "protocolVersion")]
|
||||
protocol_version: &'static str,
|
||||
capabilities: ServerCapabilities,
|
||||
#[serde(rename = "serverInfo")]
|
||||
server_info: ServerInfo,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ServerCapabilities {
|
||||
tools: ToolsCapability,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ToolsCapability {
|
||||
list_changed: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ServerInfo {
|
||||
name: &'static str,
|
||||
version: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ToolsListResult {
|
||||
tools: Vec<ToolDefinition>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ToolDefinition {
|
||||
name: &'static str,
|
||||
description: &'static str,
|
||||
#[serde(rename = "inputSchema")]
|
||||
input_schema: Value,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ToolCallParams {
|
||||
name: String,
|
||||
arguments: Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ToolCallResult {
|
||||
content: Vec<TextContent>,
|
||||
#[serde(rename = "isError")]
|
||||
is_error: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ConnectedDevicesResult {
|
||||
device_ids: Vec<String>,
|
||||
}
|
||||
|
||||
enum ToolCall {
|
||||
Execute(ExecuteRequest),
|
||||
ListDevices,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct TextContent {
|
||||
#[serde(rename = "type")]
|
||||
kind: &'static str,
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[post("/mcp", data = "<body>")]
|
||||
pub async fn route(
|
||||
body: Json<JsonRpcRequest>,
|
||||
token: BearerToken,
|
||||
state: &State<AppState>,
|
||||
) -> ApiResponse {
|
||||
let request = body.into_inner();
|
||||
|
||||
if request.jsonrpc != JSONRPC_VERSION {
|
||||
return jsonrpc_error(
|
||||
Status::BadRequest,
|
||||
request.id,
|
||||
-32600,
|
||||
"Invalid JSON-RPC version",
|
||||
);
|
||||
}
|
||||
|
||||
match request.method.as_str() {
|
||||
"initialize" => jsonrpc_ok(
|
||||
request.id,
|
||||
InitializeResult {
|
||||
protocol_version: MCP_PROTOCOL_VERSION,
|
||||
capabilities: ServerCapabilities {
|
||||
tools: ToolsCapability {
|
||||
list_changed: false,
|
||||
},
|
||||
},
|
||||
server_info: ServerInfo {
|
||||
name: "folyo-bridge",
|
||||
version: env!("CARGO_PKG_VERSION"),
|
||||
},
|
||||
},
|
||||
),
|
||||
"notifications/initialized" => jsonrpc_notification_ok(),
|
||||
"tools/list" => jsonrpc_ok(
|
||||
request.id,
|
||||
ToolsListResult {
|
||||
tools: tool_definitions(),
|
||||
},
|
||||
),
|
||||
"tools/call" => handle_tool_call(request.id, request.params, &token.0, state).await,
|
||||
_ if request.id.is_none() => jsonrpc_notification_ok(),
|
||||
_ => jsonrpc_error(Status::BadRequest, request.id, -32601, "Method not found"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_tool_call(
|
||||
id: Option<Value>,
|
||||
params: Option<Value>,
|
||||
token: &str,
|
||||
state: &State<AppState>,
|
||||
) -> ApiResponse {
|
||||
let Some(params) = params else {
|
||||
return jsonrpc_error(Status::BadRequest, id, -32602, "Missing params");
|
||||
};
|
||||
|
||||
let tool_call = match parse_tool_call(params) {
|
||||
Ok(call) => call,
|
||||
Err(_) => return jsonrpc_error(Status::BadRequest, id, -32602, "Invalid params"),
|
||||
};
|
||||
|
||||
match tool_call {
|
||||
ToolCall::Execute(request) => execute_tool_call(id, token, state, request).await,
|
||||
ToolCall::ListDevices => list_devices_tool_call(id, token, state).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn tool_definitions() -> Vec<ToolDefinition> {
|
||||
vec![
|
||||
ToolDefinition {
|
||||
name: "execute",
|
||||
description: "Execute a shell command on a connected device owned by the authenticated user.",
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"required": ["device_id", "command"],
|
||||
"properties": {
|
||||
"device_id": {
|
||||
"type": "string",
|
||||
"description": "Identifier of the connected device."
|
||||
},
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Shell command to execute on the target device."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}),
|
||||
},
|
||||
ToolDefinition {
|
||||
name: "list_devices",
|
||||
description: "List currently connected device IDs owned by the authenticated user.",
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"additionalProperties": false
|
||||
}),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
fn parse_tool_call(params: Value) -> Result<ToolCall, &'static str> {
|
||||
let call: ToolCallParams = serde_json::from_value(params).map_err(|_| "Invalid params")?;
|
||||
|
||||
match call.name.as_str() {
|
||||
"execute" => serde_json::from_value(call.arguments)
|
||||
.map(ToolCall::Execute)
|
||||
.map_err(|_| "Invalid execute arguments"),
|
||||
"list_devices" => parse_list_devices_arguments(call.arguments),
|
||||
_ => Err("Unknown tool"),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_list_devices_arguments(arguments: Value) -> Result<ToolCall, &'static str> {
|
||||
match arguments {
|
||||
Value::Object(map) if map.is_empty() => Ok(ToolCall::ListDevices),
|
||||
_ => Err("Invalid list_devices arguments"),
|
||||
}
|
||||
}
|
||||
|
||||
fn connected_device_ids_for_user(state: &State<AppState>, user_id: &str) -> Vec<String> {
|
||||
state
|
||||
.registry
|
||||
.iter()
|
||||
.filter_map(|entry| {
|
||||
let ((entry_user_id, device_id), _) = entry.pair();
|
||||
(entry_user_id == user_id).then(|| device_id.clone())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn execute_tool_call(
|
||||
id: Option<Value>,
|
||||
token: &str,
|
||||
state: &State<AppState>,
|
||||
request: ExecuteRequest,
|
||||
) -> ApiResponse {
|
||||
match execute_for_token(state, token, request).await {
|
||||
Ok(result) => json_tool_result(id, &result),
|
||||
Err((status, message)) if status == Status::Unauthorized => {
|
||||
jsonrpc_error(Status::Unauthorized, id, -32001, message)
|
||||
}
|
||||
Err((_, message)) => jsonrpc_ok(id, tool_error(message)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_devices_tool_call(
|
||||
id: Option<Value>,
|
||||
token: &str,
|
||||
state: &State<AppState>,
|
||||
) -> ApiResponse {
|
||||
let user_id = match lookup_api_token(&state.database, token).await {
|
||||
Ok(user_id) => user_id,
|
||||
Err(message) => return jsonrpc_error(Status::Unauthorized, id, -32001, message),
|
||||
};
|
||||
|
||||
let mut device_ids = connected_device_ids_for_user(state, &user_id);
|
||||
device_ids.sort();
|
||||
|
||||
json_tool_result(id, &ConnectedDevicesResult { device_ids })
|
||||
}
|
||||
|
||||
fn json_tool_result<T: Serialize>(id: Option<Value>, payload: &T) -> ApiResponse {
|
||||
match serde_json::to_string(payload) {
|
||||
Ok(text) => jsonrpc_ok(id, ToolCallResult::success(text)),
|
||||
Err(error) => jsonrpc_error(Status::InternalServerError, id, -32603, error.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn tool_error(message: String) -> ToolCallResult {
|
||||
ToolCallResult::error(message)
|
||||
}
|
||||
|
||||
impl ToolCallResult {
|
||||
fn success(text: String) -> Self {
|
||||
Self {
|
||||
content: vec![TextContent::text(text)],
|
||||
is_error: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn error(text: String) -> Self {
|
||||
Self {
|
||||
content: vec![TextContent::text(text)],
|
||||
is_error: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TextContent {
|
||||
fn text(text: String) -> Self {
|
||||
Self { kind: "text", text }
|
||||
}
|
||||
}
|
||||
|
||||
fn jsonrpc_ok<T: Serialize>(id: Option<Value>, result: T) -> ApiResponse {
|
||||
match serde_json::to_string(&JsonRpcResponse {
|
||||
jsonrpc: JSONRPC_VERSION,
|
||||
id,
|
||||
result,
|
||||
}) {
|
||||
Ok(body) => (Status::Ok, RawJson(body)),
|
||||
Err(e) => jsonrpc_error(Status::InternalServerError, None, -32603, e.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn jsonrpc_notification_ok() -> ApiResponse {
|
||||
(Status::Accepted, RawJson(String::new()))
|
||||
}
|
||||
|
||||
fn jsonrpc_error(
|
||||
status: Status,
|
||||
id: Option<Value>,
|
||||
code: i64,
|
||||
message: impl Into<String>,
|
||||
) -> ApiResponse {
|
||||
let response = JsonRpcErrorResponse {
|
||||
jsonrpc: JSONRPC_VERSION,
|
||||
id,
|
||||
error: JsonRpcError {
|
||||
code,
|
||||
message: message.into(),
|
||||
},
|
||||
};
|
||||
|
||||
match serde_json::to_string(&response) {
|
||||
Ok(body) => (status, RawJson(body)),
|
||||
Err(e) => (
|
||||
Status::InternalServerError,
|
||||
RawJson(
|
||||
json!({
|
||||
"jsonrpc": JSONRPC_VERSION,
|
||||
"id": null,
|
||||
"error": {
|
||||
"code": -32603,
|
||||
"message": e.to_string()
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
),
|
||||
),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
use std::time::Duration;
|
||||
|
||||
const DEFAULT_MAX_CONCURRENT_EXECUTIONS: usize = 16;
|
||||
const DEFAULT_MAX_CONNECTED_DEVICES: usize = 10;
|
||||
const DEFAULT_MAX_EXECUTION_SECS: u64 = 3600;
|
||||
const DEFAULT_MAX_COMMAND_LENGTH: usize = 65_536;
|
||||
const DEFAULT_PING_INTERVAL_SECS: u64 = 30;
|
||||
const DEFAULT_PING_TIMEOUT_SECS: u64 = 10;
|
||||
|
||||
pub struct Config {
|
||||
pub max_concurrent_executions: usize,
|
||||
pub max_connected_devices: usize,
|
||||
pub max_execution_time: Duration,
|
||||
pub max_command_length: usize,
|
||||
pub ping_interval: Duration,
|
||||
pub ping_timeout: Duration,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn from_env() -> Self {
|
||||
Self {
|
||||
max_concurrent_executions: parse_env(
|
||||
"MAX_CONCURRENT_EXECUTIONS",
|
||||
DEFAULT_MAX_CONCURRENT_EXECUTIONS,
|
||||
),
|
||||
max_connected_devices: parse_env(
|
||||
"MAX_CONNECTED_DEVICES",
|
||||
DEFAULT_MAX_CONNECTED_DEVICES,
|
||||
),
|
||||
max_execution_time: Duration::from_secs(parse_env(
|
||||
"MAX_EXECUTION_SECS",
|
||||
DEFAULT_MAX_EXECUTION_SECS,
|
||||
)),
|
||||
max_command_length: parse_env("MAX_COMMAND_LENGTH", DEFAULT_MAX_COMMAND_LENGTH),
|
||||
ping_interval: Duration::from_secs(parse_env(
|
||||
"PING_INTERVAL_SECS",
|
||||
DEFAULT_PING_INTERVAL_SECS,
|
||||
)),
|
||||
ping_timeout: Duration::from_secs(parse_env(
|
||||
"PING_TIMEOUT_SECS",
|
||||
DEFAULT_PING_TIMEOUT_SECS,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_env<T: std::str::FromStr>(key: &str, default: T) -> T {
|
||||
std::env::var(key)
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
use crate::Database;
|
||||
|
||||
const AUTH_TOKENS_TABLE: &str = "auth_tokens";
|
||||
const API_TOKENS_TABLE: &str = "api_tokens";
|
||||
|
||||
pub async fn lookup_auth_token(database: &Database, token: &str) -> Result<String, String> {
|
||||
lookup_token(database, AUTH_TOKENS_TABLE, token).await
|
||||
}
|
||||
|
||||
pub async fn lookup_api_token(database: &Database, token: &str) -> Result<String, String> {
|
||||
lookup_token(database, API_TOKENS_TABLE, token).await
|
||||
}
|
||||
|
||||
async fn lookup_token(
|
||||
database: &Database,
|
||||
table: &'static str,
|
||||
token: &str,
|
||||
) -> Result<String, String> {
|
||||
let sql = format!(
|
||||
"SELECT t.user_id \
|
||||
FROM {table} t \
|
||||
JOIN users u ON u.id = t.user_id \
|
||||
WHERE t.token = ? \
|
||||
AND (t.expires_at IS NULL OR t.expires_at > DATETIME('now')) \
|
||||
AND (u.expires_at IS NULL OR u.expires_at > DATETIME('now'))"
|
||||
);
|
||||
|
||||
sqlx::query_scalar::<_, String>(sqlx::AssertSqlSafe(sql))
|
||||
.bind(token)
|
||||
.fetch_optional(database)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?
|
||||
.ok_or_else(|| "Invalid or expired token".to_string())
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
pub fn validate_device_id(id: &str) -> Result<(), String> {
|
||||
if id.is_empty() {
|
||||
return Err("device_id cannot be empty".into());
|
||||
}
|
||||
if id.len() > 64 {
|
||||
return Err("device_id too long (max 64 characters)".into());
|
||||
}
|
||||
if !id
|
||||
.chars()
|
||||
.all(|c| c.is_alphanumeric() || c == '-' || c == '_')
|
||||
{
|
||||
return Err("device_id contains invalid characters (allowed: a-z A-Z 0-9 - _)".into());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
+104
@@ -0,0 +1,104 @@
|
||||
pub mod api {
|
||||
pub mod auth;
|
||||
pub mod catchers;
|
||||
pub mod mcp;
|
||||
}
|
||||
|
||||
pub mod core {
|
||||
pub mod config;
|
||||
pub mod db;
|
||||
pub mod validation;
|
||||
}
|
||||
|
||||
pub mod transport {
|
||||
pub mod execute;
|
||||
pub mod protocol;
|
||||
pub mod socket;
|
||||
}
|
||||
|
||||
use core::config::Config;
|
||||
use dashmap::DashMap;
|
||||
use log::info;
|
||||
use rocket::routes;
|
||||
use serde::Serialize;
|
||||
use sqlx::SqlitePool;
|
||||
use sqlx::sqlite::SqliteConnectOptions;
|
||||
use std::sync::{Arc, atomic::AtomicUsize};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
const DATABASE_PATH: &str = "server.db";
|
||||
|
||||
pub type Database = SqlitePool;
|
||||
pub type DeviceCounts = DashMap<String, Arc<AtomicUsize>>;
|
||||
pub type Registry = DashMap<(String, String), RegistryEntry>;
|
||||
|
||||
pub struct RegistryEntry {
|
||||
pub sender: mpsc::Sender<PendingExec>,
|
||||
pub in_flight: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ExecResult {
|
||||
pub exit_code: i32,
|
||||
pub stdout: String,
|
||||
pub stderr: String,
|
||||
}
|
||||
|
||||
pub struct PendingExec {
|
||||
pub exec_id: String,
|
||||
pub command: String,
|
||||
pub reply: oneshot::Sender<ExecResult>,
|
||||
}
|
||||
|
||||
pub struct AppState {
|
||||
pub database: Database,
|
||||
pub registry: Arc<Registry>,
|
||||
pub device_counts: Arc<DeviceCounts>,
|
||||
pub config: Arc<Config>,
|
||||
}
|
||||
|
||||
#[rocket::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
env_logger::init();
|
||||
|
||||
let database = connect_database().await?;
|
||||
apply_schema(&database).await?;
|
||||
|
||||
info!("Launching server");
|
||||
|
||||
build_rocket(database).launch().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn connect_database() -> Result<Database, sqlx::Error> {
|
||||
let options = SqliteConnectOptions::new()
|
||||
.filename(DATABASE_PATH)
|
||||
.create_if_missing(true);
|
||||
|
||||
SqlitePool::connect_with(options).await
|
||||
}
|
||||
|
||||
async fn apply_schema(database: &Database) -> Result<(), sqlx::Error> {
|
||||
for statement in include_str!("../schema.sql")
|
||||
.split(';')
|
||||
.map(str::trim)
|
||||
.filter(|statement| !statement.is_empty())
|
||||
{
|
||||
sqlx::query(statement).execute(database).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_rocket(database: Database) -> rocket::Rocket<rocket::Build> {
|
||||
rocket::build()
|
||||
.manage(AppState {
|
||||
database,
|
||||
registry: Arc::new(DashMap::new()),
|
||||
device_counts: Arc::new(DashMap::new()),
|
||||
config: Arc::new(Config::from_env()),
|
||||
})
|
||||
.mount("/", routes![transport::socket::connect, api::mcp::route])
|
||||
.register("/", rocket::catchers![api::catchers::default_catcher])
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
use crate::{
|
||||
AppState, ExecResult, PendingExec, Registry,
|
||||
core::{config::Config, db::lookup_api_token, validation::validate_device_id},
|
||||
};
|
||||
use rocket::{State, http::Status};
|
||||
use serde::Deserialize;
|
||||
use std::sync::{Arc, atomic::Ordering};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ExecuteRequest {
|
||||
pub device_id: String,
|
||||
pub command: String,
|
||||
}
|
||||
|
||||
pub async fn execute_for_token(
|
||||
state: &State<AppState>,
|
||||
token: &str,
|
||||
request: ExecuteRequest,
|
||||
) -> Result<ExecResult, (Status, String)> {
|
||||
let user_id = lookup_api_token(&state.database, token)
|
||||
.await
|
||||
.map_err(|e| (Status::Unauthorized, e))?;
|
||||
|
||||
execute_for_user(&state.registry, &user_id, request, &state.config).await
|
||||
}
|
||||
|
||||
pub async fn execute_for_user(
|
||||
registry: &Arc<Registry>,
|
||||
user_id: &str,
|
||||
request: ExecuteRequest,
|
||||
config: &Config,
|
||||
) -> Result<ExecResult, (Status, String)> {
|
||||
let ExecuteRequest { device_id, command } = request;
|
||||
|
||||
validate_execute_request(&device_id, &command, config)?;
|
||||
|
||||
execute(registry, user_id, &device_id, command, config)
|
||||
.await
|
||||
.map_err(|e| (Status::BadGateway, e))
|
||||
}
|
||||
|
||||
fn validate_execute_request(
|
||||
device_id: &str,
|
||||
command: &str,
|
||||
config: &Config,
|
||||
) -> Result<(), (Status, String)> {
|
||||
validate_device_id(device_id).map_err(|error| (Status::BadRequest, error))?;
|
||||
|
||||
if command.len() <= config.max_command_length {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err((
|
||||
Status::PayloadTooLarge,
|
||||
format!(
|
||||
"Command exceeds maximum length of {} bytes",
|
||||
config.max_command_length
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
registry: &Arc<Registry>,
|
||||
user_id: &str,
|
||||
device_id: &str,
|
||||
command: String,
|
||||
config: &Config,
|
||||
) -> Result<ExecResult, String> {
|
||||
let key = (user_id.to_owned(), device_id.to_owned());
|
||||
|
||||
let (sender, in_flight_counter) = registry
|
||||
.get(&key)
|
||||
.map(|entry| (entry.sender.clone(), Arc::clone(&entry.in_flight)))
|
||||
.ok_or_else(|| "Device not connected".to_string())?;
|
||||
|
||||
claim_execution_slot(&in_flight_counter, config.max_concurrent_executions)?;
|
||||
|
||||
let result = send_and_await(sender, command, config).await;
|
||||
in_flight_counter.fetch_sub(1, Ordering::SeqCst);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
async fn send_and_await(
|
||||
sender: mpsc::Sender<PendingExec>,
|
||||
command: String,
|
||||
config: &Config,
|
||||
) -> Result<ExecResult, String> {
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
|
||||
sender
|
||||
.send(PendingExec {
|
||||
exec_id: Uuid::new_v4().to_string(),
|
||||
command,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| "Device disconnected".to_string())?;
|
||||
|
||||
tokio::time::timeout(config.max_execution_time, reply_rx)
|
||||
.await
|
||||
.map_err(|_| "Command timed out".to_string())?
|
||||
.map_err(|_| "Device disconnected".to_string())
|
||||
}
|
||||
|
||||
fn claim_execution_slot(
|
||||
counter: &Arc<std::sync::atomic::AtomicUsize>,
|
||||
limit: usize,
|
||||
) -> Result<(), String> {
|
||||
let previous = counter.fetch_add(1, Ordering::SeqCst);
|
||||
if previous < limit {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
counter.fetch_sub(1, Ordering::SeqCst);
|
||||
Err(format!(
|
||||
"Device has reached the limit of {} concurrent executions",
|
||||
limit,
|
||||
))
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ClientMsg {
|
||||
ExecResult {
|
||||
exec_id: String,
|
||||
exit_code: i32,
|
||||
stdout: String,
|
||||
stderr: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ServerMsg {
|
||||
AuthOk,
|
||||
AuthError { reason: String },
|
||||
Exec { exec_id: String, command: String },
|
||||
}
|
||||
@@ -0,0 +1,317 @@
|
||||
use crate::{
|
||||
AppState, Database, DeviceCounts, ExecResult, PendingExec, Registry, RegistryEntry,
|
||||
api::auth::{BearerToken, MaybeDeviceId},
|
||||
core::config::Config,
|
||||
core::db::lookup_auth_token,
|
||||
transport::protocol::{ClientMsg, ServerMsg},
|
||||
};
|
||||
use dashmap::{DashMap, mapref::entry::Entry};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use log::{info, warn};
|
||||
use rocket::{State, get};
|
||||
use rocket_ws::{Channel, Message, WebSocket, stream::DuplexStream};
|
||||
use std::{
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
type PendingReplies = Arc<DashMap<String, oneshot::Sender<ExecResult>>>;
|
||||
|
||||
#[get("/connect")]
|
||||
pub fn connect<'r>(
|
||||
ws: WebSocket,
|
||||
token: BearerToken,
|
||||
device_id: MaybeDeviceId,
|
||||
state: &'r State<AppState>,
|
||||
) -> Channel<'r> {
|
||||
let database = state.database.clone();
|
||||
let registry = state.registry.clone();
|
||||
let device_counts = state.device_counts.clone();
|
||||
let config = state.config.clone();
|
||||
ws.channel(move |socket| {
|
||||
Box::pin(handle(
|
||||
database,
|
||||
registry,
|
||||
device_counts,
|
||||
config,
|
||||
socket,
|
||||
token.0,
|
||||
device_id.0,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
database: Database,
|
||||
registry: Arc<Registry>,
|
||||
device_counts: Arc<DeviceCounts>,
|
||||
config: Arc<Config>,
|
||||
mut socket: DuplexStream,
|
||||
token: String,
|
||||
device_id: Result<String, String>,
|
||||
) -> Result<(), rocket_ws::result::Error> {
|
||||
let device_id = match device_id {
|
||||
Ok(device_id) => device_id,
|
||||
Err(reason) => {
|
||||
warn!("Rejecting websocket connection: {reason}");
|
||||
return deny_connection(&mut socket, reason).await;
|
||||
}
|
||||
};
|
||||
|
||||
let user_id = match lookup_auth_token(&database, &token).await {
|
||||
Ok(user_id) => user_id,
|
||||
Err(reason) => {
|
||||
warn!("Rejecting websocket connection for device {device_id}: {reason}");
|
||||
return deny_connection(&mut socket, reason).await;
|
||||
}
|
||||
};
|
||||
|
||||
let device_counter = match claim_device_slot(&device_counts, &user_id, &config) {
|
||||
Ok(counter) => counter,
|
||||
Err(reason) => {
|
||||
warn!("Rejecting websocket connection for user {user_id} device {device_id}: {reason}");
|
||||
return deny_connection(&mut socket, reason).await;
|
||||
}
|
||||
};
|
||||
|
||||
let (exec_tx, exec_rx) = mpsc::channel::<PendingExec>(config.max_concurrent_executions);
|
||||
let key = (user_id.clone(), device_id.clone());
|
||||
|
||||
if let Err(reason) = register_connection(®istry, key.clone(), exec_tx) {
|
||||
release_device_slot(&device_counts, &user_id, device_counter);
|
||||
warn!("Rejecting duplicate websocket connection for {user_id}/{device_id}: {reason}");
|
||||
return deny_connection(&mut socket, reason).await;
|
||||
}
|
||||
|
||||
info!("Accepted websocket connection for {user_id}/{device_id}");
|
||||
|
||||
send(&mut socket, &ServerMsg::AuthOk).await.ok();
|
||||
|
||||
let pending = Arc::new(DashMap::new());
|
||||
let cancel = CancellationToken::new();
|
||||
run_loop(&mut socket, exec_rx, &pending, &config, &cancel).await;
|
||||
cancel.cancel();
|
||||
|
||||
registry.remove(&key);
|
||||
pending.clear();
|
||||
release_device_slot(&device_counts, &user_id, device_counter);
|
||||
info!("Closed websocket connection for {user_id}/{device_id}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn deny_connection(
|
||||
socket: &mut DuplexStream,
|
||||
reason: String,
|
||||
) -> Result<(), rocket_ws::result::Error> {
|
||||
send(socket, &ServerMsg::AuthError { reason }).await.ok();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn register_connection(
|
||||
registry: &Arc<Registry>,
|
||||
key: (String, String),
|
||||
sender: mpsc::Sender<PendingExec>,
|
||||
) -> Result<(), String> {
|
||||
match registry.entry(key) {
|
||||
Entry::Occupied(_) => Err("Device already connected".into()),
|
||||
Entry::Vacant(slot) => {
|
||||
slot.insert(RegistryEntry {
|
||||
sender,
|
||||
in_flight: Arc::new(AtomicUsize::new(0)),
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn claim_device_slot(
|
||||
device_counts: &Arc<DeviceCounts>,
|
||||
user_id: &str,
|
||||
config: &Config,
|
||||
) -> Result<Arc<AtomicUsize>, String> {
|
||||
let counter = {
|
||||
let entry = device_counts
|
||||
.entry(user_id.to_owned())
|
||||
.or_insert_with(|| Arc::new(AtomicUsize::new(0)));
|
||||
Arc::clone(&*entry)
|
||||
};
|
||||
|
||||
let previous = counter.fetch_add(1, Ordering::SeqCst);
|
||||
if previous >= config.max_connected_devices {
|
||||
counter.fetch_sub(1, Ordering::SeqCst);
|
||||
return Err("Too many devices connected for this account".into());
|
||||
}
|
||||
|
||||
Ok(counter)
|
||||
}
|
||||
|
||||
fn release_device_slot(
|
||||
device_counts: &Arc<DeviceCounts>,
|
||||
user_id: &str,
|
||||
counter: Arc<AtomicUsize>,
|
||||
) {
|
||||
let previous = counter.fetch_sub(1, Ordering::SeqCst);
|
||||
if previous == 1 {
|
||||
device_counts.remove_if(user_id, |_, value| value.load(Ordering::SeqCst) == 0);
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_loop(
|
||||
socket: &mut DuplexStream,
|
||||
mut exec_rx: mpsc::Receiver<PendingExec>,
|
||||
pending: &PendingReplies,
|
||||
config: &Config,
|
||||
cancel: &CancellationToken,
|
||||
) {
|
||||
let mut ping_interval = tokio::time::interval_at(
|
||||
tokio::time::Instant::now() + config.ping_interval,
|
||||
config.ping_interval,
|
||||
);
|
||||
|
||||
let pong_timer = tokio::time::sleep(Duration::from_secs(0));
|
||||
tokio::pin!(pong_timer);
|
||||
let mut awaiting_pong = false;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = ping_interval.tick() => {
|
||||
if awaiting_pong {
|
||||
break;
|
||||
}
|
||||
if send_ping(socket).await.is_err() {
|
||||
break;
|
||||
}
|
||||
awaiting_pong = true;
|
||||
pong_timer.as_mut().reset(
|
||||
tokio::time::Instant::now() + config.ping_timeout,
|
||||
);
|
||||
}
|
||||
_ = &mut pong_timer, if awaiting_pong => {
|
||||
break;
|
||||
}
|
||||
Some(exec) = exec_rx.recv() => {
|
||||
if !dispatch(socket, pending, exec, config.max_execution_time, cancel.clone()).await {
|
||||
break;
|
||||
}
|
||||
}
|
||||
message = socket.next() => {
|
||||
let action = handle_socket_message(socket, message, pending).await;
|
||||
|
||||
if action.break_loop {
|
||||
break;
|
||||
}
|
||||
|
||||
if action.clear_pong_deadline {
|
||||
awaiting_pong = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct SocketAction {
|
||||
break_loop: bool,
|
||||
clear_pong_deadline: bool,
|
||||
}
|
||||
|
||||
async fn handle_socket_message(
|
||||
socket: &mut DuplexStream,
|
||||
message: Option<Result<Message, rocket_ws::result::Error>>,
|
||||
pending: &PendingReplies,
|
||||
) -> SocketAction {
|
||||
match message {
|
||||
Some(Ok(Message::Pong(_))) => SocketAction {
|
||||
break_loop: false,
|
||||
clear_pong_deadline: true,
|
||||
},
|
||||
Some(Ok(Message::Ping(data))) => SocketAction {
|
||||
break_loop: socket.send(Message::Pong(data.clone())).await.is_err(),
|
||||
clear_pong_deadline: false,
|
||||
},
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
resolve_pending(&text, pending);
|
||||
SocketAction {
|
||||
break_loop: false,
|
||||
clear_pong_deadline: false,
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => SocketAction {
|
||||
break_loop: true,
|
||||
clear_pong_deadline: false,
|
||||
},
|
||||
_ => SocketAction {
|
||||
break_loop: false,
|
||||
clear_pong_deadline: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_ping(socket: &mut DuplexStream) -> Result<(), rocket_ws::result::Error> {
|
||||
socket.send(Message::Ping(vec![])).await
|
||||
}
|
||||
|
||||
async fn dispatch(
|
||||
socket: &mut DuplexStream,
|
||||
pending: &PendingReplies,
|
||||
exec: PendingExec,
|
||||
max_execution_time: Duration,
|
||||
cancel: CancellationToken,
|
||||
) -> bool {
|
||||
let exec_id = exec.exec_id.clone();
|
||||
pending.insert(exec_id.clone(), exec.reply);
|
||||
|
||||
let pending_for_timeout = Arc::clone(pending);
|
||||
tokio::spawn(async move {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {}
|
||||
_ = tokio::time::sleep(max_execution_time) => {
|
||||
pending_for_timeout.remove(&exec_id);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
send(
|
||||
socket,
|
||||
&ServerMsg::Exec {
|
||||
exec_id: exec.exec_id,
|
||||
command: exec.command,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
fn resolve_pending(text: &str, pending: &PendingReplies) {
|
||||
let Ok(ClientMsg::ExecResult {
|
||||
exec_id,
|
||||
exit_code,
|
||||
stdout,
|
||||
stderr,
|
||||
}) = serde_json::from_str(text)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
if let Some((_, reply)) = pending.remove(&exec_id) {
|
||||
reply
|
||||
.send(ExecResult {
|
||||
exit_code,
|
||||
stdout,
|
||||
stderr,
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
async fn send(
|
||||
socket: &mut DuplexStream,
|
||||
message: &ServerMsg,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = serde_json::to_string(message)?;
|
||||
socket.send(Message::Text(text)).await?;
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user