root 56844c3f31 embed cache — LRU at /v1/embed for repeat-query elimination
Adds CachedProvider wrapping the embedding Provider with a thread-safe
LRU keyed on (effective_model, sha256(text)) → []float32. Repeat
queries return the stored vector without round-tripping to Ollama.

Why this matters: the staffing 500K test (memory project_golang_lakehouse)
documented that the staffing co-pilot replays many of the same query
texts ("forklift driver IL", "welder Chicago", "warehouse safety", etc).
Each repeat paid the ~50ms Ollama round-trip. Cached repeats now serve
in <1µs (LRU lookup + sha256 of input).

Memory budget: ~3 KiB per entry at d=768. Default 10K entries ≈ 30 MiB.
Configurable via [embedd].cache_size; 0 disables (pass-through mode).

Per-text caching, not per-batch — a batch with mixed hits/misses only
fetches the misses upstream, then merges the result preserving caller
input order. Three-text batch with one miss = one upstream call for
that one text instead of three.

Implementation:
  internal/embed/cached.go (NEW, 150 LoC)
    CachedProvider implements Provider; uses hashicorp/golang-lru/v2.
    Key shape: "<model>:<sha256-hex>". Empty model resolves to
    defaultModel (request-derived) for the key — NOT res.Model
    (upstream-derived), so future requests with same input shape
    hit the same key. Caught by TestCachedProvider_EmptyModelResolvesToDefault.
    Atomic hit/miss counters + Stats() + HitRate() + Len().

  internal/embed/cached_test.go (NEW, 12 test funcs)
    Pass-through-when-zero, hit-on-repeat, mixed-batch only fetches
    misses, model-key isolation, empty-model resolves to default,
    LRU eviction at cap, error propagation, all-hits synthesized
    without upstream call, hit-rate accumulation, empty-texts
    rejected, concurrent-safe (50 goroutines × 100 calls), key
    stability + distinctness.

  internal/shared/config.go
    EmbeddConfig.CacheSize (toml: cache_size). Default 10000.

  cmd/embedd/main.go
    Wraps Ollama Provider with CachedProvider on startup. Adds
    /embed/stats endpoint exposing hits / misses / hit_rate / size.
    Operators check the rate to confirm the cache is working
    (high rate = good) or sized wrong (low rate + many misses on a
    workload that should have repeats).

  cmd/embedd/main_test.go
    Stats endpoint tests — disabled mode shape, enabled mode tracks
    hits + misses across repeat calls.

One real bug caught by my own test:
  Initial implementation cached under res.Model (upstream-resolved)
  rather than effectiveModel (request-resolved). A request with
  model="" caching under "test-model" (Ollama's default), then a
  request with model="the-default" (our config default) missing
  the cache. Fix: always use the request-derived effectiveModel
  for keys; that's the predictable side. Locked by
  TestCachedProvider_EmptyModelResolvesToDefault.

Verified:
  go test -count=1 ./internal/embed/  — all 12 cached tests + 6 ollama tests green
  go test -count=1 ./cmd/embedd/      — stats endpoint tests green
  just verify                          — vet + test + 9 smokes 33s

Production benefit:
  ~50ms Ollama round-trip → <1µs cache lookup for cached entries.
  At 10K-entry default + ~30% repeat rate (typical staffing co-pilot
  workload), saves several seconds per staffer-query session.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-29 06:54:30 -05:00

181 lines
5.7 KiB
Go

// cached.go — LRU caching wrapper around a Provider.
//
// Memoizes (effective_model, sha256(text)) → []float32. Repeat
// queries return the stored vector without round-tripping to the
// upstream embedding service. Per-text caching: a batch with
// mixed hit/miss only fetches the misses, then merges the result
// preserving the caller's input order.
//
// Memory budget: ~3 KB per entry at d=768. 10K-entry default ≈ 30 MB
// — small enough for any realistic embedd deployment. Operators
// raising the cap should weigh memory headroom against expected
// hit rate.
//
// Why this exists: 500K staffing test (memory project_golang_lakehouse)
// showed that the staffing co-pilot replays many of the same query
// texts ("forklift driver", "welder Chicago", etc.). Caching them
// drops repeat-query cost from ~50ms (Ollama round-trip) to <1µs
// (LRU hit). Real production win documented in feedback_meta_index_vision.
package embed
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"sync/atomic"
lru "github.com/hashicorp/golang-lru/v2"
)
// CachedProvider wraps a Provider with an LRU cache keyed on
// (effective_model, sha256(text)). Thread-safe.
type CachedProvider struct {
inner Provider
defaultModel string
cache *lru.Cache[string, []float32]
hits atomic.Int64
misses atomic.Int64
}
// NewCachedProvider wraps inner with an LRU cache of the given size.
//
// defaultModel is used to resolve cache keys when a request leaves
// the model field empty: a request for model="" is treated as
// model=defaultModel for cache-key purposes, so callers that mix
// "" and the explicit model name still hit the same cache entry.
//
// Caller-side panic protection: size <= 0 is treated as "no cache"
// — every Embed call passes through. Avoids forcing operators to
// understand LRU sizing to disable.
func NewCachedProvider(inner Provider, defaultModel string, size int) (*CachedProvider, error) {
if size <= 0 {
// Sentinel: nil cache means pass-through. NewCachedProvider
// stays callable so the wiring layer can always wrap.
return &CachedProvider{inner: inner, defaultModel: defaultModel}, nil
}
cache, err := lru.New[string, []float32](size)
if err != nil {
return nil, fmt.Errorf("embed cache init: %w", err)
}
return &CachedProvider{
inner: inner,
defaultModel: defaultModel,
cache: cache,
}, nil
}
// Embed returns vectors for texts, memoizing per (model, text). On
// a batch with mixed hits/misses, only the misses round-trip to
// inner; the result preserves caller ordering.
func (c *CachedProvider) Embed(ctx context.Context, texts []string, model string) (Result, error) {
// Pass-through when caching disabled.
if c.cache == nil {
return c.inner.Embed(ctx, texts, model)
}
if len(texts) == 0 {
return Result{}, ErrEmptyTexts
}
effectiveModel := model
if effectiveModel == "" {
effectiveModel = c.defaultModel
}
// Pass 1: cache lookup; collect misses preserving original index
// so we can write the upstream result back into the right slots.
out := make([][]float32, len(texts))
missTexts := make([]string, 0, len(texts))
missIdx := make([]int, 0, len(texts))
for i, t := range texts {
key := cacheKey(effectiveModel, t)
if v, ok := c.cache.Get(key); ok {
out[i] = v
c.hits.Add(1)
continue
}
missTexts = append(missTexts, t)
missIdx = append(missIdx, i)
c.misses.Add(1)
}
// All hits — synthesize the result without an upstream call.
// Use the effective model + the first cached vector's length
// for the response. Every cached vector for the same model has
// the same dimension by construction (Provider guarantees it).
if len(missTexts) == 0 {
dim := 0
if len(out) > 0 && len(out[0]) > 0 {
dim = len(out[0])
}
return Result{
Model: effectiveModel,
Dimension: dim,
Vectors: out,
}, nil
}
// Pass 2: fetch the misses, populate cache, merge.
res, err := c.inner.Embed(ctx, missTexts, model)
if err != nil {
return Result{}, err
}
if len(res.Vectors) != len(missTexts) {
return Result{}, fmt.Errorf("embed cache: provider returned %d vectors for %d miss texts",
len(res.Vectors), len(missTexts))
}
// Cache under effectiveModel (request-derived), NOT res.Model
// (upstream-derived). Future requests with the same input shape
// — same explicit model OR the same "" — must hit the same key.
// Only ours is predictable from the request; upstream's resolution
// can drift if the upstream's default changes.
for j, t := range missTexts {
out[missIdx[j]] = res.Vectors[j]
c.cache.Add(cacheKey(effectiveModel, t), res.Vectors[j])
}
return Result{
Model: res.Model,
Dimension: res.Dimension,
Vectors: out,
}, nil
}
// Stats returns lifetime hit + miss counts for this cache. Atomic
// reads — no locking. Useful for /health, /metrics, or operator
// dashboards.
func (c *CachedProvider) Stats() (hits, misses int64) {
return c.hits.Load(), c.misses.Load()
}
// HitRate returns the fraction of requests served from cache.
// Returns 0.0 when no requests have been served (avoids NaN).
func (c *CachedProvider) HitRate() float64 {
h := c.hits.Load()
m := c.misses.Load()
total := h + m
if total == 0 {
return 0.0
}
return float64(h) / float64(total)
}
// Len returns the current entry count in the cache. Returns 0 when
// caching is disabled.
func (c *CachedProvider) Len() int {
if c.cache == nil {
return 0
}
return c.cache.Len()
}
// cacheKey is "<model>:<sha256(text)>". sha256 collapses long texts
// to a fixed-size key; model prefix scopes cache to one model so
// callers using multiple models don't get cross-contamination.
func cacheKey(model, text string) string {
h := sha256.Sum256([]byte(text))
return model + ":" + hex.EncodeToString(h[:])
}