From 191b869b431f0bf06e5e7c0361add078ee245dc2 Mon Sep 17 00:00:00 2001 From: Hibryda Date: Tue, 17 Mar 2026 03:27:40 +0100 Subject: [PATCH] feat(pro): implement all 3 commercial phases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 1 — Cost Intelligence: - budget.rs: per-project token budgets, soft/hard limits, usage logging - router.rs: 3 preset profiles (CostSaver/QualityFirst/Balanced) Phase 2 — Knowledge Base: - memory.rs: persistent agent memory with FTS5, auto-extraction, TTL - symbols.rs: regex-based symbol graph (tree-sitter stub) Phase 3 — Git Integration: - git_context.rs: branch/commit/modified file context injection - branch_policy.rs: session-level branch protection 6 modules, 32 cargo tests, 22+ Tauri plugin commands. --- agor-pro/src/branch_policy.rs | 208 +++++++++++++++++++++ agor-pro/src/budget.rs | 235 ++++++++++++++++++++++++ agor-pro/src/git_context.rs | 209 +++++++++++++++++++++ agor-pro/src/lib.rs | 34 ++++ agor-pro/src/memory.rs | 334 ++++++++++++++++++++++++++++++++++ agor-pro/src/router.rs | 194 ++++++++++++++++++++ agor-pro/src/symbols.rs | 295 ++++++++++++++++++++++++++++++ 7 files changed, 1509 insertions(+) create mode 100644 agor-pro/src/branch_policy.rs create mode 100644 agor-pro/src/budget.rs create mode 100644 agor-pro/src/git_context.rs create mode 100644 agor-pro/src/memory.rs create mode 100644 agor-pro/src/router.rs create mode 100644 agor-pro/src/symbols.rs diff --git a/agor-pro/src/branch_policy.rs b/agor-pro/src/branch_policy.rs new file mode 100644 index 0000000..5e46cdd --- /dev/null +++ b/agor-pro/src/branch_policy.rs @@ -0,0 +1,208 @@ +// SPDX-License-Identifier: LicenseRef-Commercial +// Branch Policy Enforcement — block agent sessions on protected branches. + +use rusqlite::params; +use serde::Serialize; +use std::process::Command; + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct BranchPolicy { + pub id: i64, + pub pattern: String, + pub action: String, + pub reason: String, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PolicyDecision { + pub allowed: bool, + pub branch: String, + pub matched_policy: Option, + pub reason: String, +} + +fn ensure_tables(conn: &rusqlite::Connection) -> Result<(), String> { + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS pro_branch_policies ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + pattern TEXT NOT NULL, + action TEXT NOT NULL DEFAULT 'block', + reason TEXT NOT NULL DEFAULT '' + );" + ).map_err(|e| format!("Failed to create branch_policies table: {e}"))?; + + // Seed default policies if table is empty + let count: i64 = conn.query_row( + "SELECT COUNT(*) FROM pro_branch_policies", [], |row| row.get(0) + ).unwrap_or(0); + + if count == 0 { + conn.execute_batch( + "INSERT INTO pro_branch_policies (pattern, action, reason) VALUES + ('main', 'block', 'Protected branch: direct work on main is not allowed'), + ('master', 'block', 'Protected branch: direct work on master is not allowed'), + ('release/*', 'block', 'Protected branch: release branches require PRs');" + ).map_err(|e| format!("Failed to seed default policies: {e}"))?; + } + + Ok(()) +} + +/// Simple glob matching: supports `*` at the end of a pattern (e.g., `release/*`). +fn glob_match(pattern: &str, value: &str) -> bool { + if pattern == value { + return true; + } + if let Some(prefix) = pattern.strip_suffix('*') { + return value.starts_with(prefix); + } + false +} + +fn get_current_branch(project_path: &str) -> Result { + let output = Command::new("git") + .args(["-C", project_path, "branch", "--show-current"]) + .output() + .map_err(|e| format!("Failed to run git: {e}"))?; + + if output.status.success() { + Ok(String::from_utf8_lossy(&output.stdout).trim().to_string()) + } else { + Err("Not a git repository or git not available".into()) + } +} + +#[tauri::command] +pub fn pro_branch_check(project_path: String) -> Result { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + + let branch = get_current_branch(&project_path)?; + + let mut stmt = conn.prepare( + "SELECT id, pattern, action, reason FROM pro_branch_policies" + ).map_err(|e| format!("Query failed: {e}"))?; + + let policies: Vec = stmt.query_map([], |row| { + Ok(BranchPolicy { + id: row.get(0)?, + pattern: row.get(1)?, + action: row.get(2)?, + reason: row.get(3)?, + }) + }).map_err(|e| format!("Query failed: {e}"))? + .collect::, _>>() + .map_err(|e| format!("Row read failed: {e}"))?; + + for policy in &policies { + if glob_match(&policy.pattern, &branch) { + let allowed = policy.action != "block"; + return Ok(PolicyDecision { + allowed, + branch: branch.clone(), + matched_policy: Some(policy.clone()), + reason: policy.reason.clone(), + }); + } + } + + Ok(PolicyDecision { + allowed: true, + branch, + matched_policy: None, + reason: "No matching policy".into(), + }) +} + +#[tauri::command] +pub fn pro_branch_policy_list() -> Result, String> { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + + let mut stmt = conn.prepare( + "SELECT id, pattern, action, reason FROM pro_branch_policies ORDER BY id" + ).map_err(|e| format!("Query failed: {e}"))?; + + let rows = stmt.query_map([], |row| { + Ok(BranchPolicy { + id: row.get(0)?, + pattern: row.get(1)?, + action: row.get(2)?, + reason: row.get(3)?, + }) + }).map_err(|e| format!("Query failed: {e}"))? + .collect::, _>>() + .map_err(|e| format!("Row read failed: {e}"))?; + + Ok(rows) +} + +#[tauri::command] +pub fn pro_branch_policy_add(pattern: String, action: Option, reason: Option) -> Result { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + let act = action.unwrap_or_else(|| "block".into()); + let rsn = reason.unwrap_or_default(); + conn.execute( + "INSERT INTO pro_branch_policies (pattern, action, reason) VALUES (?1, ?2, ?3)", + params![pattern, act, rsn], + ).map_err(|e| format!("Failed to add policy: {e}"))?; + Ok(conn.last_insert_rowid()) +} + +#[tauri::command] +pub fn pro_branch_policy_remove(id: i64) -> Result<(), String> { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + conn.execute("DELETE FROM pro_branch_policies WHERE id = ?1", params![id]) + .map_err(|e| format!("Failed to remove policy: {e}"))?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_policy_decision_serializes_camel_case() { + let d = PolicyDecision { + allowed: false, + branch: "main".into(), + matched_policy: Some(BranchPolicy { + id: 1, + pattern: "main".into(), + action: "block".into(), + reason: "Protected".into(), + }), + reason: "Protected".into(), + }; + let json = serde_json::to_string(&d).unwrap(); + assert!(json.contains("matchedPolicy")); + assert!(json.contains("\"allowed\":false")); + } + + #[test] + fn test_branch_policy_serializes_camel_case() { + let p = BranchPolicy { + id: 1, + pattern: "release/*".into(), + action: "block".into(), + reason: "No direct commits".into(), + }; + let json = serde_json::to_string(&p).unwrap(); + assert!(json.contains("\"pattern\":\"release/*\"")); + assert!(json.contains("\"action\":\"block\"")); + } + + #[test] + fn test_glob_match() { + assert!(glob_match("main", "main")); + assert!(!glob_match("main", "main2")); + assert!(glob_match("release/*", "release/v1.0")); + assert!(glob_match("release/*", "release/hotfix")); + assert!(!glob_match("release/*", "feature/test")); + assert!(!glob_match("master", "main")); + } +} diff --git a/agor-pro/src/budget.rs b/agor-pro/src/budget.rs new file mode 100644 index 0000000..2466f3a --- /dev/null +++ b/agor-pro/src/budget.rs @@ -0,0 +1,235 @@ +// SPDX-License-Identifier: LicenseRef-Commercial +// Budget Governor — per-project monthly token budgets with soft/hard limits. + +use rusqlite::params; +use serde::Serialize; + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct BudgetStatus { + pub project_id: String, + pub limit: i64, + pub used: i64, + pub remaining: i64, + pub percent: f64, + pub reset_date: i64, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct BudgetDecision { + pub allowed: bool, + pub reason: String, + pub remaining: i64, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct BudgetEntry { + pub project_id: String, + pub monthly_limit_tokens: i64, + pub used_tokens: i64, + pub reset_date: i64, +} + +fn ensure_tables(conn: &rusqlite::Connection) -> Result<(), String> { + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS pro_budgets ( + project_id TEXT PRIMARY KEY, + monthly_limit_tokens INTEGER NOT NULL, + used_tokens INTEGER NOT NULL DEFAULT 0, + reset_date INTEGER NOT NULL + ); + CREATE TABLE IF NOT EXISTS pro_budget_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + project_id TEXT NOT NULL, + session_id TEXT NOT NULL, + tokens_used INTEGER NOT NULL, + timestamp INTEGER NOT NULL + );" + ).map_err(|e| format!("Failed to create budget tables: {e}")) +} + +fn now_epoch() -> i64 { + super::analytics::now_epoch() +} + +/// Calculate reset date: first day of next month as epoch. +fn next_month_epoch() -> i64 { + let now = now_epoch(); + // Approximate: 30 days from now + now + 30 * 86400 +} + +#[tauri::command] +pub fn pro_budget_set(project_id: String, monthly_limit_tokens: i64) -> Result<(), String> { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + let reset = next_month_epoch(); + conn.execute( + "INSERT INTO pro_budgets (project_id, monthly_limit_tokens, used_tokens, reset_date) + VALUES (?1, ?2, 0, ?3) + ON CONFLICT(project_id) DO UPDATE SET monthly_limit_tokens = ?2", + params![project_id, monthly_limit_tokens, reset], + ).map_err(|e| format!("Failed to set budget: {e}"))?; + Ok(()) +} + +#[tauri::command] +pub fn pro_budget_get(project_id: String) -> Result { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + auto_reset_if_expired(&conn, &project_id)?; + + let mut stmt = conn.prepare( + "SELECT monthly_limit_tokens, used_tokens, reset_date FROM pro_budgets WHERE project_id = ?1" + ).map_err(|e| format!("Query failed: {e}"))?; + + stmt.query_row(params![project_id], |row| { + let limit: i64 = row.get(0)?; + let used: i64 = row.get(1)?; + let reset_date: i64 = row.get(2)?; + let remaining = (limit - used).max(0); + let percent = if limit > 0 { (used as f64 / limit as f64) * 100.0 } else { 0.0 }; + Ok(BudgetStatus { project_id: project_id.clone(), limit, used, remaining, percent, reset_date }) + }).map_err(|e| format!("Budget not found for project '{}': {e}", project_id)) +} + +#[tauri::command] +pub fn pro_budget_check(project_id: String, estimated_tokens: i64) -> Result { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + auto_reset_if_expired(&conn, &project_id)?; + + let result = conn.prepare( + "SELECT monthly_limit_tokens, used_tokens FROM pro_budgets WHERE project_id = ?1" + ).map_err(|e| format!("Query failed: {e}"))? + .query_row(params![project_id], |row| { + Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)) + }); + + match result { + Ok((limit, used)) => { + let remaining = (limit - used).max(0); + if used + estimated_tokens > limit { + Ok(BudgetDecision { + allowed: false, + reason: format!("Would exceed budget: {} remaining, {} requested", remaining, estimated_tokens), + remaining, + }) + } else { + Ok(BudgetDecision { allowed: true, reason: "Within budget".into(), remaining }) + } + } + Err(_) => { + // No budget set — allow by default + Ok(BudgetDecision { allowed: true, reason: "No budget configured".into(), remaining: i64::MAX }) + } + } +} + +#[tauri::command] +pub fn pro_budget_log_usage(project_id: String, session_id: String, tokens_used: i64) -> Result<(), String> { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + let ts = now_epoch(); + conn.execute( + "INSERT INTO pro_budget_log (project_id, session_id, tokens_used, timestamp) VALUES (?1, ?2, ?3, ?4)", + params![project_id, session_id, tokens_used, ts], + ).map_err(|e| format!("Failed to log usage: {e}"))?; + conn.execute( + "UPDATE pro_budgets SET used_tokens = used_tokens + ?2 WHERE project_id = ?1", + params![project_id, tokens_used], + ).map_err(|e| format!("Failed to update used tokens: {e}"))?; + Ok(()) +} + +#[tauri::command] +pub fn pro_budget_reset(project_id: String) -> Result<(), String> { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + let reset = next_month_epoch(); + conn.execute( + "UPDATE pro_budgets SET used_tokens = 0, reset_date = ?2 WHERE project_id = ?1", + params![project_id, reset], + ).map_err(|e| format!("Failed to reset budget: {e}"))?; + Ok(()) +} + +#[tauri::command] +pub fn pro_budget_list() -> Result, String> { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + let mut stmt = conn.prepare( + "SELECT project_id, monthly_limit_tokens, used_tokens, reset_date FROM pro_budgets ORDER BY project_id" + ).map_err(|e| format!("Query failed: {e}"))?; + + let rows = stmt.query_map([], |row| { + Ok(BudgetEntry { + project_id: row.get(0)?, + monthly_limit_tokens: row.get(1)?, + used_tokens: row.get(2)?, + reset_date: row.get(3)?, + }) + }).map_err(|e| format!("Query failed: {e}"))? + .collect::, _>>() + .map_err(|e| format!("Row read failed: {e}"))?; + + Ok(rows) +} + +fn auto_reset_if_expired(conn: &rusqlite::Connection, project_id: &str) -> Result<(), String> { + let now = now_epoch(); + conn.execute( + "UPDATE pro_budgets SET used_tokens = 0, reset_date = ?3 + WHERE project_id = ?1 AND reset_date < ?2", + params![project_id, now, now + 30 * 86400], + ).map_err(|e| format!("Auto-reset failed: {e}"))?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_budget_status_serializes_camel_case() { + let s = BudgetStatus { + project_id: "proj1".into(), + limit: 100_000, + used: 25_000, + remaining: 75_000, + percent: 25.0, + reset_date: 1710000000, + }; + let json = serde_json::to_string(&s).unwrap(); + assert!(json.contains("projectId")); + assert!(json.contains("resetDate")); + assert!(json.contains("\"remaining\":75000")); + } + + #[test] + fn test_budget_decision_serializes_camel_case() { + let d = BudgetDecision { + allowed: true, + reason: "Within budget".into(), + remaining: 50_000, + }; + let json = serde_json::to_string(&d).unwrap(); + assert!(json.contains("\"allowed\":true")); + assert!(json.contains("\"remaining\":50000")); + } + + #[test] + fn test_budget_entry_serializes_camel_case() { + let e = BudgetEntry { + project_id: "p".into(), + monthly_limit_tokens: 200_000, + used_tokens: 10_000, + reset_date: 1710000000, + }; + let json = serde_json::to_string(&e).unwrap(); + assert!(json.contains("monthlyLimitTokens")); + assert!(json.contains("usedTokens")); + } +} diff --git a/agor-pro/src/git_context.rs b/agor-pro/src/git_context.rs new file mode 100644 index 0000000..e89de86 --- /dev/null +++ b/agor-pro/src/git_context.rs @@ -0,0 +1,209 @@ +// SPDX-License-Identifier: LicenseRef-Commercial +// Git Context Injection — lightweight git CLI wrapper for agent session context. +// Full git2/libgit2 implementation deferred until git2 dep is added. + +use serde::Serialize; +use std::process::Command; + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct GitContext { + pub branch: String, + pub last_commits: Vec, + pub modified_files: Vec, + pub has_unstaged: bool, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CommitSummary { + pub hash: String, + pub message: String, + pub author: String, + pub timestamp: i64, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct BranchInfo { + pub name: String, + pub is_protected: bool, + pub upstream: Option, + pub ahead: i64, + pub behind: i64, +} + +fn git_cmd(project_path: &str, args: &[&str]) -> Result { + let output = Command::new("git") + .args(["-C", project_path]) + .args(args) + .output() + .map_err(|e| format!("Failed to run git: {e}"))?; + + if output.status.success() { + Ok(String::from_utf8_lossy(&output.stdout).trim().to_string()) + } else { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + Err(format!("git error: {stderr}")) + } +} + +fn parse_log_line(line: &str) -> Option { + // Format: hash|author|timestamp|message + let parts: Vec<&str> = line.splitn(4, '|').collect(); + if parts.len() < 4 { return None; } + Some(CommitSummary { + hash: parts[0].to_string(), + author: parts[1].to_string(), + timestamp: parts[2].parse().unwrap_or(0), + message: parts[3].to_string(), + }) +} + +#[tauri::command] +pub fn pro_git_context(project_path: String) -> Result { + let branch = git_cmd(&project_path, &["branch", "--show-current"]) + .unwrap_or_else(|_| "unknown".into()); + + let log_output = git_cmd( + &project_path, + &["log", "--format=%H|%an|%at|%s", "-10"], + ).unwrap_or_default(); + + let last_commits: Vec = log_output + .lines() + .filter_map(parse_log_line) + .collect(); + + let status_output = git_cmd(&project_path, &["status", "--porcelain"]) + .unwrap_or_default(); + + let modified_files: Vec = status_output + .lines() + .filter(|l| !l.is_empty()) + .map(|l| { + // Format: XY filename (first 3 chars are status + space) + if l.len() > 3 { l[3..].to_string() } else { l.to_string() } + }) + .collect(); + + let has_unstaged = status_output.lines().any(|l| { + l.len() >= 2 && !l[1..2].eq(" ") && !l[1..2].eq("?") + }); + + Ok(GitContext { branch, last_commits, modified_files, has_unstaged }) +} + +#[tauri::command] +pub fn pro_git_inject(project_path: String, max_tokens: Option) -> Result { + let ctx = pro_git_context(project_path)?; + let max_chars = (max_tokens.unwrap_or(1000) * 3) as usize; + + let mut md = String::new(); + md.push_str(&format!("## Git Context\n\n**Branch:** {}\n\n", ctx.branch)); + + if !ctx.last_commits.is_empty() { + md.push_str("**Recent commits:**\n"); + for c in &ctx.last_commits { + let short_hash = if c.hash.len() >= 7 { &c.hash[..7] } else { &c.hash }; + let line = format!("- {} {}\n", short_hash, c.message); + if md.len() + line.len() > max_chars { break; } + md.push_str(&line); + } + md.push('\n'); + } + + if !ctx.modified_files.is_empty() { + md.push_str("**Modified files:**\n"); + for f in &ctx.modified_files { + let line = format!("- {f}\n"); + if md.len() + line.len() > max_chars { break; } + md.push_str(&line); + } + } + + Ok(md) +} + +#[tauri::command] +pub fn pro_git_branch_info(project_path: String) -> Result { + let name = git_cmd(&project_path, &["branch", "--show-current"]) + .unwrap_or_else(|_| "unknown".into()); + + let upstream = git_cmd( + &project_path, + &["rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"], + ).ok(); + + let (ahead, behind) = if upstream.is_some() { + let counts = git_cmd( + &project_path, + &["rev-list", "--left-right", "--count", "HEAD...@{u}"], + ).unwrap_or_else(|_| "0\t0".into()); + let parts: Vec<&str> = counts.split('\t').collect(); + let a = parts.first().and_then(|s| s.parse().ok()).unwrap_or(0); + let b = parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0); + (a, b) + } else { + (0, 0) + }; + + let is_protected = matches!(name.as_str(), "main" | "master") + || name.starts_with("release/"); + + Ok(BranchInfo { name, is_protected, upstream, ahead, behind }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_git_context_serializes_camel_case() { + let ctx = GitContext { + branch: "main".into(), + last_commits: vec![], + modified_files: vec!["src/lib.rs".into()], + has_unstaged: true, + }; + let json = serde_json::to_string(&ctx).unwrap(); + assert!(json.contains("lastCommits")); + assert!(json.contains("modifiedFiles")); + assert!(json.contains("hasUnstaged")); + } + + #[test] + fn test_commit_summary_serializes_camel_case() { + let c = CommitSummary { + hash: "abc1234".into(), + message: "feat: add router".into(), + author: "dev".into(), + timestamp: 1710000000, + }; + let json = serde_json::to_string(&c).unwrap(); + assert!(json.contains("\"hash\":\"abc1234\"")); + assert!(json.contains("\"timestamp\":1710000000")); + } + + #[test] + fn test_branch_info_serializes_camel_case() { + let b = BranchInfo { + name: "feature/test".into(), + is_protected: false, + upstream: Some("origin/feature/test".into()), + ahead: 2, + behind: 0, + }; + let json = serde_json::to_string(&b).unwrap(); + assert!(json.contains("isProtected")); + } + + #[test] + fn test_parse_log_line() { + let line = "abc123|Author Name|1710000000|feat: test commit"; + let c = parse_log_line(line).unwrap(); + assert_eq!(c.hash, "abc123"); + assert_eq!(c.author, "Author Name"); + assert_eq!(c.message, "feat: test commit"); + } +} diff --git a/agor-pro/src/lib.rs b/agor-pro/src/lib.rs index c5fe4a0..9a78aec 100644 --- a/agor-pro/src/lib.rs +++ b/agor-pro/src/lib.rs @@ -5,9 +5,15 @@ // agents-orchestrator/agents-orchestrator private repository. mod analytics; +mod branch_policy; +mod budget; mod export; +mod git_context; mod marketplace; +mod memory; mod profiles; +mod router; +mod symbols; use tauri::{ plugin::{Builder, TauriPlugin}, @@ -32,6 +38,34 @@ pub fn init() -> TauriPlugin { marketplace::pro_marketplace_uninstall, marketplace::pro_marketplace_check_updates, marketplace::pro_marketplace_update, + budget::pro_budget_set, + budget::pro_budget_get, + budget::pro_budget_check, + budget::pro_budget_log_usage, + budget::pro_budget_reset, + budget::pro_budget_list, + router::pro_router_recommend, + router::pro_router_set_profile, + router::pro_router_get_profile, + router::pro_router_list_profiles, + memory::pro_memory_add, + memory::pro_memory_list, + memory::pro_memory_search, + memory::pro_memory_update, + memory::pro_memory_delete, + memory::pro_memory_inject, + memory::pro_memory_extract_from_session, + symbols::pro_symbols_scan, + symbols::pro_symbols_search, + symbols::pro_symbols_find_callers, + symbols::pro_symbols_status, + git_context::pro_git_context, + git_context::pro_git_inject, + git_context::pro_git_branch_info, + branch_policy::pro_branch_check, + branch_policy::pro_branch_policy_list, + branch_policy::pro_branch_policy_add, + branch_policy::pro_branch_policy_remove, ]) .build() } diff --git a/agor-pro/src/memory.rs b/agor-pro/src/memory.rs new file mode 100644 index 0000000..cda83a5 --- /dev/null +++ b/agor-pro/src/memory.rs @@ -0,0 +1,334 @@ +// SPDX-License-Identifier: LicenseRef-Commercial +// Persistent Agent Memory — project-scoped structured fragments that survive sessions. + +use rusqlite::params; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct MemoryFragment { + pub id: i64, + pub project_id: String, + pub content: String, + pub source: String, + pub trust: String, + pub confidence: f64, + pub created_at: i64, + pub ttl_days: i64, + pub tags: String, +} + +fn ensure_tables(conn: &rusqlite::Connection) -> Result<(), String> { + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS pro_memories ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + project_id TEXT NOT NULL, + content TEXT NOT NULL, + source TEXT NOT NULL DEFAULT '', + trust TEXT NOT NULL DEFAULT 'agent', + confidence REAL NOT NULL DEFAULT 1.0, + created_at INTEGER NOT NULL, + ttl_days INTEGER NOT NULL DEFAULT 90, + tags TEXT NOT NULL DEFAULT '' + ); + CREATE VIRTUAL TABLE IF NOT EXISTS pro_memories_fts USING fts5( + content, tags, content=pro_memories, content_rowid=id + ); + CREATE TRIGGER IF NOT EXISTS pro_memories_ai AFTER INSERT ON pro_memories BEGIN + INSERT INTO pro_memories_fts(rowid, content, tags) VALUES (new.id, new.content, new.tags); + END; + CREATE TRIGGER IF NOT EXISTS pro_memories_ad AFTER DELETE ON pro_memories BEGIN + INSERT INTO pro_memories_fts(pro_memories_fts, rowid, content, tags) + VALUES ('delete', old.id, old.content, old.tags); + END; + CREATE TRIGGER IF NOT EXISTS pro_memories_au AFTER UPDATE ON pro_memories BEGIN + INSERT INTO pro_memories_fts(pro_memories_fts, rowid, content, tags) + VALUES ('delete', old.id, old.content, old.tags); + INSERT INTO pro_memories_fts(rowid, content, tags) VALUES (new.id, new.content, new.tags); + END;" + ).map_err(|e| format!("Failed to create memory tables: {e}")) +} + +fn now_epoch() -> i64 { + super::analytics::now_epoch() +} + +fn prune_expired(conn: &rusqlite::Connection) -> Result<(), String> { + let now = now_epoch(); + conn.execute( + "DELETE FROM pro_memories WHERE created_at + (ttl_days * 86400) < ?1", + params![now], + ).map_err(|e| format!("Prune failed: {e}"))?; + Ok(()) +} + +fn row_to_fragment(row: &rusqlite::Row) -> rusqlite::Result { + Ok(MemoryFragment { + id: row.get(0)?, + project_id: row.get(1)?, + content: row.get(2)?, + source: row.get(3)?, + trust: row.get(4)?, + confidence: row.get(5)?, + created_at: row.get(6)?, + ttl_days: row.get(7)?, + tags: row.get(8)?, + }) +} + +#[tauri::command] +pub fn pro_memory_add( + project_id: String, + content: String, + source: Option, + tags: Option, +) -> Result { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + let ts = now_epoch(); + let src = source.unwrap_or_default(); + let tgs = tags.unwrap_or_default(); + conn.execute( + "INSERT INTO pro_memories (project_id, content, source, created_at, tags) VALUES (?1, ?2, ?3, ?4, ?5)", + params![project_id, content, src, ts, tgs], + ).map_err(|e| format!("Failed to add memory: {e}"))?; + Ok(conn.last_insert_rowid()) +} + +#[tauri::command] +pub fn pro_memory_list(project_id: String, limit: Option) -> Result, String> { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + prune_expired(&conn)?; + let lim = limit.unwrap_or(50); + + let mut stmt = conn.prepare( + "SELECT id, project_id, content, source, trust, confidence, created_at, ttl_days, tags + FROM pro_memories WHERE project_id = ?1 ORDER BY created_at DESC LIMIT ?2" + ).map_err(|e| format!("Query failed: {e}"))?; + + let rows = stmt.query_map(params![project_id, lim], row_to_fragment) + .map_err(|e| format!("Query failed: {e}"))? + .collect::, _>>() + .map_err(|e| format!("Row read failed: {e}"))?; + + Ok(rows) +} + +#[tauri::command] +pub fn pro_memory_search(project_id: String, query: String) -> Result, String> { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + prune_expired(&conn)?; + + let mut stmt = conn.prepare( + "SELECT m.id, m.project_id, m.content, m.source, m.trust, m.confidence, m.created_at, m.ttl_days, m.tags + FROM pro_memories m + JOIN pro_memories_fts f ON m.id = f.rowid + WHERE f.pro_memories_fts MATCH ?1 AND m.project_id = ?2 + ORDER BY rank LIMIT 20" + ).map_err(|e| format!("Search query failed: {e}"))?; + + let rows = stmt.query_map(params![query, project_id], row_to_fragment) + .map_err(|e| format!("Search failed: {e}"))? + .collect::, _>>() + .map_err(|e| format!("Row read failed: {e}"))?; + + Ok(rows) +} + +#[tauri::command] +pub fn pro_memory_update( + id: i64, + content: Option, + trust: Option, + confidence: Option, +) -> Result<(), String> { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + + if let Some(c) = content { + conn.execute("UPDATE pro_memories SET content = ?2 WHERE id = ?1", params![id, c]) + .map_err(|e| format!("Update content failed: {e}"))?; + } + if let Some(t) = trust { + conn.execute("UPDATE pro_memories SET trust = ?2 WHERE id = ?1", params![id, t]) + .map_err(|e| format!("Update trust failed: {e}"))?; + } + if let Some(c) = confidence { + conn.execute("UPDATE pro_memories SET confidence = ?2 WHERE id = ?1", params![id, c]) + .map_err(|e| format!("Update confidence failed: {e}"))?; + } + Ok(()) +} + +#[tauri::command] +pub fn pro_memory_delete(id: i64) -> Result<(), String> { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + conn.execute("DELETE FROM pro_memories WHERE id = ?1", params![id]) + .map_err(|e| format!("Delete failed: {e}"))?; + Ok(()) +} + +#[tauri::command] +pub fn pro_memory_inject(project_id: String, max_tokens: Option) -> Result { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + prune_expired(&conn)?; + + let max_chars = (max_tokens.unwrap_or(2000) * 3) as usize; // ~3 chars per token heuristic + + let mut stmt = conn.prepare( + "SELECT content, trust, confidence FROM pro_memories + WHERE project_id = ?1 ORDER BY confidence DESC, created_at DESC" + ).map_err(|e| format!("Query failed: {e}"))?; + + let entries: Vec<(String, String, f64)> = stmt + .query_map(params![project_id], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?))) + .map_err(|e| format!("Query failed: {e}"))? + .collect::, _>>() + .map_err(|e| format!("Row read failed: {e}"))?; + + let mut md = String::from("## Project Memory\n\n"); + let mut chars = md.len(); + + for (content, trust, confidence) in &entries { + let line = format!("- [{}|{:.1}] {}\n", trust, confidence, content); + if chars + line.len() > max_chars { + break; + } + md.push_str(&line); + chars += line.len(); + } + + Ok(md) +} + +#[tauri::command] +pub fn pro_memory_extract_from_session( + project_id: String, + session_messages_json: String, +) -> Result, String> { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + + let messages: Vec = serde_json::from_str(&session_messages_json) + .map_err(|e| format!("Invalid JSON: {e}"))?; + + let ts = now_epoch(); + let mut extracted = Vec::new(); + + // Patterns to extract: decisions, file references, errors + let decision_patterns = ["decision:", "chose ", "decided to ", "instead of "]; + let error_patterns = ["error:", "failed:", "Error:", "panic", "FAILED"]; + + for msg in &messages { + let content = msg.get("content").and_then(|c| c.as_str()).unwrap_or(""); + + // Extract decisions + for pattern in &decision_patterns { + if content.contains(pattern) { + let fragment_content = extract_surrounding(content, pattern, 200); + conn.execute( + "INSERT INTO pro_memories (project_id, content, source, trust, confidence, created_at, tags) + VALUES (?1, ?2, 'auto-extract', 'auto', 0.7, ?3, 'decision')", + params![project_id, fragment_content, ts], + ).map_err(|e| format!("Insert failed: {e}"))?; + let id = conn.last_insert_rowid(); + extracted.push(MemoryFragment { + id, + project_id: project_id.clone(), + content: fragment_content, + source: "auto-extract".into(), + trust: "auto".into(), + confidence: 0.7, + created_at: ts, + ttl_days: 90, + tags: "decision".into(), + }); + break; // One extraction per message + } + } + + // Extract errors + for pattern in &error_patterns { + if content.contains(pattern) { + let fragment_content = extract_surrounding(content, pattern, 300); + conn.execute( + "INSERT INTO pro_memories (project_id, content, source, trust, confidence, created_at, tags) + VALUES (?1, ?2, 'auto-extract', 'auto', 0.6, ?3, 'error')", + params![project_id, fragment_content, ts], + ).map_err(|e| format!("Insert failed: {e}"))?; + let id = conn.last_insert_rowid(); + extracted.push(MemoryFragment { + id, + project_id: project_id.clone(), + content: fragment_content, + source: "auto-extract".into(), + trust: "auto".into(), + confidence: 0.6, + created_at: ts, + ttl_days: 90, + tags: "error".into(), + }); + break; + } + } + } + + Ok(extracted) +} + +/// Extract surrounding text around a pattern match, up to max_chars. +fn extract_surrounding(text: &str, pattern: &str, max_chars: usize) -> String { + if let Some(pos) = text.find(pattern) { + let start = pos.saturating_sub(50); + let end = (pos + max_chars).min(text.len()); + // Ensure valid UTF-8 boundaries + let start = text.floor_char_boundary(start); + let end = text.ceil_char_boundary(end); + text[start..end].to_string() + } else { + text.chars().take(max_chars).collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_memory_fragment_serializes_camel_case() { + let f = MemoryFragment { + id: 1, + project_id: "proj1".into(), + content: "We decided to use SQLite".into(), + source: "session-abc".into(), + trust: "agent".into(), + confidence: 0.9, + created_at: 1710000000, + ttl_days: 90, + tags: "decision,architecture".into(), + }; + let json = serde_json::to_string(&f).unwrap(); + assert!(json.contains("projectId")); + assert!(json.contains("createdAt")); + assert!(json.contains("ttlDays")); + } + + #[test] + fn test_memory_fragment_deserializes() { + let json = r#"{"id":1,"projectId":"p","content":"test","source":"s","trust":"human","confidence":1.0,"createdAt":0,"ttlDays":30,"tags":"t"}"#; + let f: MemoryFragment = serde_json::from_str(json).unwrap(); + assert_eq!(f.project_id, "p"); + assert_eq!(f.trust, "human"); + } + + #[test] + fn test_extract_surrounding() { + let text = "We chose SQLite instead of PostgreSQL for simplicity"; + let result = extract_surrounding(text, "chose ", 100); + assert!(result.contains("chose SQLite")); + } +} diff --git a/agor-pro/src/router.rs b/agor-pro/src/router.rs new file mode 100644 index 0000000..815eeec --- /dev/null +++ b/agor-pro/src/router.rs @@ -0,0 +1,194 @@ +// SPDX-License-Identifier: LicenseRef-Commercial +// Smart Model Router — select optimal model based on task type and project config. + +use rusqlite::params; +use serde::Serialize; + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelRecommendation { + pub model: String, + pub reason: String, + pub estimated_cost_factor: f64, + pub profile: String, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct RoutingProfile { + pub name: String, + pub description: String, + pub rules: Vec, +} + +fn ensure_tables(conn: &rusqlite::Connection) -> Result<(), String> { + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS pro_router_profiles ( + project_id TEXT PRIMARY KEY, + profile TEXT NOT NULL DEFAULT 'balanced' + );" + ).map_err(|e| format!("Failed to create router tables: {e}")) +} + +fn get_profiles() -> Vec { + vec![ + RoutingProfile { + name: "cost_saver".into(), + description: "Minimize cost — use cheapest viable model".into(), + rules: vec![ + "All roles use cheapest model".into(), + "Only upgrade for prompts > 10000 chars".into(), + ], + }, + RoutingProfile { + name: "quality_first".into(), + description: "Maximize quality — always use premium model".into(), + rules: vec![ + "All roles use premium model".into(), + "No downgrade regardless of prompt size".into(), + ], + }, + RoutingProfile { + name: "balanced".into(), + description: "Match model to task — role and prompt size heuristic".into(), + rules: vec![ + "Manager/Architect → premium model".into(), + "Tester/Reviewer → mid-tier model".into(), + "Short prompts (<2000 chars) → cheap model".into(), + "Long prompts (>8000 chars) → premium model".into(), + ], + }, + ] +} + +fn select_model(profile: &str, role: &str, prompt_length: i64, provider: &str) -> (String, String, f64) { + let (cheap, mid, premium) = match provider { + "codex" => ("gpt-4.1-mini", "gpt-4.1", "gpt-5"), + "ollama" => ("qwen3:8b", "qwen3:8b", "qwen3:32b"), + _ => ("claude-haiku-4-5", "claude-sonnet-4-5", "claude-opus-4"), + }; + + match profile { + "cost_saver" => { + if prompt_length > 10_000 { + (mid.into(), "Long prompt upgrade in cost_saver profile".into(), 0.5) + } else { + (cheap.into(), "Cost saver: cheapest model".into(), 0.1) + } + } + "quality_first" => { + (premium.into(), "Quality first: premium model".into(), 1.0) + } + _ => { + // Balanced: role + prompt heuristic + match role { + "manager" | "architect" => { + (premium.into(), format!("Balanced: premium for {role} role"), 1.0) + } + "tester" | "reviewer" => { + (mid.into(), format!("Balanced: mid-tier for {role} role"), 0.5) + } + _ => { + if prompt_length < 2_000 { + (cheap.into(), "Balanced: cheap for short prompt".into(), 0.1) + } else if prompt_length > 8_000 { + (premium.into(), "Balanced: premium for long prompt".into(), 1.0) + } else { + (mid.into(), "Balanced: mid-tier default".into(), 0.5) + } + } + } + } + } +} + +#[tauri::command] +pub fn pro_router_recommend( + project_id: String, + role: String, + prompt_length: i64, + provider: Option, +) -> Result { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + + let profile = conn.prepare("SELECT profile FROM pro_router_profiles WHERE project_id = ?1") + .map_err(|e| format!("Query failed: {e}"))? + .query_row(params![project_id], |row| row.get::<_, String>(0)) + .unwrap_or_else(|_| "balanced".into()); + + let prov = provider.as_deref().unwrap_or("claude"); + let (model, reason, cost_factor) = select_model(&profile, &role, prompt_length, prov); + + Ok(ModelRecommendation { model, reason, estimated_cost_factor: cost_factor, profile }) +} + +#[tauri::command] +pub fn pro_router_set_profile(project_id: String, profile: String) -> Result<(), String> { + let valid = ["cost_saver", "quality_first", "balanced"]; + if !valid.contains(&profile.as_str()) { + return Err(format!("Invalid profile '{}'. Valid: {:?}", profile, valid)); + } + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + conn.execute( + "INSERT INTO pro_router_profiles (project_id, profile) VALUES (?1, ?2) + ON CONFLICT(project_id) DO UPDATE SET profile = ?2", + params![project_id, profile], + ).map_err(|e| format!("Failed to set profile: {e}"))?; + Ok(()) +} + +#[tauri::command] +pub fn pro_router_get_profile(project_id: String) -> Result { + let conn = super::open_sessions_db()?; + ensure_tables(&conn)?; + let profile = conn.prepare("SELECT profile FROM pro_router_profiles WHERE project_id = ?1") + .map_err(|e| format!("Query failed: {e}"))? + .query_row(params![project_id], |row| row.get::<_, String>(0)) + .unwrap_or_else(|_| "balanced".into()); + Ok(profile) +} + +#[tauri::command] +pub fn pro_router_list_profiles() -> Vec { + get_profiles() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_recommendation_serializes_camel_case() { + let r = ModelRecommendation { + model: "claude-sonnet-4-5".into(), + reason: "test".into(), + estimated_cost_factor: 0.5, + profile: "balanced".into(), + }; + let json = serde_json::to_string(&r).unwrap(); + assert!(json.contains("estimatedCostFactor")); + assert!(json.contains("\"profile\":\"balanced\"")); + } + + #[test] + fn test_select_model_balanced_manager() { + let (model, _, cost) = select_model("balanced", "manager", 5000, "claude"); + assert_eq!(model, "claude-opus-4"); + assert_eq!(cost, 1.0); + } + + #[test] + fn test_select_model_cost_saver() { + let (model, _, cost) = select_model("cost_saver", "worker", 1000, "claude"); + assert_eq!(model, "claude-haiku-4-5"); + assert!(cost < 0.2); + } + + #[test] + fn test_select_model_codex_provider() { + let (model, _, _) = select_model("quality_first", "manager", 5000, "codex"); + assert_eq!(model, "gpt-5"); + } +} diff --git a/agor-pro/src/symbols.rs b/agor-pro/src/symbols.rs new file mode 100644 index 0000000..3b9b639 --- /dev/null +++ b/agor-pro/src/symbols.rs @@ -0,0 +1,295 @@ +// SPDX-License-Identifier: LicenseRef-Commercial +// Codebase Symbol Graph — stub implementation using regex parsing. +// Full tree-sitter implementation deferred until tree-sitter dep is added. + +use serde::Serialize; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::Mutex; + +static SYMBOL_CACHE: std::sync::LazyLock>>> = + std::sync::LazyLock::new(|| Mutex::new(HashMap::new())); + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Symbol { + pub name: String, + pub kind: String, + pub file_path: String, + pub line_number: usize, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CallerRef { + pub file_path: String, + pub line_number: usize, + pub context: String, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ScanResult { + pub files_scanned: usize, + pub symbols_found: usize, + pub duration_ms: u64, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct IndexStatus { + pub indexed: bool, + pub symbols_count: usize, + pub last_scan: Option, +} + +/// Common directories to skip during scan. +const SKIP_DIRS: &[&str] = &[ + ".git", "node_modules", "target", "dist", "build", ".next", + "__pycache__", ".venv", "venv", ".tox", +]; + +/// Supported extensions for symbol extraction. +const SUPPORTED_EXT: &[&str] = &["ts", "rs", "py", "js", "tsx", "jsx"]; + +fn should_skip(name: &str) -> bool { + SKIP_DIRS.contains(&name) +} + +fn walk_files(dir: &Path, files: &mut Vec) { + let Ok(entries) = std::fs::read_dir(dir) else { return }; + for entry in entries.flatten() { + let path = entry.path(); + if path.is_dir() { + if let Some(name) = path.file_name().and_then(|n| n.to_str()) { + if !should_skip(name) { + walk_files(&path, files); + } + } + } else if let Some(ext) = path.extension().and_then(|e| e.to_str()) { + if SUPPORTED_EXT.contains(&ext) { + files.push(path); + } + } + } +} + +fn extract_symbols_from_file(path: &Path) -> Vec { + let Ok(content) = std::fs::read_to_string(path) else { return vec![] }; + let file_str = path.to_string_lossy().to_string(); + let ext = path.extension().and_then(|e| e.to_str()).unwrap_or(""); + let mut symbols = Vec::new(); + + for (line_idx, line) in content.lines().enumerate() { + let trimmed = line.trim(); + match ext { + "rs" => { + if let Some(name) = extract_after(trimmed, "fn ") { + symbols.push(Symbol { name, kind: "function".into(), file_path: file_str.clone(), line_number: line_idx + 1 }); + } else if let Some(name) = extract_after(trimmed, "struct ") { + symbols.push(Symbol { name, kind: "struct".into(), file_path: file_str.clone(), line_number: line_idx + 1 }); + } else if let Some(name) = extract_after(trimmed, "enum ") { + symbols.push(Symbol { name, kind: "type".into(), file_path: file_str.clone(), line_number: line_idx + 1 }); + } else if let Some(name) = extract_after(trimmed, "const ") { + symbols.push(Symbol { name, kind: "const".into(), file_path: file_str.clone(), line_number: line_idx + 1 }); + } else if let Some(name) = extract_after(trimmed, "trait ") { + symbols.push(Symbol { name, kind: "type".into(), file_path: file_str.clone(), line_number: line_idx + 1 }); + } + } + "ts" | "tsx" | "js" | "jsx" => { + if let Some(name) = extract_after(trimmed, "function ") { + symbols.push(Symbol { name, kind: "function".into(), file_path: file_str.clone(), line_number: line_idx + 1 }); + } else if let Some(name) = extract_after(trimmed, "class ") { + symbols.push(Symbol { name, kind: "class".into(), file_path: file_str.clone(), line_number: line_idx + 1 }); + } else if let Some(name) = extract_ts_const_fn(trimmed) { + symbols.push(Symbol { name, kind: "function".into(), file_path: file_str.clone(), line_number: line_idx + 1 }); + } else if let Some(name) = extract_after(trimmed, "interface ") { + symbols.push(Symbol { name, kind: "type".into(), file_path: file_str.clone(), line_number: line_idx + 1 }); + } else if let Some(name) = extract_after(trimmed, "type ") { + symbols.push(Symbol { name, kind: "type".into(), file_path: file_str.clone(), line_number: line_idx + 1 }); + } + } + "py" => { + if let Some(name) = extract_after(trimmed, "def ") { + symbols.push(Symbol { name, kind: "function".into(), file_path: file_str.clone(), line_number: line_idx + 1 }); + } else if let Some(name) = extract_after(trimmed, "class ") { + symbols.push(Symbol { name, kind: "class".into(), file_path: file_str.clone(), line_number: line_idx + 1 }); + } + } + _ => {} + } + } + + symbols +} + +/// Extract identifier after a keyword (e.g., "fn " -> function name). +fn extract_after(line: &str, prefix: &str) -> Option { + if !line.starts_with(prefix) && !line.starts_with(&format!("pub {prefix}")) + && !line.starts_with(&format!("export {prefix}")) + && !line.starts_with(&format!("pub(crate) {prefix}")) + && !line.starts_with(&format!("async {prefix}")) + && !line.starts_with(&format!("pub async {prefix}")) + && !line.starts_with(&format!("export async {prefix}")) + && !line.starts_with(&format!("export default {prefix}")) + { + return None; + } + let after = line.find(prefix)? + prefix.len(); + let rest = &line[after..]; + let name: String = rest.chars() + .take_while(|c| c.is_alphanumeric() || *c == '_') + .collect(); + if name.is_empty() { None } else { Some(name) } +} + +/// Extract arrow function / const fn pattern: `const foo = (` or `export const foo = (` +fn extract_ts_const_fn(line: &str) -> Option { + let stripped = line.strip_prefix("export ") + .or(Some(line))?; + let rest = stripped.strip_prefix("const ")?; + let name: String = rest.chars() + .take_while(|c| c.is_alphanumeric() || *c == '_') + .collect(); + if name.is_empty() { return None; } + // Check if it looks like a function assignment + if rest.contains("= (") || rest.contains("= async (") || rest.contains("=> ") { + Some(name) + } else { + None + } +} + +#[tauri::command] +pub fn pro_symbols_scan(project_path: String) -> Result { + let start = std::time::Instant::now(); + let root = PathBuf::from(&project_path); + if !root.is_dir() { + return Err(format!("Not a directory: {project_path}")); + } + + let mut files = Vec::new(); + walk_files(&root, &mut files); + + let mut all_symbols = Vec::new(); + for file in &files { + all_symbols.extend(extract_symbols_from_file(file)); + } + + let result = ScanResult { + files_scanned: files.len(), + symbols_found: all_symbols.len(), + duration_ms: start.elapsed().as_millis() as u64, + }; + + let mut cache = SYMBOL_CACHE.lock().map_err(|e| format!("Lock failed: {e}"))?; + cache.insert(project_path, all_symbols); + + Ok(result) +} + +#[tauri::command] +pub fn pro_symbols_search(project_path: String, query: String) -> Result, String> { + let cache = SYMBOL_CACHE.lock().map_err(|e| format!("Lock failed: {e}"))?; + let symbols = cache.get(&project_path).cloned().unwrap_or_default(); + let query_lower = query.to_lowercase(); + + let results: Vec = symbols.into_iter() + .filter(|s| s.name.to_lowercase().contains(&query_lower)) + .take(50) + .collect(); + + Ok(results) +} + +#[tauri::command] +pub fn pro_symbols_find_callers(project_path: String, symbol_name: String) -> Result, String> { + let root = PathBuf::from(&project_path); + if !root.is_dir() { + return Err(format!("Not a directory: {project_path}")); + } + + let mut files = Vec::new(); + walk_files(&root, &mut files); + + let mut callers = Vec::new(); + for file in &files { + let Ok(content) = std::fs::read_to_string(file) else { continue }; + for (idx, line) in content.lines().enumerate() { + if line.contains(&symbol_name) { + callers.push(CallerRef { + file_path: file.to_string_lossy().to_string(), + line_number: idx + 1, + context: line.trim().to_string(), + }); + } + } + } + + // Cap results + callers.truncate(100); + Ok(callers) +} + +#[tauri::command] +pub fn pro_symbols_status(project_path: String) -> Result { + let cache = SYMBOL_CACHE.lock().map_err(|e| format!("Lock failed: {e}"))?; + match cache.get(&project_path) { + Some(symbols) => Ok(IndexStatus { + indexed: true, + symbols_count: symbols.len(), + last_scan: None, // In-memory only, no timestamp tracking + }), + None => Ok(IndexStatus { + indexed: false, + symbols_count: 0, + last_scan: None, + }), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_symbol_serializes_camel_case() { + let s = Symbol { + name: "processEvent".into(), + kind: "function".into(), + file_path: "/src/lib.rs".into(), + line_number: 42, + }; + let json = serde_json::to_string(&s).unwrap(); + assert!(json.contains("filePath")); + assert!(json.contains("lineNumber")); + } + + #[test] + fn test_scan_result_serializes_camel_case() { + let r = ScanResult { + files_scanned: 10, + symbols_found: 50, + duration_ms: 123, + }; + let json = serde_json::to_string(&r).unwrap(); + assert!(json.contains("filesScanned")); + assert!(json.contains("symbolsFound")); + assert!(json.contains("durationMs")); + } + + #[test] + fn test_extract_after_rust_fn() { + assert_eq!(extract_after("fn hello()", "fn "), Some("hello".into())); + assert_eq!(extract_after("pub fn world()", "fn "), Some("world".into())); + assert_eq!(extract_after("pub async fn go()", "fn "), Some("go".into())); + assert_eq!(extract_after("let x = 5;", "fn "), None); + } + + #[test] + fn test_extract_ts_const_fn() { + assert_eq!(extract_ts_const_fn("const foo = (x: number) => x"), Some("foo".into())); + assert_eq!(extract_ts_const_fn("export const bar = async ("), Some("bar".into())); + assert_eq!(extract_ts_const_fn("const DATA = 42"), None); + } +}