feat(agor-pty): complete PTY daemon — auth, sessions, output fanout
This commit is contained in:
parent
4b5583430d
commit
f3456bd09d
6 changed files with 1853 additions and 65 deletions
441
agor-pty/src/daemon.rs
Normal file
441
agor-pty/src/daemon.rs
Normal file
|
|
@ -0,0 +1,441 @@
|
|||
use std::collections::{HashMap, HashSet};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tokio::sync::{broadcast, mpsc, Mutex};
|
||||
|
||||
use crate::auth::AuthToken;
|
||||
use crate::protocol::{decode_input, encode_output, ClientMessage, DaemonMessage};
|
||||
use crate::session::SessionManager;
|
||||
|
||||
/// High-water mark for the per-client send queue (in messages, not bytes).
|
||||
/// We limit to ~256 KB worth of medium-sized chunks before dropping the client.
|
||||
const CLIENT_QUEUE_CAP: usize = 64;
|
||||
|
||||
/// Shared mutable state accessible from all tasks.
|
||||
struct State {
|
||||
sessions: SessionManager,
|
||||
/// session_id → set of client_ids currently subscribed.
|
||||
subscriptions: HashMap<String, HashSet<u64>>,
|
||||
/// client_id → channel to push messages back to that client's write task.
|
||||
client_txs: HashMap<u64, mpsc::Sender<DaemonMessage>>,
|
||||
next_client_id: u64,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new(default_shell: String) -> Self {
|
||||
Self {
|
||||
sessions: SessionManager::new(default_shell),
|
||||
subscriptions: HashMap::new(),
|
||||
client_txs: HashMap::new(),
|
||||
next_client_id: 1,
|
||||
}
|
||||
}
|
||||
|
||||
fn alloc_client_id(&mut self) -> u64 {
|
||||
let id = self.next_client_id;
|
||||
self.next_client_id += 1;
|
||||
id
|
||||
}
|
||||
|
||||
/// Remove a client from all subscription sets and from the client map.
|
||||
fn remove_client(&mut self, cid: u64) {
|
||||
self.client_txs.remove(&cid);
|
||||
for subs in self.subscriptions.values_mut() {
|
||||
subs.remove(&cid);
|
||||
}
|
||||
}
|
||||
|
||||
/// Fan-out a message to all subscribers of `session_id`.
|
||||
fn fanout(&self, session_id: &str, msg: DaemonMessage) {
|
||||
if let Some(subs) = self.subscriptions.get(session_id) {
|
||||
for cid in subs {
|
||||
if let Some(tx) = self.client_txs.get(cid) {
|
||||
// Non-blocking: drop slow clients silently.
|
||||
let _ = tx.try_send(msg.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Daemon {
|
||||
socket_path: PathBuf,
|
||||
token: AuthToken,
|
||||
default_shell: String,
|
||||
}
|
||||
|
||||
impl Daemon {
|
||||
pub fn new(socket_path: PathBuf, token: AuthToken, default_shell: String) -> Self {
|
||||
Self {
|
||||
socket_path,
|
||||
token,
|
||||
default_shell,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run until `shutdown_rx` fires.
|
||||
pub async fn run(self, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), String> {
|
||||
// Remove stale socket from a previous run.
|
||||
let _ = std::fs::remove_file(&self.socket_path);
|
||||
|
||||
let listener = UnixListener::bind(&self.socket_path)
|
||||
.map_err(|e| format!("bind {:?}: {e}", self.socket_path))?;
|
||||
|
||||
log::info!("agor-ptyd v0.1.0 listening on {:?}", self.socket_path);
|
||||
|
||||
let state = Arc::new(Mutex::new(State::new(self.default_shell.clone())));
|
||||
let token = Arc::new(self.token);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
accept = listener.accept() => {
|
||||
match accept {
|
||||
Ok((stream, _addr)) => {
|
||||
let state = state.clone();
|
||||
let token = token.clone();
|
||||
tokio::spawn(handle_client(stream, state, token));
|
||||
}
|
||||
Err(e) => log::warn!("accept error: {e}"),
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
log::info!("shutdown signal received — stopping daemon");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup socket file.
|
||||
let _ = std::fs::remove_file(&self.socket_path);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a single client connection from handshake to disconnect.
|
||||
async fn handle_client(
|
||||
stream: UnixStream,
|
||||
state: Arc<Mutex<State>>,
|
||||
token: Arc<AuthToken>,
|
||||
) {
|
||||
let (read_half, write_half) = stream.into_split();
|
||||
let mut reader = BufReader::new(read_half);
|
||||
|
||||
// First message must be Auth.
|
||||
let mut line = String::new();
|
||||
if reader.read_line(&mut line).await.unwrap_or(0) == 0 {
|
||||
log::warn!("client disconnected before auth");
|
||||
return;
|
||||
}
|
||||
let auth_msg: ClientMessage = match serde_json::from_str(line.trim()) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
log::warn!("invalid auth message: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let presented_token = match auth_msg {
|
||||
ClientMessage::Auth { token: t } => t,
|
||||
_ => {
|
||||
log::warn!("first message was not Auth — dropping client");
|
||||
return;
|
||||
}
|
||||
};
|
||||
if !token.verify(&presented_token) {
|
||||
log::warn!("auth failed (token={} redacted)", token.redacted());
|
||||
// Send failure then drop the connection.
|
||||
let _ = send_line(
|
||||
write_half,
|
||||
&DaemonMessage::AuthResult { ok: false },
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
// Register client.
|
||||
let (out_tx, out_rx) = mpsc::channel::<DaemonMessage>(CLIENT_QUEUE_CAP);
|
||||
let cid = {
|
||||
let mut st = state.lock().await;
|
||||
let cid = st.alloc_client_id();
|
||||
st.client_txs.insert(cid, out_tx.clone());
|
||||
cid
|
||||
};
|
||||
log::info!("client {cid} authenticated");
|
||||
|
||||
// Send auth success.
|
||||
if let Err(e) = out_tx.try_send(DaemonMessage::AuthResult { ok: true }) {
|
||||
log::warn!("client {cid}: failed to queue AuthResult: {e}");
|
||||
state.lock().await.remove_client(cid);
|
||||
return;
|
||||
}
|
||||
|
||||
// Spawn a dedicated write task so the reader loop is never blocked by
|
||||
// slow writes to the socket.
|
||||
let write_task = tokio::spawn(write_loop(write_half, out_rx));
|
||||
|
||||
// Read loop.
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
match reader.read_line(&mut line).await {
|
||||
Ok(0) => break, // EOF
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
log::debug!("client {cid} read error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
let msg: ClientMessage = match serde_json::from_str(line.trim()) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
log::warn!("client {cid} bad message: {e}");
|
||||
let _ = out_tx
|
||||
.try_send(DaemonMessage::Error {
|
||||
message: format!("parse error: {e}"),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
handle_message(cid, msg, &state, &out_tx).await;
|
||||
}
|
||||
|
||||
// Cleanup on disconnect.
|
||||
log::info!("client {cid} disconnected");
|
||||
state.lock().await.remove_client(cid);
|
||||
write_task.abort();
|
||||
}
|
||||
|
||||
/// Dispatch a single client message to the appropriate handler.
|
||||
async fn handle_message(
|
||||
cid: u64,
|
||||
msg: ClientMessage,
|
||||
state: &Arc<Mutex<State>>,
|
||||
out_tx: &mpsc::Sender<DaemonMessage>,
|
||||
) {
|
||||
match msg {
|
||||
ClientMessage::Auth { .. } => {
|
||||
// Already authenticated — ignore duplicate.
|
||||
}
|
||||
|
||||
ClientMessage::Ping => {
|
||||
let _ = out_tx.try_send(DaemonMessage::Pong);
|
||||
}
|
||||
|
||||
ClientMessage::ListSessions => {
|
||||
let list = state.lock().await.sessions.list();
|
||||
let _ = out_tx.try_send(DaemonMessage::SessionList { sessions: list });
|
||||
}
|
||||
|
||||
ClientMessage::CreateSession { id, shell, cwd, env, cols, rows } => {
|
||||
let state_clone = state.clone();
|
||||
let out_tx_clone = out_tx.clone();
|
||||
let id_clone = id.clone();
|
||||
|
||||
let result = {
|
||||
let mut st = state.lock().await;
|
||||
st.sessions.create_session(
|
||||
id.clone(),
|
||||
shell,
|
||||
cwd,
|
||||
env,
|
||||
cols,
|
||||
rows,
|
||||
move |sid, code| {
|
||||
// Invoked from the blocking reader task when child exits.
|
||||
let state_clone = state_clone.clone();
|
||||
let _ = &out_tx_clone; // captured for lifetime, not used
|
||||
tokio::spawn(async move {
|
||||
let st = state_clone.lock().await;
|
||||
st.fanout(
|
||||
&sid,
|
||||
DaemonMessage::SessionClosed {
|
||||
session_id: sid.clone(),
|
||||
exit_code: code,
|
||||
},
|
||||
);
|
||||
drop(st);
|
||||
});
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok((pid, output_rx)) => {
|
||||
let _ = out_tx.try_send(DaemonMessage::SessionCreated {
|
||||
session_id: id_clone.clone(),
|
||||
pid,
|
||||
});
|
||||
// Immediately subscribe the creating client.
|
||||
{
|
||||
let mut st = state.lock().await;
|
||||
st.subscriptions
|
||||
.entry(id_clone.clone())
|
||||
.or_default()
|
||||
.insert(cid);
|
||||
}
|
||||
// Start a fanout task for this session's output.
|
||||
let state_clone = state.clone();
|
||||
tokio::spawn(output_fanout_task(id_clone, output_rx, state_clone));
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = out_tx.try_send(DaemonMessage::Error { message: e });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::WriteInput { session_id, data } => {
|
||||
let bytes = match decode_input(&data) {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
let _ = out_tx.try_send(DaemonMessage::Error {
|
||||
message: format!("bad input encoding: {e}"),
|
||||
});
|
||||
return;
|
||||
}
|
||||
};
|
||||
let st = state.lock().await;
|
||||
match st.sessions.get(&session_id) {
|
||||
Some(sess) => {
|
||||
if let Err(e) = sess.write_input(&bytes).await {
|
||||
let _ = out_tx.try_send(DaemonMessage::Error { message: e });
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let _ = out_tx.try_send(DaemonMessage::Error {
|
||||
message: format!("session {session_id} not found"),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::Resize { session_id, cols, rows } => {
|
||||
let mut st = state.lock().await;
|
||||
match st.sessions.get_mut(&session_id) {
|
||||
Some(sess) => {
|
||||
sess.note_resize(cols, rows);
|
||||
}
|
||||
None => {
|
||||
let _ = out_tx.try_send(DaemonMessage::Error {
|
||||
message: format!("session {session_id} not found"),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::Subscribe { session_id } => {
|
||||
let (exists, rx) = {
|
||||
let st = state.lock().await;
|
||||
let exists = st.sessions.get(&session_id).is_some();
|
||||
let rx = st
|
||||
.sessions
|
||||
.get(&session_id)
|
||||
.map(|s| s.subscribe());
|
||||
(exists, rx)
|
||||
};
|
||||
if !exists {
|
||||
let _ = out_tx.try_send(DaemonMessage::Error {
|
||||
message: format!("session {session_id} not found"),
|
||||
});
|
||||
return;
|
||||
}
|
||||
{
|
||||
let mut st = state.lock().await;
|
||||
st.subscriptions
|
||||
.entry(session_id.clone())
|
||||
.or_default()
|
||||
.insert(cid);
|
||||
}
|
||||
// If a new rx came back, start a fanout task (handles reconnect case
|
||||
// where the original fanout task has gone away after all receivers
|
||||
// dropped). We always start one; duplicates are harmless since the
|
||||
// broadcast channel keeps all messages.
|
||||
if let Some(rx) = rx {
|
||||
let state_clone = state.clone();
|
||||
tokio::spawn(output_fanout_task(session_id, rx, state_clone));
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::Unsubscribe { session_id } => {
|
||||
let mut st = state.lock().await;
|
||||
if let Some(subs) = st.subscriptions.get_mut(&session_id) {
|
||||
subs.remove(&cid);
|
||||
}
|
||||
}
|
||||
|
||||
ClientMessage::CloseSession { session_id } => {
|
||||
let mut st = state.lock().await;
|
||||
if let Err(e) = st.sessions.close_session(&session_id) {
|
||||
let _ = out_tx.try_send(DaemonMessage::Error { message: e });
|
||||
} else {
|
||||
st.subscriptions.remove(&session_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads from a session's broadcast channel and fans output to all subscribed
|
||||
/// clients via their individual mpsc queues.
|
||||
async fn output_fanout_task(
|
||||
session_id: String,
|
||||
mut rx: broadcast::Receiver<Vec<u8>>,
|
||||
state: Arc<Mutex<State>>,
|
||||
) {
|
||||
loop {
|
||||
match rx.recv().await {
|
||||
Ok(chunk) => {
|
||||
let encoded = encode_output(&chunk);
|
||||
let msg = DaemonMessage::SessionOutput {
|
||||
session_id: session_id.clone(),
|
||||
data: encoded,
|
||||
};
|
||||
state.lock().await.fanout(&session_id, msg);
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||
log::warn!("session {session_id} fanout lagged, dropped {n} messages");
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => {
|
||||
log::debug!("session {session_id} output channel closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Drains the per-client mpsc queue and writes newline-delimited JSON to the
|
||||
/// socket.
|
||||
async fn write_loop(
|
||||
mut writer: tokio::net::unix::OwnedWriteHalf,
|
||||
mut rx: mpsc::Receiver<DaemonMessage>,
|
||||
) {
|
||||
while let Some(msg) = rx.recv().await {
|
||||
match serde_json::to_string(&msg) {
|
||||
Ok(mut json) => {
|
||||
json.push('\n');
|
||||
if let Err(e) = writer.write_all(json.as_bytes()).await {
|
||||
log::debug!("write error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("serialize error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// One-shot write for pre-auth messages (write_half not yet consumed by the
|
||||
/// write_loop task).
|
||||
async fn send_line(
|
||||
mut writer: tokio::net::unix::OwnedWriteHalf,
|
||||
msg: &DaemonMessage,
|
||||
) -> Result<(), String> {
|
||||
let mut json = serde_json::to_string(msg)
|
||||
.map_err(|e| format!("serialize: {e}"))?;
|
||||
json.push('\n');
|
||||
writer
|
||||
.write_all(json.as_bytes())
|
||||
.await
|
||||
.map_err(|e| format!("write: {e}"))
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue