Adds CachedProvider wrapping the embedding Provider with a thread-safe
LRU keyed on (effective_model, sha256(text)) → []float32. Repeat
queries return the stored vector without round-tripping to Ollama.
Why this matters: the staffing 500K test (memory project_golang_lakehouse)
documented that the staffing co-pilot replays many of the same query
texts ("forklift driver IL", "welder Chicago", "warehouse safety", etc).
Each repeat paid the ~50ms Ollama round-trip. Cached repeats now serve
in <1µs (LRU lookup + sha256 of input).
Memory budget: ~3 KiB per entry at d=768. Default 10K entries ≈ 30 MiB.
Configurable via [embedd].cache_size; 0 disables (pass-through mode).
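For reference, the TOML shape (cache_size is the new key; the value shown is the default):

    [embedd]
    # entry count, not bytes; ~3 KiB per entry at d=768. 0 disables caching.
    cache_size = 10000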
Per-text caching, not per-batch — a batch with mixed hits/misses only
fetches the misses upstream, then merges the result preserving caller
input order. Three-text batch with one miss = one upstream call for
that one text instead of three.
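Sketch of that behavior from the caller's side (stubProvider, demoMixedBatch, and
"the-default" are illustrative names, not code from this change; the stub assumes
Provider is just the Embed method used by cached.go):

// stubProvider is an illustrative upstream: fixed 768-dim vectors plus a
// count of how many texts it was actually asked to embed.
type stubProvider struct{ embedded int }

func (s *stubProvider) Embed(_ context.Context, texts []string, _ string) (Result, error) {
	s.embedded += len(texts)
	vecs := make([][]float32, len(texts))
	for i := range vecs {
		vecs[i] = make([]float32, 768)
	}
	return Result{Model: "stub-model", Dimension: 768, Vectors: vecs}, nil
}

func demoMixedBatch(ctx context.Context) {
	up := &stubProvider{}
	cp, _ := NewCachedProvider(up, "the-default", 1000)

	cp.Embed(ctx, []string{"forklift driver IL", "welder Chicago", "warehouse safety"}, "")
	// up.embedded == 3: all three texts were misses.

	cp.Embed(ctx, []string{"forklift driver IL", "welder Chicago", "crane operator"}, "")
	// up.embedded == 4: only "crane operator" went upstream; the two
	// repeats came from the cache, and the result keeps caller order.
}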
Implementation:
internal/embed/cached.go (NEW, 150 LoC)
CachedProvider implements Provider; uses hashicorp/golang-lru/v2.
Key shape: "<model>:<sha256-hex>". Empty model resolves to
defaultModel (request-derived) for the key — NOT res.Model
(upstream-derived), so future requests with same input shape
hit the same key. Caught by TestCachedProvider_EmptyModelResolvesToDefault.
Atomic hit/miss counters + Stats() + HitRate() + Len().
internal/embed/cached_test.go (NEW, 12 test funcs)
Pass-through-when-zero, hit-on-repeat, mixed-batch only fetches
misses, model-key isolation, empty-model resolves to default,
LRU eviction at cap, error propagation, all-hits synthesized
without upstream call, hit-rate accumulation, empty-texts
rejected, concurrent-safe (50 goroutines × 100 calls), key
stability + distinctness.
internal/shared/config.go
EmbeddConfig.CacheSize (toml: cache_size). Default 10000.
cmd/embedd/main.go
Wraps Ollama Provider with CachedProvider on startup. Adds
/embed/stats endpoint exposing hits / misses / hit_rate / size.
Operators check the hit rate to confirm the cache is working
(high rate) or is sized wrong (low rate plus many misses on a
workload that should have repeats). Wiring and handler are
sketched below, after this list.
cmd/embedd/main_test.go
Stats endpoint tests — disabled mode shape, enabled mode tracks
hits + misses across repeat calls.
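Roughly the main.go side (a sketch, not the literal code: wireEmbedCache is a
hypothetical helper, it assumes encoding/json and net/http are imported, and the
caller passes whatever provider/model/size the existing startup code resolves;
JSON field names follow the /embed/stats description above):

// wireEmbedCache wraps the upstream provider and exposes its counters.
func wireEmbedCache(mux *http.ServeMux, upstream embed.Provider, defaultModel string, cacheSize int) (*embed.CachedProvider, error) {
	cached, err := embed.NewCachedProvider(upstream, defaultModel, cacheSize)
	if err != nil {
		return nil, err
	}
	mux.HandleFunc("/embed/stats", func(w http.ResponseWriter, _ *http.Request) {
		hits, misses := cached.Stats()
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]any{
			"hits":     hits,
			"misses":   misses,
			"hit_rate": cached.HitRate(),
			"size":     cached.Len(),
		})
	})
	return cached, nil
}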
One real bug caught by my own test:
The initial implementation cached under res.Model (upstream-resolved)
rather than effectiveModel (request-resolved). A request with
model="" got cached under "test-model" (Ollama's default), so a
later request with model="the-default" (our config default) missed
the cache even though the text was identical. Fix: always use the
request-derived effectiveModel for keys; that's the predictable
side. Locked by TestCachedProvider_EmptyModelResolvesToDefault,
sketched just below.
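The regression test has roughly this shape (a sketch, not the file contents;
it reuses the illustrative stubProvider from the earlier snippet):

func TestCachedProvider_EmptyModelResolvesToDefault(t *testing.T) {
	up := &stubProvider{}
	cp, err := NewCachedProvider(up, "the-default", 16)
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.Background()

	// model="" must be keyed as "the-default", not as whatever model
	// the upstream reports back.
	if _, err := cp.Embed(ctx, []string{"forklift driver IL"}, ""); err != nil {
		t.Fatal(err)
	}
	// Naming the default explicitly must land on the same key.
	if _, err := cp.Embed(ctx, []string{"forklift driver IL"}, "the-default"); err != nil {
		t.Fatal(err)
	}

	if hits, misses := cp.Stats(); hits != 1 || misses != 1 {
		t.Fatalf("hits=%d misses=%d, want 1 and 1", hits, misses)
	}
}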
Verified:
go test -count=1 ./internal/embed/ — all 12 cached tests + 6 ollama tests green
go test -count=1 ./cmd/embedd/ — stats endpoint tests green
just verify — vet + test + 9 smokes in 33s
Production benefit:
~50ms Ollama round-trip → <1µs cache lookup for cached entries.
At 10K-entry default + ~30% repeat rate (typical staffing co-pilot
workload), saves several seconds per staffer-query session.
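(Back-of-envelope, assuming ~200 embedding calls per session: 200 × 0.30 × 50 ms ≈ 3 s saved.)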
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
internal/embed/cached.go — 181 lines, 5.7 KiB, Go:
// cached.go — LRU caching wrapper around a Provider.
//
// Memoizes (effective_model, sha256(text)) → []float32. Repeat
// queries return the stored vector without round-tripping to the
// upstream embedding service. Per-text caching: a batch with
// mixed hit/miss only fetches the misses, then merges the result
// preserving the caller's input order.
//
// Memory budget: ~3 KB per entry at d=768. 10K-entry default ≈ 30 MB
// — small enough for any realistic embedd deployment. Operators
// raising the cap should weigh memory headroom against expected
// hit rate.
//
// Why this exists: 500K staffing test (memory project_golang_lakehouse)
// showed that the staffing co-pilot replays many of the same query
// texts ("forklift driver", "welder Chicago", etc.). Caching them
// drops repeat-query cost from ~50ms (Ollama round-trip) to <1µs
// (LRU hit). Real production win documented in feedback_meta_index_vision.

package embed

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"sync/atomic"

	lru "github.com/hashicorp/golang-lru/v2"
)

// CachedProvider wraps a Provider with an LRU cache keyed on
// (effective_model, sha256(text)). Thread-safe.
type CachedProvider struct {
	inner        Provider
	defaultModel string
	cache        *lru.Cache[string, []float32]
	hits         atomic.Int64
	misses       atomic.Int64
}

// NewCachedProvider wraps inner with an LRU cache of the given size.
//
// defaultModel is used to resolve cache keys when a request leaves
// the model field empty: a request for model="" is treated as
// model=defaultModel for cache-key purposes, so callers that mix
// "" and the explicit model name still hit the same cache entry.
//
// Caller-side panic protection: size <= 0 is treated as "no cache"
// — every Embed call passes through. Avoids forcing operators to
// understand LRU sizing to disable.
func NewCachedProvider(inner Provider, defaultModel string, size int) (*CachedProvider, error) {
	if size <= 0 {
		// Sentinel: nil cache means pass-through. NewCachedProvider
		// stays callable so the wiring layer can always wrap.
		return &CachedProvider{inner: inner, defaultModel: defaultModel}, nil
	}
	cache, err := lru.New[string, []float32](size)
	if err != nil {
		return nil, fmt.Errorf("embed cache init: %w", err)
	}
	return &CachedProvider{
		inner:        inner,
		defaultModel: defaultModel,
		cache:        cache,
	}, nil
}

// Embed returns vectors for texts, memoizing per (model, text). On
// a batch with mixed hits/misses, only the misses round-trip to
// inner; the result preserves caller ordering.
func (c *CachedProvider) Embed(ctx context.Context, texts []string, model string) (Result, error) {
	// Pass-through when caching disabled.
	if c.cache == nil {
		return c.inner.Embed(ctx, texts, model)
	}

	if len(texts) == 0 {
		return Result{}, ErrEmptyTexts
	}

	effectiveModel := model
	if effectiveModel == "" {
		effectiveModel = c.defaultModel
	}

	// Pass 1: cache lookup; collect misses preserving original index
	// so we can write the upstream result back into the right slots.
	out := make([][]float32, len(texts))
	missTexts := make([]string, 0, len(texts))
	missIdx := make([]int, 0, len(texts))
	for i, t := range texts {
		key := cacheKey(effectiveModel, t)
		if v, ok := c.cache.Get(key); ok {
			out[i] = v
			c.hits.Add(1)
			continue
		}
		missTexts = append(missTexts, t)
		missIdx = append(missIdx, i)
		c.misses.Add(1)
	}

	// All hits — synthesize the result without an upstream call.
	// Use the effective model + the first cached vector's length
	// for the response. Every cached vector for the same model has
	// the same dimension by construction (Provider guarantees it).
	if len(missTexts) == 0 {
		dim := 0
		if len(out) > 0 && len(out[0]) > 0 {
			dim = len(out[0])
		}
		return Result{
			Model:     effectiveModel,
			Dimension: dim,
			Vectors:   out,
		}, nil
	}

	// Pass 2: fetch the misses, populate cache, merge.
	res, err := c.inner.Embed(ctx, missTexts, model)
	if err != nil {
		return Result{}, err
	}
	if len(res.Vectors) != len(missTexts) {
		return Result{}, fmt.Errorf("embed cache: provider returned %d vectors for %d miss texts",
			len(res.Vectors), len(missTexts))
	}
	// Cache under effectiveModel (request-derived), NOT res.Model
	// (upstream-derived). Future requests with the same input shape
	// — same explicit model OR the same "" — must hit the same key.
	// Only ours is predictable from the request; upstream's resolution
	// can drift if the upstream's default changes.
	for j, t := range missTexts {
		out[missIdx[j]] = res.Vectors[j]
		c.cache.Add(cacheKey(effectiveModel, t), res.Vectors[j])
	}

	return Result{
		Model:     res.Model,
		Dimension: res.Dimension,
		Vectors:   out,
	}, nil
}

// Stats returns lifetime hit + miss counts for this cache. Atomic
// reads — no locking. Useful for /health, /metrics, or operator
// dashboards.
func (c *CachedProvider) Stats() (hits, misses int64) {
	return c.hits.Load(), c.misses.Load()
}

// HitRate returns the fraction of requests served from cache.
// Returns 0.0 when no requests have been served (avoids NaN).
func (c *CachedProvider) HitRate() float64 {
	h := c.hits.Load()
	m := c.misses.Load()
	total := h + m
	if total == 0 {
		return 0.0
	}
	return float64(h) / float64(total)
}

// Len returns the current entry count in the cache. Returns 0 when
// caching is disabled.
func (c *CachedProvider) Len() int {
	if c.cache == nil {
		return 0
	}
	return c.cache.Len()
}

// cacheKey is "<model>:<sha256(text)>". sha256 collapses long texts
// to a fixed-size key; the model prefix scopes the cache to one model
// so callers using multiple models don't get cross-contamination.
func cacheKey(model, text string) string {
	h := sha256.Sum256([]byte(text))
	return model + ":" + hex.EncodeToString(h[:])
}