package embed

import (
	"context"
	"errors"
	"sync"
	"testing"
)

// stubProvider records every call so tests can verify caching
// behavior (call counts, request texts) without spinning up Ollama.
type stubProvider struct {
	mu        sync.Mutex
	model     string
	dim       int
	callCount int
	lastTexts []string
	err       error

	// vectorFn returns a deterministic vector for a given text so
	// cache hits assert byte-equality, not just shape.
	vectorFn func(text string) []float32
}

func (s *stubProvider) Embed(_ context.Context, texts []string, model string) (Result, error) {
	s.mu.Lock()
	s.callCount++
	s.lastTexts = append([]string(nil), texts...)
	s.mu.Unlock()

	if s.err != nil {
		return Result{}, s.err
	}
	out := make([][]float32, len(texts))
	for i, t := range texts {
		out[i] = s.vectorFn(t)
	}
	resolvedModel := model
	if resolvedModel == "" {
		resolvedModel = s.model
	}
	return Result{Model: resolvedModel, Dimension: s.dim, Vectors: out}, nil
}

func deterministicVector(t string) []float32 {
	v := make([]float32, 4)
	for i := range v {
		v[i] = float32(int(t[0]) + i)
	}
	return v
}

func newStub() *stubProvider {
	return &stubProvider{
		model:    "test-model",
		dim:      4,
		vectorFn: deterministicVector,
	}
}

func TestCachedProvider_PassThroughWhenSizeZero(t *testing.T) {
	stub := newStub()
	c, err := NewCachedProvider(stub, "test-model", 0)
	if err != nil {
		t.Fatalf("NewCachedProvider: %v", err)
	}

	for i := 0; i < 3; i++ {
		_, err := c.Embed(context.Background(), []string{"hello"}, "")
		if err != nil {
			t.Fatalf("Embed: %v", err)
		}
	}
	if stub.callCount != 3 {
		t.Errorf("size=0 should pass through every call, got callCount=%d", stub.callCount)
	}
	if c.Len() != 0 {
		t.Errorf("Len()=%d, want 0 when caching disabled", c.Len())
	}
}

func TestCachedProvider_HitOnRepeatText(t *testing.T) {
	stub := newStub()
	c, err := NewCachedProvider(stub, "test-model", 100)
	if err != nil {
		t.Fatalf("NewCachedProvider: %v", err)
	}

	// First call: miss, hits upstream.
	r1, err := c.Embed(context.Background(), []string{"hello"}, "test-model")
	if err != nil {
		t.Fatalf("first: %v", err)
	}
	if stub.callCount != 1 {
		t.Errorf("first call should hit upstream, callCount=%d", stub.callCount)
	}

	// Second call: hit, no upstream call.
	r2, err := c.Embed(context.Background(), []string{"hello"}, "test-model")
	if err != nil {
		t.Fatalf("second: %v", err)
	}
	if stub.callCount != 1 {
		t.Errorf("second call should hit cache, callCount=%d (want unchanged 1)", stub.callCount)
	}

	// Result should be byte-equal.
	if len(r1.Vectors[0]) != len(r2.Vectors[0]) {
		t.Errorf("vector dim differs: %d vs %d", len(r1.Vectors[0]), len(r2.Vectors[0]))
	}
	for i := range r1.Vectors[0] {
		if r1.Vectors[0][i] != r2.Vectors[0][i] {
			t.Errorf("vector[%d]: %v vs %v", i, r1.Vectors[0][i], r2.Vectors[0][i])
		}
	}

	hits, misses := c.Stats()
	if hits != 1 || misses != 1 {
		t.Errorf("hits/misses = %d/%d, want 1/1", hits, misses)
	}
}
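
// BenchmarkCachedProvider_Hit is an added sketch, not part of the original
// suite: it measures the all-hits path exercised by the test above. It uses
// only names the tests already rely on (newStub, NewCachedProvider, Embed);
// run with `go test -bench=CachedProvider_Hit -benchmem`.
func BenchmarkCachedProvider_Hit(b *testing.B) {
	stub := newStub()
	c, err := NewCachedProvider(stub, "test-model", 100)
	if err != nil {
		b.Fatalf("NewCachedProvider: %v", err)
	}
	// Prime the cache so every timed iteration is a hit.
	if _, err := c.Embed(context.Background(), []string{"hot"}, ""); err != nil {
		b.Fatalf("prime: %v", err)
	}
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := c.Embed(context.Background(), []string{"hot"}, ""); err != nil {
			b.Fatalf("Embed: %v", err)
		}
	}
}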

func TestCachedProvider_MixedBatchOnlyFetchesMisses(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "test-model", 100)

	// Prime cache with one text.
	c.Embed(context.Background(), []string{"alpha"}, "test-model")
	if stub.callCount != 1 {
		t.Fatalf("priming should produce 1 upstream call, got %d", stub.callCount)
	}

	// Mixed batch: alpha cached, beta + gamma fresh.
	r, err := c.Embed(context.Background(), []string{"alpha", "beta", "gamma"}, "test-model")
	if err != nil {
		t.Fatalf("Embed: %v", err)
	}

	// Provider should have been called only with the misses.
	if len(stub.lastTexts) != 2 {
		t.Errorf("upstream called with %d texts, want 2 (only misses)", len(stub.lastTexts))
	}
	for _, lt := range stub.lastTexts {
		if lt == "alpha" {
			t.Errorf("upstream got 'alpha' but it was cached")
		}
	}

	// Result preserves order: alpha first (cached), beta + gamma in
	// their original positions.
	if len(r.Vectors) != 3 {
		t.Fatalf("got %d vectors, want 3", len(r.Vectors))
	}
	for i, txt := range []string{"alpha", "beta", "gamma"} {
		want := deterministicVector(txt)
		for j := range want {
			if r.Vectors[i][j] != want[j] {
				t.Errorf("vec[%d (%s)][%d] = %v, want %v", i, txt, j, r.Vectors[i][j], want[j])
			}
		}
	}
}

func TestCachedProvider_ModelKeyIsolation(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "default-model", 100)

	// Same text, different models → distinct cache entries.
	c.Embed(context.Background(), []string{"shared"}, "model-A")
	c.Embed(context.Background(), []string{"shared"}, "model-B")
	if stub.callCount != 2 {
		t.Errorf("different models should NOT share cache, got callCount=%d", stub.callCount)
	}

	// Repeat with model-A → hits cache.
	c.Embed(context.Background(), []string{"shared"}, "model-A")
	if stub.callCount != 2 {
		t.Errorf("repeat of model-A should hit cache, got callCount=%d", stub.callCount)
	}
}

func TestCachedProvider_EmptyModelResolvesToDefault(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "the-default", 100)

	// First call with model="": caches under "the-default".
	c.Embed(context.Background(), []string{"text"}, "")
	if stub.callCount != 1 {
		t.Fatalf("first call should hit upstream, got %d", stub.callCount)
	}

	// Second call with model="the-default" should hit the cache because
	// the first call resolved "" → "the-default" for the key.
	c.Embed(context.Background(), []string{"text"}, "the-default")
	if stub.callCount != 1 {
		t.Errorf("explicit default-model should hit cache populated by '' request, got callCount=%d", stub.callCount)
	}
}

func TestCachedProvider_LRUEviction(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "test-model", 2) // cap=2

	// Fill: a, b
	c.Embed(context.Background(), []string{"a"}, "")
	c.Embed(context.Background(), []string{"b"}, "")

	// "c" arrives and evicts "a", the least recently used entry; a
	// minimal sketch of this policy follows the test.
	c.Embed(context.Background(), []string{"c"}, "")
	if c.Len() > 2 {
		t.Errorf("Len()=%d, want ≤2 with cap=2", c.Len())
	}
	if stub.callCount != 3 {
		t.Errorf("callCount=%d, want 3 after a/b/c distinct", stub.callCount)
	}

	// Re-request a → miss (was evicted).
	c.Embed(context.Background(), []string{"a"}, "")
	if stub.callCount != 4 {
		t.Errorf("evicted a should miss, callCount=%d want 4", stub.callCount)
	}
}
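
// minimalLRU is an added sketch, not part of the package under test: it
// illustrates the least-recently-used policy TestCachedProvider_LRUEviction
// assumes. The real cache may be implemented differently (e.g. around
// container/list); this map+slice version keeps the sketch dependency-free.
type minimalLRU struct {
	capacity int
	vals     map[string][]float32
	order    []string // least-recently-used key first
}

func newMinimalLRU(capacity int) *minimalLRU {
	return &minimalLRU{capacity: capacity, vals: make(map[string][]float32)}
}

// touch moves key to the most-recently-used end of order.
func (l *minimalLRU) touch(key string) {
	kept := l.order[:0]
	for _, k := range l.order {
		if k != key {
			kept = append(kept, k)
		}
	}
	l.order = append(kept, key)
}

func (l *minimalLRU) get(key string) ([]float32, bool) {
	v, ok := l.vals[key]
	if ok {
		l.touch(key)
	}
	return v, ok
}

func (l *minimalLRU) put(key string, v []float32) {
	if l.capacity <= 0 {
		return // mirrors the size=0 pass-through behavior tested above
	}
	if _, exists := l.vals[key]; !exists && len(l.vals) >= l.capacity {
		oldest := l.order[0] // evict the least recently used key
		l.order = l.order[1:]
		delete(l.vals, oldest)
	}
	l.vals[key] = v
	l.touch(key)
}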

func TestCachedProvider_PropagatesUpstreamError(t *testing.T) {
	stub := &stubProvider{
		model:    "test-model",
		dim:      4,
		vectorFn: deterministicVector,
		err:      errors.New("upstream broken"),
	}
	c, _ := NewCachedProvider(stub, "test-model", 10)

	_, err := c.Embed(context.Background(), []string{"hello"}, "")
	if err == nil {
		t.Fatal("expected upstream error to propagate")
	}

	// Cache should NOT have stored anything on error.
	if c.Len() != 0 {
		t.Errorf("error path should not populate cache, Len()=%d", c.Len())
	}
}

func TestCachedProvider_AllHitsSynthesized(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "test-model", 100)

	// Prime the cache.
	c.Embed(context.Background(), []string{"x", "y"}, "")
	stub.mu.Lock()
	stub.callCount = 0 // reset for clarity
	stub.mu.Unlock()

	// All-hits batch → no upstream call.
	r, err := c.Embed(context.Background(), []string{"x", "y"}, "")
	if err != nil {
		t.Fatalf("Embed: %v", err)
	}
	if stub.callCount != 0 {
		t.Errorf("all-hits batch should not hit upstream, callCount=%d", stub.callCount)
	}
	if r.Dimension != 4 {
		t.Errorf("all-hits Dimension=%d, want 4 (from cached vector len)", r.Dimension)
	}
	if r.Model != "test-model" {
		t.Errorf("all-hits Model=%q, want resolved 'test-model'", r.Model)
	}
}

func TestCachedProvider_HitRate(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "test-model", 100)

	if c.HitRate() != 0.0 {
		t.Errorf("initial HitRate=%v, want 0.0", c.HitRate())
	}

	c.Embed(context.Background(), []string{"a"}, "") // miss
	c.Embed(context.Background(), []string{"a"}, "") // hit
	c.Embed(context.Background(), []string{"a"}, "") // hit

	got := c.HitRate()
	want := 2.0 / 3.0
	if got < want-0.01 || got > want+0.01 {
		t.Errorf("HitRate=%v, want %v", got, want)
	}
}

func TestCachedProvider_EmptyTextsRejected(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "test-model", 100)

	_, err := c.Embed(context.Background(), []string{}, "")
	if !errors.Is(err, ErrEmptyTexts) {
		t.Errorf("empty texts should return ErrEmptyTexts, got %v", err)
	}
}

func TestCachedProvider_ConcurrentSafe(t *testing.T) {
	stub := newStub()
	c, _ := NewCachedProvider(stub, "test-model", 100)

	// 50 goroutines × 100 calls each, all requesting the same text.
	var wg sync.WaitGroup
	for g := 0; g < 50; g++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := 0; i < 100; i++ {
				_, err := c.Embed(context.Background(), []string{"shared"}, "")
				if err != nil {
					t.Errorf("concurrent embed err: %v", err)
					return
				}
			}
		}()
	}
	wg.Wait()

	hits, misses := c.Stats()
	if hits+misses != 5000 {
		t.Errorf("total calls = %d, want 5000", hits+misses)
	}
	// The first call is a guaranteed miss and later ones should mostly
	// be hits, but with 50 goroutines racing an empty cache, up to ~50
	// first-arrivers may miss before the first write lands.
	if misses > 100 {
		t.Errorf("misses=%d unreasonably high; cache concurrency broken?", misses)
	}
}

func TestCacheKey_StableAndDistinct(t *testing.T) {
	a := cacheKey("m1", "text-a")
	a2 := cacheKey("m1", "text-a")
	b := cacheKey("m1", "text-b")
	c := cacheKey("m2", "text-a")

	if a != a2 {
		t.Errorf("cache key not stable: %q != %q", a, a2)
	}
	if a == b {
		t.Errorf("different texts produced same key: %q", a)
	}
	if a == c {
		t.Errorf("different models produced same key: %q", a)
	}

	// Key shape: model + ":" + 64-hex-char sha256.
	const sha256HexLen = 64
	want := len("m1") + 1 + sha256HexLen
	if len(a) != want {
		t.Errorf("key len=%d, want %d", len(a), want)
	}
}
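
// FuzzCacheKey is an added sketch (assumes Go 1.18+ fuzzing support) showing
// how the stability and key-shape properties asserted above could be checked
// over arbitrary inputs. It assumes cacheKey accepts any model/text pair;
// drop or adjust this if the real implementation restricts its inputs.
func FuzzCacheKey(f *testing.F) {
	f.Add("m1", "text-a")
	f.Add("", "")
	f.Fuzz(func(t *testing.T, model, text string) {
		k1 := cacheKey(model, text)
		k2 := cacheKey(model, text)
		if k1 != k2 {
			t.Errorf("cacheKey not deterministic for (%q, %q): %q vs %q", model, text, k1, k2)
		}
		// Shape check mirrors TestCacheKey_StableAndDistinct:
		// model + ":" + 64-hex-char sha256.
		const sha256HexLen = 64
		if len(k1) != len(model)+1+sha256HexLen {
			t.Errorf("key len=%d, want %d", len(k1), len(model)+1+sha256HexLen)
		}
	})
}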