use std::collections::{HashMap, HashSet}; use std::path::PathBuf; use std::sync::Arc; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::net::{UnixListener, UnixStream}; use tokio::sync::{broadcast, mpsc, Mutex}; use crate::auth::AuthToken; use crate::protocol::{decode_input, encode_output, ClientMessage, DaemonMessage}; use crate::session::SessionManager; /// High-water mark for the per-client send queue (in messages, not bytes). /// We limit to ~256 KB worth of medium-sized chunks before dropping the client. const CLIENT_QUEUE_CAP: usize = 64; /// Shared mutable state accessible from all tasks. struct State { sessions: SessionManager, /// session_id → set of client_ids currently subscribed. subscriptions: HashMap>, /// client_id → channel to push messages back to that client's write task. client_txs: HashMap>, next_client_id: u64, } impl State { fn new(default_shell: String) -> Self { Self { sessions: SessionManager::new(default_shell), subscriptions: HashMap::new(), client_txs: HashMap::new(), next_client_id: 1, } } fn alloc_client_id(&mut self) -> u64 { let id = self.next_client_id; self.next_client_id += 1; id } /// Remove a client from all subscription sets and from the client map. fn remove_client(&mut self, cid: u64) { self.client_txs.remove(&cid); for subs in self.subscriptions.values_mut() { subs.remove(&cid); } } /// Fan-out a message to all subscribers of `session_id`. fn fanout(&self, session_id: &str, msg: DaemonMessage) { if let Some(subs) = self.subscriptions.get(session_id) { for cid in subs { if let Some(tx) = self.client_txs.get(cid) { // Non-blocking: drop slow clients silently. let _ = tx.try_send(msg.clone()); } } } } } pub struct Daemon { socket_path: PathBuf, token: AuthToken, default_shell: String, } impl Daemon { pub fn new(socket_path: PathBuf, token: AuthToken, default_shell: String) -> Self { Self { socket_path, token, default_shell, } } /// Run until `shutdown_rx` fires. pub async fn run(self, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), String> { // Remove stale socket from a previous run. let _ = std::fs::remove_file(&self.socket_path); let listener = UnixListener::bind(&self.socket_path) .map_err(|e| format!("bind {:?}: {e}", self.socket_path))?; log::info!("agor-ptyd v0.1.0 listening on {:?}", self.socket_path); let state = Arc::new(Mutex::new(State::new(self.default_shell.clone()))); let token = Arc::new(self.token); loop { tokio::select! { accept = listener.accept() => { match accept { Ok((stream, _addr)) => { let state = state.clone(); let token = token.clone(); tokio::spawn(handle_client(stream, state, token)); } Err(e) => log::warn!("accept error: {e}"), } } _ = shutdown_rx.recv() => { log::info!("shutdown signal received — stopping daemon"); break; } } } // Cleanup socket file. let _ = std::fs::remove_file(&self.socket_path); Ok(()) } } /// Handle a single client connection from handshake to disconnect. async fn handle_client( stream: UnixStream, state: Arc>, token: Arc, ) { let (read_half, write_half) = stream.into_split(); let mut reader = BufReader::new(read_half); // First message must be Auth. let mut line = String::new(); if reader.read_line(&mut line).await.unwrap_or(0) == 0 { log::warn!("client disconnected before auth"); return; } let auth_msg: ClientMessage = match serde_json::from_str(line.trim()) { Ok(m) => m, Err(e) => { log::warn!("invalid auth message: {e}"); return; } }; let presented_token = match auth_msg { ClientMessage::Auth { token: t } => t, _ => { log::warn!("first message was not Auth — dropping client"); return; } }; if !token.verify(&presented_token) { log::warn!("auth failed (token={} redacted)", token.redacted()); // Send failure then drop the connection. let _ = send_line( write_half, &DaemonMessage::AuthResult { ok: false }, ) .await; return; } // Register client. let (out_tx, out_rx) = mpsc::channel::(CLIENT_QUEUE_CAP); let cid = { let mut st = state.lock().await; let cid = st.alloc_client_id(); st.client_txs.insert(cid, out_tx.clone()); cid }; log::info!("client {cid} authenticated"); // Send auth success. if let Err(e) = out_tx.try_send(DaemonMessage::AuthResult { ok: true }) { log::warn!("client {cid}: failed to queue AuthResult: {e}"); state.lock().await.remove_client(cid); return; } // Spawn a dedicated write task so the reader loop is never blocked by // slow writes to the socket. let write_task = tokio::spawn(write_loop(write_half, out_rx)); // Read loop. loop { let mut line = String::new(); match reader.read_line(&mut line).await { Ok(0) => break, // EOF Ok(_) => {} Err(e) => { log::debug!("client {cid} read error: {e}"); break; } } let msg: ClientMessage = match serde_json::from_str(line.trim()) { Ok(m) => m, Err(e) => { log::warn!("client {cid} bad message: {e}"); let _ = out_tx .try_send(DaemonMessage::Error { message: format!("parse error: {e}"), }); continue; } }; handle_message(cid, msg, &state, &out_tx).await; } // Cleanup on disconnect. log::info!("client {cid} disconnected"); state.lock().await.remove_client(cid); write_task.abort(); } /// Dispatch a single client message to the appropriate handler. async fn handle_message( cid: u64, msg: ClientMessage, state: &Arc>, out_tx: &mpsc::Sender, ) { match msg { ClientMessage::Auth { .. } => { // Already authenticated — ignore duplicate. } ClientMessage::Ping => { let _ = out_tx.try_send(DaemonMessage::Pong); } ClientMessage::ListSessions => { let list = state.lock().await.sessions.list(); let _ = out_tx.try_send(DaemonMessage::SessionList { sessions: list }); } ClientMessage::CreateSession { id, shell, cwd, env, cols, rows } => { let state_clone = state.clone(); let out_tx_clone = out_tx.clone(); let id_clone = id.clone(); let result = { let mut st = state.lock().await; st.sessions.create_session( id.clone(), shell, cwd, env, cols, rows, move |sid, code| { // Invoked from the blocking reader task when child exits. let state_clone = state_clone.clone(); let _ = &out_tx_clone; // captured for lifetime, not used tokio::spawn(async move { let st = state_clone.lock().await; st.fanout( &sid, DaemonMessage::SessionClosed { session_id: sid.clone(), exit_code: code, }, ); drop(st); }); }, ) }; match result { Ok((pid, output_rx)) => { let _ = out_tx.try_send(DaemonMessage::SessionCreated { session_id: id_clone.clone(), pid, }); // Immediately subscribe the creating client. { let mut st = state.lock().await; st.subscriptions .entry(id_clone.clone()) .or_default() .insert(cid); } // Start a fanout task for this session's output. let state_clone = state.clone(); tokio::spawn(output_fanout_task(id_clone, output_rx, state_clone)); } Err(e) => { let _ = out_tx.try_send(DaemonMessage::Error { message: e }); } } } ClientMessage::WriteInput { session_id, data } => { let bytes = match decode_input(&data) { Ok(b) => b, Err(e) => { let _ = out_tx.try_send(DaemonMessage::Error { message: format!("bad input encoding: {e}"), }); return; } }; let st = state.lock().await; match st.sessions.get(&session_id) { Some(sess) => { if let Err(e) = sess.write_input(&bytes).await { let _ = out_tx.try_send(DaemonMessage::Error { message: e }); } } None => { let _ = out_tx.try_send(DaemonMessage::Error { message: format!("session {session_id} not found"), }); } } } ClientMessage::Resize { session_id, cols, rows } => { let mut st = state.lock().await; match st.sessions.get_mut(&session_id) { Some(sess) => { if let Err(e) = sess.resize(cols, rows).await { log::warn!("resize {session_id}: {e}"); } } None => { let _ = out_tx.try_send(DaemonMessage::Error { message: format!("session {session_id} not found"), }); } } } ClientMessage::Subscribe { session_id } => { let (exists, rx) = { let st = state.lock().await; let exists = st.sessions.get(&session_id).is_some(); let rx = st .sessions .get(&session_id) .map(|s| s.subscribe()); (exists, rx) }; if !exists { let _ = out_tx.try_send(DaemonMessage::Error { message: format!("session {session_id} not found"), }); return; } { let mut st = state.lock().await; st.subscriptions .entry(session_id.clone()) .or_default() .insert(cid); } // If a new rx came back, start a fanout task (handles reconnect case // where the original fanout task has gone away after all receivers // dropped). We always start one; duplicates are harmless since the // broadcast channel keeps all messages. if let Some(rx) = rx { let state_clone = state.clone(); tokio::spawn(output_fanout_task(session_id, rx, state_clone)); } } ClientMessage::Unsubscribe { session_id } => { let mut st = state.lock().await; if let Some(subs) = st.subscriptions.get_mut(&session_id) { subs.remove(&cid); } } ClientMessage::CloseSession { session_id } => { let mut st = state.lock().await; if let Err(e) = st.sessions.close_session(&session_id) { let _ = out_tx.try_send(DaemonMessage::Error { message: e }); } else { st.subscriptions.remove(&session_id); } } } } /// Reads from a session's broadcast channel and fans output to all subscribed /// clients via their individual mpsc queues. async fn output_fanout_task( session_id: String, mut rx: broadcast::Receiver>, state: Arc>, ) { loop { match rx.recv().await { Ok(chunk) => { let encoded = encode_output(&chunk); let msg = DaemonMessage::SessionOutput { session_id: session_id.clone(), data: encoded, }; state.lock().await.fanout(&session_id, msg); } Err(broadcast::error::RecvError::Lagged(n)) => { log::warn!("session {session_id} fanout lagged, dropped {n} messages"); } Err(broadcast::error::RecvError::Closed) => { log::debug!("session {session_id} output channel closed"); break; } } } } /// Drains the per-client mpsc queue and writes newline-delimited JSON to the /// socket. async fn write_loop( mut writer: tokio::net::unix::OwnedWriteHalf, mut rx: mpsc::Receiver, ) { while let Some(msg) = rx.recv().await { match serde_json::to_string(&msg) { Ok(mut json) => { json.push('\n'); if let Err(e) = writer.write_all(json.as_bytes()).await { log::debug!("write error: {e}"); break; } } Err(e) => { log::warn!("serialize error: {e}"); } } } } /// One-shot write for pre-auth messages (write_half not yet consumed by the /// write_loop task). async fn send_line( mut writer: tokio::net::unix::OwnedWriteHalf, msg: &DaemonMessage, ) -> Result<(), String> { let mut json = serde_json::to_string(msg) .map_err(|e| format!("serialize: {e}"))?; json.push('\n'); writer .write_all(json.as_bytes()) .await .map_err(|e| format!("write: {e}")) }