//! Phase 21 — INPUT-overflow handler. Ports `generateTreeSplit` from //! `tests/multi-agent/agent.ts`. //! //! When the input corpus exceeds the model's window (200 playbooks //! pasted into a T4 strategic prompt, a long retrospective digest, a //! cross-corpus summarization), raising `max_tokens` doesn't help — //! the prompt itself is the problem. The answer is map-reduce: //! //! 1. Caller shards the input at semantic boundaries (records, //! paragraphs, playbook entries). //! 2. For each shard, build a map prompt that includes the running //! scratchpad and run it through `generate_continuable`. //! 3. Append the map output to the scratchpad (oldest-first //! truncation when it outgrows `scratchpad_budget`). //! 4. Build a reduce prompt from the final scratchpad and run it. //! //! Every shard prompt and the reduce prompt go through //! `assert_context_budget` first — if a single shard still overflows //! we bubble the error up rather than silently truncating. That's the //! whole point of Phase 21. use crate::context::{assert_context_budget, BudgetOpts, overflow_message, DEFAULT_MAX_TOKENS, DEFAULT_SAFETY_MARGIN}; use shared::model_matrix::ModelMatrix; use crate::continuation::{generate_continuable, ContinuableOpts, ResponseShape, TextGenerator}; /// Callback signatures — caller supplies closures that stitch the /// scratchpad into each shard's prompt and build the final reduce /// prompt. Kept as `Fn` (not `FnMut`) so the map loop can call them /// by reference. pub type MapPromptFn<'a> = dyn Fn(&str, &str) -> String + Send + Sync + 'a; pub type ReducePromptFn<'a> = dyn Fn(&str) -> String + Send + Sync + 'a; /// Knobs for `generate_tree_split`. #[derive(Debug, Clone)] pub struct TreeSplitOpts { pub model: String, pub system: Option, pub temperature: Option, /// max_tokens for map AND reduce (reduce defaults are usually /// higher; caller overrides for just reduce by calling through /// continuable directly if needed). pub max_tokens: Option, pub reduce_max_tokens: Option, pub think: Option, /// Soft ceiling on scratchpad size (estimated tokens). When it /// grows past this, the oldest shard digest gets dropped. Default /// 6000, matching the TS implementation. pub scratchpad_budget: usize, pub safety_margin: Option, } impl TreeSplitOpts { pub fn new(model: impl Into) -> Self { Self { model: model.into(), system: None, temperature: None, max_tokens: None, reduce_max_tokens: None, think: None, scratchpad_budget: 6_000, safety_margin: None, } } } /// Result — final reduce response plus the accumulated scratchpad so /// the caller can inspect what was kept vs truncated. #[derive(Debug, Clone)] pub struct TreeSplitResult { pub response: String, pub scratchpad: String, pub shards_processed: usize, pub scratchpad_truncations: usize, pub total_continuations: usize, } /// Drop shard-digest blocks from the head of `scratchpad` until its /// estimated-token count fits the budget. Digest blocks are delimited /// by `\n— shard N/M digest —\n` so we can find the first one and /// chop everything before its successor. fn truncate_scratchpad(scratchpad: &mut String, budget_tokens: usize) -> bool { if ModelMatrix::estimate_tokens(scratchpad) <= budget_tokens { return false; } // Find the second delimiter — everything before it gets dropped. const DELIM_PREFIX: &str = "\n— shard "; let mut cursor = 0; let mut truncated = false; while ModelMatrix::estimate_tokens(&scratchpad[cursor..]) > budget_tokens { // Skip past a leading delimiter (if we're sitting on one from // a previous iteration), then find the next. let search_from = cursor + if scratchpad[cursor..].starts_with(DELIM_PREFIX) { DELIM_PREFIX.len() } else { 0 }; let Some(rel_next) = scratchpad[search_from..].find(DELIM_PREFIX) else { break }; cursor = search_from + rel_next; truncated = true; } if cursor > 0 { scratchpad.drain(..cursor); } truncated } /// Phase 21 — map-reduce over shards with a running scratchpad. See /// module docs. pub async fn generate_tree_split( generator: &G, shards: &[String], map_prompt: &MapPromptFn<'_>, reduce_prompt: &ReducePromptFn<'_>, opts: &TreeSplitOpts, ) -> Result { let mut scratchpad = String::new(); let safety = opts.safety_margin.unwrap_or(DEFAULT_SAFETY_MARGIN); let map_max = opts.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS as u32); let reduce_max = opts.reduce_max_tokens.unwrap_or(1_500); let mut truncations = 0usize; let mut total_continuations = 0usize; for (i, shard) in shards.iter().enumerate() { let shard_prompt = map_prompt(shard, &scratchpad); // Loud-fail on per-shard overflow — caller sharded too // coarsely. Silent truncation is exactly the mode J rejected. let budget = BudgetOpts { system: opts.system.as_deref(), max_tokens: Some(map_max as usize), safety_margin: Some(safety), bypass: false, }; let check = assert_context_budget(&opts.model, &shard_prompt, budget) .map_err(|(c, over)| overflow_message(&opts.model, &c, over, safety))?; let _ = check; let mut cont_opts = ContinuableOpts::new(&opts.model); cont_opts.max_tokens = Some(map_max); cont_opts.temperature = opts.temperature; cont_opts.system = opts.system.clone(); cont_opts.shape = ResponseShape::Text; cont_opts.think = opts.think; let outcome = generate_continuable(generator, &shard_prompt, &cont_opts).await?; total_continuations += outcome.continuations; // Append this shard's digest and, if needed, drop oldest. scratchpad.push_str(&format!( "\n— shard {}/{} digest —\n{}", i + 1, shards.len(), outcome.text.trim(), )); if truncate_scratchpad(&mut scratchpad, opts.scratchpad_budget) { truncations += 1; } } // Reduce pass. Budget check first — if the scratchpad is still too // big for the reduce prompt we fail loud with numbers. let reduce_p = reduce_prompt(&scratchpad); let budget = BudgetOpts { system: opts.system.as_deref(), max_tokens: Some(reduce_max as usize), safety_margin: Some(safety), bypass: false, }; assert_context_budget(&opts.model, &reduce_p, budget) .map_err(|(c, over)| overflow_message(&opts.model, &c, over, safety))?; let mut cont_opts = ContinuableOpts::new(&opts.model); cont_opts.max_tokens = Some(reduce_max); cont_opts.temperature = opts.temperature; cont_opts.system = opts.system.clone(); cont_opts.shape = ResponseShape::Text; cont_opts.think = opts.think; let outcome = generate_continuable(generator, &reduce_p, &cont_opts).await?; total_continuations += outcome.continuations; Ok(TreeSplitResult { response: outcome.text, scratchpad, shards_processed: shards.len(), scratchpad_truncations: truncations, total_continuations, }) } #[cfg(test)] mod tests { use super::*; use crate::continuation::ScriptedGenerator; fn simple_map(shard: &str, scratchpad: &str) -> String { format!("SCRATCHPAD:\n{scratchpad}\n---\nSHARD:\n{shard}\n---\nDIGEST:") } fn simple_reduce(scratchpad: &str) -> String { format!("SCRATCHPAD:\n{scratchpad}\n---\nFINAL:") } #[tokio::test] async fn tree_split_runs_map_then_reduce() { // 3 shards → 3 map calls → 1 reduce call = 4 responses scripted. let generator = ScriptedGenerator::new(vec![ Ok("digest-1".to_string()), Ok("digest-2".to_string()), Ok("digest-3".to_string()), Ok("FINAL ANSWER".to_string()), ]); let shards: Vec = ["a", "b", "c"].iter().map(|s| s.to_string()).collect(); let opts = TreeSplitOpts::new("qwen3:latest"); let map_fn: &MapPromptFn = &simple_map; let reduce_fn: &ReducePromptFn = &simple_reduce; let result = generate_tree_split(&generator, &shards, map_fn, reduce_fn, &opts) .await .unwrap(); assert_eq!(result.shards_processed, 3); assert_eq!(result.response, "FINAL ANSWER"); assert_eq!(generator.call_count(), 4); // Scratchpad must carry all three digests in order. assert!(result.scratchpad.contains("digest-1")); assert!(result.scratchpad.contains("digest-2")); assert!(result.scratchpad.contains("digest-3")); } #[tokio::test] async fn tree_split_reduce_prompt_sees_full_scratchpad() { let generator = ScriptedGenerator::new(vec![ Ok("summary-A".to_string()), Ok("summary-B".to_string()), Ok("REDUCED".to_string()), ]); let shards = vec!["input-one".to_string(), "input-two".to_string()]; let opts = TreeSplitOpts::new("qwen3:latest"); let _ = generate_tree_split(&generator, &shards, &simple_map, &simple_reduce, &opts) .await .unwrap(); // Third call = reduce. Its prompt must include both digests. let calls = generator.calls(); let reduce_prompt = &calls[2].prompt; assert!(reduce_prompt.contains("summary-A"), "reduce prompt must see first shard digest"); assert!(reduce_prompt.contains("summary-B"), "reduce prompt must see second shard digest"); } #[tokio::test] async fn tree_split_loud_fails_on_shard_overflow() { let generator = ScriptedGenerator::new(vec![Ok("digest".to_string())]); // One gigantic shard — well over qwen3's 40K window even as a // prompt. The budget check must reject before any generate call. let shards = vec!["x".repeat(200_000)]; let opts = TreeSplitOpts::new("qwen3:latest"); let err = generate_tree_split(&generator, &shards, &simple_map, &simple_reduce, &opts) .await .expect_err("shard-sized overflow must be rejected"); assert!(err.contains("overflow"), "error should mention overflow: {err}"); assert_eq!(generator.call_count(), 0, "generate must not be called on overflow"); } #[tokio::test] async fn tree_split_truncates_scratchpad_when_over_budget() { // Tight budget so each shard trips truncation. qwen3's 40K // window is fine; the budget we care about is the scratchpad // cap, not the model window. let generator = ScriptedGenerator::new(vec![ Ok("A".repeat(2_000)), Ok("B".repeat(2_000)), Ok("C".repeat(2_000)), Ok("D".repeat(2_000)), Ok("FINAL".to_string()), ]); let shards: Vec = (0..4).map(|i| format!("shard{i}")).collect(); let mut opts = TreeSplitOpts::new("qwen3:latest"); opts.scratchpad_budget = 1_000; // ~4000 chars — one digest barely fits let result = generate_tree_split(&generator, &shards, &simple_map, &simple_reduce, &opts) .await .unwrap(); assert!(result.scratchpad_truncations > 0, "tight budget must trigger truncation"); // Scratchpad should still fit roughly within the budget // (post-truncation); the estimator uses chars/4 so the bound // is ~budget*4 chars. Give some slack for the delimiter. let scratchpad_tokens = ModelMatrix::estimate_tokens(&result.scratchpad); assert!(scratchpad_tokens <= opts.scratchpad_budget * 2, "scratchpad {} tokens vs budget {}", scratchpad_tokens, opts.scratchpad_budget); } #[tokio::test] async fn tree_split_reports_continuations_from_map_and_reduce() { // First shard: truncated-then-continued. Reduce: truncated-then-continued. // 1 shard: 2 map calls (initial + continuation), then 2 reduce calls. let generator = ScriptedGenerator::new(vec![ Ok("partial".to_string()), // map shape=text, non-empty → complete on first pass Ok("reduce-out".to_string()), ]); let shards = vec!["only".to_string()]; let opts = TreeSplitOpts::new("qwen3:latest"); let result = generate_tree_split(&generator, &shards, &simple_map, &simple_reduce, &opts) .await .unwrap(); // Text shape treats non-empty as complete → 0 continuations. assert_eq!(result.total_continuations, 0); assert_eq!(result.shards_processed, 1); } #[test] fn truncate_scratchpad_noop_when_under_budget() { let mut s = "\n— shard 1/1 digest —\nshort".to_string(); let truncated = truncate_scratchpad(&mut s, 1_000); assert!(!truncated); assert!(s.contains("short")); } #[test] fn truncate_scratchpad_drops_oldest_first() { let mut s = format!( "\n— shard 1/3 digest —\n{}\n— shard 2/3 digest —\n{}\n— shard 3/3 digest —\nshort", "x".repeat(4_000), // ~1000 tokens "y".repeat(4_000), // ~1000 tokens ); let truncated = truncate_scratchpad(&mut s, 500); // ~2000 chars assert!(truncated); assert!(!s.contains(&"x".repeat(4_000)), "oldest digest should be dropped"); assert!(s.contains("short"), "newest digest should survive"); } }