package embed

import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
)

// TestOllama_EmbedBatch_PreservesOrder verifies that Embed returns one vector
// per input text, in input order, even though each text goes upstream as its
// own request. Each known prompt is answered with a distinct one-hot 4-d
// vector so the caller can assert position, and the prompt log confirms all
// three texts reached the server.
func TestOllama_EmbedBatch_PreservesOrder(t *testing.T) {
	var mu sync.Mutex
	var seenPrompts []string
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var req ollamaRequest
		_ = json.NewDecoder(r.Body).Decode(&req)
		// Handlers run on per-request goroutines; guard the shared log.
		mu.Lock()
		seenPrompts = append(seenPrompts, req.Prompt)
		mu.Unlock()

		// Return a vector that encodes which prompt this was, so
		// we can assert order at the caller. 4-d vector for cheap.
		var vec [4]float64
		switch req.Prompt {
		case "alpha":
			vec = [4]float64{1, 0, 0, 0}
		case "beta":
			vec = [4]float64{0, 1, 0, 0}
		case "gamma":
			vec = [4]float64{0, 0, 1, 0}
		}
		_ = json.NewEncoder(w).Encode(map[string]any{"embedding": vec[:]})
	}))
	defer srv.Close()

	p := NewOllama(srv.URL, "test-model")
	res, err := p.Embed(context.Background(), []string{"alpha", "beta", "gamma"}, "")
	if err != nil {
		t.Fatal(err)
	}
	if res.Model != "test-model" || res.Dimension != 4 || len(res.Vectors) != 3 {
		t.Fatalf("Result: got %+v", res)
	}
	if res.Vectors[0][0] != 1 || res.Vectors[1][1] != 1 || res.Vectors[2][2] != 1 {
		t.Errorf("vectors out of order: %v", res.Vectors)
	}
	// Sanity: all three prompts hit the server.
	mu.Lock()
	n := len(seenPrompts)
	mu.Unlock()
	if n != 3 {
		t.Errorf("expected 3 upstream calls, got %d", n)
	}
}

// TestOllama_EmptyTextsErrors verifies that a nil/empty batch fails fast with
// the ErrEmptyTexts sentinel before any network activity (the base URL is
// unreachable on purpose).
func TestOllama_EmptyTextsErrors(t *testing.T) {
	p := NewOllama("http://nope:0", "x")
	_, err := p.Embed(context.Background(), nil, "")
	if !errors.Is(err, ErrEmptyTexts) {
		t.Errorf("expected ErrEmptyTexts, got %v", err)
	}
}

// TestOllama_NoModelNoDefault verifies that Embed rejects a request when
// neither a per-call model nor a provider default is configured.
func TestOllama_NoModelNoDefault(t *testing.T) {
	p := NewOllama("http://nope:0", "") // empty default
	_, err := p.Embed(context.Background(), []string{"hi"}, "")
	if err == nil || !strings.Contains(err.Error(), "no model") {
		t.Errorf("expected no-model error, got %v", err)
	}
}

// TestOllama_UpstreamErrorPropagates verifies that a non-2xx upstream
// response surfaces as a wrapped error mentioning the status code.
func TestOllama_UpstreamErrorPropagates(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "model not loaded", http.StatusInternalServerError)
	}))
	defer srv.Close()

	p := NewOllama(srv.URL, "x")
	_, err := p.Embed(context.Background(), []string{"hi"}, "")
	if err == nil || !strings.Contains(err.Error(), "upstream status 500") {
		t.Errorf("expected wrapped 500 error, got %v", err)
	}
}

// TestOllama_DimensionMismatchMidBatch verifies that the provider returns
// ErrModelMismatch when the embedding dimension changes mid-batch (as if the
// server swapped models between requests).
func TestOllama_DimensionMismatchMidBatch(t *testing.T) {
	// The counter is shared by per-request handler goroutines; an unguarded
	// int would be a data race under -race if the provider parallelizes the
	// batch, so protect it with a mutex like the order-preservation test does.
	var mu sync.Mutex
	calls := 0
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		mu.Lock()
		calls++
		first := calls == 1
		mu.Unlock()

		// First call returns 4-d, second returns 8-d → server changed
		// model under us. Provider should ErrModelMismatch.
		v := []float64{1, 0, 0, 0}
		if !first {
			v = []float64{1, 0, 0, 0, 0, 0, 0, 0}
		}
		_ = json.NewEncoder(w).Encode(map[string]any{"embedding": v})
	}))
	defer srv.Close()

	p := NewOllama(srv.URL, "x")
	_, err := p.Embed(context.Background(), []string{"a", "b"}, "")
	if !errors.Is(err, ErrModelMismatch) {
		t.Errorf("expected ErrModelMismatch, got %v", err)
	}
}

// TestOllama_EmptyEmbeddingErrors verifies that a well-formed response with a
// zero-length embedding vector is rejected with an "empty embedding" error.
func TestOllama_EmptyEmbeddingErrors(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_ = json.NewEncoder(w).Encode(map[string]any{"embedding": []float64{}})
	}))
	defer srv.Close()

	p := NewOllama(srv.URL, "x")
	_, err := p.Embed(context.Background(), []string{"hi"}, "")
	if err == nil || !strings.Contains(err.Error(), "empty embedding") {
		t.Errorf("expected empty-embedding error, got %v", err)
	}
}