diff --git a/crates/ui/src/main.rs b/crates/ui/src/main.rs index 3bbd6dc..4786c58 100644 --- a/crates/ui/src/main.rs +++ b/crates/ui/src/main.rs @@ -111,18 +111,17 @@ async fn fetch_health(path: &str) -> Result { /// Get schema context for all datasets (used for AI SQL generation) async fn get_schema_context(datasets: &[Dataset]) -> String { - let mut ctx = String::from("Available tables:\n\n"); + let mut ctx = String::from("DATABASE SCHEMA:\n\n"); for ds in datasets { let desc = run_sql(&format!("DESCRIBE {}", ds.name)).await; match desc { Ok(resp) => { - ctx.push_str(&format!("TABLE: {}\n", ds.name)); + ctx.push_str(&format!("TABLE: {}\n Columns:\n", ds.name)); if let Some(rows) = resp.rows.as_array() { for row in rows { let col = row.get("column_name").and_then(|v| v.as_str()).unwrap_or("?"); let dt = row.get("data_type").and_then(|v| v.as_str()).unwrap_or("?"); - let nullable = row.get("is_nullable").and_then(|v| v.as_str()).unwrap_or("?"); - ctx.push_str(&format!(" - {} ({}, nullable={})\n", col, dt, nullable)); + ctx.push_str(&format!(" {}.{} ({})\n", ds.name, col, dt)); } } ctx.push('\n'); @@ -132,6 +131,18 @@ async fn get_schema_context(datasets: &[Dataset]) -> String { } } } + // Add relationship hints so the model knows how to JOIN + ctx.push_str("RELATIONSHIPS:\n"); + ctx.push_str(" candidates.candidate_id = placements.candidate_id\n"); + ctx.push_str(" candidates.candidate_id = timesheets.candidate_id\n"); + ctx.push_str(" candidates.candidate_id = call_log.candidate_id\n"); + ctx.push_str(" candidates.candidate_id = email_log.candidate_id\n"); + ctx.push_str(" clients.client_id = job_orders.client_id\n"); + ctx.push_str(" clients.client_id = placements.client_id\n"); + ctx.push_str(" clients.client_id = timesheets.client_id\n"); + ctx.push_str(" job_orders.job_order_id = placements.job_order_id\n"); + ctx.push_str(" placements.placement_id = timesheets.placement_id\n"); + ctx.push_str("\nNOTE: 'vertical' is only in candidates, clients, and job_orders. To get vertical for timesheets or placements, JOIN to those tables.\n"); ctx } @@ -272,9 +283,7 @@ fn AskPanel(datasets: Vec) -> Element { match ai_generate(&prompt, 512).await { Ok(resp) => { - let sql = resp.text.trim().to_string(); - // Clean markdown backticks if model adds them - let sql = sql.trim_start_matches("```sql").trim_start_matches("```").trim_end_matches("```").trim().to_string(); + let sql = clean_sql(&resp.text); generated_sql.set(Some(sql.clone())); // Step 3: Execute @@ -292,9 +301,7 @@ fn AskPanel(datasets: Vec) -> Element { Write a CORRECTED SQL query using ONLY the columns listed in the schema. Output ONLY SQL." ); if let Ok(retry_resp) = ai_generate(&retry_prompt, 512).await { - let retry_sql = retry_resp.text.trim() - .trim_start_matches("```sql").trim_start_matches("```") - .trim_end_matches("```").trim().to_string(); + let retry_sql = clean_sql(&retry_resp.text); generated_sql.set(Some(retry_sql.clone())); step.set("running corrected query...".into()); let retry_result = run_sql(&retry_sql).await; @@ -356,7 +363,7 @@ fn AskPanel(datasets: Vec) -> Element { ); match ai_generate(&prompt, 512).await { Ok(resp) => { - let sql = resp.text.trim().trim_start_matches("```sql").trim_start_matches("```").trim_end_matches("```").trim().to_string(); + let sql = clean_sql(&resp.text); generated_sql.set(Some(sql.clone())); step.set("running query...".into()); let query_result = run_sql(&sql).await; @@ -367,7 +374,7 @@ fn AskPanel(datasets: Vec) -> Element { "The SQL you wrote had an error:\n{err}\n\n{schema_ctx}\n\nOriginal question: {q}\n\nWrite a CORRECTED SQL query using ONLY the columns listed. Output ONLY SQL." ); if let Ok(rr) = ai_generate(&retry_prompt, 512).await { - let rsql = rr.text.trim().trim_start_matches("```sql").trim_start_matches("```").trim_end_matches("```").trim().to_string(); + let rsql = clean_sql(&rr.text); generated_sql.set(Some(rsql.clone())); step.set("running corrected query...".into()); result.set(Some(run_sql(&rsql).await)); @@ -1018,6 +1025,25 @@ fn ResultsTable(response: QueryResponse) -> Element { } } +/// Clean AI-generated SQL: strip markdown fences, leading "sql" keyword, explanations. +fn clean_sql(raw: &str) -> String { + let mut s = raw.trim().to_string(); + // Remove markdown code fences + s = s.trim_start_matches("```sql").trim_start_matches("```").trim_end_matches("```").trim().to_string(); + // Remove leading "sql" keyword on its own line + let lines: Vec<&str> = s.lines().collect(); + if let Some(first) = lines.first() { + if first.trim().eq_ignore_ascii_case("sql") || first.trim().eq_ignore_ascii_case("sql;") { + s = lines[1..].join("\n").trim().to_string(); + } + } + // If the model added explanation after the SQL, keep only up to the first semicolon line + if let Some(pos) = s.find(";\n\n") { + s = s[..pos + 1].to_string(); + } + s +} + fn format_cell(val: Option<&serde_json::Value>) -> String { match val { None | Some(serde_json::Value::Null) => "—".to_string(),