diff --git a/v2/Cargo.lock b/v2/Cargo.lock index 90821cd..8a9aa2e 100644 --- a/v2/Cargo.lock +++ b/v2/Cargo.lock @@ -416,9 +416,11 @@ dependencies = [ "env_logger", "futures-util", "log", + "native-tls", "serde", "serde_json", "tokio", + "tokio-native-tls", "tokio-tungstenite", "uuid", ] diff --git a/v2/bterminal-relay/Cargo.toml b/v2/bterminal-relay/Cargo.toml index ee77a34..3e95c7e 100644 --- a/v2/bterminal-relay/Cargo.toml +++ b/v2/bterminal-relay/Cargo.toml @@ -17,6 +17,8 @@ log = "0.4" env_logger = "0.11" tokio = { version = "1", features = ["full"] } tokio-tungstenite = { version = "0.21", features = ["native-tls"] } +tokio-native-tls = "0.3" +native-tls = "0.2" futures-util = "0.3" clap = { version = "4", features = ["derive"] } uuid = { version = "1", features = ["v4"] } diff --git a/v2/bterminal-relay/src/main.rs b/v2/bterminal-relay/src/main.rs index d87e80e..7e98edf 100644 --- a/v2/bterminal-relay/src/main.rs +++ b/v2/bterminal-relay/src/main.rs @@ -28,6 +28,14 @@ struct Cli { #[arg(long, default_value = "false")] insecure: bool, + /// TLS certificate file (PEM format). Enables wss:// when provided with --tls-key. + #[arg(long)] + tls_cert: Option, + + /// TLS private key file (PEM format). Required when --tls-cert is provided. + #[arg(long)] + tls_key: Option, + /// Additional sidecar search paths #[arg(long)] sidecar_path: Vec, @@ -75,14 +83,52 @@ impl EventSink for WsEventSink { } } +/// Build a native-tls TLS acceptor from PEM cert and key files. +fn build_tls_acceptor(cert_path: &str, key_path: &str) -> Result { + let cert_pem = std::fs::read(cert_path) + .map_err(|e| format!("Failed to read TLS cert '{}': {}", cert_path, e))?; + let key_pem = std::fs::read(key_path) + .map_err(|e| format!("Failed to read TLS key '{}': {}", key_path, e))?; + + let identity = native_tls::Identity::from_pkcs8(&cert_pem, &key_pem) + .map_err(|e| format!("Failed to parse TLS identity (cert+key): {e}"))?; + + let tls_acceptor = native_tls::TlsAcceptor::builder(identity) + .min_protocol_version(Some(native_tls::Protocol::Tlsv12)) + .build() + .map_err(|e| format!("Failed to build TLS acceptor: {e}"))?; + + Ok(tokio_native_tls::TlsAcceptor::from(tls_acceptor)) +} + #[tokio::main] async fn main() { env_logger::init(); let cli = Cli::parse(); + // Validate TLS args + let tls_acceptor = match (&cli.tls_cert, &cli.tls_key) { + (Some(cert), Some(key)) => { + let acceptor = build_tls_acceptor(cert, key).expect("TLS setup failed"); + log::info!("TLS enabled (cert: {cert}, key: {key})"); + Some(Arc::new(acceptor)) + } + (Some(_), None) | (None, Some(_)) => { + eprintln!("Error: --tls-cert and --tls-key must both be provided"); + std::process::exit(1); + } + (None, None) => { + if !cli.insecure { + log::warn!("Running without TLS. Use --tls-cert/--tls-key for encrypted connections, or --insecure to suppress this warning."); + } + None + } + }; + let addr = SocketAddr::from(([0, 0, 0, 0], cli.port)); let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); - log::info!("bterminal-relay listening on {addr}"); + let protocol = if tls_acceptor.is_some() { "wss" } else { "ws" }; + log::info!("bterminal-relay listening on {protocol}://{addr}"); // Build sidecar config let mut search_paths: Vec = cli @@ -111,6 +157,7 @@ async fn main() { let token = token.clone(); let sidecar_config = sidecar_config.clone(); let auth_failures = auth_failures.clone(); + let tls = tls_acceptor.clone(); tokio::spawn(async move { // Check rate limit @@ -128,8 +175,23 @@ async fn main() { } } - if let Err(e) = handle_connection(stream, peer, &token, &sidecar_config, &auth_failures).await { - log::error!("Connection error from {peer}: {e}"); + if let Some(tls_acceptor) = tls { + // TLS path: wrap TCP stream with TLS, then upgrade to WebSocket + match tls_acceptor.accept(stream).await { + Ok(tls_stream) => { + if let Err(e) = handle_tls_connection(tls_stream, peer, &token, &sidecar_config, &auth_failures).await { + log::error!("TLS connection error from {peer}: {e}"); + } + } + Err(e) => { + log::error!("TLS handshake failed from {peer}: {e}"); + } + } + } else { + // Plain WebSocket path + if let Err(e) = handle_connection(stream, peer, &token, &sidecar_config, &auth_failures).await { + log::error!("Connection error from {peer}: {e}"); + } } }); } @@ -142,12 +204,36 @@ async fn handle_connection( sidecar_config: &SidecarConfig, auth_failures: &tokio::sync::Mutex>, ) -> Result<(), String> { - // Accept WebSocket with auth validation - let ws_stream = tokio_tungstenite::accept_hdr_async(stream, |req: &http::Request<()>, response: http::Response<()>| { - // Validate auth token from headers + let ws_stream = accept_ws_with_auth(stream, expected_token, peer, auth_failures).await?; + run_ws_session(ws_stream, peer, sidecar_config).await +} + +async fn handle_tls_connection( + stream: tokio_native_tls::TlsStream, + peer: SocketAddr, + expected_token: &str, + sidecar_config: &SidecarConfig, + auth_failures: &tokio::sync::Mutex>, +) -> Result<(), String> { + let ws_stream = accept_ws_with_auth(stream, expected_token, peer, auth_failures).await?; + run_ws_session(ws_stream, peer, sidecar_config).await +} + +/// Accept a WebSocket connection with Bearer token auth validation. +async fn accept_ws_with_auth( + stream: S, + expected_token: &str, + peer: SocketAddr, + auth_failures: &tokio::sync::Mutex>, +) -> Result, String> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + let expected = format!("Bearer {expected_token}"); + tokio_tungstenite::accept_hdr_async(stream, |req: &http::Request<()>, response: http::Response<()>| { let auth = req.headers().get("authorization").and_then(|v| v.to_str().ok()); match auth { - Some(value) if value == format!("Bearer {expected_token}") => Ok(response), + Some(value) if value == expected => Ok(response), _ => { Err(http::Response::builder() .status(http::StatusCode::UNAUTHORIZED) @@ -158,15 +244,24 @@ async fn handle_connection( }) .await .map_err(|e| { - // Record auth failure let _ = auth_failures.try_lock().map(|mut f| { let entry = f.entry(peer).or_insert((0, std::time::Instant::now())); entry.0 += 1; entry.1 = std::time::Instant::now(); }); format!("WebSocket handshake failed: {e}") - })?; + }) +} +/// Run the WebSocket session (managers, event forwarding, command processing). +async fn run_ws_session( + ws_stream: tokio_tungstenite::WebSocketStream, + peer: SocketAddr, + sidecar_config: &SidecarConfig, +) -> Result<(), String> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, +{ log::info!("Client connected: {peer}"); // Set up event channel — shared between EventSink and command response sender