diff --git a/src/app/client.rs b/src/app/client.rs new file mode 100644 index 0000000..c9b9356 --- /dev/null +++ b/src/app/client.rs @@ -0,0 +1,163 @@ +use std::{sync::Arc, time::Duration}; + +use futures::StreamExt; +use log::{error, info}; +use tokio::{ + sync::{mpsc, oneshot}, + time::Instant, +}; +use tokio_tungstenite::tungstenite::Message; + +use crate::app::{ + command, + config::Config, + connection::{self, OutgoingMessage, SocketReader}, + protocol::{ClientMsg, ServerMsg}, +}; + +pub async fn run_forever() { + let config = match Config::from_env() { + Ok(config) => Arc::new(config), + Err(error) => { + error!("Configuration error: {error}"); + std::process::exit(1); + } + }; + + loop { + match run_session(&config).await { + Ok(()) => info!("Disconnected. Reconnecting in 5s..."), + Err(error) => error!("Error: {error}. Reconnecting in 5s..."), + } + + tokio::time::sleep(Duration::from_secs(5)).await; + } +} + +async fn run_session(config: &Arc) -> Result<(), String> { + let mut socket = connection::connect(config).await?; + authenticate(&mut socket).await?; + + let (sink, stream) = connection::split(socket); + let (outgoing_tx, writer_closed) = + connection::spawn_writer(sink, config.writer_channel_capacity); + + process_messages(stream, outgoing_tx, writer_closed, config).await +} + +async fn authenticate(socket: &mut connection::Socket) -> Result<(), String> { + match connection::read_server_message(socket).await? { + ServerMsg::AuthOk => { + info!("Authenticated."); + Ok(()) + } + ServerMsg::AuthError { reason } => Err(format!("Auth failed: {reason}")), + ServerMsg::Exec { .. } => Err("Unexpected message during auth".into()), + } +} + +async fn process_messages( + mut stream: SocketReader, + outgoing_tx: mpsc::Sender, + mut writer_closed: oneshot::Receiver<()>, + config: &Arc, +) -> Result<(), String> { + let mut ping_interval = + tokio::time::interval_at(Instant::now() + config.ping_interval, config.ping_interval); + + let pong_timer = tokio::time::sleep(Duration::from_secs(0)); + tokio::pin!(pong_timer); + let mut awaiting_pong = false; + + loop { + tokio::select! { + _ = &mut writer_closed => return Err("Write half closed".into()), + _ = ping_interval.tick() => { + if awaiting_pong { + return Err("Server did not respond to ping".into()); + } + + outgoing_tx + .send(OutgoingMessage::Ping) + .await + .map_err(|error| error.to_string())?; + awaiting_pong = true; + pong_timer.as_mut().reset(Instant::now() + config.ping_timeout); + } + _ = &mut pong_timer, if awaiting_pong => { + return Err("Server did not respond to ping".into()); + } + next_message = stream.next() => { + if handle_stream_message(next_message, &outgoing_tx, config.max_output_bytes).await? { + awaiting_pong = false; + } + } + } + } +} + +async fn handle_stream_message( + message: Option>, + outgoing_tx: &mpsc::Sender, + max_output_bytes: u64, +) -> Result { + match message { + Some(Ok(Message::Text(text))) => { + handle_server_message( + serde_json::from_str(&text).map_err(|error| error.to_string())?, + outgoing_tx, + max_output_bytes, + ) + .await?; + Ok(false) + } + Some(Ok(Message::Ping(payload))) => { + outgoing_tx + .send(OutgoingMessage::Pong(payload.to_vec())) + .await + .map_err(|error| error.to_string())?; + Ok(false) + } + Some(Ok(Message::Pong(_))) => Ok(true), + Some(Ok(Message::Close(_))) | None => Err("Connection closed".into()), + Some(Err(error)) => Err(error.to_string()), + Some(Ok(_)) => Ok(false), + } +} + +async fn handle_server_message( + message: ServerMsg, + outgoing_tx: &mpsc::Sender, + max_output_bytes: u64, +) -> Result<(), String> { + match message { + ServerMsg::Exec { exec_id, command } => { + spawn_command_task(exec_id, command, outgoing_tx.clone(), max_output_bytes); + Ok(()) + } + ServerMsg::AuthError { reason } => Err(reason), + ServerMsg::AuthOk => Ok(()), + } +} + +fn spawn_command_task( + exec_id: String, + command: String, + outgoing_tx: mpsc::Sender, + max_output_bytes: u64, +) { + tokio::spawn(async move { + info!("Executing [{exec_id}]: {command}"); + let (exit_code, stdout, stderr) = command::execute(&command, max_output_bytes).await; + let result = ClientMsg::ExecResult { + exec_id, + exit_code, + stdout, + stderr, + }; + + if let Err(error) = outgoing_tx.send(OutgoingMessage::Application(result)).await { + error!("Failed to send command result: {error}"); + } + }); +} diff --git a/src/app/command.rs b/src/app/command.rs new file mode 100644 index 0000000..b37ecd4 --- /dev/null +++ b/src/app/command.rs @@ -0,0 +1,54 @@ +use std::process::Stdio; + +use tokio::{ + io::{AsyncRead, AsyncReadExt}, + process::Command, +}; + +pub async fn execute(command: &str, max_output_bytes: u64) -> (i32, String, String) { + let mut child = match Command::new("sh") + .arg("-c") + .arg(command) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + { + Ok(child) => child, + Err(error) => return (-1, String::new(), error.to_string()), + }; + + let Some(stdout) = child.stdout.take() else { + return (-1, String::new(), "stdout pipe unavailable".to_string()); + }; + let Some(stderr) = child.stderr.take() else { + return (-1, String::new(), "stderr pipe unavailable".to_string()); + }; + + let (stdout_data, stderr_data) = tokio::join!( + read_limited_output(stdout, max_output_bytes), + read_limited_output(stderr, max_output_bytes), + ); + + child.kill().await.ok(); + let exit_code = child + .wait() + .await + .ok() + .and_then(|status| status.code()) + .unwrap_or(-1); + + ( + exit_code, + String::from_utf8_lossy(&stdout_data).into_owned(), + String::from_utf8_lossy(&stderr_data).into_owned(), + ) +} + +async fn read_limited_output(reader: R, max_output_bytes: u64) -> Vec +where + R: AsyncRead + Unpin, +{ + let mut buffer = Vec::new(); + let _ = reader.take(max_output_bytes).read_to_end(&mut buffer).await; + buffer +} diff --git a/src/app/connection.rs b/src/app/connection.rs new file mode 100644 index 0000000..359cbb8 --- /dev/null +++ b/src/app/connection.rs @@ -0,0 +1,111 @@ +use futures::{ + SinkExt, StreamExt, + stream::{SplitSink, SplitStream}, +}; +use log::info; +use tokio::sync::{mpsc, oneshot}; +use tokio_tungstenite::{ + connect_async_with_config, + tungstenite::{Message, Utf8Bytes, client::IntoClientRequest, http::HeaderValue}, +}; + +use crate::app::{config::Config, protocol::ClientMsg}; + +pub type Socket = + tokio_tungstenite::WebSocketStream>; +pub type SocketReader = SplitStream; +type SocketWriter = SplitSink; + +pub enum OutgoingMessage { + Application(ClientMsg), + Ping, + Pong(Vec), +} + +pub async fn connect(config: &Config) -> Result { + let mut request = format!("{}/connect", config.base_url) + .into_client_request() + .map_err(|error| error.to_string())?; + + request.headers_mut().insert( + "Authorization", + HeaderValue::from_str(&format!("Bearer {}", config.token)) + .map_err(|error| error.to_string())?, + ); + request.headers_mut().insert( + "X-Device-ID", + HeaderValue::from_str(&config.device_id).map_err(|error| error.to_string())?, + ); + + let (socket, _) = connect_async_with_config(request, None, false) + .await + .map_err(|error| error.to_string())?; + + info!("Connected."); + Ok(socket) +} + +pub fn split(socket: Socket) -> (SocketWriter, SocketReader) { + socket.split() +} + +pub fn spawn_writer( + sink: SocketWriter, + capacity: usize, +) -> (mpsc::Sender, oneshot::Receiver<()>) { + let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(capacity); + let (closed_tx, closed_rx) = oneshot::channel::<()>(); + + tokio::spawn(async move { + let mut sink = sink; + let _closed = closed_tx; + + while let Some(message) = outgoing_rx.recv().await { + if write_message(&mut sink, message).await.is_err() { + break; + } + } + }); + + (outgoing_tx, closed_rx) +} + +pub async fn read_server_message( + socket: &mut Socket, +) -> Result { + loop { + match socket.next().await { + Some(Ok(Message::Text(text))) => { + return serde_json::from_str(&text).map_err(|error| error.to_string()); + } + Some(Ok(Message::Ping(payload))) => { + socket + .send(Message::Pong(payload)) + .await + .map_err(|error| error.to_string())?; + } + Some(Ok(Message::Close(_))) | None => return Err("Connection closed".into()), + Some(Err(error)) => return Err(error.to_string()), + Some(Ok(_)) => {} + } + } +} + +async fn write_message(sink: &mut SocketWriter, message: OutgoingMessage) -> Result<(), String> { + match message { + OutgoingMessage::Application(message) => sink + .send(Message::Text(Utf8Bytes::from( + serde_json::to_string(&message).map_err(|error| error.to_string())?, + ))) + .await + .map_err(|error| error.to_string()), + OutgoingMessage::Ping => sink + .send(Message::Ping(Vec::new().into())) + .await + .map_err(|error| error.to_string()), + OutgoingMessage::Pong(payload) => sink + .send(Message::Pong(payload.into())) + .await + .map_err(|error| error.to_string()), + } +} diff --git a/src/app/mod.rs b/src/app/mod.rs new file mode 100644 index 0000000..b7f4742 --- /dev/null +++ b/src/app/mod.rs @@ -0,0 +1,5 @@ +pub mod client; +pub mod command; +pub mod config; +pub mod connection; +pub mod protocol; diff --git a/src/main.rs b/src/main.rs index e05e357..8d59ac0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,322 +1,7 @@ -mod app { - pub mod config; - pub mod protocol; -} - -use std::{process::Stdio, sync::Arc, time::Duration}; - -use app::config::Config; -use app::protocol::{ClientMsg, ServerMsg}; -use futures::{ - SinkExt, StreamExt, - stream::{SplitSink, SplitStream}, -}; -use log::{error, info}; -use tokio::{ - io::AsyncRead, - io::AsyncReadExt, - process::Command, - sync::{mpsc, oneshot}, -}; -use tokio_tungstenite::{ - connect_async_with_config, - tungstenite::{Message, Utf8Bytes, client::IntoClientRequest, http::HeaderValue}, -}; - -type Sock = - tokio_tungstenite::WebSocketStream>; -type Writer = SplitSink; - -enum OutgoingMessage { - Application(ClientMsg), - Ping, - Pong(Vec), -} +mod app; #[tokio::main] async fn main() { env_logger::init(); - - let config = match Config::from_env() { - Ok(config) => Arc::new(config), - Err(e) => { - error!("Configuration error: {e}"); - std::process::exit(1); - } - }; - - loop { - match run_client(&config).await { - Ok(()) => info!("Disconnected. Reconnecting in 5s..."), - Err(e) => error!("Error: {e}. Reconnecting in 5s..."), - } - tokio::time::sleep(Duration::from_secs(5)).await; - } -} - -async fn run_client(config: &Arc) -> Result<(), String> { - let mut socket = connect_to_server(config).await?; - authenticate(&mut socket).await?; - - let (sink, stream) = socket.split(); - let (outgoing_tx, writer_closed) = spawn_writer_task(sink, config.writer_channel_capacity); - - process_messages(stream, outgoing_tx, writer_closed, config).await -} - -async fn connect_to_server(config: &Config) -> Result { - let url = format!("{}/connect", config.base_url); - let mut request = url.into_client_request().map_err(|e| e.to_string())?; - let headers = request.headers_mut(); - headers.insert( - "Authorization", - HeaderValue::from_str(&format!("Bearer {}", config.token)).map_err(|e| e.to_string())?, - ); - headers.insert( - "X-Device-ID", - HeaderValue::from_str(&config.device_id).map_err(|e| e.to_string())?, - ); - - let (socket, _) = connect_async_with_config(request, None, false) - .await - .map_err(|e| e.to_string())?; - info!("Connected."); - Ok(socket) -} - -async fn authenticate(socket: &mut Sock) -> Result<(), String> { - match read_protocol_message(socket).await? { - ServerMsg::AuthOk => { - info!("Authenticated."); - Ok(()) - } - ServerMsg::AuthError { reason } => Err(format!("Auth failed: {reason}")), - ServerMsg::Exec { .. } => Err("Unexpected message during auth".into()), - } -} - -async fn read_protocol_message(socket: &mut Sock) -> Result { - loop { - match socket.next().await { - Some(Ok(Message::Text(text))) => return parse_server_message(&text), - Some(Ok(Message::Ping(data))) => { - socket - .send(Message::Pong(data)) - .await - .map_err(|e| e.to_string())?; - } - Some(Ok(Message::Close(_))) | None => return Err("Connection closed".into()), - Some(Err(error)) => return Err(error.to_string()), - _ => {} - } - } -} - -fn parse_server_message(text: &str) -> Result { - serde_json::from_str(text).map_err(|e| e.to_string()) -} - -async fn process_messages( - mut stream: SplitStream, - outgoing_tx: mpsc::Sender, - mut writer_closed: oneshot::Receiver<()>, - config: &Arc, -) -> Result<(), String> { - let mut ping_interval = tokio::time::interval_at( - tokio::time::Instant::now() + config.ping_interval, - config.ping_interval, - ); - - let pong_timer = tokio::time::sleep(Duration::from_secs(0)); - tokio::pin!(pong_timer); - let mut awaiting_pong = false; - - loop { - tokio::select! { - _ = &mut writer_closed => return Err("Write half closed".into()), - _ = ping_interval.tick() => { - if awaiting_pong { - return Err("Server did not respond to ping".into()); - } - - send_outgoing(&outgoing_tx, OutgoingMessage::Ping).await?; - awaiting_pong = true; - pong_timer.as_mut().reset( - tokio::time::Instant::now() + config.ping_timeout, - ); - } - _ = &mut pong_timer, if awaiting_pong => { - return Err("Server did not respond to ping".into()); - } - next_message = stream.next() => { - let action = handle_stream_message(next_message, &outgoing_tx, config.max_output_bytes).await?; - - if action.clear_pong_deadline { - awaiting_pong = false; - } - } - } - } -} - -fn spawn_writer_task( - sink: Writer, - capacity: usize, -) -> (mpsc::Sender, oneshot::Receiver<()>) { - let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(capacity); - let (closed_tx, closed_rx) = oneshot::channel::<()>(); - - tokio::spawn(async move { - let mut sink = sink; - let _closed = closed_tx; - - while let Some(message) = outgoing_rx.recv().await { - if write_message(&mut sink, message).await.is_err() { - break; - } - } - }); - - (outgoing_tx, closed_rx) -} - -async fn write_message(sink: &mut Writer, message: OutgoingMessage) -> Result<(), String> { - match message { - OutgoingMessage::Application(message) => { - let text = serde_json::to_string(&message).map_err(|e| e.to_string())?; - sink.send(Message::Text(Utf8Bytes::from(text))) - .await - .map_err(|e| e.to_string()) - } - OutgoingMessage::Ping => sink - .send(Message::Ping(Vec::new().into())) - .await - .map_err(|e| e.to_string()), - OutgoingMessage::Pong(payload) => sink - .send(Message::Pong(payload.into())) - .await - .map_err(|e| e.to_string()), - } -} - -struct StreamAction { - clear_pong_deadline: bool, -} - -async fn send_outgoing( - outgoing_tx: &mpsc::Sender, - message: OutgoingMessage, -) -> Result<(), String> { - outgoing_tx.send(message).await.map_err(|e| e.to_string()) -} - -async fn handle_stream_message( - message: Option>, - outgoing_tx: &mpsc::Sender, - max_output_bytes: u64, -) -> Result { - match message { - Some(Ok(Message::Text(text))) => { - handle_server_message(parse_server_message(&text)?, outgoing_tx, max_output_bytes) - .await?; - Ok(StreamAction { - clear_pong_deadline: false, - }) - } - Some(Ok(Message::Ping(payload))) => { - send_outgoing(outgoing_tx, OutgoingMessage::Pong(payload.to_vec())).await?; - Ok(StreamAction { - clear_pong_deadline: false, - }) - } - Some(Ok(Message::Close(_))) | None => Err("Connection closed".into()), - Some(Err(error)) => Err(error.to_string()), - Some(Ok(Message::Pong(_))) => Ok(StreamAction { - clear_pong_deadline: true, - }), - Some(Ok(_)) => Ok(StreamAction { - clear_pong_deadline: false, - }), - } -} - -async fn handle_server_message( - message: ServerMsg, - outgoing_tx: &mpsc::Sender, - max_output_bytes: u64, -) -> Result<(), String> { - match message { - ServerMsg::Exec { exec_id, command } => { - spawn_command_task(exec_id, command, outgoing_tx.clone(), max_output_bytes); - Ok(()) - } - ServerMsg::AuthError { reason } => Err(reason), - ServerMsg::AuthOk => Ok(()), - } -} - -fn spawn_command_task( - exec_id: String, - command: String, - outgoing_tx: mpsc::Sender, - max_output_bytes: u64, -) { - tokio::spawn(async move { - info!("Executing [{exec_id}]: {command}"); - let (exit_code, stdout, stderr) = execute_command(&command, max_output_bytes).await; - let result = ClientMsg::ExecResult { - exec_id, - exit_code, - stdout, - stderr, - }; - - if let Err(error) = send_outgoing(&outgoing_tx, OutgoingMessage::Application(result)).await - { - error!("Failed to send command result: {error}"); - } - }); -} - -async fn execute_command(command: &str, max_output_bytes: u64) -> (i32, String, String) { - let mut child = match Command::new("sh") - .arg("-c") - .arg(command) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - { - Ok(c) => c, - Err(e) => return (-1, String::new(), e.to_string()), - }; - - let Some(stdout) = child.stdout.take() else { - return (-1, String::new(), "stdout pipe unavailable".to_string()); - }; - let Some(stderr) = child.stderr.take() else { - return (-1, String::new(), "stderr pipe unavailable".to_string()); - }; - - let (stdout_data, stderr_data) = tokio::join!( - read_limited_output(stdout, max_output_bytes), - read_limited_output(stderr, max_output_bytes), - ); - - child.kill().await.ok(); - let exit_code = child.wait().await.ok().and_then(|s| s.code()).unwrap_or(-1); - - ( - exit_code, - String::from_utf8_lossy(&stdout_data).into_owned(), - String::from_utf8_lossy(&stderr_data).into_owned(), - ) -} - -async fn read_limited_output(reader: R, max_output_bytes: u64) -> Vec -where - R: AsyncRead + Unpin, -{ - let mut buffer = Vec::new(); - let _ = reader.take(max_output_bytes).read_to_end(&mut buffer).await; - buffer + app::client::run_forever().await; }