// cached.go — LRU caching wrapper around a Provider.
//
// Memoizes (effective_model, sha256(text)) → []float32. Repeat
// queries return the stored vector without round-tripping to the
// upstream embedding service. Per-text caching: a batch with
// mixed hit/miss only fetches the misses, then merges the result
// preserving the caller's input order.
//
// Memory budget: ~3 KB per entry at d=768. 10K-entry default ≈ 30 MB
// — small enough for any realistic embedding deployment. Operators
// raising the cap should weigh memory headroom against expected
// hit rate.
//
// Why this exists: 500K staffing test (memory project_golang_lakehouse)
// showed that the staffing co-pilot replays many of the same query
// texts ("forklift driver", "welder Chicago", etc.). Caching them
// drops repeat-query cost from ~50ms (Ollama round-trip) to <1µs
// (LRU hit). Real production win documented in feedback_meta_index_vision.
package embed

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"sync/atomic"

	lru "github.com/hashicorp/golang-lru/v2"
)

// CachedProvider wraps a Provider with an LRU cache keyed on
// (effective_model, sha256(text)). Thread-safe.
type CachedProvider struct {
	inner        Provider
	defaultModel string
	cache        *lru.Cache[string, []float32]
	hits         atomic.Int64
	misses       atomic.Int64
}

// NewCachedProvider wraps inner with an LRU cache of the given size.
//
// defaultModel is used to resolve cache keys when a request leaves
// the model field empty: a request for model="" is treated as
// model=defaultModel for cache-key purposes, so callers that mix
// "" and the explicit model name still hit the same cache entry.
//
// Caller-side panic protection: size <= 0 is treated as "no cache"
// — every Embed call passes through. Avoids forcing operators to
// understand LRU sizing just to disable caching.
func NewCachedProvider(inner Provider, defaultModel string, size int) (*CachedProvider, error) {
	if size <= 0 {
		// Sentinel: nil cache means pass-through. NewCachedProvider
		// stays callable so the wiring layer can always wrap.
		return &CachedProvider{inner: inner, defaultModel: defaultModel}, nil
	}
	cache, err := lru.New[string, []float32](size)
	if err != nil {
		return nil, fmt.Errorf("embed cache init: %w", err)
	}
	return &CachedProvider{
		inner:        inner,
		defaultModel: defaultModel,
		cache:        cache,
	}, nil
}

// Embed returns vectors for texts, memoizing per (model, text). On
// a batch with mixed hits/misses, only the misses round-trip to
// inner; the result preserves caller ordering.
func (c *CachedProvider) Embed(ctx context.Context, texts []string, model string) (Result, error) {
	// Pass-through when caching is disabled.
	if c.cache == nil {
		return c.inner.Embed(ctx, texts, model)
	}
	if len(texts) == 0 {
		return Result{}, ErrEmptyTexts
	}

	effectiveModel := model
	if effectiveModel == "" {
		effectiveModel = c.defaultModel
	}

	// Pass 1: cache lookup; collect misses preserving original index
	// so we can write the upstream result back into the right slots.
	out := make([][]float32, len(texts))
	missTexts := make([]string, 0, len(texts))
	missIdx := make([]int, 0, len(texts))
	for i, t := range texts {
		key := cacheKey(effectiveModel, t)
		if v, ok := c.cache.Get(key); ok {
			out[i] = v
			c.hits.Add(1)
			continue
		}
		missTexts = append(missTexts, t)
		missIdx = append(missIdx, i)
		c.misses.Add(1)
	}

	// All hits — synthesize the result without an upstream call.
	// Use the effective model + the first cached vector's length
	// for the response. Every cached vector for the same model has
	// the same dimension by construction (Provider guarantees it).
	if len(missTexts) == 0 {
		dim := 0
		if len(out) > 0 && len(out[0]) > 0 {
			dim = len(out[0])
		}
		return Result{
			Model:     effectiveModel,
			Dimension: dim,
			Vectors:   out,
		}, nil
	}

	// Pass 2: fetch the misses, populate the cache, merge.
	res, err := c.inner.Embed(ctx, missTexts, model)
	if err != nil {
		return Result{}, err
	}
	if len(res.Vectors) != len(missTexts) {
		return Result{}, fmt.Errorf("embed cache: provider returned %d vectors for %d miss texts",
			len(res.Vectors), len(missTexts))
	}

	// Cache under effectiveModel (request-derived), NOT res.Model
	// (upstream-derived). Future requests with the same input shape
	// — same explicit model OR the same "" — must hit the same key.
	// Only the request-derived name is predictable; the upstream's
	// resolution can drift if the upstream's default changes.
	for j, t := range missTexts {
		out[missIdx[j]] = res.Vectors[j]
		c.cache.Add(cacheKey(effectiveModel, t), res.Vectors[j])
	}

	return Result{
		Model:     res.Model,
		Dimension: res.Dimension,
		Vectors:   out,
	}, nil
}

// Stats returns lifetime hit and miss counts for this cache. Atomic
// reads — no locking. Useful for /health, /metrics, or operator
// dashboards.
func (c *CachedProvider) Stats() (hits, misses int64) {
	return c.hits.Load(), c.misses.Load()
}

// HitRate returns the fraction of requests served from cache.
// Returns 0.0 when no requests have been served (avoids NaN).
func (c *CachedProvider) HitRate() float64 {
	h := c.hits.Load()
	m := c.misses.Load()
	total := h + m
	if total == 0 {
		return 0.0
	}
	return float64(h) / float64(total)
}

// Len returns the current entry count in the cache. Returns 0 when
// caching is disabled.
func (c *CachedProvider) Len() int {
	if c.cache == nil {
		return 0
	}
	return c.cache.Len()
}

// cacheKey is "<model>:<hex(sha256(text))>". sha256 collapses long
// texts to a fixed-size key; the model prefix scopes the cache to
// one model so callers using multiple models don't get
// cross-contamination.
func cacheKey(model, text string) string {
	h := sha256.Sum256([]byte(text))
	return model + ":" + hex.EncodeToString(h[:])
}
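
// Usage sketch (illustrative only): how a wiring layer might wrap an
// upstream Provider and report cache effectiveness. The upstream
// constructor (newOllamaProvider), the model name, and the ctx/log
// plumbing shown here are assumptions rather than part of this
// package; substitute whatever concrete Provider and default model
// the deployment actually uses.
//
//	upstream := newOllamaProvider("http://localhost:11434")
//	cached, err := NewCachedProvider(upstream, "nomic-embed-text", 10_000)
//	if err != nil {
//		log.Fatalf("wiring embed cache: %v", err)
//	}
//
//	// The first call misses and round-trips to the upstream service;
//	// repeating the same texts later is served from the LRU.
//	res, err := cached.Embed(ctx, []string{"forklift driver", "welder Chicago"}, "")
//	if err != nil {
//		log.Fatalf("embed: %v", err)
//	}
//	_ = res.Vectors
//
//	hits, misses := cached.Stats()
//	log.Printf("embed cache: %d hits, %d misses (%.0f%% hit rate)",
//		hits, misses, 100*cached.HitRate())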