root 166470f532 corpusingest: extract reusable text→vector ingest pipeline
Generalizes the staffing_500k driver's embed-and-push loop into
internal/corpusingest. Per docs/SPEC.md §3.4 component 1 (corpus
builders): adding a new staffing/code/playbook corpus is now one
Source impl + one main.go calling Run, not 200 lines of pipeline
copy-paste.

API:
  type Source interface { Next() (Row, error) }
  func Run(ctx, Config, Source) (Stats, error)

Library owns:
  - Index lifecycle (create, optional drop-existing, idempotent
    reuse on 409)
  - Parallel embed dispatcher (configurable workers + batch size)
  - Vectord push batching
  - Progress logging + Stats reporting
  - Partial-failure semantics (log + continue per-batch errors;
    operator decides on re-run via Stats.Embedded vs Scanned delta)

Per-corpus driver owns: source parsing + column→Row mapping +
post-ingest validation queries.

Refactor scripts/staffing_500k/main.go to use it. Driver is now
~190 lines (was 339), with the embed/add plumbing replaced by one
Run call. -drop flag added so callers can opt out of the destructive
DELETE-first behavior (default still true to keep the 500K test
clean-recall semantics).

Unit tests (internal/corpusingest/ingest_test.go, 8/8 PASS):
  - Pipeline shape: 50 rows / 16 batch → 4 embed + 4 add calls,
    every ID added exactly once, vectors at correct dimension
  - DropExisting fires DELETE
  - 409 on create → reuse existing index
  - Limit stops early
  - Empty Text rows skipped (counted as scanned, not added)
  - Required IndexName + Dimension validation
  - Context cancel stops mid-pipeline

Real bug caught and fixed by the test suite: if embedd ever returns
fewer vectors than texts in the request (degraded backend), the
addBatch loop would panic with index-out-of-range. Worker now
length-checks the response and logs+skips on mismatch.

12-smoke regression sweep all green (D1-D6, G1, G1P, G2,
storaged_cap, pathway, matrix). vet clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-29 18:47:18 -05:00

376 lines
10 KiB
Go

package corpusingest
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
)
// fakeGateway records the embed + add calls corpusingest fires and
// returns canned responses. The whole point of the unit test is to
// validate the pipeline shape (request payloads, batching, stats)
// without needing live embedd/vectord.
type fakeGateway struct {
	mu             sync.Mutex  // guards all fields below; handlers run on concurrent worker requests
	embedCalls     int         // number of POST /v1/embed requests observed
	embedTexts     [][]string  // texts per call
	addCalls       int         // number of .../add requests observed
	addItems       [][]addItem // items per call
	createCalled   bool        // POST /v1/vectors/index was hit at least once
	deleteCalled   bool        // DELETE under /v1/vectors/index/ was hit at least once
	indexConflict  bool        // simulate "index already exists" → 409
	embedDimension int         // length of every synthesized embedding vector
}
// addItem mirrors one element of the gateway's vector-add payload. The
// fake decodes incoming add requests into this shape so tests can
// assert on IDs, vector dimensions, and metadata after the run.
type addItem struct {
	ID       string          `json:"id"`
	Vector   []float32       `json:"vector"`
	Metadata json.RawMessage `json:"metadata,omitempty"`
}
// newFakeGateway constructs a fakeGateway whose /v1/embed responses
// carry vectors of the given dimension.
func newFakeGateway(dim int) *fakeGateway {
	fg := new(fakeGateway)
	fg.embedDimension = dim
	return fg
}
// handler returns an http.Handler covering the three gateway routes
// corpusingest uses: index create (POST /v1/vectors/index), embedding
// (POST /v1/embed), and index delete / vector add under
// /v1/vectors/index/{name}. Every recorded observation is taken under
// f.mu because embed workers may hit the fake concurrently.
func (f *fakeGateway) handler() http.Handler {
	mux := http.NewServeMux()
	mux.HandleFunc("/v1/vectors/index", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			http.Error(w, "wrong method", http.StatusMethodNotAllowed)
			return
		}
		f.mu.Lock()
		f.createCalled = true
		conflict := f.indexConflict
		f.mu.Unlock()
		if conflict {
			// Simulated "index already exists" so tests can cover the
			// 409-reuse path.
			http.Error(w, "exists", http.StatusConflict)
			return
		}
		w.WriteHeader(http.StatusCreated)
	})
	mux.HandleFunc("/v1/embed", func(w http.ResponseWriter, r *http.Request) {
		var req struct {
			Texts []string `json:"texts"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// Synthesize deterministic vectors: element j of vector i is
		// float32(i + j + 1), so every vector is non-zero and distinct.
		vecs := make([][]float32, len(req.Texts))
		for i := range vecs {
			v := make([]float32, f.embedDimension)
			for j := range v {
				v[j] = float32(i + j + 1)
			}
			vecs[i] = v
		}
		f.mu.Lock()
		f.embedCalls++
		// Copy because we'll release the slice after returning.
		texts := append([]string(nil), req.Texts...)
		f.embedTexts = append(f.embedTexts, texts)
		f.mu.Unlock()
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]any{
			"vectors":   vecs,
			"dimension": f.embedDimension,
			"model":     "fake-embed",
		})
	})
	mux.HandleFunc("/v1/vectors/index/", func(w http.ResponseWriter, r *http.Request) {
		// /v1/vectors/index/{name}/add — anything under this prefix that
		// is not an add is either DELETE {name} (drop index) or unhandled.
		if !strings.HasSuffix(r.URL.Path, "/add") {
			if r.Method == http.MethodDelete {
				f.mu.Lock()
				f.deleteCalled = true
				f.mu.Unlock()
				w.WriteHeader(http.StatusNoContent)
				return
			}
			http.Error(w, "unhandled "+r.URL.Path, http.StatusNotFound)
			return
		}
		var req struct {
			Items []addItem `json:"items"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		f.mu.Lock()
		f.addCalls++
		// Clone the decoded items; the request's slice dies with the handler.
		f.addItems = append(f.addItems, append([]addItem(nil), req.Items...))
		f.mu.Unlock()
		_, _ = io.WriteString(w, `{"added":`+fmt.Sprint(len(req.Items))+`}`)
	})
	return mux
}
// staticSource yields a fixed slice of rows in order, then io.EOF.
type staticSource struct {
	rows []Row
	i    int
}

// Next returns the row at the cursor and advances it; once the slice is
// exhausted it reports io.EOF forever after.
func (s *staticSource) Next() (Row, error) {
	if s.i < len(s.rows) {
		row := s.rows[s.i]
		s.i++
		return row, nil
	}
	return Row{}, io.EOF
}
// TestRun_PipelineShapeAndStats drives Run against the fake gateway and
// asserts the pipeline's observable shape: how rows batch into embed and
// add calls, that every ID lands exactly once at the right dimension,
// and that Stats reflects all 50 rows.
func TestRun_PipelineShapeAndStats(t *testing.T) {
	const dim = 4
	gw := newFakeGateway(dim)
	srv := httptest.NewServer(gw.handler())
	defer srv.Close()

	rows := make([]Row, 50)
	for i := 0; i < len(rows); i++ {
		rows[i] = Row{
			ID:       fmt.Sprintf("r-%03d", i),
			Text:     fmt.Sprintf("row %d text", i),
			Metadata: map[string]any{"i": i, "kind": "test"},
		}
	}

	cfg := Config{
		GatewayURL:   srv.URL,
		IndexName:    "test_corpus",
		Dimension:    dim,
		Distance:     "cosine",
		EmbedBatch:   16,
		EmbedWorkers: 4,
		HTTPClient:   srv.Client(),
		LogProgress:  0,
	}
	stats, err := Run(context.Background(), cfg, &staticSource{rows: rows})
	if err != nil {
		t.Fatalf("Run: %v", err)
	}

	if stats.Scanned != 50 {
		t.Errorf("Scanned: want 50, got %d", stats.Scanned)
	}
	if stats.Embedded != 50 {
		t.Errorf("Embedded: want 50, got %d", stats.Embedded)
	}
	if stats.Added != 50 {
		t.Errorf("Added: want 50, got %d", stats.Added)
	}
	if !gw.createCalled {
		t.Error("expected create-index to be called")
	}

	// 50 rows / 16 batch = ceil(50/16) = 4 batches → 4 embed calls + 4 add calls
	if gw.embedCalls != 4 {
		t.Errorf("embedCalls: want 4 (50 rows / 16 batch), got %d", gw.embedCalls)
	}
	if gw.addCalls != 4 {
		t.Errorf("addCalls: want 4, got %d", gw.addCalls)
	}

	// Sum of texts across embed calls must be 50, and IDs across add
	// calls must be every r-NNN exactly once.
	textCount := 0
	for _, batch := range gw.embedTexts {
		textCount += len(batch)
	}
	if textCount != 50 {
		t.Errorf("total embedded texts: want 50, got %d", textCount)
	}

	addedIDs := make(map[string]bool)
	for _, batch := range gw.addItems {
		for _, item := range batch {
			if addedIDs[item.ID] {
				t.Errorf("duplicate id in add stream: %s", item.ID)
			}
			addedIDs[item.ID] = true
			if len(item.Vector) != dim {
				t.Errorf("vector dim: want %d, got %d", dim, len(item.Vector))
			}
		}
	}
	if len(addedIDs) != 50 {
		t.Errorf("unique ids added: want 50, got %d", len(addedIDs))
	}
}
// TestRun_DropExistingFiresDelete verifies that DropExisting=true issues
// a DELETE on the index before ingesting.
func TestRun_DropExistingFiresDelete(t *testing.T) {
	gw := newFakeGateway(4)
	srv := httptest.NewServer(gw.handler())
	defer srv.Close()

	src := &staticSource{rows: []Row{{ID: "x", Text: "y", Metadata: nil}}}
	cfg := Config{
		GatewayURL:   srv.URL,
		IndexName:    "drops_first",
		Dimension:    4,
		DropExisting: true,
		HTTPClient:   srv.Client(),
	}
	if _, err := Run(context.Background(), cfg, src); err != nil {
		t.Fatalf("Run: %v", err)
	}
	if !gw.deleteCalled {
		t.Error("expected delete-index to fire when DropExisting=true")
	}
}
// TestRun_IndexAlreadyExistsIsReused verifies that a 409 from index
// creation is treated as "already exists" and ingest proceeds anyway.
func TestRun_IndexAlreadyExistsIsReused(t *testing.T) {
	gw := newFakeGateway(4)
	gw.indexConflict = true // first POST /v1/vectors/index → 409
	srv := httptest.NewServer(gw.handler())
	defer srv.Close()

	src := &staticSource{rows: []Row{{ID: "x", Text: "y", Metadata: nil}}}
	stats, err := Run(context.Background(), Config{
		GatewayURL:   srv.URL,
		IndexName:    "exists_already",
		Dimension:    4,
		HTTPClient:   srv.Client(),
		EmbedWorkers: 1,
	}, src)
	if err != nil {
		t.Fatalf("Run with existing index should succeed: %v", err)
	}
	if got := stats.Added; got != 1 {
		t.Errorf("Added: want 1, got %d", got)
	}
}
// TestRun_LimitStopsEarly verifies that Limit caps how many rows are
// pulled from the source, even when more are available.
func TestRun_LimitStopsEarly(t *testing.T) {
	gw := newFakeGateway(4)
	srv := httptest.NewServer(gw.handler())
	defer srv.Close()

	rows := make([]Row, 100)
	for i := 0; i < 100; i++ {
		rows[i] = Row{ID: fmt.Sprintf("r-%d", i), Text: "t", Metadata: nil}
	}
	stats, err := Run(context.Background(), Config{
		GatewayURL:   srv.URL,
		IndexName:    "limited",
		Dimension:    4,
		Limit:        25,
		EmbedBatch:   8,
		EmbedWorkers: 2,
		HTTPClient:   srv.Client(),
	}, &staticSource{rows: rows})
	if err != nil {
		t.Fatalf("Run: %v", err)
	}
	if got := stats.Scanned; got != 25 {
		t.Errorf("Scanned: want 25 (limit), got %d", got)
	}
}
// TestRun_EmptyTextSkipped verifies that rows with empty Text count
// toward Scanned but are excluded from embedding and adding.
func TestRun_EmptyTextSkipped(t *testing.T) {
	gw := newFakeGateway(4)
	srv := httptest.NewServer(gw.handler())
	defer srv.Close()

	src := &staticSource{rows: []Row{
		{ID: "a", Text: "real text", Metadata: nil},
		{ID: "b", Text: "", Metadata: nil}, // skipped
		{ID: "c", Text: "more text", Metadata: nil},
	}}
	stats, err := Run(context.Background(), Config{
		GatewayURL: srv.URL, IndexName: "skip", Dimension: 4,
		HTTPClient: srv.Client(),
	}, src)
	if err != nil {
		t.Fatalf("Run: %v", err)
	}
	if stats.Scanned != 3 {
		t.Errorf("Scanned: want 3 (b is skipped but counted as scanned), got %d", stats.Scanned)
	}
	if stats.Added != 2 {
		t.Errorf("Added: want 2 (b excluded from embed), got %d", stats.Added)
	}
}
// TestRun_RequiresIndexName verifies Config validation: an empty
// IndexName must be rejected before any work happens.
func TestRun_RequiresIndexName(t *testing.T) {
	_, err := Run(context.Background(), Config{Dimension: 4}, &staticSource{})
	if err == nil || !strings.Contains(err.Error(), "IndexName") {
		t.Errorf("want IndexName-required error, got %v", err)
	}
}
// TestRun_RequiresDimension verifies Config validation: a zero
// Dimension must be rejected before any work happens.
func TestRun_RequiresDimension(t *testing.T) {
	_, err := Run(context.Background(), Config{IndexName: "x"}, &staticSource{})
	if err == nil || !strings.Contains(err.Error(), "Dimension") {
		t.Errorf("want Dimension-required error, got %v", err)
	}
}
// TestRun_ContextCancel verifies the pipeline drains cleanly when
// ctx is cancelled mid-run. Source returns rows fast enough that
// without ctx the run would complete; cancelling early should stop
// well before all 1000 rows are processed.
func TestRun_ContextCancel(t *testing.T) {
fg := newFakeGateway(4)
// Slow embed handler: each call sleeps 50ms.
mux := http.NewServeMux()
mux.HandleFunc("/v1/vectors/index", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
})
mux.HandleFunc("/v1/embed", func(w http.ResponseWriter, r *http.Request) {
var req struct {
Texts []string `json:"texts"`
}
_ = json.NewDecoder(r.Body).Decode(&req)
// Simulate slow-but-valid backend so we test ctx cancel, not
// degraded-payload handling (that's covered in production by
// the len-mismatch guard in Run's worker).
time.Sleep(50 * time.Millisecond)
_ = fg
vecs := make([][]float32, len(req.Texts))
for i := range vecs {
vecs[i] = []float32{1, 2, 3, 4}
}
_ = json.NewEncoder(w).Encode(map[string]any{
"vectors": vecs,
"dimension": 4,
"model": "x",
})
})
mux.HandleFunc("/v1/vectors/index/", func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, `{}`)
})
srv := httptest.NewServer(mux)
defer srv.Close()
rows := make([]Row, 1000)
for i := range rows {
rows[i] = Row{ID: fmt.Sprintf("r-%d", i), Text: "t"}
}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
stats, err := Run(ctx, Config{
GatewayURL: srv.URL, IndexName: "cancel_me", Dimension: 4,
EmbedBatch: 1, EmbedWorkers: 1, HTTPClient: srv.Client(),
}, &staticSource{rows: rows})
// Either an error or a partial stats; the point is "didn't process all 1000."
if stats.Scanned >= 1000 {
t.Errorf("ctx cancel did not stop early: scanned=%d err=%v", stats.Scanned, err)
}
}