diff --git a/Cargo.lock b/Cargo.lock index e77fe34..98762bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2373,8 +2373,10 @@ name = "gateway" version = "0.1.0" dependencies = [ "aibridge", + "arrow", "axum", "catalogd", + "chrono", "ingestd", "journald", "object_store", diff --git a/crates/gateway/Cargo.toml b/crates/gateway/Cargo.toml index 92cf82a..54a8620 100644 --- a/crates/gateway/Cargo.toml +++ b/crates/gateway/Cargo.toml @@ -26,3 +26,5 @@ opentelemetry = { workspace = true } opentelemetry_sdk = { workspace = true } opentelemetry-stdout = { workspace = true } tracing-opentelemetry = { workspace = true } +arrow = { workspace = true } +chrono = { workspace = true } diff --git a/crates/gateway/src/main.rs b/crates/gateway/src/main.rs index 1aa588e..c6dd406 100644 --- a/crates/gateway/src/main.rs +++ b/crates/gateway/src/main.rs @@ -1,5 +1,6 @@ mod auth; mod observability; +mod tools; use axum::{Router, extract::DefaultBodyLimit, routing::get}; use proto::lakehouse::catalog_service_server::CatalogServiceServer; @@ -51,7 +52,7 @@ async fn main() { .route("/health", get(health)) .nest("/storage", storaged::service::router(store.clone())) .nest("/catalog", catalogd::service::router(registry.clone())) - .nest("/query", queryd::service::router(engine)) + .nest("/query", queryd::service::router(engine.clone())) .nest("/ai", aibridge::service::router(ai_client.clone())) .nest("/ingest", ingestd::service::router(ingestd::service::IngestState { store: store.clone(), @@ -68,7 +69,15 @@ async fn main() { } })) .nest("/workspaces", queryd::workspace_service::router(workspace_mgr)) - .nest("/journal", journald::service::router(journal)); + .nest("/journal", journald::service::router(journal)) + .nest("/tools", tools::service::router({ + let tool_reg = tools::registry::ToolRegistry::new_with_defaults(); + tool_reg.register_defaults().await; + tools::ToolState { + registry: tool_reg, + query_fn: tools::QueryExecutor::new(engine.clone()), + } + })); // Auth middleware (if 
enabled)
     if config.auth.enabled {
diff --git a/crates/gateway/src/tools/mod.rs b/crates/gateway/src/tools/mod.rs
new file mode 100644
index 0000000..2f316f6
--- /dev/null
+++ b/crates/gateway/src/tools/mod.rs
@@ -0,0 +1,59 @@
+pub mod registry;
+pub mod service;
+
+use queryd::context::QueryEngine;
+use arrow::json::writer::{JsonArray, Writer as JsonWriter};
+
+/// Shared state handed to every `/tools` route handler.
+#[derive(Clone)]
+pub struct ToolState {
+    /// Tool definitions plus the in-memory audit log.
+    pub registry: registry::ToolRegistry,
+    /// Runs the SQL rendered for a tool call.
+    pub query_fn: QueryExecutor,
+}
+
+/// Wraps [`QueryEngine`] to provide a simple execute interface for tools.
+#[derive(Clone)]
+pub struct QueryExecutor {
+    engine: QueryEngine,
+}
+
+impl QueryExecutor {
+    /// Wrap `engine` so tool handlers can run SQL through it.
+    pub fn new(engine: QueryEngine) -> Self {
+        Self { engine }
+    }
+
+    /// Execute `sql` and return `(rows as a JSON array, row count)`.
+    ///
+    /// # Errors
+    /// Propagates the engine's error, or returns a `JSON write:` /
+    /// `JSON finish:` / `JSON parse:` message if serialization fails.
+    pub async fn execute(&self, sql: &str) -> Result<(serde_json::Value, usize), String> {
+        let batches = self.engine.query(sql).await?;
+
+        // Count rows from the batches themselves, and answer a result made
+        // only of zero-row batches here so the outcome does not depend on
+        // what the JSON writer emits when it never wrote a row.
+        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
+        if row_count == 0 {
+            return Ok((serde_json::Value::Array(Vec::new()), 0));
+        }
+
+        let mut buf = Vec::new();
+        {
+            // Scoped so the writer's borrow of `buf` ends before it is parsed.
+            let mut writer = JsonWriter::<_, JsonArray>::new(&mut buf);
+            for batch in &batches {
+                writer.write(batch).map_err(|e| format!("JSON write: {e}"))?;
+            }
+            writer.finish().map_err(|e| format!("JSON finish: {e}"))?;
+        }
+
+        let rows: serde_json::Value =
+            serde_json::from_slice(&buf).map_err(|e| format!("JSON parse: {e}"))?;
+
+        Ok((rows, row_count))
+    }
+}
diff --git a/crates/gateway/src/tools/registry.rs b/crates/gateway/src/tools/registry.rs
new file mode 100644
index 0000000..d6cf446
--- /dev/null
+++ b/crates/gateway/src/tools/registry.rs
@@ -0,0 +1,274 @@
+/// Tool Registry: named, governed business actions for AI agents.
+/// Instead of raw SQL, agents call validated tools with audit trails.
+///
+/// Each tool has:
+/// - Name and description (for LLM tool-use)
+/// - Parameter schema (validated before execution)
+/// - Permission level (read / write / admin)
+/// - Audit logging (every invocation recorded)
+/// - Rate limiting (per agent)
+///
+/// NOTE(review): nothing in this file implements rate limiting or checks a
+/// caller against a tool's permission level — confirm both live elsewhere.
+
+use chrono::{DateTime, Utc};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::sync::Arc;
+use tokio::sync::RwLock;
+
+/// Permission level for a tool.
+///
+/// Serialized in lowercase (`"read"`, `"write"`, `"admin"`).
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum Permission {
+    /// Auto-approved, no side effects.
+    Read,
+    /// Modifies data, logged.
+    Write,
+    /// Destructive, requires confirmation.
+    Admin,
+}
+
+/// Parameter definition for a tool.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ParamDef {
+    /// Key the caller uses in the `params` object.
+    pub name: String,
+    /// One of `"string"`, `"integer"`, `"boolean"`, `"float"`.
+    pub param_type: String,
+    /// Whether the caller must supply a value.
+    pub required: bool,
+    /// Human/LLM-readable explanation of the parameter.
+    pub description: String,
+    /// Value used when the caller omits the parameter.
+    pub default: Option<serde_json::Value>,
+}
+
+/// Tool definition — what agents see and can call.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolDef {
+    pub name: String,
+    pub description: String,
+    pub permission: Permission,
+    pub parameters: Vec<ParamDef>,
+    /// Description of the return value.
+    pub returns: String,
+    /// SQL with `{param}` placeholders.
+    pub sql_template: String,
+    /// Grouping label, e.g. "candidates", "placements", "analytics".
+    pub category: String,
+}
+
+/// Audit log entry for a tool invocation.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolInvocation {
+    pub id: String,
+    pub tool_name: String,
+    /// Caller-supplied agent name.
+    pub agent: String,
+    /// Parameters exactly as the caller sent them.
+    pub params: serde_json::Value,
+    /// Permission level of the tool that was called.
+    pub permission: Permission,
+    pub timestamp: DateTime<Utc>,
+    pub success: bool,
+    /// Error message when `success` is false.
+    pub error: Option<String>,
+    /// Number of rows returned when `success` is true.
+    pub rows_returned: Option<usize>,
+}
+
+/// The registry — holds tool definitions and audit log.
+#[derive(Clone)]
+pub struct ToolRegistry {
+    /// Tool definitions keyed by tool name.
+    tools: Arc<RwLock<HashMap<String, ToolDef>>>,
+    /// Invocation records, oldest first, capped at `MAX_AUDIT_ENTRIES`.
+    audit_log: Arc<RwLock<std::collections::VecDeque<ToolInvocation>>>,
+}
+
+/// Upper bound on retained audit entries; the oldest entry is dropped once
+/// the bound is exceeded so the in-memory log cannot grow without limit.
+const MAX_AUDIT_ENTRIES: usize = 10_000;
+
+impl Default for ToolRegistry {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ToolRegistry {
+    /// Create a registry pre-populated with the built-in staffing tools.
+    ///
+    /// The map is filled before it is wrapped in the lock, so this neither
+    /// blocks the calling thread nor needs a Tokio runtime.
+    pub fn new() -> Self {
+        let tools: HashMap<String, ToolDef> = default_tools()
+            .into_iter()
+            .map(|tool| (tool.name.clone(), tool))
+            .collect();
+        Self {
+            tools: Arc::new(RwLock::new(tools)),
+            audit_log: Arc::new(RwLock::new(std::collections::VecDeque::new())),
+        }
+    }
+
+    /// Same as [`ToolRegistry::new`]: a registry holding the built-in tools.
+    ///
+    /// Calling [`ToolRegistry::register_defaults`] afterwards is harmless; it
+    /// re-inserts the same definitions under the same names.
+    pub fn new_with_defaults() -> Self {
+        Self::new()
+    }
+
+    /// Register default staffing tools, replacing any same-named entries.
+    pub async fn register_defaults(&self) {
+        let mut reg = self.tools.write().await;
+        for tool in default_tools() {
+            reg.insert(tool.name.clone(), tool);
+        }
+    }
+
+    /// Get tool definition (for LLM tool-use schema).
+    pub async fn get_tool(&self, name: &str) -> Option<ToolDef> {
+        self.tools.read().await.get(name).cloned()
+    }
+
+    /// List all tools (for LLM tool discovery). Order is unspecified.
+    pub async fn list_tools(&self) -> Vec<ToolDef> {
+        self.tools.read().await.values().cloned().collect()
+    }
+
+    /// Build SQL from tool parameters. Validates and sanitizes.
+    ///
+    /// Every replacement is computed first and the template is then rendered
+    /// in a single pass, so text supplied by the caller is never re-scanned
+    /// for placeholders.
+    ///
+    /// Filter parameters (`skills`, `city`, `state`, `vertical`, `status`,
+    /// `min_years`) expand to an `AND ...` clause, or to nothing when the
+    /// value is missing, empty, or of the wrong JSON type. Any other
+    /// parameter is substituted for `{name}` after being checked against its
+    /// declared `param_type`.
+    ///
+    /// # Errors
+    /// Returns a message when `params` is not a JSON object, when a required
+    /// parameter used by the template is missing, or when a directly
+    /// substituted value does not match its declared type.
+    pub fn build_sql(tool: &ToolDef, params: &serde_json::Value) -> Result<String, String> {
+        let params = params.as_object().ok_or("params must be an object")?;
+        let mut subs: HashMap<String, String> = HashMap::new();
+
+        for param_def in &tool.parameters {
+            let value = params
+                .get(&param_def.name)
+                .or(param_def.default.as_ref());
+
+            match param_def.name.as_str() {
+                // Filter parameters (conditional WHERE clauses).
+                "skills" => {
+                    // Empty tokens ("a,,b", trailing comma) are skipped; an
+                    // empty token would render LIKE '%%' and match every row.
+                    let clauses: Vec<String> = non_empty_str(value)
+                        .into_iter()
+                        .flat_map(|list| list.split(','))
+                        .map(str::trim)
+                        .filter(|skill| !skill.is_empty())
+                        .map(|skill| format!("skills LIKE '%{}%'", escape_sql_literal(skill)))
+                        .collect();
+                    let clause = if clauses.is_empty() {
+                        String::new()
+                    } else {
+                        format!("AND ({})", clauses.join(" OR "))
+                    };
+                    subs.insert("skills_filter".to_owned(), clause);
+                }
+                // NOTE(review): `city` is unqualified while `open_jobs` joins
+                // two tables — confirm only one of them has a `city` column.
+                "city" => {
+                    subs.insert("city_filter".to_owned(), equality_filter("city", value));
+                }
+                "state" => {
+                    subs.insert("state_filter".to_owned(), equality_filter("state", value));
+                }
+                "vertical" => {
+                    subs.insert("vertical_filter".to_owned(), equality_filter("j.vertical", value));
+                }
+                "status" => {
+                    subs.insert("status_filter".to_owned(), equality_filter("status", value));
+                }
+                "min_years" => {
+                    let clause = value
+                        .and_then(|v| v.as_i64())
+                        .map(|years| format!("AND years_experience >= {years}"))
+                        .unwrap_or_default();
+                    subs.insert("years_filter".to_owned(), clause);
+                }
+                // Direct substitution for simple params.
+                _ => {
+                    let placeholder = format!("{{{}}}", param_def.name);
+                    if !tool.sql_template.contains(&placeholder) {
+                        continue;
+                    }
+                    // An explicit JSON null counts as "not supplied".
+                    let supplied = value
+                        .filter(|v| !v.is_null())
+                        .or(param_def.default.as_ref())
+                        .filter(|v| !v.is_null());
+                    match supplied {
+                        Some(v) => {
+                            subs.insert(param_def.name.clone(), typed_literal(param_def, v)?);
+                        }
+                        None if param_def.required => {
+                            return Err(format!("missing required param: {}", param_def.name));
+                        }
+                        // Optional, no default: the placeholder stays in the SQL.
+                        None => {}
+                    }
+                }
+            }
+        }
+
+        Ok(render_template(&tool.sql_template, &subs))
+    }
+
+    /// Log a tool invocation, dropping the oldest entry once the log exceeds
+    /// `MAX_AUDIT_ENTRIES`.
+    pub async fn log_invocation(&self, inv: ToolInvocation) {
+        let mut log = self.audit_log.write().await;
+        log.push_back(inv);
+        while log.len() > MAX_AUDIT_ENTRIES {
+            log.pop_front();
+        }
+    }
+
+    /// Get recent audit log: at most `limit` entries, newest first.
+    pub async fn recent_audit(&self, limit: usize) -> Vec<ToolInvocation> {
+        let log = self.audit_log.read().await;
+        log.iter().rev().take(limit).cloned().collect()
+    }
+}
+
+/// Double single quotes so `raw` can sit inside a `'...'` SQL literal.
+fn escape_sql_literal(raw: &str) -> String {
+    raw.replace('\'', "''")
+}
+
+/// The value as a string slice, if it is a non-empty JSON string.
+fn non_empty_str(value: Option<&serde_json::Value>) -> Option<&str> {
+    value.and_then(|v| v.as_str()).filter(|s| !s.is_empty())
+}
+
+/// `AND <column> = '<value>'` for a non-empty string value, else an empty string.
+fn equality_filter(column: &str, value: Option<&serde_json::Value>) -> String {
+    match non_empty_str(value) {
+        Some(text) => format!("AND {column} = '{}'", escape_sql_literal(text)),
+        None => String::new(),
+    }
+}
+
+/// Render `value` as SQL text for direct substitution, checked against the
+/// parameter's declared type.
+///
+/// Numeric and boolean placeholders are not quoted in the templates, so a
+/// string must never be accepted for them. Arrays and objects are rejected
+/// for every type because their JSON text cannot be escaped as one literal.
+fn typed_literal(param_def: &ParamDef, value: &serde_json::Value) -> Result<String, String> {
+    let mismatch = || {
+        format!(
+            "param {} must be of type {}",
+            param_def.name, param_def.param_type
+        )
+    };
+    match param_def.param_type.as_str() {
+        "integer" => value.as_i64().map(|n| n.to_string()).ok_or_else(mismatch),
+        "float" => value
+            .as_f64()
+            .filter(|n| n.is_finite())
+            .map(|n| n.to_string())
+            .ok_or_else(mismatch),
+        "boolean" => value.as_bool().map(|b| b.to_string()).ok_or_else(mismatch),
+        _ => match value {
+            serde_json::Value::String(s) => Ok(escape_sql_literal(s)),
+            serde_json::Value::Number(n) => Ok(n.to_string()),
+            serde_json::Value::Bool(b) => Ok(b.to_string()),
+            _ => Err(mismatch()),
+        },
+    }
+}
+
+/// Replace each `{key}` in `template` that has an entry in `subs`, in one
+/// left-to-right pass. Replacement text is copied verbatim and never
+/// re-scanned; braces that do not form a known placeholder are kept as-is.
+fn render_template(template: &str, subs: &HashMap<String, String>) -> String {
+    let mut out = String::with_capacity(template.len());
+    let mut rest = template;
+    while let Some(open) = rest.find('{') {
+        out.push_str(&rest[..open]);
+        let after = &rest[open + 1..];
+        let known = after
+            .find('}')
+            .and_then(|close| subs.get(&after[..close]).map(|text| (close, text)));
+        match known {
+            Some((close, text)) => {
+                out.push_str(text);
+                rest = &after[close + 1..];
+            }
+            None => {
+                // Not a known placeholder: keep the brace and scan on.
+                out.push('{');
+                rest = after;
+            }
+        }
+    }
+    out.push_str(rest);
+    out
+}
+
+/// Shorthand constructor for a [`ParamDef`].
+fn param(
+    name: &str,
+    param_type: &str,
+    required: bool,
+    description: &str,
+    default: Option<serde_json::Value>,
+) -> ParamDef {
+    ParamDef {
+        name: name.into(),
+        param_type: param_type.into(),
+        required,
+        description: description.into(),
+        default,
+    }
+}
+
+/// The built-in staffing tools. All of them are `Permission::Read`.
+fn default_tools() -> Vec<ToolDef> {
+    vec![
+        ToolDef {
+            name: "search_candidates".into(),
+            description: "Search for candidates by skills, city, state, availability, and experience. Returns matching candidates with contact info.".into(),
+            permission: Permission::Read,
+            parameters: vec![
+                param("skills", "string", false, "Comma-separated skills to match (uses LIKE)", None),
+                param("city", "string", false, "City name", None),
+                param("state", "string", false, "State abbreviation", None),
+                param("min_years", "integer", false, "Minimum years of experience", Some(serde_json::json!(0))),
+                param("status", "string", false, "Candidate status (active, inactive, placed)", Some(serde_json::json!("active"))),
+                param("limit", "integer", false, "Max results", Some(serde_json::json!(20))),
+            ],
+            returns: "List of candidates with id, name, phone, email, skills, experience".into(),
+            sql_template: "SELECT candidate_id, first_name, last_name, phone, email, city, state, zip, vertical, skills, years_experience FROM candidates WHERE 1=1 {skills_filter} {city_filter} {state_filter} {years_filter} {status_filter} ORDER BY years_experience DESC LIMIT {limit}".into(),
+            category: "candidates".into(),
+        },
+        ToolDef {
+            name: "get_candidate".into(),
+            description: "Get full details for a specific candidate by ID.".into(),
+            permission: Permission::Read,
+            parameters: vec![
+                param("candidate_id", "string", true, "Candidate ID (e.g. CAND-000001)", None),
+            ],
+            returns: "Full candidate record".into(),
+            sql_template: "SELECT * FROM candidates WHERE candidate_id = '{candidate_id}'".into(),
+            category: "candidates".into(),
+        },
+        // NOTE(review): the description mentions a date-range filter, but the
+        // tool declares no date parameters and the SQL has no date predicate.
+        ToolDef {
+            name: "revenue_by_client".into(),
+            description: "Show total billed revenue, pay costs, and gross profit by client. Filter by date range.".into(),
+            permission: Permission::Read,
+            parameters: vec![
+                param("limit", "integer", false, "Top N clients", Some(serde_json::json!(10))),
+            ],
+            returns: "Client name, total billed, total paid, gross profit, timesheet count".into(),
+            sql_template: "SELECT c.company_name, COUNT(*) as timesheets, ROUND(SUM(t.bill_total),2) as total_billed, ROUND(SUM(t.pay_total),2) as total_paid, ROUND(SUM(t.bill_total) - SUM(t.pay_total),2) as gross_profit FROM timesheets t JOIN clients c ON t.client_id = c.client_id WHERE t.approved = true GROUP BY c.company_name ORDER BY total_billed DESC LIMIT {limit}".into(),
+            category: "analytics".into(),
+        },
+        ToolDef {
+            name: "recruiter_performance".into(),
+            description: "Show recruiter performance: placements, unique candidates, and total revenue generated.".into(),
+            permission: Permission::Read,
+            parameters: vec![
+                param("limit", "integer", false, "Top N recruiters", Some(serde_json::json!(10))),
+            ],
+            returns: "Recruiter name, placement count, unique candidates, total revenue".into(),
+            sql_template: "SELECT p.recruiter, COUNT(DISTINCT p.placement_id) as placements, COUNT(DISTINCT p.candidate_id) as unique_candidates, ROUND(SUM(t.bill_total),2) as total_revenue FROM placements p JOIN timesheets t ON p.placement_id = t.placement_id GROUP BY p.recruiter ORDER BY total_revenue DESC LIMIT {limit}".into(),
+            category: "analytics".into(),
+        },
+        ToolDef {
+            name: "cold_leads".into(),
+            description: "Find candidates who were called multiple times but never placed — potential lost opportunities.".into(),
+            permission: Permission::Read,
+            parameters: vec![
+                param("min_calls", "integer", false, "Minimum call count", Some(serde_json::json!(5))),
+                param("limit", "integer", false, "Max results", Some(serde_json::json!(20))),
+            ],
+            returns: "Candidates with call counts who were never placed".into(),
+            sql_template: "SELECT c.candidate_id, c.first_name, c.last_name, c.phone, c.vertical, cl.calls FROM candidates c JOIN (SELECT candidate_id, COUNT(*) as calls FROM call_log GROUP BY candidate_id HAVING COUNT(*) >= {min_calls}) cl ON c.candidate_id = cl.candidate_id WHERE c.candidate_id NOT IN (SELECT DISTINCT candidate_id FROM placements) ORDER BY cl.calls DESC LIMIT {limit}".into(),
+            category: "analytics".into(),
+        },
+        ToolDef {
+            name: "open_jobs".into(),
+            description: "List open job orders with client, title, rates, and location.".into(),
+            permission: Permission::Read,
+            parameters: vec![
+                param("vertical", "string", false, "Filter by vertical (IT, Healthcare, etc)", None),
+                param("city", "string", false, "Filter by city", None),
+                param("limit", "integer", false, "Max results", Some(serde_json::json!(20))),
+            ],
+            returns: "Open job orders with details".into(),
+            sql_template: "SELECT j.job_order_id, c.company_name, j.title, j.vertical, j.city, j.state, j.bill_rate, j.pay_rate, j.status FROM job_orders j JOIN clients c ON j.client_id = c.client_id WHERE j.status = 'open' {vertical_filter} {city_filter} ORDER BY j.bill_rate DESC LIMIT {limit}".into(),
+            category: "jobs".into(),
+        },
+    ]
+}
diff --git a/crates/gateway/src/tools/service.rs b/crates/gateway/src/tools/service.rs
new file mode 100644
index 0000000..408b2bb
--- /dev/null
+++ b/crates/gateway/src/tools/service.rs
@@ -0,0 +1,148 @@
+use axum::{
+    Json, Router,
+    extract::{Path, Query, State},
+    http::StatusCode,
+    response::IntoResponse,
+    routing::{get, post},
+};
+use serde::Deserialize;
+
+use super::registry::{Permission, ToolInvocation, ToolRegistry};
+use crate::tools::ToolState;
+
+/// Build the `/tools` router: tool listing, lookup, invocation and audit log.
+///
+/// NOTE(review): `/audit` and `/{name}` both match `GET /audit` — confirm the
+/// router prefers the static segment; a tool named `audit` would then be
+/// unreachable through `GET /{name}`.
+pub fn router(state: ToolState) -> Router {
+    Router::new()
+        .route("/", get(list_tools))
+        .route("/{name}", get(get_tool))
+        .route("/{name}/call", post(call_tool))
+        .route("/audit", get(audit_log))
+        .with_state(state)
+}
+
+/// List all available tools (for LLM tool-use discovery).
+async fn list_tools(State(state): State<ToolState>) -> impl IntoResponse {
+    let tools = state.registry.list_tools().await;
+    // NOTE(review): originally described as "MCP-compatible"; the shape is a
+    // flat parameter list per tool — confirm against the MCP tool schema.
+    let tool_list: Vec<serde_json::Value> = tools
+        .iter()
+        .map(|t| {
+            let parameters: Vec<serde_json::Value> = t
+                .parameters
+                .iter()
+                .map(|p| {
+                    serde_json::json!({
+                        "name": p.name,
+                        "type": p.param_type,
+                        "required": p.required,
+                        "description": p.description,
+                        "default": p.default,
+                    })
+                })
+                .collect();
+            serde_json::json!({
+                "name": t.name,
+                "description": t.description,
+                "permission": t.permission,
+                "category": t.category,
+                "parameters": parameters,
+            })
+        })
+        .collect();
+    Json(tool_list)
+}
+
+/// Get a specific tool definition, or 404 when no tool has that name.
+async fn get_tool(
+    State(state): State<ToolState>,
+    Path(name): Path<String>,
+) -> impl IntoResponse {
+    match state.registry.get_tool(&name).await {
+        Some(tool) => Ok(Json(tool)),
+        None => Err((StatusCode::NOT_FOUND, format!("tool not found: {name}"))),
+    }
+}
+
+/// Request body for `POST /{name}/call`.
+#[derive(Deserialize)]
+struct CallRequest {
+    /// Tool parameters; must be a JSON object.
+    params: serde_json::Value,
+    /// Caller-supplied agent name, recorded in the audit log.
+    agent: String,
+}
+
+/// Process-wide sequence appended to invocation ids so two calls landing in
+/// the same millisecond still get distinct ids.
+static INVOCATION_SEQ: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
+
+/// Build the audit record for one call. `outcome` is the returned row count
+/// on success or the error message on failure.
+fn invocation_record(
+    tool_name: &str,
+    req: &CallRequest,
+    permission: &Permission,
+    outcome: Result<usize, String>,
+) -> ToolInvocation {
+    let now = chrono::Utc::now();
+    let seq = INVOCATION_SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+    let (success, error, rows_returned) = match outcome {
+        Ok(rows) => (true, None, Some(rows)),
+        Err(message) => (false, Some(message), None),
+    };
+    ToolInvocation {
+        id: format!("inv-{}-{seq}", now.timestamp_millis()),
+        tool_name: tool_name.to_owned(),
+        agent: req.agent.clone(),
+        params: req.params.clone(),
+        permission: permission.clone(),
+        timestamp: now,
+        success,
+        error,
+        rows_returned,
+    }
+}
+
+/// Call a tool — validate params, execute SQL, log invocation.
+///
+/// Responds 404 for an unknown tool, 400 when the SQL cannot be built from
+/// the parameters, and 500 when the query fails. Both failure kinds are
+/// written to the audit log. The rendered SQL is echoed back on success.
+///
+/// NOTE(review): `tool.permission` is recorded but never checked here. Every
+/// built-in tool is `Read`; gate `Write`/`Admin` before registering such tools.
+async fn call_tool(
+    State(state): State<ToolState>,
+    Path(name): Path<String>,
+    Json(req): Json<CallRequest>,
+) -> impl IntoResponse {
+    let Some(tool) = state.registry.get_tool(&name).await else {
+        return Err((StatusCode::NOT_FOUND, format!("tool not found: {name}")));
+    };
+
+    tracing::info!("tool call: {} by agent '{}' with {:?}", name, req.agent, req.params);
+
+    // Build SQL from params
+    let sql = match ToolRegistry::build_sql(&tool, &req.params) {
+        Ok(sql) => sql,
+        Err(e) => {
+            let record = invocation_record(&name, &req, &tool.permission, Err(e.clone()));
+            state.registry.log_invocation(record).await;
+            return Err((StatusCode::BAD_REQUEST, e));
+        }
+    };
+
+    // Execute via query engine
+    match state.query_fn.execute(&sql).await {
+        Ok((rows, row_count)) => {
+            let record = invocation_record(&name, &req, &tool.permission, Ok(row_count));
+            state.registry.log_invocation(record).await;
+
+            Ok(Json(serde_json::json!({
+                "tool": name,
+                "rows": rows,
+                "row_count": row_count,
+                "sql": sql,
+            })))
+        }
+        Err(e) => {
+            let record = invocation_record(&name, &req, &tool.permission, Err(e.clone()));
+            state.registry.log_invocation(record).await;
+
+            Err((StatusCode::INTERNAL_SERVER_ERROR, e))
+        }
+    }
+}
+
+/// Query string for `GET /audit`.
+#[derive(Deserialize)]
+struct AuditQuery {
+    /// Maximum number of entries to return; 50 when omitted.
+    limit: Option<usize>,
+}
+
+/// Return the most recent audit entries, newest first.
+async fn audit_log(
+    State(state): State<ToolState>,
+    Query(q): Query<AuditQuery>,
+) -> impl IntoResponse {
+    let log = state.registry.recent_audit(q.limit.unwrap_or(50)).await;
+    Json(log)
+}