// SPDX-License-Identifier: LicenseRef-Commercial // Smart Model Router — select optimal model based on task type and project config. use rusqlite::params; use serde::Serialize; #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub struct ModelRecommendation { pub model: String, pub reason: String, pub estimated_cost_factor: f64, pub profile: String, } #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] pub struct RoutingProfile { pub name: String, pub description: String, pub rules: Vec, } fn ensure_tables(conn: &rusqlite::Connection) -> Result<(), String> { conn.execute_batch( "CREATE TABLE IF NOT EXISTS pro_router_profiles ( project_id TEXT PRIMARY KEY, profile TEXT NOT NULL DEFAULT 'balanced' );" ).map_err(|e| format!("Failed to create router tables: {e}")) } fn get_profiles() -> Vec { vec![ RoutingProfile { name: "cost_saver".into(), description: "Minimize cost — use cheapest viable model".into(), rules: vec![ "All roles use cheapest model".into(), "Only upgrade for prompts > 10000 chars".into(), ], }, RoutingProfile { name: "quality_first".into(), description: "Maximize quality — always use premium model".into(), rules: vec![ "All roles use premium model".into(), "No downgrade regardless of prompt size".into(), ], }, RoutingProfile { name: "balanced".into(), description: "Match model to task — role and prompt size heuristic".into(), rules: vec![ "Manager/Architect → premium model".into(), "Tester/Reviewer → mid-tier model".into(), "Short prompts (<2000 chars) → cheap model".into(), "Long prompts (>8000 chars) → premium model".into(), ], }, ] } fn select_model(profile: &str, role: &str, prompt_length: i64, provider: &str) -> (String, String, f64) { let (cheap, mid, premium) = match provider { "codex" => ("gpt-4.1-mini", "gpt-4.1", "gpt-5"), "ollama" => ("qwen3:8b", "qwen3:8b", "qwen3:32b"), _ => ("claude-haiku-4-5", "claude-sonnet-4-5", "claude-opus-4"), }; match profile { "cost_saver" => { if prompt_length > 10_000 { (mid.into(), "Long prompt upgrade in cost_saver profile".into(), 0.5) } else { (cheap.into(), "Cost saver: cheapest model".into(), 0.1) } } "quality_first" => { (premium.into(), "Quality first: premium model".into(), 1.0) } _ => { // Balanced: role + prompt heuristic match role { "manager" | "architect" => { (premium.into(), format!("Balanced: premium for {role} role"), 1.0) } "tester" | "reviewer" => { (mid.into(), format!("Balanced: mid-tier for {role} role"), 0.5) } _ => { if prompt_length < 2_000 { (cheap.into(), "Balanced: cheap for short prompt".into(), 0.1) } else if prompt_length > 8_000 { (premium.into(), "Balanced: premium for long prompt".into(), 1.0) } else { (mid.into(), "Balanced: mid-tier default".into(), 0.5) } } } } } } #[tauri::command] pub fn pro_router_recommend( project_id: String, role: String, prompt_length: i64, provider: Option, ) -> Result { let conn = super::open_sessions_db()?; ensure_tables(&conn)?; let profile = conn.prepare("SELECT profile FROM pro_router_profiles WHERE project_id = ?1") .map_err(|e| format!("Query failed: {e}"))? .query_row(params![project_id], |row| row.get::<_, String>(0)) .unwrap_or_else(|_| "balanced".into()); let prov = provider.as_deref().unwrap_or("claude"); let (model, reason, cost_factor) = select_model(&profile, &role, prompt_length, prov); Ok(ModelRecommendation { model, reason, estimated_cost_factor: cost_factor, profile }) } #[tauri::command] pub fn pro_router_set_profile(project_id: String, profile: String) -> Result<(), String> { let valid = ["cost_saver", "quality_first", "balanced"]; if !valid.contains(&profile.as_str()) { return Err(format!("Invalid profile '{}'. Valid: {:?}", profile, valid)); } let conn = super::open_sessions_db()?; ensure_tables(&conn)?; conn.execute( "INSERT INTO pro_router_profiles (project_id, profile) VALUES (?1, ?2) ON CONFLICT(project_id) DO UPDATE SET profile = ?2", params![project_id, profile], ).map_err(|e| format!("Failed to set profile: {e}"))?; Ok(()) } #[tauri::command] pub fn pro_router_get_profile(project_id: String) -> Result { let conn = super::open_sessions_db()?; ensure_tables(&conn)?; let profile = conn.prepare("SELECT profile FROM pro_router_profiles WHERE project_id = ?1") .map_err(|e| format!("Query failed: {e}"))? .query_row(params![project_id], |row| row.get::<_, String>(0)) .unwrap_or_else(|_| "balanced".into()); Ok(profile) } #[tauri::command] pub fn pro_router_list_profiles() -> Vec { get_profiles() } #[cfg(test)] mod tests { use super::*; #[test] fn test_recommendation_serializes_camel_case() { let r = ModelRecommendation { model: "claude-sonnet-4-5".into(), reason: "test".into(), estimated_cost_factor: 0.5, profile: "balanced".into(), }; let json = serde_json::to_string(&r).unwrap(); assert!(json.contains("estimatedCostFactor")); assert!(json.contains("\"profile\":\"balanced\"")); } #[test] fn test_select_model_balanced_manager() { let (model, _, cost) = select_model("balanced", "manager", 5000, "claude"); assert_eq!(model, "claude-opus-4"); assert_eq!(cost, 1.0); } #[test] fn test_select_model_cost_saver() { let (model, _, cost) = select_model("cost_saver", "worker", 1000, "claude"); assert_eq!(model, "claude-haiku-4-5"); assert!(cost < 0.2); } #[test] fn test_select_model_codex_provider() { let (model, _, _) = select_model("quality_first", "manager", 5000, "codex"); assert_eq!(model, "gpt-5"); } }