// Review changelog (imported from commit summary):
// Rust fixes (HIGH):
//   - symbols.rs: path validation (reject near-root, 50K file limit, symlink filter)
//   - memory.rs: FTS5 query quoting (prevent operator injection), 1000 fragment cap,
//     content length limit, transaction wrapping
//   - budget.rs: atomic check-and-reserve via transaction, input validation, index on budget_log
//   - export.rs: safe UTF-8 truncation via chars().take()
//   - git_context.rs: reject paths starting with '-' (flag injection)
//   - branch_policy.rs: action validation (block|warn only), path validation
// Rust fixes (MEDIUM):
//   - export.rs: named column access (positional -> named)
//   - budget.rs: named column access, negative value guards
// Svelte fixes:
//   - AccountSwitcher: 2-step confirmation before account switch
//   - ProjectMemory: expand/collapse content, 2-step delete confirm, tags split fix
//   - CodeIntelligence: min 2-char symbol query, CodeSymbol rename, aria-labels
//   - BudgetManager: 10M upper bound, aria-label on input, named constants
//   - SessionExporter: timeout cleanup on destroy, aria-live feedback
//   - AnalyticsDashboard: SVG aria-label, removed unused import, named constant
// SPDX-License-Identifier: LicenseRef-Commercial

// Codebase Symbol Graph — stub implementation using regex parsing.
// Full tree-sitter implementation deferred until tree-sitter dep is added.
use serde::Serialize;
|
|
use std::collections::HashMap;
|
|
use std::path::{Path, PathBuf};
|
|
use std::sync::Mutex;
|
|
|
|
/// Process-wide, in-memory symbol index, keyed by project path.
/// Populated by `pro_symbols_scan`; read by the search and status commands.
/// NOTE(review): entries are never evicted for the lifetime of the process.
static SYMBOL_CACHE: std::sync::LazyLock<Mutex<HashMap<String, Vec<Symbol>>>> =
    std::sync::LazyLock::new(|| Mutex::new(HashMap::new()));
/// A single source symbol found by the regex-based scanner.
/// Serialized to the frontend with camelCase keys.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Symbol {
    /// Identifier as written in source (e.g. function or struct name).
    pub name: String,
    /// Coarse category: "function", "struct", "class", "type", or "const".
    pub kind: String,
    /// Lossy-UTF-8 path of the file the symbol was found in.
    pub file_path: String,
    /// 1-based line number of the definition.
    pub line_number: usize,
}
/// One occurrence of a symbol name in a scanned file (a potential caller).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CallerRef {
    /// Lossy-UTF-8 path of the file containing the occurrence.
    pub file_path: String,
    /// 1-based line number of the occurrence.
    pub line_number: usize,
    /// The trimmed source line in which the name appeared.
    pub context: String,
}
/// Summary returned by `pro_symbols_scan`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ScanResult {
    /// Number of supported source files visited.
    pub files_scanned: usize,
    /// Total symbols extracted across all files.
    pub symbols_found: usize,
    /// Wall-clock duration of the scan in milliseconds.
    pub duration_ms: u64,
}
/// Index state reported by `pro_symbols_status`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct IndexStatus {
    /// True when the project has an entry in the in-memory cache.
    pub indexed: bool,
    /// Number of cached symbols (0 when not indexed).
    pub symbols_count: usize,
    /// Always `None` — the in-memory cache does not track scan timestamps.
    pub last_scan: Option<String>,
}
/// Common directories to skip during scan.
const SKIP_DIRS: &[&str] = &[
    ".git",
    "node_modules",
    "target",
    "dist",
    "build",
    ".next",
    "__pycache__",
    ".venv",
    "venv",
    ".tox",
];

/// Supported extensions for symbol extraction.
const SUPPORTED_EXT: &[&str] = &["ts", "rs", "py", "js", "tsx", "jsx"];

/// True when `name` is a directory the walker should not descend into.
fn should_skip(name: &str) -> bool {
    SKIP_DIRS.iter().any(|dir| *dir == name)
}
const MAX_FILES: usize = 50_000;
|
|
const MAX_DEPTH: usize = 20;
|
|
|
|
fn walk_files(dir: &Path, files: &mut Vec<PathBuf>) {
|
|
walk_files_bounded(dir, files, 0);
|
|
}
|
|
|
|
fn walk_files_bounded(dir: &Path, files: &mut Vec<PathBuf>, depth: usize) {
|
|
if depth >= MAX_DEPTH || files.len() >= MAX_FILES {
|
|
return;
|
|
}
|
|
let Ok(entries) = std::fs::read_dir(dir) else { return };
|
|
for entry in entries.flatten() {
|
|
if files.len() >= MAX_FILES {
|
|
return;
|
|
}
|
|
let ft = entry.file_type();
|
|
// Skip symlinks
|
|
if ft.as_ref().map_or(false, |ft| ft.is_symlink()) {
|
|
continue;
|
|
}
|
|
let path = entry.path();
|
|
if ft.as_ref().map_or(false, |ft| ft.is_dir()) {
|
|
if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
|
|
if !should_skip(name) {
|
|
walk_files_bounded(&path, files, depth + 1);
|
|
}
|
|
}
|
|
} else if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
|
|
if SUPPORTED_EXT.contains(&ext) {
|
|
files.push(path);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn extract_symbols_from_file(path: &Path) -> Vec<Symbol> {
|
|
let Ok(content) = std::fs::read_to_string(path) else { return vec![] };
|
|
let file_str = path.to_string_lossy().to_string();
|
|
let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
|
|
let mut symbols = Vec::new();
|
|
|
|
for (line_idx, line) in content.lines().enumerate() {
|
|
let trimmed = line.trim();
|
|
match ext {
|
|
"rs" => {
|
|
if let Some(name) = extract_after(trimmed, "fn ") {
|
|
symbols.push(Symbol { name, kind: "function".into(), file_path: file_str.clone(), line_number: line_idx + 1 });
|
|
} else if let Some(name) = extract_after(trimmed, "struct ") {
|
|
symbols.push(Symbol { name, kind: "struct".into(), file_path: file_str.clone(), line_number: line_idx + 1 });
|
|
} else if let Some(name) = extract_after(trimmed, "enum ") {
|
|
symbols.push(Symbol { name, kind: "type".into(), file_path: file_str.clone(), line_number: line_idx + 1 });
|
|
} else if let Some(name) = extract_after(trimmed, "const ") {
|
|
symbols.push(Symbol { name, kind: "const".into(), file_path: file_str.clone(), line_number: line_idx + 1 });
|
|
} else if let Some(name) = extract_after(trimmed, "trait ") {
|
|
symbols.push(Symbol { name, kind: "type".into(), file_path: file_str.clone(), line_number: line_idx + 1 });
|
|
}
|
|
}
|
|
"ts" | "tsx" | "js" | "jsx" => {
|
|
if let Some(name) = extract_after(trimmed, "function ") {
|
|
symbols.push(Symbol { name, kind: "function".into(), file_path: file_str.clone(), line_number: line_idx + 1 });
|
|
} else if let Some(name) = extract_after(trimmed, "class ") {
|
|
symbols.push(Symbol { name, kind: "class".into(), file_path: file_str.clone(), line_number: line_idx + 1 });
|
|
} else if let Some(name) = extract_ts_const_fn(trimmed) {
|
|
symbols.push(Symbol { name, kind: "function".into(), file_path: file_str.clone(), line_number: line_idx + 1 });
|
|
} else if let Some(name) = extract_after(trimmed, "interface ") {
|
|
symbols.push(Symbol { name, kind: "type".into(), file_path: file_str.clone(), line_number: line_idx + 1 });
|
|
} else if let Some(name) = extract_after(trimmed, "type ") {
|
|
symbols.push(Symbol { name, kind: "type".into(), file_path: file_str.clone(), line_number: line_idx + 1 });
|
|
}
|
|
}
|
|
"py" => {
|
|
if let Some(name) = extract_after(trimmed, "def ") {
|
|
symbols.push(Symbol { name, kind: "function".into(), file_path: file_str.clone(), line_number: line_idx + 1 });
|
|
} else if let Some(name) = extract_after(trimmed, "class ") {
|
|
symbols.push(Symbol { name, kind: "class".into(), file_path: file_str.clone(), line_number: line_idx + 1 });
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
symbols
|
|
}
|
|
|
|
/// Extract the identifier that follows a declaration keyword (e.g. `"fn "`).
///
/// The keyword may appear at the start of the line or be preceded by any
/// combination of modifier prefixes (`pub`, `pub(crate)`, `pub(super)`,
/// `export`, `default`, `async`). Returns `None` when the line does not
/// start with such a declaration or the identifier is empty.
///
/// Compared with the previous version this strips modifiers in a loop
/// instead of enumerating every combination, which (a) removes the per-call
/// `format!` allocations and (b) additionally accepts combinations the old
/// list missed, such as `pub(crate) async fn`. Every previously accepted
/// line is still accepted (strict superset).
fn extract_after(line: &str, prefix: &str) -> Option<String> {
    const MODIFIERS: &[&str] = &[
        "pub(crate) ",
        "pub(super) ",
        "pub ",
        "export ",
        "default ",
        "async ",
    ];

    // Peel leading modifiers until none remain.
    let mut rest = line;
    loop {
        let Some(stripped) = MODIFIERS.iter().find_map(|m| rest.strip_prefix(m)) else {
            break;
        };
        rest = stripped;
    }

    // The declaration keyword must come immediately after the modifiers.
    let body = rest.strip_prefix(prefix)?;
    let name: String = body
        .chars()
        .take_while(|c| c.is_alphanumeric() || *c == '_')
        .collect();
    if name.is_empty() { None } else { Some(name) }
}
/// Extract arrow function / const fn pattern: `const foo = (` or `export const foo = (`
fn extract_ts_const_fn(line: &str) -> Option<String> {
    // `export ` is an optional leading keyword; fall back to the line itself.
    let without_export = line.strip_prefix("export ").unwrap_or(line);
    let rest = without_export.strip_prefix("const ")?;

    // Collect the identifier immediately after `const `.
    let mut name = String::new();
    for ch in rest.chars() {
        if ch.is_alphanumeric() || ch == '_' {
            name.push(ch);
        } else {
            break;
        }
    }
    if name.is_empty() {
        return None;
    }

    // Only report the binding when its right-hand side looks like a
    // function value (parenthesized params or an arrow).
    let looks_like_fn = ["= (", "= async (", "=> "]
        .iter()
        .any(|marker| rest.contains(marker));
    looks_like_fn.then_some(name)
}
#[tauri::command]
|
|
pub fn pro_symbols_scan(project_path: String) -> Result<ScanResult, String> {
|
|
let start = std::time::Instant::now();
|
|
let root = PathBuf::from(&project_path);
|
|
if !root.is_absolute() || root.components().count() < 3 {
|
|
return Err("Invalid project path: must be an absolute path at least 3 levels deep".into());
|
|
}
|
|
if !root.is_dir() {
|
|
return Err(format!("Not a directory: {project_path}"));
|
|
}
|
|
|
|
let mut files = Vec::new();
|
|
walk_files(&root, &mut files);
|
|
|
|
let mut all_symbols = Vec::new();
|
|
for file in &files {
|
|
all_symbols.extend(extract_symbols_from_file(file));
|
|
}
|
|
|
|
let result = ScanResult {
|
|
files_scanned: files.len(),
|
|
symbols_found: all_symbols.len(),
|
|
duration_ms: start.elapsed().as_millis() as u64,
|
|
};
|
|
|
|
let mut cache = SYMBOL_CACHE.lock().map_err(|e| format!("Lock failed: {e}"))?;
|
|
cache.insert(project_path, all_symbols);
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
#[tauri::command]
|
|
pub fn pro_symbols_search(project_path: String, query: String) -> Result<Vec<Symbol>, String> {
|
|
let cache = SYMBOL_CACHE.lock().map_err(|e| format!("Lock failed: {e}"))?;
|
|
let symbols = cache.get(&project_path).cloned().unwrap_or_default();
|
|
let query_lower = query.to_lowercase();
|
|
|
|
let results: Vec<Symbol> = symbols.into_iter()
|
|
.filter(|s| s.name.to_lowercase().contains(&query_lower))
|
|
.take(50)
|
|
.collect();
|
|
|
|
Ok(results)
|
|
}
|
|
|
|
#[tauri::command]
|
|
pub fn pro_symbols_find_callers(project_path: String, symbol_name: String) -> Result<Vec<CallerRef>, String> {
|
|
let root = PathBuf::from(&project_path);
|
|
if !root.is_absolute() || root.components().count() < 3 {
|
|
return Err("Invalid project path: must be an absolute path at least 3 levels deep".into());
|
|
}
|
|
if !root.is_dir() {
|
|
return Err(format!("Not a directory: {project_path}"));
|
|
}
|
|
|
|
let mut files = Vec::new();
|
|
walk_files(&root, &mut files);
|
|
|
|
let mut callers = Vec::new();
|
|
for file in &files {
|
|
let Ok(content) = std::fs::read_to_string(file) else { continue };
|
|
for (idx, line) in content.lines().enumerate() {
|
|
if line.contains(&symbol_name) {
|
|
callers.push(CallerRef {
|
|
file_path: file.to_string_lossy().to_string(),
|
|
line_number: idx + 1,
|
|
context: line.trim().to_string(),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
// Cap results
|
|
callers.truncate(100);
|
|
Ok(callers)
|
|
}
|
|
|
|
#[tauri::command]
|
|
pub fn pro_symbols_status(project_path: String) -> Result<IndexStatus, String> {
|
|
let cache = SYMBOL_CACHE.lock().map_err(|e| format!("Lock failed: {e}"))?;
|
|
match cache.get(&project_path) {
|
|
Some(symbols) => Ok(IndexStatus {
|
|
indexed: true,
|
|
symbols_count: symbols.len(),
|
|
last_scan: None, // In-memory only, no timestamp tracking
|
|
}),
|
|
None => Ok(IndexStatus {
|
|
indexed: false,
|
|
symbols_count: 0,
|
|
last_scan: None,
|
|
}),
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    // The frontend expects camelCase keys; these tests pin the serde rename.
    #[test]
    fn test_symbol_serializes_camel_case() {
        let s = Symbol {
            name: "processEvent".into(),
            kind: "function".into(),
            file_path: "/src/lib.rs".into(),
            line_number: 42,
        };
        let json = serde_json::to_string(&s).unwrap();
        assert!(json.contains("filePath"));
        assert!(json.contains("lineNumber"));
    }

    #[test]
    fn test_scan_result_serializes_camel_case() {
        let r = ScanResult {
            files_scanned: 10,
            symbols_found: 50,
            duration_ms: 123,
        };
        let json = serde_json::to_string(&r).unwrap();
        assert!(json.contains("filesScanned"));
        assert!(json.contains("symbolsFound"));
        assert!(json.contains("durationMs"));
    }

    // `extract_after` must accept common visibility/async prefixes and reject
    // lines that do not start with the keyword.
    #[test]
    fn test_extract_after_rust_fn() {
        assert_eq!(extract_after("fn hello()", "fn "), Some("hello".into()));
        assert_eq!(extract_after("pub fn world()", "fn "), Some("world".into()));
        assert_eq!(extract_after("pub async fn go()", "fn "), Some("go".into()));
        assert_eq!(extract_after("let x = 5;", "fn "), None);
    }

    // Arrow-function detection: plain value assignments must not be reported.
    #[test]
    fn test_extract_ts_const_fn() {
        assert_eq!(extract_ts_const_fn("const foo = (x: number) => x"), Some("foo".into()));
        assert_eq!(extract_ts_const_fn("export const bar = async ("), Some("bar".into()));
        assert_eq!(extract_ts_const_fn("const DATA = 42"), None);
    }
}