package corpusingest

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"
)

// fakeGateway records the embed + add calls corpusingest fires and
// returns canned responses. The whole point of the unit test is to
// validate the pipeline shape (request payloads, batching, stats)
// without needing live embedd/vectord.
//
// All recorded fields below are guarded by mu; handlers run on the
// httptest server's goroutines, so tests must also take mu before
// reading them.
type fakeGateway struct {
	mu             sync.Mutex
	embedCalls     int
	embedTexts     [][]string // texts per call
	addCalls       int
	addItems       [][]addItem // items per call
	createCalled   bool
	deleteCalled   bool
	indexConflict  bool // simulate "index already exists" → 409
	embedDimension int
}

// addItem mirrors the wire shape of one vector in an add request.
type addItem struct {
	ID       string          `json:"id"`
	Vector   []float32       `json:"vector"`
	Metadata json.RawMessage `json:"metadata,omitempty"`
}

// newFakeGateway returns a fake gateway that synthesizes vectors of
// the given dimension.
func newFakeGateway(dim int) *fakeGateway {
	return &fakeGateway{embedDimension: dim}
}

// handler builds the HTTP surface the ingest pipeline talks to:
// create-index, embed, add, and delete-index.
func (f *fakeGateway) handler() http.Handler {
	mux := http.NewServeMux()
	mux.HandleFunc("/v1/vectors/index", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			http.Error(w, "wrong method", http.StatusMethodNotAllowed)
			return
		}
		f.mu.Lock()
		f.createCalled = true
		conflict := f.indexConflict
		f.mu.Unlock()
		if conflict {
			http.Error(w, "exists", http.StatusConflict)
			return
		}
		w.WriteHeader(http.StatusCreated)
	})
	mux.HandleFunc("/v1/embed", func(w http.ResponseWriter, r *http.Request) {
		var req struct {
			Texts []string `json:"texts"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// Synthesize deterministic vectors: vecs[i][j] = float32(i+j+1).
		vecs := make([][]float32, len(req.Texts))
		for i := range vecs {
			v := make([]float32, f.embedDimension)
			for j := range v {
				v[j] = float32(i + j + 1)
			}
			vecs[i] = v
		}
		f.mu.Lock()
		f.embedCalls++
		// Copy because we'll release the slice after returning.
		texts := append([]string(nil), req.Texts...)
		f.embedTexts = append(f.embedTexts, texts)
		f.mu.Unlock()
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]any{
			"vectors":   vecs,
			"dimension": f.embedDimension,
			"model":     "fake-embed",
		})
	})
	mux.HandleFunc("/v1/vectors/index/", func(w http.ResponseWriter, r *http.Request) {
		// /v1/vectors/index/{name}/add
		if !strings.HasSuffix(r.URL.Path, "/add") {
			if r.Method == http.MethodDelete {
				f.mu.Lock()
				f.deleteCalled = true
				f.mu.Unlock()
				w.WriteHeader(http.StatusNoContent)
				return
			}
			http.Error(w, "unhandled "+r.URL.Path, http.StatusNotFound)
			return
		}
		var req struct {
			Items []addItem `json:"items"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		f.mu.Lock()
		f.addCalls++
		f.addItems = append(f.addItems, append([]addItem(nil), req.Items...))
		f.mu.Unlock()
		_, _ = io.WriteString(w, `{"added":`+fmt.Sprint(len(req.Items))+`}`)
	})
	return mux
}

// staticSource yields a fixed slice of rows.
type staticSource struct {
	rows []Row
	i    int
}

// Next returns the next row, or io.EOF once the slice is exhausted.
func (s *staticSource) Next() (Row, error) {
	if s.i >= len(s.rows) {
		return Row{}, io.EOF
	}
	r := s.rows[s.i]
	s.i++
	return r, nil
}

func TestRun_PipelineShapeAndStats(t *testing.T) {
	const dim = 4
	fg := newFakeGateway(dim)
	srv := httptest.NewServer(fg.handler())
	defer srv.Close()
	rows := make([]Row, 50)
	for i := range rows {
		rows[i] = Row{
			ID:       fmt.Sprintf("r-%03d", i),
			Text:     fmt.Sprintf("row %d text", i),
			Metadata: map[string]any{"i": i, "kind": "test"},
		}
	}
	stats, err := Run(context.Background(), Config{
		GatewayURL:   srv.URL,
		IndexName:    "test_corpus",
		Dimension:    dim,
		Distance:     "cosine",
		EmbedBatch:   16,
		EmbedWorkers: 4,
		HTTPClient:   srv.Client(),
		LogProgress:  0,
	}, &staticSource{rows: rows})
	if err != nil {
		t.Fatalf("Run: %v", err)
	}
	if stats.Scanned != 50 {
		t.Errorf("Scanned: want 50, got %d", stats.Scanned)
	}
	if stats.Embedded != 50 {
		t.Errorf("Embedded: want 50, got %d", stats.Embedded)
	}
	if stats.Added != 50 {
		t.Errorf("Added: want 50, got %d", stats.Added)
	}
	// Handlers are done once Run returns, but the recorded state is
	// mutex-guarded, so hold the lock for the assertions below.
	fg.mu.Lock()
	defer fg.mu.Unlock()
	if !fg.createCalled {
		t.Error("expected create-index to be called")
	}
	// 50 rows / 16 batch = ceil(50/16) = 4 batches → 4 embed calls + 4 add calls
	if fg.embedCalls != 4 {
		t.Errorf("embedCalls: want 4 (50 rows / 16 batch), got %d", fg.embedCalls)
	}
	if fg.addCalls != 4 {
		t.Errorf("addCalls: want 4, got %d", fg.addCalls)
	}
	// Sum of texts across embed calls must be 50, and IDs across add
	// calls must be every r-NNN exactly once.
	totalTexts := 0
	for _, ts := range fg.embedTexts {
		totalTexts += len(ts)
	}
	if totalTexts != 50 {
		t.Errorf("total embedded texts: want 50, got %d", totalTexts)
	}
	seen := make(map[string]bool)
	for _, items := range fg.addItems {
		for _, it := range items {
			if seen[it.ID] {
				t.Errorf("duplicate id in add stream: %s", it.ID)
			}
			seen[it.ID] = true
			if len(it.Vector) != dim {
				t.Errorf("vector dim: want %d, got %d", dim, len(it.Vector))
			}
		}
	}
	if len(seen) != 50 {
		t.Errorf("unique ids added: want 50, got %d", len(seen))
	}
}

func TestRun_DropExistingFiresDelete(t *testing.T) {
	fg := newFakeGateway(4)
	srv := httptest.NewServer(fg.handler())
	defer srv.Close()
	_, err := Run(context.Background(), Config{
		GatewayURL:   srv.URL,
		IndexName:    "drops_first",
		Dimension:    4,
		DropExisting: true,
		HTTPClient:   srv.Client(),
	}, &staticSource{rows: []Row{{ID: "x", Text: "y", Metadata: nil}}})
	if err != nil {
		t.Fatalf("Run: %v", err)
	}
	fg.mu.Lock()
	deleted := fg.deleteCalled
	fg.mu.Unlock()
	if !deleted {
		t.Error("expected delete-index to fire when DropExisting=true")
	}
}

func TestRun_IndexAlreadyExistsIsReused(t *testing.T) {
	fg := newFakeGateway(4)
	// Every POST /v1/vectors/index → 409 (the fake never clears this);
	// only one create is attempted here, so the effect is the same as
	// "index already exists".
	fg.indexConflict = true
	srv := httptest.NewServer(fg.handler())
	defer srv.Close()
	stats, err := Run(context.Background(), Config{
		GatewayURL:   srv.URL,
		IndexName:    "exists_already",
		Dimension:    4,
		HTTPClient:   srv.Client(),
		EmbedWorkers: 1,
	}, &staticSource{rows: []Row{{ID: "x", Text: "y", Metadata: nil}}})
	if err != nil {
		t.Fatalf("Run with existing index should succeed: %v", err)
	}
	if stats.Added != 1 {
		t.Errorf("Added: want 1, got %d", stats.Added)
	}
}

func TestRun_LimitStopsEarly(t *testing.T) {
	fg := newFakeGateway(4)
	srv := httptest.NewServer(fg.handler())
	defer srv.Close()
	rows := make([]Row, 100)
	for i := range rows {
		rows[i] = Row{ID: fmt.Sprintf("r-%d", i), Text: "t", Metadata: nil}
	}
	stats, err := Run(context.Background(), Config{
		GatewayURL:   srv.URL,
		IndexName:    "limited",
		Dimension:    4,
		Limit:        25,
		EmbedBatch:   8,
		EmbedWorkers: 2,
		HTTPClient:   srv.Client(),
	}, &staticSource{rows: rows})
	if err != nil {
		t.Fatalf("Run: %v", err)
	}
	if stats.Scanned != 25 {
		t.Errorf("Scanned: want 25 (limit), got %d", stats.Scanned)
	}
}

func TestRun_EmptyTextSkipped(t *testing.T) {
	fg := newFakeGateway(4)
	srv := httptest.NewServer(fg.handler())
	defer srv.Close()
	rows := []Row{
		{ID: "a", Text: "real text", Metadata: nil},
		{ID: "b", Text: "", Metadata: nil}, // skipped
		{ID: "c", Text: "more text", Metadata: nil},
	}
	stats, err := Run(context.Background(), Config{
		GatewayURL: srv.URL,
		IndexName:  "skip",
		Dimension:  4,
		HTTPClient: srv.Client(),
	}, &staticSource{rows: rows})
	if err != nil {
		t.Fatalf("Run: %v", err)
	}
	if stats.Scanned != 3 {
		t.Errorf("Scanned: want 3 (b is skipped but counted as scanned), got %d", stats.Scanned)
	}
	if stats.Added != 2 {
		t.Errorf("Added: want 2 (b excluded from embed), got %d", stats.Added)
	}
}

// TestRun_ProgressLoggerExits guards the bug caught 2026-04-29 in
// the candidates e2e: when LogProgress > 0, the progress goroutine's
// only exit was ctx.Done(). With context.Background() in the
// production driver, Run hung forever after the pipeline finished.
// This test bounds Run's wall to a few hundred ms — if it regresses,
// the test deadline kicks in.
func TestRun_ProgressLoggerExits(t *testing.T) { fg := newFakeGateway(4) srv := httptest.NewServer(fg.handler()) defer srv.Close() rows := []Row{ {ID: "a", Text: "x", Metadata: nil}, {ID: "b", Text: "y", Metadata: nil}, } done := make(chan error, 1) go func() { _, err := Run(context.Background(), Config{ GatewayURL: srv.URL, IndexName: "progress_test", Dimension: 4, HTTPClient: srv.Client(), LogProgress: 50 * time.Millisecond, }, &staticSource{rows: rows}) done <- err }() select { case err := <-done: if err != nil { t.Fatalf("Run: %v", err) } case <-time.After(2 * time.Second): t.Fatal("Run did not return within 2s — progress goroutine likely hanging") } } func TestRun_RequiresIndexName(t *testing.T) { _, err := Run(context.Background(), Config{Dimension: 4}, &staticSource{rows: nil}) if err == nil || !strings.Contains(err.Error(), "IndexName") { t.Errorf("want IndexName-required error, got %v", err) } } func TestRun_RequiresDimension(t *testing.T) { _, err := Run(context.Background(), Config{IndexName: "x"}, &staticSource{rows: nil}) if err == nil || !strings.Contains(err.Error(), "Dimension") { t.Errorf("want Dimension-required error, got %v", err) } } // TestRun_ContextCancel verifies the pipeline drains cleanly when // ctx is cancelled mid-run. Source returns rows fast enough that // without ctx the run would complete; cancelling early should stop // well before all 1000 rows are processed. func TestRun_ContextCancel(t *testing.T) { fg := newFakeGateway(4) // Slow embed handler: each call sleeps 50ms. 
mux := http.NewServeMux() mux.HandleFunc("/v1/vectors/index", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusCreated) }) mux.HandleFunc("/v1/embed", func(w http.ResponseWriter, r *http.Request) { var req struct { Texts []string `json:"texts"` } _ = json.NewDecoder(r.Body).Decode(&req) // Simulate slow-but-valid backend so we test ctx cancel, not // degraded-payload handling (that's covered in production by // the len-mismatch guard in Run's worker). time.Sleep(50 * time.Millisecond) _ = fg vecs := make([][]float32, len(req.Texts)) for i := range vecs { vecs[i] = []float32{1, 2, 3, 4} } _ = json.NewEncoder(w).Encode(map[string]any{ "vectors": vecs, "dimension": 4, "model": "x", }) }) mux.HandleFunc("/v1/vectors/index/", func(w http.ResponseWriter, r *http.Request) { _, _ = io.WriteString(w, `{}`) }) srv := httptest.NewServer(mux) defer srv.Close() rows := make([]Row, 1000) for i := range rows { rows[i] = Row{ID: fmt.Sprintf("r-%d", i), Text: "t"} } ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() stats, err := Run(ctx, Config{ GatewayURL: srv.URL, IndexName: "cancel_me", Dimension: 4, EmbedBatch: 1, EmbedWorkers: 1, HTTPClient: srv.Client(), }, &staticSource{rows: rows}) // Either an error or a partial stats; the point is "didn't process all 1000." if stats.Scanned >= 1000 { t.Errorf("ctx cancel did not stop early: scanned=%d err=%v", stats.Scanned, err) } }