Previously, resize only updated the cached dimensions without performing the PTY resize, so the shell believed the terminal was the wrong size (double prompts, leaked escape codes). Changes: the Session now stores the master PTY handle (Arc<Mutex<Box<dyn MasterPty>>>); resize() calls master.resize(PtySize), which issues TIOCSWINSZ; and the reader task no longer owns the master handle (it uses a cloned reader only).
443 lines
15 KiB
Rust
443 lines
15 KiB
Rust
use std::collections::{HashMap, HashSet};
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
|
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
|
use tokio::net::{UnixListener, UnixStream};
|
|
use tokio::sync::{broadcast, mpsc, Mutex};
|
|
|
|
use crate::auth::AuthToken;
|
|
use crate::protocol::{decode_input, encode_output, ClientMessage, DaemonMessage};
|
|
use crate::session::SessionManager;
|
|
|
|
/// High-water mark for the per-client send queue (in messages, not bytes).
|
|
/// We limit to ~256 KB worth of medium-sized chunks before dropping the client.
|
|
const CLIENT_QUEUE_CAP: usize = 64;
|
|
|
|
/// Shared mutable state accessible from all tasks.
///
/// Guarded by a single `tokio::sync::Mutex` at the call sites; handlers take
/// the lock only for short, non-blocking critical sections.
struct State {
    /// All live PTY sessions; looked up by session id string.
    sessions: SessionManager,
    /// session_id → set of client_ids currently subscribed.
    subscriptions: HashMap<String, HashSet<u64>>,
    /// client_id → channel to push messages back to that client's write task.
    client_txs: HashMap<u64, mpsc::Sender<DaemonMessage>>,
    /// Next id to hand out; monotonically increasing, starts at 1.
    next_client_id: u64,
}
|
|
|
|
impl State {
|
|
fn new(default_shell: String) -> Self {
|
|
Self {
|
|
sessions: SessionManager::new(default_shell),
|
|
subscriptions: HashMap::new(),
|
|
client_txs: HashMap::new(),
|
|
next_client_id: 1,
|
|
}
|
|
}
|
|
|
|
fn alloc_client_id(&mut self) -> u64 {
|
|
let id = self.next_client_id;
|
|
self.next_client_id += 1;
|
|
id
|
|
}
|
|
|
|
/// Remove a client from all subscription sets and from the client map.
|
|
fn remove_client(&mut self, cid: u64) {
|
|
self.client_txs.remove(&cid);
|
|
for subs in self.subscriptions.values_mut() {
|
|
subs.remove(&cid);
|
|
}
|
|
}
|
|
|
|
/// Fan-out a message to all subscribers of `session_id`.
|
|
fn fanout(&self, session_id: &str, msg: DaemonMessage) {
|
|
if let Some(subs) = self.subscriptions.get(session_id) {
|
|
for cid in subs {
|
|
if let Some(tx) = self.client_txs.get(cid) {
|
|
// Non-blocking: drop slow clients silently.
|
|
let _ = tx.try_send(msg.clone());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// The daemon: owns the socket path to listen on, the auth token every
/// client must present, and the default shell for new sessions.
pub struct Daemon {
    /// Filesystem path of the Unix socket to bind (stale files are removed).
    socket_path: PathBuf,
    /// Token checked against each client's first (Auth) message.
    token: AuthToken,
    /// Shell passed through to `SessionManager::new`.
    default_shell: String,
}
|
|
|
|
impl Daemon {
|
|
pub fn new(socket_path: PathBuf, token: AuthToken, default_shell: String) -> Self {
|
|
Self {
|
|
socket_path,
|
|
token,
|
|
default_shell,
|
|
}
|
|
}
|
|
|
|
/// Run until `shutdown_rx` fires.
|
|
pub async fn run(self, mut shutdown_rx: broadcast::Receiver<()>) -> Result<(), String> {
|
|
// Remove stale socket from a previous run.
|
|
let _ = std::fs::remove_file(&self.socket_path);
|
|
|
|
let listener = UnixListener::bind(&self.socket_path)
|
|
.map_err(|e| format!("bind {:?}: {e}", self.socket_path))?;
|
|
|
|
log::info!("agor-ptyd v0.1.0 listening on {:?}", self.socket_path);
|
|
|
|
let state = Arc::new(Mutex::new(State::new(self.default_shell.clone())));
|
|
let token = Arc::new(self.token);
|
|
|
|
loop {
|
|
tokio::select! {
|
|
accept = listener.accept() => {
|
|
match accept {
|
|
Ok((stream, _addr)) => {
|
|
let state = state.clone();
|
|
let token = token.clone();
|
|
tokio::spawn(handle_client(stream, state, token));
|
|
}
|
|
Err(e) => log::warn!("accept error: {e}"),
|
|
}
|
|
}
|
|
_ = shutdown_rx.recv() => {
|
|
log::info!("shutdown signal received — stopping daemon");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Cleanup socket file.
|
|
let _ = std::fs::remove_file(&self.socket_path);
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// Handle a single client connection from handshake to disconnect.
|
|
async fn handle_client(
|
|
stream: UnixStream,
|
|
state: Arc<Mutex<State>>,
|
|
token: Arc<AuthToken>,
|
|
) {
|
|
let (read_half, write_half) = stream.into_split();
|
|
let mut reader = BufReader::new(read_half);
|
|
|
|
// First message must be Auth.
|
|
let mut line = String::new();
|
|
if reader.read_line(&mut line).await.unwrap_or(0) == 0 {
|
|
log::warn!("client disconnected before auth");
|
|
return;
|
|
}
|
|
let auth_msg: ClientMessage = match serde_json::from_str(line.trim()) {
|
|
Ok(m) => m,
|
|
Err(e) => {
|
|
log::warn!("invalid auth message: {e}");
|
|
return;
|
|
}
|
|
};
|
|
let presented_token = match auth_msg {
|
|
ClientMessage::Auth { token: t } => t,
|
|
_ => {
|
|
log::warn!("first message was not Auth — dropping client");
|
|
return;
|
|
}
|
|
};
|
|
if !token.verify(&presented_token) {
|
|
log::warn!("auth failed (token={} redacted)", token.redacted());
|
|
// Send failure then drop the connection.
|
|
let _ = send_line(
|
|
write_half,
|
|
&DaemonMessage::AuthResult { ok: false },
|
|
)
|
|
.await;
|
|
return;
|
|
}
|
|
|
|
// Register client.
|
|
let (out_tx, out_rx) = mpsc::channel::<DaemonMessage>(CLIENT_QUEUE_CAP);
|
|
let cid = {
|
|
let mut st = state.lock().await;
|
|
let cid = st.alloc_client_id();
|
|
st.client_txs.insert(cid, out_tx.clone());
|
|
cid
|
|
};
|
|
log::info!("client {cid} authenticated");
|
|
|
|
// Send auth success.
|
|
if let Err(e) = out_tx.try_send(DaemonMessage::AuthResult { ok: true }) {
|
|
log::warn!("client {cid}: failed to queue AuthResult: {e}");
|
|
state.lock().await.remove_client(cid);
|
|
return;
|
|
}
|
|
|
|
// Spawn a dedicated write task so the reader loop is never blocked by
|
|
// slow writes to the socket.
|
|
let write_task = tokio::spawn(write_loop(write_half, out_rx));
|
|
|
|
// Read loop.
|
|
loop {
|
|
let mut line = String::new();
|
|
match reader.read_line(&mut line).await {
|
|
Ok(0) => break, // EOF
|
|
Ok(_) => {}
|
|
Err(e) => {
|
|
log::debug!("client {cid} read error: {e}");
|
|
break;
|
|
}
|
|
}
|
|
let msg: ClientMessage = match serde_json::from_str(line.trim()) {
|
|
Ok(m) => m,
|
|
Err(e) => {
|
|
log::warn!("client {cid} bad message: {e}");
|
|
let _ = out_tx
|
|
.try_send(DaemonMessage::Error {
|
|
message: format!("parse error: {e}"),
|
|
});
|
|
continue;
|
|
}
|
|
};
|
|
|
|
handle_message(cid, msg, &state, &out_tx).await;
|
|
}
|
|
|
|
// Cleanup on disconnect.
|
|
log::info!("client {cid} disconnected");
|
|
state.lock().await.remove_client(cid);
|
|
write_task.abort();
|
|
}
|
|
|
|
/// Dispatch a single client message to the appropriate handler.
|
|
async fn handle_message(
|
|
cid: u64,
|
|
msg: ClientMessage,
|
|
state: &Arc<Mutex<State>>,
|
|
out_tx: &mpsc::Sender<DaemonMessage>,
|
|
) {
|
|
match msg {
|
|
ClientMessage::Auth { .. } => {
|
|
// Already authenticated — ignore duplicate.
|
|
}
|
|
|
|
ClientMessage::Ping => {
|
|
let _ = out_tx.try_send(DaemonMessage::Pong);
|
|
}
|
|
|
|
ClientMessage::ListSessions => {
|
|
let list = state.lock().await.sessions.list();
|
|
let _ = out_tx.try_send(DaemonMessage::SessionList { sessions: list });
|
|
}
|
|
|
|
ClientMessage::CreateSession { id, shell, cwd, env, cols, rows } => {
|
|
let state_clone = state.clone();
|
|
let out_tx_clone = out_tx.clone();
|
|
let id_clone = id.clone();
|
|
|
|
let result = {
|
|
let mut st = state.lock().await;
|
|
st.sessions.create_session(
|
|
id.clone(),
|
|
shell,
|
|
cwd,
|
|
env,
|
|
cols,
|
|
rows,
|
|
move |sid, code| {
|
|
// Invoked from the blocking reader task when child exits.
|
|
let state_clone = state_clone.clone();
|
|
let _ = &out_tx_clone; // captured for lifetime, not used
|
|
tokio::spawn(async move {
|
|
let st = state_clone.lock().await;
|
|
st.fanout(
|
|
&sid,
|
|
DaemonMessage::SessionClosed {
|
|
session_id: sid.clone(),
|
|
exit_code: code,
|
|
},
|
|
);
|
|
drop(st);
|
|
});
|
|
},
|
|
)
|
|
};
|
|
|
|
match result {
|
|
Ok((pid, output_rx)) => {
|
|
let _ = out_tx.try_send(DaemonMessage::SessionCreated {
|
|
session_id: id_clone.clone(),
|
|
pid,
|
|
});
|
|
// Immediately subscribe the creating client.
|
|
{
|
|
let mut st = state.lock().await;
|
|
st.subscriptions
|
|
.entry(id_clone.clone())
|
|
.or_default()
|
|
.insert(cid);
|
|
}
|
|
// Start a fanout task for this session's output.
|
|
let state_clone = state.clone();
|
|
tokio::spawn(output_fanout_task(id_clone, output_rx, state_clone));
|
|
}
|
|
Err(e) => {
|
|
let _ = out_tx.try_send(DaemonMessage::Error { message: e });
|
|
}
|
|
}
|
|
}
|
|
|
|
ClientMessage::WriteInput { session_id, data } => {
|
|
let bytes = match decode_input(&data) {
|
|
Ok(b) => b,
|
|
Err(e) => {
|
|
let _ = out_tx.try_send(DaemonMessage::Error {
|
|
message: format!("bad input encoding: {e}"),
|
|
});
|
|
return;
|
|
}
|
|
};
|
|
let st = state.lock().await;
|
|
match st.sessions.get(&session_id) {
|
|
Some(sess) => {
|
|
if let Err(e) = sess.write_input(&bytes).await {
|
|
let _ = out_tx.try_send(DaemonMessage::Error { message: e });
|
|
}
|
|
}
|
|
None => {
|
|
let _ = out_tx.try_send(DaemonMessage::Error {
|
|
message: format!("session {session_id} not found"),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
ClientMessage::Resize { session_id, cols, rows } => {
|
|
let mut st = state.lock().await;
|
|
match st.sessions.get_mut(&session_id) {
|
|
Some(sess) => {
|
|
if let Err(e) = sess.resize(cols, rows).await {
|
|
log::warn!("resize {session_id}: {e}");
|
|
}
|
|
}
|
|
None => {
|
|
let _ = out_tx.try_send(DaemonMessage::Error {
|
|
message: format!("session {session_id} not found"),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
ClientMessage::Subscribe { session_id } => {
|
|
let (exists, rx) = {
|
|
let st = state.lock().await;
|
|
let exists = st.sessions.get(&session_id).is_some();
|
|
let rx = st
|
|
.sessions
|
|
.get(&session_id)
|
|
.map(|s| s.subscribe());
|
|
(exists, rx)
|
|
};
|
|
if !exists {
|
|
let _ = out_tx.try_send(DaemonMessage::Error {
|
|
message: format!("session {session_id} not found"),
|
|
});
|
|
return;
|
|
}
|
|
{
|
|
let mut st = state.lock().await;
|
|
st.subscriptions
|
|
.entry(session_id.clone())
|
|
.or_default()
|
|
.insert(cid);
|
|
}
|
|
// If a new rx came back, start a fanout task (handles reconnect case
|
|
// where the original fanout task has gone away after all receivers
|
|
// dropped). We always start one; duplicates are harmless since the
|
|
// broadcast channel keeps all messages.
|
|
if let Some(rx) = rx {
|
|
let state_clone = state.clone();
|
|
tokio::spawn(output_fanout_task(session_id, rx, state_clone));
|
|
}
|
|
}
|
|
|
|
ClientMessage::Unsubscribe { session_id } => {
|
|
let mut st = state.lock().await;
|
|
if let Some(subs) = st.subscriptions.get_mut(&session_id) {
|
|
subs.remove(&cid);
|
|
}
|
|
}
|
|
|
|
ClientMessage::CloseSession { session_id } => {
|
|
let mut st = state.lock().await;
|
|
if let Err(e) = st.sessions.close_session(&session_id) {
|
|
let _ = out_tx.try_send(DaemonMessage::Error { message: e });
|
|
} else {
|
|
st.subscriptions.remove(&session_id);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Reads from a session's broadcast channel and fans output to all subscribed
|
|
/// clients via their individual mpsc queues.
|
|
async fn output_fanout_task(
|
|
session_id: String,
|
|
mut rx: broadcast::Receiver<Vec<u8>>,
|
|
state: Arc<Mutex<State>>,
|
|
) {
|
|
loop {
|
|
match rx.recv().await {
|
|
Ok(chunk) => {
|
|
let encoded = encode_output(&chunk);
|
|
let msg = DaemonMessage::SessionOutput {
|
|
session_id: session_id.clone(),
|
|
data: encoded,
|
|
};
|
|
state.lock().await.fanout(&session_id, msg);
|
|
}
|
|
Err(broadcast::error::RecvError::Lagged(n)) => {
|
|
log::warn!("session {session_id} fanout lagged, dropped {n} messages");
|
|
}
|
|
Err(broadcast::error::RecvError::Closed) => {
|
|
log::debug!("session {session_id} output channel closed");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Drains the per-client mpsc queue and writes newline-delimited JSON to the
|
|
/// socket.
|
|
async fn write_loop(
|
|
mut writer: tokio::net::unix::OwnedWriteHalf,
|
|
mut rx: mpsc::Receiver<DaemonMessage>,
|
|
) {
|
|
while let Some(msg) = rx.recv().await {
|
|
match serde_json::to_string(&msg) {
|
|
Ok(mut json) => {
|
|
json.push('\n');
|
|
if let Err(e) = writer.write_all(json.as_bytes()).await {
|
|
log::debug!("write error: {e}");
|
|
break;
|
|
}
|
|
}
|
|
Err(e) => {
|
|
log::warn!("serialize error: {e}");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// One-shot write for pre-auth messages (write_half not yet consumed by the
|
|
/// write_loop task).
|
|
async fn send_line(
|
|
mut writer: tokio::net::unix::OwnedWriteHalf,
|
|
msg: &DaemonMessage,
|
|
) -> Result<(), String> {
|
|
let mut json = serde_json::to_string(msg)
|
|
.map_err(|e| format!("serialize: {e}"))?;
|
|
json.push('\n');
|
|
writer
|
|
.write_all(json.as_bytes())
|
|
.await
|
|
.map_err(|e| format!("write: {e}"))
|
|
}
|