lakehouse/crates/aibridge/src/tree_split.rs
root cdc24d8bd0
Some checks failed
lakehouse/auditor 1 blocking issue: todo!() macro call in tests/real-world/scrum_master_pipeline.ts
shared: build ModelMatrix — migrate 5 call sites off deprecated estimate_tokens
The `aibridge::context::estimate_tokens` deprecation has been pointing
at `shared::model_matrix::ModelMatrix::estimate_tokens` for a while,
but that module didn't exist — so the deprecation was aspirational
noise, not actionable guidance.

Built the minimal target: `shared::model_matrix::ModelMatrix` with
an associated `estimate_tokens(text: &str) -> usize` method. Same
chars/4 ceiling heuristic as the deprecated helper. 6 tests cover
empty/3/4/5-char cases, multi-byte UTF-8 (emoji count as 1 char each),
and linear scaling to 400-char inputs.

Migrated 6 call sites (plus the tree_split.rs import):
  - aibridge/context.rs:88 — opts.system token count
  - aibridge/context.rs:89 — prompt token count
  - aibridge/tree_split.rs:22 — import (now uses ModelMatrix)
  - aibridge/tree_split.rs:84, 89 — truncate_scratchpad budget loop
  - aibridge/tree_split.rs:282 — scratchpad post-truncation assertion
  - aibridge/context.rs:183 — system-prompt budget test

Also cleaned up two parallel test warnings:
  - aibridge/context.rs legacy estimate_tokens_ceiling_divides_by_four
    test deleted (ModelMatrix's tests cover the same behavior now).
  - vectord/playbook_memory.rs:1650 unused_mut on e_alive.

Net workspace warning count: 11 → 0 (including --tests build).

The deprecated `estimate_tokens` wrapper stays in aibridge/context.rs
for external callers. Future commits can remove it entirely once no
public API surface still references it.

The applier's warning-count gate now has a floor of 0 — any future
patch that introduces a single warning trips the gate automatically.
Previously a floor of 11 tolerated noise.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-24 06:32:16 -05:00

328 lines
14 KiB
Rust

//! Phase 21 — INPUT-overflow handler. Ports `generateTreeSplit` from
//! `tests/multi-agent/agent.ts`.
//!
//! When the input corpus exceeds the model's window (200 playbooks
//! pasted into a T4 strategic prompt, a long retrospective digest, a
//! cross-corpus summarization), raising `max_tokens` doesn't help —
//! the prompt itself is the problem. The answer is map-reduce:
//!
//! 1. Caller shards the input at semantic boundaries (records,
//! paragraphs, playbook entries).
//! 2. For each shard, build a map prompt that includes the running
//! scratchpad and run it through `generate_continuable`.
//! 3. Append the map output to the scratchpad (oldest-first
//! truncation when it outgrows `scratchpad_budget`).
//! 4. Build a reduce prompt from the final scratchpad and run it.
//!
//! Every shard prompt and the reduce prompt go through
//! `assert_context_budget` first — if a single shard still overflows
//! we bubble the error up rather than silently truncating. That's the
//! whole point of Phase 21.
use crate::context::{assert_context_budget, BudgetOpts, overflow_message,
DEFAULT_MAX_TOKENS, DEFAULT_SAFETY_MARGIN};
use shared::model_matrix::ModelMatrix;
use crate::continuation::{generate_continuable, ContinuableOpts, ResponseShape, TextGenerator};
/// Callback signatures — caller supplies closures that stitch the
/// scratchpad into each shard's prompt and build the final reduce
/// prompt. Kept as `Fn` (not `FnMut`) so the map loop can call them
/// by reference.
///
/// `MapPromptFn(shard, scratchpad)` builds the prompt for one map
/// pass; `ReducePromptFn(scratchpad)` builds the final reduce prompt.
pub type MapPromptFn<'a> = dyn Fn(&str, &str) -> String + Send + Sync + 'a;
pub type ReducePromptFn<'a> = dyn Fn(&str) -> String + Send + Sync + 'a;
/// Knobs for `generate_tree_split`.
#[derive(Debug, Clone)]
pub struct TreeSplitOpts {
    pub model: String,
    pub system: Option<String>,
    pub temperature: Option<f64>,
    /// max_tokens for map AND reduce (reduce defaults are usually
    /// higher; caller overrides for just reduce by calling through
    /// continuable directly if needed).
    pub max_tokens: Option<u32>,
    pub reduce_max_tokens: Option<u32>,
    pub think: Option<bool>,
    /// Soft ceiling on scratchpad size (estimated tokens). When it
    /// grows past this, the oldest shard digest gets dropped. Default
    /// 6000, matching the TS implementation.
    pub scratchpad_budget: usize,
    pub safety_margin: Option<usize>,
}

impl TreeSplitOpts {
    /// Options for `model` with every other knob at its default: no
    /// system prompt, provider-default temperature and max_tokens, a
    /// 6 000-token scratchpad ceiling, and the crate-default safety
    /// margin.
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            scratchpad_budget: 6_000,
            system: None,
            temperature: None,
            max_tokens: None,
            reduce_max_tokens: None,
            think: None,
            safety_margin: None,
        }
    }
}
/// Result — final reduce response plus the accumulated scratchpad so
/// the caller can inspect what was kept vs truncated.
#[derive(Debug, Clone)]
pub struct TreeSplitResult {
    /// Text produced by the final reduce pass.
    pub response: String,
    /// Scratchpad as it stood when the reduce prompt was built —
    /// post-truncation, so the oldest digests may be missing.
    pub scratchpad: String,
    /// Number of input shards that went through the map pass.
    pub shards_processed: usize,
    /// How many times the scratchpad outgrew its budget and dropped
    /// oldest digest block(s).
    pub scratchpad_truncations: usize,
    /// Continuation round-trips summed across all map passes plus the
    /// reduce pass.
    pub total_continuations: usize,
}
/// Drop shard-digest blocks from the head of `scratchpad` until its
/// estimated-token count fits the budget. Digest blocks are delimited
/// by `\n— shard N/M digest —\n` so we can find the first one and
/// chop everything before its successor.
fn truncate_scratchpad(scratchpad: &mut String, budget_tokens: usize) -> bool {
if ModelMatrix::estimate_tokens(scratchpad) <= budget_tokens { return false; }
// Find the second delimiter — everything before it gets dropped.
const DELIM_PREFIX: &str = "\n— shard ";
let mut cursor = 0;
let mut truncated = false;
while ModelMatrix::estimate_tokens(&scratchpad[cursor..]) > budget_tokens {
// Skip past a leading delimiter (if we're sitting on one from
// a previous iteration), then find the next.
let search_from = cursor + if scratchpad[cursor..].starts_with(DELIM_PREFIX) {
DELIM_PREFIX.len()
} else { 0 };
let Some(rel_next) = scratchpad[search_from..].find(DELIM_PREFIX) else { break };
cursor = search_from + rel_next;
truncated = true;
}
if cursor > 0 {
scratchpad.drain(..cursor);
}
truncated
}
/// Phase 21 — map-reduce over shards with a running scratchpad. See
/// module docs.
///
/// For each shard: build the map prompt (shard + current scratchpad),
/// loud-fail if it overflows the model window, generate a digest,
/// append it to the scratchpad, and truncate the scratchpad
/// oldest-first when it outgrows `opts.scratchpad_budget`. Then
/// budget-check and run the reduce prompt over the final scratchpad.
///
/// # Errors
/// Returns the `overflow_message` string when a shard prompt or the
/// reduce prompt exceeds the model window, or any error bubbled up
/// from `generate_continuable`.
pub async fn generate_tree_split<G: TextGenerator>(
    generator: &G,
    shards: &[String],
    map_prompt: &MapPromptFn<'_>,
    reduce_prompt: &ReducePromptFn<'_>,
    opts: &TreeSplitOpts,
) -> Result<TreeSplitResult, String> {
    let mut scratchpad = String::new();
    let safety = opts.safety_margin.unwrap_or(DEFAULT_SAFETY_MARGIN);
    let map_max = opts.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS as u32);
    let reduce_max = opts.reduce_max_tokens.unwrap_or(1_500);
    let mut truncations = 0usize;
    let mut total_continuations = 0usize;
    // Map and reduce share identical budget/continuation settings
    // except for max_tokens — build them through these two helpers
    // instead of duplicating the field-by-field setup at both sites.
    let budget_for = |max_tokens: usize| BudgetOpts {
        system: opts.system.as_deref(),
        max_tokens: Some(max_tokens),
        safety_margin: Some(safety),
        bypass: false,
    };
    let cont_opts_for = |max_tokens: u32| {
        let mut cont = ContinuableOpts::new(&opts.model);
        cont.max_tokens = Some(max_tokens);
        cont.temperature = opts.temperature;
        cont.system = opts.system.clone();
        cont.shape = ResponseShape::Text;
        cont.think = opts.think;
        cont
    };
    for (i, shard) in shards.iter().enumerate() {
        let shard_prompt = map_prompt(shard, &scratchpad);
        // Loud-fail on per-shard overflow — caller sharded too
        // coarsely. Silent truncation is exactly the mode J rejected.
        // (The budget result itself carries nothing we need here.)
        assert_context_budget(&opts.model, &shard_prompt, budget_for(map_max as usize))
            .map_err(|(c, over)| overflow_message(&opts.model, &c, over, safety))?;
        let outcome =
            generate_continuable(generator, &shard_prompt, &cont_opts_for(map_max)).await?;
        total_continuations += outcome.continuations;
        // Append this shard's digest and, if needed, drop oldest.
        scratchpad.push_str(&format!(
            "\n— shard {}/{} digest —\n{}",
            i + 1,
            shards.len(),
            outcome.text.trim(),
        ));
        if truncate_scratchpad(&mut scratchpad, opts.scratchpad_budget) {
            truncations += 1;
        }
    }
    // Reduce pass. Budget check first — if the scratchpad is still too
    // big for the reduce prompt we fail loud with numbers.
    let reduce_p = reduce_prompt(&scratchpad);
    assert_context_budget(&opts.model, &reduce_p, budget_for(reduce_max as usize))
        .map_err(|(c, over)| overflow_message(&opts.model, &c, over, safety))?;
    let outcome = generate_continuable(generator, &reduce_p, &cont_opts_for(reduce_max)).await?;
    total_continuations += outcome.continuations;
    Ok(TreeSplitResult {
        response: outcome.text,
        scratchpad,
        shards_processed: shards.len(),
        scratchpad_truncations: truncations,
        total_continuations,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::continuation::ScriptedGenerator;

    // Minimal map/reduce prompt builders — the tests only need the
    // shard text and scratchpad visibly stitched into the prompt.
    fn simple_map(shard: &str, scratchpad: &str) -> String {
        format!("SCRATCHPAD:\n{scratchpad}\n---\nSHARD:\n{shard}\n---\nDIGEST:")
    }
    fn simple_reduce(scratchpad: &str) -> String {
        format!("SCRATCHPAD:\n{scratchpad}\n---\nFINAL:")
    }

    #[tokio::test]
    async fn tree_split_runs_map_then_reduce() {
        // 3 shards → 3 map calls → 1 reduce call = 4 responses scripted.
        let generator = ScriptedGenerator::new(vec![
            Ok("digest-1".to_string()),
            Ok("digest-2".to_string()),
            Ok("digest-3".to_string()),
            Ok("FINAL ANSWER".to_string()),
        ]);
        let shards: Vec<String> = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
        let opts = TreeSplitOpts::new("qwen3:latest");
        let map_fn: &MapPromptFn = &simple_map;
        let reduce_fn: &ReducePromptFn = &simple_reduce;
        let result = generate_tree_split(&generator, &shards, map_fn, reduce_fn, &opts)
            .await
            .unwrap();
        assert_eq!(result.shards_processed, 3);
        assert_eq!(result.response, "FINAL ANSWER");
        assert_eq!(generator.call_count(), 4);
        // Scratchpad must carry all three digests in order.
        assert!(result.scratchpad.contains("digest-1"));
        assert!(result.scratchpad.contains("digest-2"));
        assert!(result.scratchpad.contains("digest-3"));
    }

    #[tokio::test]
    async fn tree_split_reduce_prompt_sees_full_scratchpad() {
        // 2 map responses + 1 reduce response scripted.
        let generator = ScriptedGenerator::new(vec![
            Ok("summary-A".to_string()),
            Ok("summary-B".to_string()),
            Ok("REDUCED".to_string()),
        ]);
        let shards = vec!["input-one".to_string(), "input-two".to_string()];
        let opts = TreeSplitOpts::new("qwen3:latest");
        let _ = generate_tree_split(&generator, &shards, &simple_map, &simple_reduce, &opts)
            .await
            .unwrap();
        // Third call = reduce. Its prompt must include both digests.
        let calls = generator.calls();
        let reduce_prompt = &calls[2].prompt;
        assert!(reduce_prompt.contains("summary-A"),
            "reduce prompt must see first shard digest");
        assert!(reduce_prompt.contains("summary-B"),
            "reduce prompt must see second shard digest");
    }

    #[tokio::test]
    async fn tree_split_loud_fails_on_shard_overflow() {
        let generator = ScriptedGenerator::new(vec![Ok("digest".to_string())]);
        // One gigantic shard — well over qwen3's 40K window even as a
        // prompt. The budget check must reject before any generate call.
        let shards = vec!["x".repeat(200_000)];
        let opts = TreeSplitOpts::new("qwen3:latest");
        let err = generate_tree_split(&generator, &shards, &simple_map, &simple_reduce, &opts)
            .await
            .expect_err("shard-sized overflow must be rejected");
        assert!(err.contains("overflow"), "error should mention overflow: {err}");
        assert_eq!(generator.call_count(), 0, "generate must not be called on overflow");
    }

    #[tokio::test]
    async fn tree_split_truncates_scratchpad_when_over_budget() {
        // Tight budget so each shard trips truncation. qwen3's 40K
        // window is fine; the budget we care about is the scratchpad
        // cap, not the model window.
        let generator = ScriptedGenerator::new(vec![
            Ok("A".repeat(2_000)),
            Ok("B".repeat(2_000)),
            Ok("C".repeat(2_000)),
            Ok("D".repeat(2_000)),
            Ok("FINAL".to_string()),
        ]);
        let shards: Vec<String> = (0..4).map(|i| format!("shard{i}")).collect();
        let mut opts = TreeSplitOpts::new("qwen3:latest");
        opts.scratchpad_budget = 1_000; // ~4000 chars — one digest barely fits
        let result = generate_tree_split(&generator, &shards, &simple_map, &simple_reduce, &opts)
            .await
            .unwrap();
        assert!(result.scratchpad_truncations > 0,
            "tight budget must trigger truncation");
        // Scratchpad should still fit roughly within the budget
        // (post-truncation); the estimator uses chars/4 so the bound
        // is ~budget*4 chars. Give some slack for the delimiter.
        let scratchpad_tokens = ModelMatrix::estimate_tokens(&result.scratchpad);
        assert!(scratchpad_tokens <= opts.scratchpad_budget * 2,
            "scratchpad {} tokens vs budget {}", scratchpad_tokens, opts.scratchpad_budget);
    }

    #[tokio::test]
    async fn tree_split_reports_continuations_from_map_and_reduce() {
        // First shard: truncated-then-continued. Reduce: truncated-then-continued.
        // 1 shard: 2 map calls (initial + continuation), then 2 reduce calls.
        let generator = ScriptedGenerator::new(vec![
            Ok("partial".to_string()), // map shape=text, non-empty → complete on first pass
            Ok("reduce-out".to_string()),
        ]);
        let shards = vec!["only".to_string()];
        let opts = TreeSplitOpts::new("qwen3:latest");
        let result = generate_tree_split(&generator, &shards, &simple_map, &simple_reduce, &opts)
            .await
            .unwrap();
        // Text shape treats non-empty as complete → 0 continuations.
        assert_eq!(result.total_continuations, 0);
        assert_eq!(result.shards_processed, 1);
    }

    #[test]
    fn truncate_scratchpad_noop_when_under_budget() {
        // Under budget: no drop, content untouched, returns false.
        let mut s = "\n— shard 1/1 digest —\nshort".to_string();
        let truncated = truncate_scratchpad(&mut s, 1_000);
        assert!(!truncated);
        assert!(s.contains("short"));
    }

    #[test]
    fn truncate_scratchpad_drops_oldest_first() {
        // Three digests, budget only fits the tail — shard 1 must go.
        let mut s = format!(
            "\n— shard 1/3 digest —\n{}\n— shard 2/3 digest —\n{}\n— shard 3/3 digest —\nshort",
            "x".repeat(4_000), // ~1000 tokens
            "y".repeat(4_000), // ~1000 tokens
        );
        let truncated = truncate_scratchpad(&mut s, 500); // ~2000 chars
        assert!(truncated);
        assert!(!s.contains(&"x".repeat(4_000)),
            "oldest digest should be dropped");
        assert!(s.contains("short"),
            "newest digest should survive");
    }
}