feat(agor-pty): complete PTY daemon — auth, sessions, output fanout

This commit is contained in:
Hibryda 2026-03-20 03:10:49 +01:00
parent 4b5583430d
commit f3456bd09d
6 changed files with 1853 additions and 65 deletions

441
agor-pty/src/daemon.rs Normal file
View file

@ -0,0 +1,441 @@
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::{broadcast, mpsc, Mutex};
use crate::auth::AuthToken;
use crate::protocol::{decode_input, encode_output, ClientMessage, DaemonMessage};
use crate::session::SessionManager;
/// Capacity of each client's outbound queue, counted in messages (not bytes).
///
/// This is the bound passed to `mpsc::channel` when a client is registered.
/// When a client's queue is full, `State::fanout` uses `try_send` and the
/// message that does not fit is discarded; the client itself stays connected.
///
/// NOTE(review): the original comment equated this to "~256 KB", which
/// presumes roughly 4 KB per message — confirm against the session reader's
/// chunk size.
const CLIENT_QUEUE_CAP: usize = 64;
/// Shared mutable state accessible from all tasks.
///
/// `Daemon::run` creates a single instance wrapped in `Arc<Mutex<_>>` and
/// hands a clone of that handle to every client task.
struct State {
/// Registry of PTY sessions; constructed from the default shell in `State::new`.
sessions: SessionManager,
/// session_id → set of client_ids currently subscribed.
subscriptions: HashMap<String, HashSet<u64>>,
/// client_id → channel to push messages back to that client's write task.
client_txs: HashMap<u64, mpsc::Sender<DaemonMessage>>,
/// Next id returned by `alloc_client_id`; starts at 1 and is incremented on every allocation.
next_client_id: u64,
}
impl State {
fn new(default_shell: String) -> Self {
Self {
sessions: SessionManager::new(default_shell),
subscriptions: HashMap::new(),
client_txs: HashMap::new(),
next_client_id: 1,
}
}
fn alloc_client_id(&mut self) -> u64 {
let id = self.next_client_id;
self.next_client_id += 1;
id
}
/// Remove a client from all subscription sets and from the client map.
fn remove_client(&mut self, cid: u64) {
self.client_txs.remove(&cid);
for subs in self.subscriptions.values_mut() {
subs.remove(&cid);
}
}
/// Fan-out a message to all subscribers of `session_id`.
fn fanout(&self, session_id: &str, msg: DaemonMessage) {
if let Some(subs) = self.subscriptions.get(session_id) {
for cid in subs {
if let Some(tx) = self.client_txs.get(cid) {
// Non-blocking: drop slow clients silently.
let _ = tx.try_send(msg.clone());
}
}
}
}
}
/// Configuration for one PTY daemon instance; consumed by `Daemon::run`.
pub struct Daemon {
/// Filesystem path of the Unix socket that `run` binds and listens on.
socket_path: PathBuf,
/// Token a client's first (`Auth`) message is verified against.
token: AuthToken,
/// Shell handed to `SessionManager::new`.
/// NOTE(review): presumably the fallback when a session request names no shell — confirm in the session module.
default_shell: String,
}
impl Daemon {
pub fn new(socket_path: PathBuf, token: AuthToken, default_shell: String) -> Self {
Self {
socket_path,
token,
default_shell,
}
}
/// Run until `shutdown_rx` fires.
pub async fn run(self, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), String> {
// Remove stale socket from a previous run.
let _ = std::fs::remove_file(&self.socket_path);
let listener = UnixListener::bind(&self.socket_path)
.map_err(|e| format!("bind {:?}: {e}", self.socket_path))?;
log::info!("agor-ptyd v0.1.0 listening on {:?}", self.socket_path);
let state = Arc::new(Mutex::new(State::new(self.default_shell.clone())));
let token = Arc::new(self.token);
loop {
tokio::select! {
accept = listener.accept() => {
match accept {
Ok((stream, _addr)) => {
let state = state.clone();
let token = token.clone();
tokio::spawn(handle_client(stream, state, token));
}
Err(e) => log::warn!("accept error: {e}"),
}
}
_ = shutdown_rx.recv() => {
log::info!("shutdown signal received — stopping daemon");
break;
}
}
}
// Cleanup socket file.
let _ = std::fs::remove_file(&self.socket_path);
Ok(())
}
}
/// Handle a single client connection from handshake to disconnect.
async fn handle_client(
stream: UnixStream,
state: Arc<Mutex<State>>,
token: Arc<AuthToken>,
) {
let (read_half, write_half) = stream.into_split();
let mut reader = BufReader::new(read_half);
// First message must be Auth.
let mut line = String::new();
if reader.read_line(&mut line).await.unwrap_or(0) == 0 {
log::warn!("client disconnected before auth");
return;
}
let auth_msg: ClientMessage = match serde_json::from_str(line.trim()) {
Ok(m) => m,
Err(e) => {
log::warn!("invalid auth message: {e}");
return;
}
};
let presented_token = match auth_msg {
ClientMessage::Auth { token: t } => t,
_ => {
log::warn!("first message was not Auth — dropping client");
return;
}
};
if !token.verify(&presented_token) {
log::warn!("auth failed (token={} redacted)", token.redacted());
// Send failure then drop the connection.
let _ = send_line(
write_half,
&DaemonMessage::AuthResult { ok: false },
)
.await;
return;
}
// Register client.
let (out_tx, out_rx) = mpsc::channel::<DaemonMessage>(CLIENT_QUEUE_CAP);
let cid = {
let mut st = state.lock().await;
let cid = st.alloc_client_id();
st.client_txs.insert(cid, out_tx.clone());
cid
};
log::info!("client {cid} authenticated");
// Send auth success.
if let Err(e) = out_tx.try_send(DaemonMessage::AuthResult { ok: true }) {
log::warn!("client {cid}: failed to queue AuthResult: {e}");
state.lock().await.remove_client(cid);
return;
}
// Spawn a dedicated write task so the reader loop is never blocked by
// slow writes to the socket.
let write_task = tokio::spawn(write_loop(write_half, out_rx));
// Read loop.
loop {
let mut line = String::new();
match reader.read_line(&mut line).await {
Ok(0) => break, // EOF
Ok(_) => {}
Err(e) => {
log::debug!("client {cid} read error: {e}");
break;
}
}
let msg: ClientMessage = match serde_json::from_str(line.trim()) {
Ok(m) => m,
Err(e) => {
log::warn!("client {cid} bad message: {e}");
let _ = out_tx
.try_send(DaemonMessage::Error {
message: format!("parse error: {e}"),
});
continue;
}
};
handle_message(cid, msg, &state, &out_tx).await;
}
// Cleanup on disconnect.
log::info!("client {cid} disconnected");
state.lock().await.remove_client(cid);
write_task.abort();
}
/// Dispatch a single client message to the appropriate handler.
async fn handle_message(
cid: u64,
msg: ClientMessage,
state: &Arc<Mutex<State>>,
out_tx: &mpsc::Sender<DaemonMessage>,
) {
match msg {
ClientMessage::Auth { .. } => {
// Already authenticated — ignore duplicate.
}
ClientMessage::Ping => {
let _ = out_tx.try_send(DaemonMessage::Pong);
}
ClientMessage::ListSessions => {
let list = state.lock().await.sessions.list();
let _ = out_tx.try_send(DaemonMessage::SessionList { sessions: list });
}
ClientMessage::CreateSession { id, shell, cwd, env, cols, rows } => {
let state_clone = state.clone();
let out_tx_clone = out_tx.clone();
let id_clone = id.clone();
let result = {
let mut st = state.lock().await;
st.sessions.create_session(
id.clone(),
shell,
cwd,
env,
cols,
rows,
move |sid, code| {
// Invoked from the blocking reader task when child exits.
let state_clone = state_clone.clone();
let _ = &out_tx_clone; // captured for lifetime, not used
tokio::spawn(async move {
let st = state_clone.lock().await;
st.fanout(
&sid,
DaemonMessage::SessionClosed {
session_id: sid.clone(),
exit_code: code,
},
);
drop(st);
});
},
)
};
match result {
Ok((pid, output_rx)) => {
let _ = out_tx.try_send(DaemonMessage::SessionCreated {
session_id: id_clone.clone(),
pid,
});
// Immediately subscribe the creating client.
{
let mut st = state.lock().await;
st.subscriptions
.entry(id_clone.clone())
.or_default()
.insert(cid);
}
// Start a fanout task for this session's output.
let state_clone = state.clone();
tokio::spawn(output_fanout_task(id_clone, output_rx, state_clone));
}
Err(e) => {
let _ = out_tx.try_send(DaemonMessage::Error { message: e });
}
}
}
ClientMessage::WriteInput { session_id, data } => {
let bytes = match decode_input(&data) {
Ok(b) => b,
Err(e) => {
let _ = out_tx.try_send(DaemonMessage::Error {
message: format!("bad input encoding: {e}"),
});
return;
}
};
let st = state.lock().await;
match st.sessions.get(&session_id) {
Some(sess) => {
if let Err(e) = sess.write_input(&bytes).await {
let _ = out_tx.try_send(DaemonMessage::Error { message: e });
}
}
None => {
let _ = out_tx.try_send(DaemonMessage::Error {
message: format!("session {session_id} not found"),
});
}
}
}
ClientMessage::Resize { session_id, cols, rows } => {
let mut st = state.lock().await;
match st.sessions.get_mut(&session_id) {
Some(sess) => {
sess.note_resize(cols, rows);
}
None => {
let _ = out_tx.try_send(DaemonMessage::Error {
message: format!("session {session_id} not found"),
});
}
}
}
ClientMessage::Subscribe { session_id } => {
let (exists, rx) = {
let st = state.lock().await;
let exists = st.sessions.get(&session_id).is_some();
let rx = st
.sessions
.get(&session_id)
.map(|s| s.subscribe());
(exists, rx)
};
if !exists {
let _ = out_tx.try_send(DaemonMessage::Error {
message: format!("session {session_id} not found"),
});
return;
}
{
let mut st = state.lock().await;
st.subscriptions
.entry(session_id.clone())
.or_default()
.insert(cid);
}
// If a new rx came back, start a fanout task (handles reconnect case
// where the original fanout task has gone away after all receivers
// dropped). We always start one; duplicates are harmless since the
// broadcast channel keeps all messages.
if let Some(rx) = rx {
let state_clone = state.clone();
tokio::spawn(output_fanout_task(session_id, rx, state_clone));
}
}
ClientMessage::Unsubscribe { session_id } => {
let mut st = state.lock().await;
if let Some(subs) = st.subscriptions.get_mut(&session_id) {
subs.remove(&cid);
}
}
ClientMessage::CloseSession { session_id } => {
let mut st = state.lock().await;
if let Err(e) = st.sessions.close_session(&session_id) {
let _ = out_tx.try_send(DaemonMessage::Error { message: e });
} else {
st.subscriptions.remove(&session_id);
}
}
}
}
/// Reads from a session's broadcast channel and fans output to all subscribed
/// clients via their individual mpsc queues.
async fn output_fanout_task(
session_id: String,
mut rx: broadcast::Receiver<Vec<u8>>,
state: Arc<Mutex<State>>,
) {
loop {
match rx.recv().await {
Ok(chunk) => {
let encoded = encode_output(&chunk);
let msg = DaemonMessage::SessionOutput {
session_id: session_id.clone(),
data: encoded,
};
state.lock().await.fanout(&session_id, msg);
}
Err(broadcast::error::RecvError::Lagged(n)) => {
log::warn!("session {session_id} fanout lagged, dropped {n} messages");
}
Err(broadcast::error::RecvError::Closed) => {
log::debug!("session {session_id} output channel closed");
break;
}
}
}
}
/// Writer half of a client connection: takes messages off the client's mpsc
/// queue and writes each one to the socket as a single line of JSON.
///
/// A message that fails to serialize is logged and skipped. The loop stops on
/// the first write error, or once the queue is closed and drained.
async fn write_loop(
    mut writer: tokio::net::unix::OwnedWriteHalf,
    mut rx: mpsc::Receiver<DaemonMessage>,
) {
    while let Some(msg) = rx.recv().await {
        let mut frame = match serde_json::to_string(&msg) {
            Ok(text) => text,
            Err(e) => {
                log::warn!("serialize error: {e}");
                continue;
            }
        };
        // Newline terminates the frame.
        frame.push('\n');

        if let Err(e) = writer.write_all(frame.as_bytes()).await {
            log::debug!("write error: {e}");
            break;
        }
    }
}
/// Writes a single message as one line of JSON and gives up the writer.
///
/// Used before authentication completes, while the write half has not yet
/// been handed to the `write_loop` task. The writer is taken by value, so it
/// is dropped when this function returns.
///
/// # Errors
/// Returns `"serialize: …"` if the message cannot be encoded as JSON, or
/// `"write: …"` if the socket write fails.
async fn send_line(
    mut writer: tokio::net::unix::OwnedWriteHalf,
    msg: &DaemonMessage,
) -> Result<(), String> {
    let mut frame = match serde_json::to_string(msg) {
        Ok(text) => text,
        Err(e) => return Err(format!("serialize: {e}")),
    };
    frame.push('\n');

    match writer.write_all(frame.as_bytes()).await {
        Ok(()) => Ok(()),
        Err(e) => Err(format!("write: {e}")),
    }
}