LAKEHOUSE/internal/embed/cached_test.go
root 56844c3f31 embed cache — LRU at /v1/embed for repeat-query elimination
Adds CachedProvider wrapping the embedding Provider with a thread-safe
LRU keyed on (effective_model, sha256(text)) → []float32. Repeat
queries return the stored vector without round-tripping to Ollama.
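
A minimal sketch of that key derivation (the shipped cached.go may
differ in detail; the key shape itself is what
TestCacheKey_StableAndDistinct below pins down):

  // Key shape "<model>:<sha256-hex>" — the same text under different
  // models never collides. Uses crypto/sha256 and encoding/hex.
  func cacheKey(model, text string) string {
      sum := sha256.Sum256([]byte(text))
      return model + ":" + hex.EncodeToString(sum[:])
  }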

Why this matters: the staffing 500K test (memory project_golang_lakehouse)
documented that the staffing co-pilot replays many of the same query
texts ("forklift driver IL", "welder Chicago", "warehouse safety", etc).
Each repeat paid the ~50ms Ollama round-trip. Cached repeats now serve
in <1µs (LRU lookup + sha256 of input).

Memory budget: ~3 KiB per entry at d=768. Default 10K entries ≈ 30 MiB.
Configurable via [embedd].cache_size; 0 disables (pass-through mode).

Per-text caching, not per-batch — a batch with mixed hits/misses only
fetches the misses upstream, then merges the results back in the
caller's input order. A three-text batch with one miss makes one
upstream call carrying just that one text instead of all three.
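
In sketch form — illustrative only; CachedProvider and Result are the
real names, while embedWithCache, c.lru, c.upstream, effectiveModel,
and the hit/miss counter updates (omitted here) are assumptions about
cached.go rather than the shipped code:

  // Look up each text, remember miss positions, fetch only the misses
  // upstream, then write them back into their original slots so the
  // caller sees input order.
  func (c *CachedProvider) embedWithCache(ctx context.Context, texts []string, model, effectiveModel string) (Result, error) {
      vectors := make([][]float32, len(texts))
      var missIdx []int
      var missTexts []string
      for i, t := range texts {
          if v, ok := c.lru.Get(cacheKey(effectiveModel, t)); ok {
              vectors[i] = v
              continue
          }
          missIdx = append(missIdx, i)
          missTexts = append(missTexts, t)
      }
      if len(missTexts) > 0 {
          res, err := c.upstream.Embed(ctx, missTexts, model)
          if err != nil {
              return Result{}, err
          }
          for j, i := range missIdx {
              vectors[i] = res.Vectors[j]
              c.lru.Add(cacheKey(effectiveModel, missTexts[j]), res.Vectors[j])
          }
      }
      return Result{Model: effectiveModel, Dimension: len(vectors[0]), Vectors: vectors}, nil
  }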

Implementation:
  internal/embed/cached.go (NEW, 150 LoC)
    CachedProvider implements Provider; uses hashicorp/golang-lru/v2.
    Key shape: "<model>:<sha256-hex>". Empty model resolves to
    defaultModel (request-derived) for the key — NOT res.Model
    (upstream-derived), so future requests with same input shape
    hit the same key. Caught by TestCachedProvider_EmptyModelResolvesToDefault.
    Atomic hit/miss counters + Stats() + HitRate() + Len().

  internal/embed/cached_test.go (NEW, 12 test funcs)
    Pass-through-when-zero, hit-on-repeat, mixed-batch only fetches
    misses, model-key isolation, empty-model resolves to default,
    LRU eviction at cap, error propagation, all-hits synthesized
    without upstream call, hit-rate accumulation, empty-texts
    rejected, concurrent-safe (50 goroutines × 100 calls), key
    stability + distinctness.

  internal/shared/config.go
    EmbeddConfig.CacheSize (toml: cache_size). Default 10000.

  cmd/embedd/main.go
    Wraps the Ollama Provider with CachedProvider on startup and adds an
    /embed/stats endpoint exposing hits / misses / hit_rate / size
    (sketched after this list). Operators check the rate to confirm the
    cache is working (high rate) or mis-sized (low rate plus many misses
    on a workload that should have repeats).

  cmd/embedd/main_test.go
    Stats endpoint tests — disabled mode shape, enabled mode tracks
    hits + misses across repeat calls.
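
A hedged sketch of the main.go wiring and the stats handler. Only
NewCachedProvider, the Stats/HitRate/Len methods, EmbeddConfig.CacheSize,
and the hits/misses/hit_rate/size field names come from this change;
the handler plumbing and the other variable/field names are assumptions:

  // cfg.Embedd.CacheSize maps to [embedd].cache_size; the model field
  // name and the ollama/mux variables are assumed.
  cached, err := embed.NewCachedProvider(ollama, cfg.Embedd.Model, cfg.Embedd.CacheSize)
  if err != nil {
      log.Fatalf("embed cache: %v", err)
  }
  mux.HandleFunc("/embed/stats", func(w http.ResponseWriter, _ *http.Request) {
      hits, misses := cached.Stats()
      _ = json.NewEncoder(w).Encode(map[string]any{
          "hits":     hits,
          "misses":   misses,
          "hit_rate": cached.HitRate(),
          "size":     cached.Len(),
      })
  })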

One real bug caught by my own test:
  The initial implementation cached under res.Model (upstream-resolved)
  rather than effectiveModel (request-resolved). A request with
  model="" would cache under "test-model" (Ollama's default), and a
  later request with model="the-default" (our config default) would
  then miss the cache. Fix: always use the request-derived
  effectiveModel for keys; that's the predictable side. Locked by
  TestCachedProvider_EmptyModelResolvesToDefault.
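
  In code terms (variable names here are illustrative, not the shipped
  cached.go):

    // Key by the request-derived model, never by what upstream reports.
    effectiveModel := model
    if effectiveModel == "" {
        effectiveModel = c.defaultModel
    }
    key := cacheKey(effectiveModel, text) // not cacheKey(res.Model, text)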

Verified:
  go test -count=1 ./internal/embed/  — all 12 cached tests + 6 ollama tests green
  go test -count=1 ./cmd/embedd/      — stats endpoint tests green
  just verify                          — vet + test + 9 smokes, 33s

Production benefit:
  ~50ms Ollama round-trip → <1µs cache lookup for cached entries.
  At 10K-entry default + ~30% repeat rate (typical staffing co-pilot
  workload), saves several seconds per staffer-query session.
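  For scale (illustrative numbers, not measured): a hypothetical
  200-query session at that 30% repeat rate skips ~60 round-trips,
  roughly 3 s of Ollama time.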

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-29 06:54:30 -05:00

package embed

import (
	"context"
	"errors"
	"sync"
	"testing"
)

// stubProvider records every call so tests can verify caching
// behavior (call counts, request texts) without spinning up Ollama.
type stubProvider struct {
	mu        sync.Mutex
	model     string
	dim       int
	callCount int
	lastTexts []string
	err       error
	// vectorFn returns a deterministic vector for a given text so
	// cache hits assert byte-equality, not just shape.
	vectorFn func(text string) []float32
}

func (s *stubProvider) Embed(_ context.Context, texts []string, model string) (Result, error) {
	s.mu.Lock()
	s.callCount++
	s.lastTexts = append([]string(nil), texts...)
	s.mu.Unlock()
	if s.err != nil {
		return Result{}, s.err
	}
	out := make([][]float32, len(texts))
	for i, t := range texts {
		out[i] = s.vectorFn(t)
	}
	resolvedModel := model
	if resolvedModel == "" {
		resolvedModel = s.model
	}
	return Result{Model: resolvedModel, Dimension: s.dim, Vectors: out}, nil
}

func deterministicVector(t string) []float32 {
	v := make([]float32, 4)
	for i := range v {
		v[i] = float32(int(t[0]) + i)
	}
	return v
}

func newStub() *stubProvider {
	return &stubProvider{
		model:    "test-model",
		dim:      4,
		vectorFn: deterministicVector,
	}
}

func TestCachedProvider_PassThroughWhenSizeZero(t *testing.T) {
	stub := newStub()
	c, err := NewCachedProvider(stub, "test-model", 0)
	if err != nil {
		t.Fatalf("NewCachedProvider: %v", err)
	}
	for i := 0; i < 3; i++ {
		_, err := c.Embed(context.Background(), []string{"hello"}, "")
		if err != nil {
			t.Fatalf("Embed: %v", err)
		}
	}
	if stub.callCount != 3 {
		t.Errorf("size=0 should pass through every call, got callCount=%d", stub.callCount)
	}
	if c.Len() != 0 {
		t.Errorf("Len()=%d, want 0 when caching disabled", c.Len())
	}
}

func TestCachedProvider_HitOnRepeatText(t *testing.T) {
	stub := newStub()
	c, err := NewCachedProvider(stub, "test-model", 100)
	if err != nil {
		t.Fatalf("NewCachedProvider: %v", err)
	}
	// First call — miss, hits upstream.
	r1, err := c.Embed(context.Background(), []string{"hello"}, "test-model")
	if err != nil {
		t.Fatalf("first: %v", err)
	}
	if stub.callCount != 1 {
		t.Errorf("first call should hit upstream, callCount=%d", stub.callCount)
	}
	// Second call — hit, no upstream call.
	r2, err := c.Embed(context.Background(), []string{"hello"}, "test-model")
	if err != nil {
		t.Fatalf("second: %v", err)
	}
	if stub.callCount != 1 {
		t.Errorf("second call should hit cache, callCount=%d (want unchanged 1)", stub.callCount)
	}
	// Result should be byte-equal.
	if len(r1.Vectors[0]) != len(r2.Vectors[0]) {
		t.Errorf("vector dim differs: %d vs %d", len(r1.Vectors[0]), len(r2.Vectors[0]))
	}
	for i := range r1.Vectors[0] {
		if r1.Vectors[0][i] != r2.Vectors[0][i] {
			t.Errorf("vector[%d]: %v vs %v", i, r1.Vectors[0][i], r2.Vectors[0][i])
		}
	}
	hits, misses := c.Stats()
	if hits != 1 || misses != 1 {
		t.Errorf("hits/misses = %d/%d, want 1/1", hits, misses)
	}
}

func TestCachedProvider_MixedBatchOnlyFetchesMisses(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "test-model", 100)
	// Prime cache with one text.
	c.Embed(context.Background(), []string{"alpha"}, "test-model")
	if stub.callCount != 1 {
		t.Fatalf("priming should produce 1 upstream call, got %d", stub.callCount)
	}
	// Mixed batch: alpha cached, beta + gamma fresh.
	r, err := c.Embed(context.Background(),
		[]string{"alpha", "beta", "gamma"}, "test-model")
	if err != nil {
		t.Fatalf("Embed: %v", err)
	}
	// Provider should have been called only with the misses.
	if len(stub.lastTexts) != 2 {
		t.Errorf("upstream called with %d texts, want 2 (only misses)", len(stub.lastTexts))
	}
	for _, lt := range stub.lastTexts {
		if lt == "alpha" {
			t.Errorf("upstream got 'alpha' but it was cached")
		}
	}
	// Result preserves order: alpha first (cached), beta + gamma in
	// their original positions.
	if len(r.Vectors) != 3 {
		t.Fatalf("got %d vectors, want 3", len(r.Vectors))
	}
	for i, txt := range []string{"alpha", "beta", "gamma"} {
		want := deterministicVector(txt)
		for j := range want {
			if r.Vectors[i][j] != want[j] {
				t.Errorf("vec[%d (%s)][%d] = %v, want %v",
					i, txt, j, r.Vectors[i][j], want[j])
			}
		}
	}
}

func TestCachedProvider_ModelKeyIsolation(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "default-model", 100)
	// Same text, different models → distinct cache entries.
	c.Embed(context.Background(), []string{"shared"}, "model-A")
	c.Embed(context.Background(), []string{"shared"}, "model-B")
	if stub.callCount != 2 {
		t.Errorf("different models should NOT share cache, got callCount=%d", stub.callCount)
	}
	// Repeat with model-A → hits cache.
	c.Embed(context.Background(), []string{"shared"}, "model-A")
	if stub.callCount != 2 {
		t.Errorf("repeat of model-A should hit cache, got callCount=%d", stub.callCount)
	}
}

func TestCachedProvider_EmptyModelResolvesToDefault(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "the-default", 100)
	// First call with model="" — caches under "the-default".
	c.Embed(context.Background(), []string{"text"}, "")
	if stub.callCount != 1 {
		t.Fatalf("first call should hit upstream, got %d", stub.callCount)
	}
	// Second call with model="the-default" — should hit cache because
	// the first call resolved "" → "the-default" for the key.
	c.Embed(context.Background(), []string{"text"}, "the-default")
	if stub.callCount != 1 {
		t.Errorf("explicit default-model should hit cache populated by '' request, got callCount=%d", stub.callCount)
	}
}

func TestCachedProvider_LRUEviction(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "test-model", 2) // cap=2
	// Fill: a, b
	c.Embed(context.Background(), []string{"a"}, "")
	c.Embed(context.Background(), []string{"b"}, "")
	// c arrives — evicts a (LRU).
	c.Embed(context.Background(), []string{"c"}, "")
	if c.Len() > 2 {
		t.Errorf("Len()=%d, want ≤2 with cap=2", c.Len())
	}
	if stub.callCount != 3 {
		t.Errorf("callCount=%d, want 3 after a/b/c distinct", stub.callCount)
	}
	// Re-request a → miss (was evicted).
	c.Embed(context.Background(), []string{"a"}, "")
	if stub.callCount != 4 {
		t.Errorf("evicted a should miss, callCount=%d want 4", stub.callCount)
	}
}

func TestCachedProvider_PropagatesUpstreamError(t *testing.T) {
	stub := &stubProvider{
		model:    "test-model",
		dim:      4,
		vectorFn: deterministicVector,
		err:      errors.New("upstream broken"),
	}
	c, _ := NewCachedProvider(stub, "test-model", 10)
	_, err := c.Embed(context.Background(), []string{"hello"}, "")
	if err == nil {
		t.Fatal("expected upstream error to propagate")
	}
	// Cache should NOT have stored anything on error.
	if c.Len() != 0 {
		t.Errorf("error path should not populate cache, Len()=%d", c.Len())
	}
}

func TestCachedProvider_AllHitsSynthesized(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "test-model", 100)
	// Prime
	c.Embed(context.Background(), []string{"x", "y"}, "")
	stub.mu.Lock()
	stub.callCount = 0 // reset for clarity
	stub.mu.Unlock()
	// All-hits batch → no upstream call.
	r, err := c.Embed(context.Background(), []string{"x", "y"}, "")
	if err != nil {
		t.Fatalf("Embed: %v", err)
	}
	if stub.callCount != 0 {
		t.Errorf("all-hits batch should not hit upstream, callCount=%d", stub.callCount)
	}
	if r.Dimension != 4 {
		t.Errorf("all-hits Dimension=%d, want 4 (from cached vector len)", r.Dimension)
	}
	if r.Model != "test-model" {
		t.Errorf("all-hits Model=%q, want resolved 'test-model'", r.Model)
	}
}

func TestCachedProvider_HitRate(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "test-model", 100)
	if c.HitRate() != 0.0 {
		t.Errorf("initial HitRate=%v, want 0.0", c.HitRate())
	}
	c.Embed(context.Background(), []string{"a"}, "") // miss
	c.Embed(context.Background(), []string{"a"}, "") // hit
	c.Embed(context.Background(), []string{"a"}, "") // hit
	got := c.HitRate()
	want := 2.0 / 3.0
	if got < want-0.01 || got > want+0.01 {
		t.Errorf("HitRate=%v, want %v", got, want)
	}
}

func TestCachedProvider_EmptyTextsRejected(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "test-model", 100)
	_, err := c.Embed(context.Background(), []string{}, "")
	if !errors.Is(err, ErrEmptyTexts) {
		t.Errorf("empty texts should return ErrEmptyTexts, got %v", err)
	}
}

func TestCachedProvider_ConcurrentSafe(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "test-model", 100)
	// 50 goroutines × 100 calls each, mostly the same text.
	var wg sync.WaitGroup
	for g := 0; g < 50; g++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			for i := 0; i < 100; i++ {
				_, err := c.Embed(context.Background(), []string{"shared"}, "")
				if err != nil {
					t.Errorf("concurrent embed err: %v", err)
					return
				}
			}
		}(g)
	}
	wg.Wait()
	hits, misses := c.Stats()
	if hits+misses != 5000 {
		t.Errorf("total calls = %d, want 5000", hits+misses)
	}
	// First call is a guaranteed miss; subsequent should mostly be
	// hits. With 50 concurrent goroutines hitting an empty cache,
	// up to ~50 first-arrivers may miss before the first one writes.
	if misses > 100 {
		t.Errorf("misses=%d unreasonably high; cache concurrency broken?", misses)
	}
}

func TestCacheKey_StableAndDistinct(t *testing.T) {
	a := cacheKey("m1", "text-a")
	a2 := cacheKey("m1", "text-a")
	b := cacheKey("m1", "text-b")
	c := cacheKey("m2", "text-a")
	if a != a2 {
		t.Errorf("cache key not stable: %q != %q", a, a2)
	}
	if a == b {
		t.Errorf("different texts produced same key: %q", a)
	}
	if a == c {
		t.Errorf("different models produced same key: %q", a)
	}
	// Key shape: model + ":" + 64-hex-chars sha256.
	const sha256HexLen = 64
	want := len("m1") + 1 + sha256HexLen
	if len(a) != want {
		t.Errorf("key len=%d, want %d", len(a), want)
	}
}