This commit is contained in:
2026-05-22 08:24:20 +02:00
parent 01e72931d1
commit 59469df9f8
5 changed files with 335 additions and 317 deletions
+163
View File
@@ -0,0 +1,163 @@
use std::{sync::Arc, time::Duration};
use futures::StreamExt;
use log::{error, info};
use tokio::{
sync::{mpsc, oneshot},
time::Instant,
};
use tokio_tungstenite::tungstenite::Message;
use crate::app::{
command,
config::Config,
connection::{self, OutgoingMessage, SocketReader},
protocol::{ClientMsg, ServerMsg},
};
pub async fn run_forever() {
let config = match Config::from_env() {
Ok(config) => Arc::new(config),
Err(error) => {
error!("Configuration error: {error}");
std::process::exit(1);
}
};
loop {
match run_session(&config).await {
Ok(()) => info!("Disconnected. Reconnecting in 5s..."),
Err(error) => error!("Error: {error}. Reconnecting in 5s..."),
}
tokio::time::sleep(Duration::from_secs(5)).await;
}
}
async fn run_session(config: &Arc<Config>) -> Result<(), String> {
let mut socket = connection::connect(config).await?;
authenticate(&mut socket).await?;
let (sink, stream) = connection::split(socket);
let (outgoing_tx, writer_closed) =
connection::spawn_writer(sink, config.writer_channel_capacity);
process_messages(stream, outgoing_tx, writer_closed, config).await
}
async fn authenticate(socket: &mut connection::Socket) -> Result<(), String> {
match connection::read_server_message(socket).await? {
ServerMsg::AuthOk => {
info!("Authenticated.");
Ok(())
}
ServerMsg::AuthError { reason } => Err(format!("Auth failed: {reason}")),
ServerMsg::Exec { .. } => Err("Unexpected message during auth".into()),
}
}
async fn process_messages(
mut stream: SocketReader,
outgoing_tx: mpsc::Sender<OutgoingMessage>,
mut writer_closed: oneshot::Receiver<()>,
config: &Arc<Config>,
) -> Result<(), String> {
let mut ping_interval =
tokio::time::interval_at(Instant::now() + config.ping_interval, config.ping_interval);
let pong_timer = tokio::time::sleep(Duration::from_secs(0));
tokio::pin!(pong_timer);
let mut awaiting_pong = false;
loop {
tokio::select! {
_ = &mut writer_closed => return Err("Write half closed".into()),
_ = ping_interval.tick() => {
if awaiting_pong {
return Err("Server did not respond to ping".into());
}
outgoing_tx
.send(OutgoingMessage::Ping)
.await
.map_err(|error| error.to_string())?;
awaiting_pong = true;
pong_timer.as_mut().reset(Instant::now() + config.ping_timeout);
}
_ = &mut pong_timer, if awaiting_pong => {
return Err("Server did not respond to ping".into());
}
next_message = stream.next() => {
if handle_stream_message(next_message, &outgoing_tx, config.max_output_bytes).await? {
awaiting_pong = false;
}
}
}
}
}
async fn handle_stream_message(
message: Option<Result<Message, tokio_tungstenite::tungstenite::Error>>,
outgoing_tx: &mpsc::Sender<OutgoingMessage>,
max_output_bytes: u64,
) -> Result<bool, String> {
match message {
Some(Ok(Message::Text(text))) => {
handle_server_message(
serde_json::from_str(&text).map_err(|error| error.to_string())?,
outgoing_tx,
max_output_bytes,
)
.await?;
Ok(false)
}
Some(Ok(Message::Ping(payload))) => {
outgoing_tx
.send(OutgoingMessage::Pong(payload.to_vec()))
.await
.map_err(|error| error.to_string())?;
Ok(false)
}
Some(Ok(Message::Pong(_))) => Ok(true),
Some(Ok(Message::Close(_))) | None => Err("Connection closed".into()),
Some(Err(error)) => Err(error.to_string()),
Some(Ok(_)) => Ok(false),
}
}
async fn handle_server_message(
message: ServerMsg,
outgoing_tx: &mpsc::Sender<OutgoingMessage>,
max_output_bytes: u64,
) -> Result<(), String> {
match message {
ServerMsg::Exec { exec_id, command } => {
spawn_command_task(exec_id, command, outgoing_tx.clone(), max_output_bytes);
Ok(())
}
ServerMsg::AuthError { reason } => Err(reason),
ServerMsg::AuthOk => Ok(()),
}
}
fn spawn_command_task(
exec_id: String,
command: String,
outgoing_tx: mpsc::Sender<OutgoingMessage>,
max_output_bytes: u64,
) {
tokio::spawn(async move {
info!("Executing [{exec_id}]: {command}");
let (exit_code, stdout, stderr) = command::execute(&command, max_output_bytes).await;
let result = ClientMsg::ExecResult {
exec_id,
exit_code,
stdout,
stderr,
};
if let Err(error) = outgoing_tx.send(OutgoingMessage::Application(result)).await {
error!("Failed to send command result: {error}");
}
});
}
+54
View File
@@ -0,0 +1,54 @@
use std::process::Stdio;
use tokio::{
io::{AsyncRead, AsyncReadExt},
process::Command,
};
pub async fn execute(command: &str, max_output_bytes: u64) -> (i32, String, String) {
let mut child = match Command::new("sh")
.arg("-c")
.arg(command)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
{
Ok(child) => child,
Err(error) => return (-1, String::new(), error.to_string()),
};
let Some(stdout) = child.stdout.take() else {
return (-1, String::new(), "stdout pipe unavailable".to_string());
};
let Some(stderr) = child.stderr.take() else {
return (-1, String::new(), "stderr pipe unavailable".to_string());
};
let (stdout_data, stderr_data) = tokio::join!(
read_limited_output(stdout, max_output_bytes),
read_limited_output(stderr, max_output_bytes),
);
child.kill().await.ok();
let exit_code = child
.wait()
.await
.ok()
.and_then(|status| status.code())
.unwrap_or(-1);
(
exit_code,
String::from_utf8_lossy(&stdout_data).into_owned(),
String::from_utf8_lossy(&stderr_data).into_owned(),
)
}
async fn read_limited_output<R>(reader: R, max_output_bytes: u64) -> Vec<u8>
where
R: AsyncRead + Unpin,
{
let mut buffer = Vec::new();
let _ = reader.take(max_output_bytes).read_to_end(&mut buffer).await;
buffer
}
+111
View File
@@ -0,0 +1,111 @@
use futures::{
SinkExt, StreamExt,
stream::{SplitSink, SplitStream},
};
use log::info;
use tokio::sync::{mpsc, oneshot};
use tokio_tungstenite::{
connect_async_with_config,
tungstenite::{Message, Utf8Bytes, client::IntoClientRequest, http::HeaderValue},
};
use crate::app::{config::Config, protocol::ClientMsg};
pub type Socket =
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
pub type SocketReader = SplitStream<Socket>;
type SocketWriter = SplitSink<Socket, Message>;
pub enum OutgoingMessage {
Application(ClientMsg),
Ping,
Pong(Vec<u8>),
}
pub async fn connect(config: &Config) -> Result<Socket, String> {
let mut request = format!("{}/connect", config.base_url)
.into_client_request()
.map_err(|error| error.to_string())?;
request.headers_mut().insert(
"Authorization",
HeaderValue::from_str(&format!("Bearer {}", config.token))
.map_err(|error| error.to_string())?,
);
request.headers_mut().insert(
"X-Device-ID",
HeaderValue::from_str(&config.device_id).map_err(|error| error.to_string())?,
);
let (socket, _) = connect_async_with_config(request, None, false)
.await
.map_err(|error| error.to_string())?;
info!("Connected.");
Ok(socket)
}
pub fn split(socket: Socket) -> (SocketWriter, SocketReader) {
socket.split()
}
pub fn spawn_writer(
sink: SocketWriter,
capacity: usize,
) -> (mpsc::Sender<OutgoingMessage>, oneshot::Receiver<()>) {
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingMessage>(capacity);
let (closed_tx, closed_rx) = oneshot::channel::<()>();
tokio::spawn(async move {
let mut sink = sink;
let _closed = closed_tx;
while let Some(message) = outgoing_rx.recv().await {
if write_message(&mut sink, message).await.is_err() {
break;
}
}
});
(outgoing_tx, closed_rx)
}
pub async fn read_server_message(
socket: &mut Socket,
) -> Result<crate::app::protocol::ServerMsg, String> {
loop {
match socket.next().await {
Some(Ok(Message::Text(text))) => {
return serde_json::from_str(&text).map_err(|error| error.to_string());
}
Some(Ok(Message::Ping(payload))) => {
socket
.send(Message::Pong(payload))
.await
.map_err(|error| error.to_string())?;
}
Some(Ok(Message::Close(_))) | None => return Err("Connection closed".into()),
Some(Err(error)) => return Err(error.to_string()),
Some(Ok(_)) => {}
}
}
}
async fn write_message(sink: &mut SocketWriter, message: OutgoingMessage) -> Result<(), String> {
match message {
OutgoingMessage::Application(message) => sink
.send(Message::Text(Utf8Bytes::from(
serde_json::to_string(&message).map_err(|error| error.to_string())?,
)))
.await
.map_err(|error| error.to_string()),
OutgoingMessage::Ping => sink
.send(Message::Ping(Vec::new().into()))
.await
.map_err(|error| error.to_string()),
OutgoingMessage::Pong(payload) => sink
.send(Message::Pong(payload.into()))
.await
.map_err(|error| error.to_string()),
}
}
+5
View File
@@ -0,0 +1,5 @@
pub mod client;
pub mod command;
pub mod config;
pub mod connection;
pub mod protocol;
+2 -317
View File
@@ -1,322 +1,7 @@
mod app { mod app;
pub mod config;
pub mod protocol;
}
use std::{process::Stdio, sync::Arc, time::Duration};
use app::config::Config;
use app::protocol::{ClientMsg, ServerMsg};
use futures::{
SinkExt, StreamExt,
stream::{SplitSink, SplitStream},
};
use log::{error, info};
use tokio::{
io::AsyncRead,
io::AsyncReadExt,
process::Command,
sync::{mpsc, oneshot},
};
use tokio_tungstenite::{
connect_async_with_config,
tungstenite::{Message, Utf8Bytes, client::IntoClientRequest, http::HeaderValue},
};
type Sock =
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
type Writer = SplitSink<Sock, Message>;
enum OutgoingMessage {
Application(ClientMsg),
Ping,
Pong(Vec<u8>),
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
env_logger::init(); env_logger::init();
app::client::run_forever().await;
let config = match Config::from_env() {
Ok(config) => Arc::new(config),
Err(e) => {
error!("Configuration error: {e}");
std::process::exit(1);
}
};
loop {
match run_client(&config).await {
Ok(()) => info!("Disconnected. Reconnecting in 5s..."),
Err(e) => error!("Error: {e}. Reconnecting in 5s..."),
}
tokio::time::sleep(Duration::from_secs(5)).await;
}
}
async fn run_client(config: &Arc<Config>) -> Result<(), String> {
let mut socket = connect_to_server(config).await?;
authenticate(&mut socket).await?;
let (sink, stream) = socket.split();
let (outgoing_tx, writer_closed) = spawn_writer_task(sink, config.writer_channel_capacity);
process_messages(stream, outgoing_tx, writer_closed, config).await
}
async fn connect_to_server(config: &Config) -> Result<Sock, String> {
let url = format!("{}/connect", config.base_url);
let mut request = url.into_client_request().map_err(|e| e.to_string())?;
let headers = request.headers_mut();
headers.insert(
"Authorization",
HeaderValue::from_str(&format!("Bearer {}", config.token)).map_err(|e| e.to_string())?,
);
headers.insert(
"X-Device-ID",
HeaderValue::from_str(&config.device_id).map_err(|e| e.to_string())?,
);
let (socket, _) = connect_async_with_config(request, None, false)
.await
.map_err(|e| e.to_string())?;
info!("Connected.");
Ok(socket)
}
async fn authenticate(socket: &mut Sock) -> Result<(), String> {
match read_protocol_message(socket).await? {
ServerMsg::AuthOk => {
info!("Authenticated.");
Ok(())
}
ServerMsg::AuthError { reason } => Err(format!("Auth failed: {reason}")),
ServerMsg::Exec { .. } => Err("Unexpected message during auth".into()),
}
}
async fn read_protocol_message(socket: &mut Sock) -> Result<ServerMsg, String> {
loop {
match socket.next().await {
Some(Ok(Message::Text(text))) => return parse_server_message(&text),
Some(Ok(Message::Ping(data))) => {
socket
.send(Message::Pong(data))
.await
.map_err(|e| e.to_string())?;
}
Some(Ok(Message::Close(_))) | None => return Err("Connection closed".into()),
Some(Err(error)) => return Err(error.to_string()),
_ => {}
}
}
}
fn parse_server_message(text: &str) -> Result<ServerMsg, String> {
serde_json::from_str(text).map_err(|e| e.to_string())
}
async fn process_messages(
mut stream: SplitStream<Sock>,
outgoing_tx: mpsc::Sender<OutgoingMessage>,
mut writer_closed: oneshot::Receiver<()>,
config: &Arc<Config>,
) -> Result<(), String> {
let mut ping_interval = tokio::time::interval_at(
tokio::time::Instant::now() + config.ping_interval,
config.ping_interval,
);
let pong_timer = tokio::time::sleep(Duration::from_secs(0));
tokio::pin!(pong_timer);
let mut awaiting_pong = false;
loop {
tokio::select! {
_ = &mut writer_closed => return Err("Write half closed".into()),
_ = ping_interval.tick() => {
if awaiting_pong {
return Err("Server did not respond to ping".into());
}
send_outgoing(&outgoing_tx, OutgoingMessage::Ping).await?;
awaiting_pong = true;
pong_timer.as_mut().reset(
tokio::time::Instant::now() + config.ping_timeout,
);
}
_ = &mut pong_timer, if awaiting_pong => {
return Err("Server did not respond to ping".into());
}
next_message = stream.next() => {
let action = handle_stream_message(next_message, &outgoing_tx, config.max_output_bytes).await?;
if action.clear_pong_deadline {
awaiting_pong = false;
}
}
}
}
}
fn spawn_writer_task(
sink: Writer,
capacity: usize,
) -> (mpsc::Sender<OutgoingMessage>, oneshot::Receiver<()>) {
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<OutgoingMessage>(capacity);
let (closed_tx, closed_rx) = oneshot::channel::<()>();
tokio::spawn(async move {
let mut sink = sink;
let _closed = closed_tx;
while let Some(message) = outgoing_rx.recv().await {
if write_message(&mut sink, message).await.is_err() {
break;
}
}
});
(outgoing_tx, closed_rx)
}
async fn write_message(sink: &mut Writer, message: OutgoingMessage) -> Result<(), String> {
match message {
OutgoingMessage::Application(message) => {
let text = serde_json::to_string(&message).map_err(|e| e.to_string())?;
sink.send(Message::Text(Utf8Bytes::from(text)))
.await
.map_err(|e| e.to_string())
}
OutgoingMessage::Ping => sink
.send(Message::Ping(Vec::new().into()))
.await
.map_err(|e| e.to_string()),
OutgoingMessage::Pong(payload) => sink
.send(Message::Pong(payload.into()))
.await
.map_err(|e| e.to_string()),
}
}
struct StreamAction {
clear_pong_deadline: bool,
}
async fn send_outgoing(
outgoing_tx: &mpsc::Sender<OutgoingMessage>,
message: OutgoingMessage,
) -> Result<(), String> {
outgoing_tx.send(message).await.map_err(|e| e.to_string())
}
async fn handle_stream_message(
message: Option<Result<Message, tokio_tungstenite::tungstenite::Error>>,
outgoing_tx: &mpsc::Sender<OutgoingMessage>,
max_output_bytes: u64,
) -> Result<StreamAction, String> {
match message {
Some(Ok(Message::Text(text))) => {
handle_server_message(parse_server_message(&text)?, outgoing_tx, max_output_bytes)
.await?;
Ok(StreamAction {
clear_pong_deadline: false,
})
}
Some(Ok(Message::Ping(payload))) => {
send_outgoing(outgoing_tx, OutgoingMessage::Pong(payload.to_vec())).await?;
Ok(StreamAction {
clear_pong_deadline: false,
})
}
Some(Ok(Message::Close(_))) | None => Err("Connection closed".into()),
Some(Err(error)) => Err(error.to_string()),
Some(Ok(Message::Pong(_))) => Ok(StreamAction {
clear_pong_deadline: true,
}),
Some(Ok(_)) => Ok(StreamAction {
clear_pong_deadline: false,
}),
}
}
async fn handle_server_message(
message: ServerMsg,
outgoing_tx: &mpsc::Sender<OutgoingMessage>,
max_output_bytes: u64,
) -> Result<(), String> {
match message {
ServerMsg::Exec { exec_id, command } => {
spawn_command_task(exec_id, command, outgoing_tx.clone(), max_output_bytes);
Ok(())
}
ServerMsg::AuthError { reason } => Err(reason),
ServerMsg::AuthOk => Ok(()),
}
}
fn spawn_command_task(
exec_id: String,
command: String,
outgoing_tx: mpsc::Sender<OutgoingMessage>,
max_output_bytes: u64,
) {
tokio::spawn(async move {
info!("Executing [{exec_id}]: {command}");
let (exit_code, stdout, stderr) = execute_command(&command, max_output_bytes).await;
let result = ClientMsg::ExecResult {
exec_id,
exit_code,
stdout,
stderr,
};
if let Err(error) = send_outgoing(&outgoing_tx, OutgoingMessage::Application(result)).await
{
error!("Failed to send command result: {error}");
}
});
}
async fn execute_command(command: &str, max_output_bytes: u64) -> (i32, String, String) {
let mut child = match Command::new("sh")
.arg("-c")
.arg(command)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
{
Ok(c) => c,
Err(e) => return (-1, String::new(), e.to_string()),
};
let Some(stdout) = child.stdout.take() else {
return (-1, String::new(), "stdout pipe unavailable".to_string());
};
let Some(stderr) = child.stderr.take() else {
return (-1, String::new(), "stderr pipe unavailable".to_string());
};
let (stdout_data, stderr_data) = tokio::join!(
read_limited_output(stdout, max_output_bytes),
read_limited_output(stderr, max_output_bytes),
);
child.kill().await.ok();
let exit_code = child.wait().await.ok().and_then(|s| s.code()).unwrap_or(-1);
(
exit_code,
String::from_utf8_lossy(&stdout_data).into_owned(),
String::from_utf8_lossy(&stderr_data).into_owned(),
)
}
async fn read_limited_output<R>(reader: R, max_output_bytes: u64) -> Vec<u8>
where
R: AsyncRead + Unpin,
{
let mut buffer = Vec::new();
let _ = reader.take(max_output_bytes).read_to_end(&mut buffer).await;
buffer
} }