feat: add TLS support to bterminal-relay
Add optional --tls-cert and --tls-key CLI args. When provided, the relay wraps TCP streams with native-tls before WebSocket upgrade. Refactored to generic accept_ws_with_auth<S> and run_ws_session<S> to avoid code duplication between plain and TLS paths. Client side already supports wss:// URLs via connect_async with native-tls feature.
This commit is contained in:
parent
cd774ab4bd
commit
83c6711cd6
3 changed files with 108 additions and 9 deletions
2
v2/Cargo.lock
generated
2
v2/Cargo.lock
generated
|
|
@ -416,9 +416,11 @@ dependencies = [
|
||||||
"env_logger",
|
"env_logger",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"log",
|
"log",
|
||||||
|
"native-tls",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-native-tls",
|
||||||
"tokio-tungstenite",
|
"tokio-tungstenite",
|
||||||
"uuid",
|
"uuid",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,8 @@ log = "0.4"
|
||||||
env_logger = "0.11"
|
env_logger = "0.11"
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
tokio-tungstenite = { version = "0.21", features = ["native-tls"] }
|
tokio-tungstenite = { version = "0.21", features = ["native-tls"] }
|
||||||
|
tokio-native-tls = "0.3"
|
||||||
|
native-tls = "0.2"
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
clap = { version = "4", features = ["derive"] }
|
clap = { version = "4", features = ["derive"] }
|
||||||
uuid = { version = "1", features = ["v4"] }
|
uuid = { version = "1", features = ["v4"] }
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,14 @@ struct Cli {
|
||||||
#[arg(long, default_value = "false")]
|
#[arg(long, default_value = "false")]
|
||||||
insecure: bool,
|
insecure: bool,
|
||||||
|
|
||||||
|
/// TLS certificate file (PEM format). Enables wss:// when provided with --tls-key.
|
||||||
|
#[arg(long)]
|
||||||
|
tls_cert: Option<String>,
|
||||||
|
|
||||||
|
/// TLS private key file (PEM format). Required when --tls-cert is provided.
|
||||||
|
#[arg(long)]
|
||||||
|
tls_key: Option<String>,
|
||||||
|
|
||||||
/// Additional sidecar search paths
|
/// Additional sidecar search paths
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
sidecar_path: Vec<String>,
|
sidecar_path: Vec<String>,
|
||||||
|
|
@ -75,14 +83,52 @@ impl EventSink for WsEventSink {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build a native-tls TLS acceptor from PEM cert and key files.
|
||||||
|
fn build_tls_acceptor(cert_path: &str, key_path: &str) -> Result<tokio_native_tls::TlsAcceptor, String> {
|
||||||
|
let cert_pem = std::fs::read(cert_path)
|
||||||
|
.map_err(|e| format!("Failed to read TLS cert '{}': {}", cert_path, e))?;
|
||||||
|
let key_pem = std::fs::read(key_path)
|
||||||
|
.map_err(|e| format!("Failed to read TLS key '{}': {}", key_path, e))?;
|
||||||
|
|
||||||
|
let identity = native_tls::Identity::from_pkcs8(&cert_pem, &key_pem)
|
||||||
|
.map_err(|e| format!("Failed to parse TLS identity (cert+key): {e}"))?;
|
||||||
|
|
||||||
|
let tls_acceptor = native_tls::TlsAcceptor::builder(identity)
|
||||||
|
.min_protocol_version(Some(native_tls::Protocol::Tlsv12))
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to build TLS acceptor: {e}"))?;
|
||||||
|
|
||||||
|
Ok(tokio_native_tls::TlsAcceptor::from(tls_acceptor))
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
env_logger::init();
|
env_logger::init();
|
||||||
let cli = Cli::parse();
|
let cli = Cli::parse();
|
||||||
|
|
||||||
|
// Validate TLS args
|
||||||
|
let tls_acceptor = match (&cli.tls_cert, &cli.tls_key) {
|
||||||
|
(Some(cert), Some(key)) => {
|
||||||
|
let acceptor = build_tls_acceptor(cert, key).expect("TLS setup failed");
|
||||||
|
log::info!("TLS enabled (cert: {cert}, key: {key})");
|
||||||
|
Some(Arc::new(acceptor))
|
||||||
|
}
|
||||||
|
(Some(_), None) | (None, Some(_)) => {
|
||||||
|
eprintln!("Error: --tls-cert and --tls-key must both be provided");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
(None, None) => {
|
||||||
|
if !cli.insecure {
|
||||||
|
log::warn!("Running without TLS. Use --tls-cert/--tls-key for encrypted connections, or --insecure to suppress this warning.");
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let addr = SocketAddr::from(([0, 0, 0, 0], cli.port));
|
let addr = SocketAddr::from(([0, 0, 0, 0], cli.port));
|
||||||
let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
|
let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
|
||||||
log::info!("bterminal-relay listening on {addr}");
|
let protocol = if tls_acceptor.is_some() { "wss" } else { "ws" };
|
||||||
|
log::info!("bterminal-relay listening on {protocol}://{addr}");
|
||||||
|
|
||||||
// Build sidecar config
|
// Build sidecar config
|
||||||
let mut search_paths: Vec<std::path::PathBuf> = cli
|
let mut search_paths: Vec<std::path::PathBuf> = cli
|
||||||
|
|
@ -111,6 +157,7 @@ async fn main() {
|
||||||
let token = token.clone();
|
let token = token.clone();
|
||||||
let sidecar_config = sidecar_config.clone();
|
let sidecar_config = sidecar_config.clone();
|
||||||
let auth_failures = auth_failures.clone();
|
let auth_failures = auth_failures.clone();
|
||||||
|
let tls = tls_acceptor.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
// Check rate limit
|
// Check rate limit
|
||||||
|
|
@ -128,9 +175,24 @@ async fn main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(tls_acceptor) = tls {
|
||||||
|
// TLS path: wrap TCP stream with TLS, then upgrade to WebSocket
|
||||||
|
match tls_acceptor.accept(stream).await {
|
||||||
|
Ok(tls_stream) => {
|
||||||
|
if let Err(e) = handle_tls_connection(tls_stream, peer, &token, &sidecar_config, &auth_failures).await {
|
||||||
|
log::error!("TLS connection error from {peer}: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("TLS handshake failed from {peer}: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Plain WebSocket path
|
||||||
if let Err(e) = handle_connection(stream, peer, &token, &sidecar_config, &auth_failures).await {
|
if let Err(e) = handle_connection(stream, peer, &token, &sidecar_config, &auth_failures).await {
|
||||||
log::error!("Connection error from {peer}: {e}");
|
log::error!("Connection error from {peer}: {e}");
|
||||||
}
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -142,12 +204,36 @@ async fn handle_connection(
|
||||||
sidecar_config: &SidecarConfig,
|
sidecar_config: &SidecarConfig,
|
||||||
auth_failures: &tokio::sync::Mutex<std::collections::HashMap<SocketAddr, (u32, std::time::Instant)>>,
|
auth_failures: &tokio::sync::Mutex<std::collections::HashMap<SocketAddr, (u32, std::time::Instant)>>,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
// Accept WebSocket with auth validation
|
let ws_stream = accept_ws_with_auth(stream, expected_token, peer, auth_failures).await?;
|
||||||
let ws_stream = tokio_tungstenite::accept_hdr_async(stream, |req: &http::Request<()>, response: http::Response<()>| {
|
run_ws_session(ws_stream, peer, sidecar_config).await
|
||||||
// Validate auth token from headers
|
}
|
||||||
|
|
||||||
|
async fn handle_tls_connection(
|
||||||
|
stream: tokio_native_tls::TlsStream<TcpStream>,
|
||||||
|
peer: SocketAddr,
|
||||||
|
expected_token: &str,
|
||||||
|
sidecar_config: &SidecarConfig,
|
||||||
|
auth_failures: &tokio::sync::Mutex<std::collections::HashMap<SocketAddr, (u32, std::time::Instant)>>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let ws_stream = accept_ws_with_auth(stream, expected_token, peer, auth_failures).await?;
|
||||||
|
run_ws_session(ws_stream, peer, sidecar_config).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Accept a WebSocket connection with Bearer token auth validation.
|
||||||
|
async fn accept_ws_with_auth<S>(
|
||||||
|
stream: S,
|
||||||
|
expected_token: &str,
|
||||||
|
peer: SocketAddr,
|
||||||
|
auth_failures: &tokio::sync::Mutex<std::collections::HashMap<SocketAddr, (u32, std::time::Instant)>>,
|
||||||
|
) -> Result<tokio_tungstenite::WebSocketStream<S>, String>
|
||||||
|
where
|
||||||
|
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
let expected = format!("Bearer {expected_token}");
|
||||||
|
tokio_tungstenite::accept_hdr_async(stream, |req: &http::Request<()>, response: http::Response<()>| {
|
||||||
let auth = req.headers().get("authorization").and_then(|v| v.to_str().ok());
|
let auth = req.headers().get("authorization").and_then(|v| v.to_str().ok());
|
||||||
match auth {
|
match auth {
|
||||||
Some(value) if value == format!("Bearer {expected_token}") => Ok(response),
|
Some(value) if value == expected => Ok(response),
|
||||||
_ => {
|
_ => {
|
||||||
Err(http::Response::builder()
|
Err(http::Response::builder()
|
||||||
.status(http::StatusCode::UNAUTHORIZED)
|
.status(http::StatusCode::UNAUTHORIZED)
|
||||||
|
|
@ -158,15 +244,24 @@ async fn handle_connection(
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
// Record auth failure
|
|
||||||
let _ = auth_failures.try_lock().map(|mut f| {
|
let _ = auth_failures.try_lock().map(|mut f| {
|
||||||
let entry = f.entry(peer).or_insert((0, std::time::Instant::now()));
|
let entry = f.entry(peer).or_insert((0, std::time::Instant::now()));
|
||||||
entry.0 += 1;
|
entry.0 += 1;
|
||||||
entry.1 = std::time::Instant::now();
|
entry.1 = std::time::Instant::now();
|
||||||
});
|
});
|
||||||
format!("WebSocket handshake failed: {e}")
|
format!("WebSocket handshake failed: {e}")
|
||||||
})?;
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the WebSocket session (managers, event forwarding, command processing).
|
||||||
|
async fn run_ws_session<S>(
|
||||||
|
ws_stream: tokio_tungstenite::WebSocketStream<S>,
|
||||||
|
peer: SocketAddr,
|
||||||
|
sidecar_config: &SidecarConfig,
|
||||||
|
) -> Result<(), String>
|
||||||
|
where
|
||||||
|
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
log::info!("Client connected: {peer}");
|
log::info!("Client connected: {peer}");
|
||||||
|
|
||||||
// Set up event channel — shared between EventSink and command response sender
|
// Set up event channel — shared between EventSink and command response sender
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue