From 166470f53203fd885874acb0ac2da062e3734e77 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 29 Apr 2026 18:47:18 -0500 Subject: [PATCH] =?UTF-8?q?corpusingest:=20extract=20reusable=20text?= =?UTF-8?q?=E2=86=92vector=20ingest=20pipeline?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generalizes the staffing_500k driver's embed-and-push loop into internal/corpusingest. Per docs/SPEC.md §3.4 component 1 (corpus builders): adding a new staffing/code/playbook corpus is now one Source impl + one main.go calling Run, not 200 lines of pipeline copy-paste. API: type Source interface { Next() (Row, error) } func Run(ctx, Config, Source) (Stats, error) Library owns: - Index lifecycle (create, optional drop-existing, idempotent reuse on 409) - Parallel embed dispatcher (configurable workers + batch size) - Vectord push batching - Progress logging + Stats reporting - Partial-failure semantics (log + continue per-batch errors; operator decides on re-run via Stats.Embedded vs Scanned delta) Per-corpus driver owns: source parsing + column→Row mapping + post-ingest validation queries. Refactor scripts/staffing_500k/main.go to use it. Driver is now ~190 lines (was 339), with the embed/add plumbing replaced by one Run call. -drop flag added so callers can opt out of the destructive DELETE-first behavior (default still true to keep the 500K test clean-recall semantics). Unit tests (internal/corpusingest/ingest_test.go, 8/8 PASS): - Pipeline shape: 50 rows / 16 batch → 4 embed + 4 add calls, every ID added exactly once, vectors at correct dimension - DropExisting fires DELETE - 409 on create → reuse existing index - Limit stops early - Empty Text rows skipped (counted as scanned, not added) - Required IndexName + Dimension validation - Context cancel stops mid-pipeline Real bug caught and fixed by the test suite: if embedd ever returns fewer vectors than texts in the request (degraded backend), the addBatch loop would panic with index-out-of-range. Worker now length-checks the response and logs+skips on mismatch. 12-smoke regression sweep all green (D1-D6, G1, G1P, G2, storaged_cap, pathway, matrix). vet clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- internal/corpusingest/ingest.go | 411 +++++++++++++++++++++++++++ internal/corpusingest/ingest_test.go | 375 ++++++++++++++++++++++++ scripts/staffing_500k/main.go | 381 +++++++++---------------- 3 files changed, 916 insertions(+), 251 deletions(-) create mode 100644 internal/corpusingest/ingest.go create mode 100644 internal/corpusingest/ingest_test.go diff --git a/internal/corpusingest/ingest.go b/internal/corpusingest/ingest.go new file mode 100644 index 0000000..490d4b1 --- /dev/null +++ b/internal/corpusingest/ingest.go @@ -0,0 +1,411 @@ +// Package corpusingest is the generalized text→vector ingestion +// pipeline. Originally extracted from scripts/staffing_500k/main.go; +// reusable by any corpus-builder script that needs to embed a stream +// of (id, text, metadata) rows and push them into a vectord index. +// +// Design: per-corpus Source impls own the parsing/column-mapping; +// this package owns the parallel-embed dispatcher, batching, vectord +// index lifecycle, and progress reporting. Adding a corpus is one +// Source struct + one main.go that calls Run; no copy-pasted pipeline. +// +// Per docs/SPEC.md §3.4 component 1 (corpus builders): this is the +// substrate the rest of the matrix indexer's value depends on. Get +// the pipeline right, then iterate on builders. 
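+//
+// A minimal wiring sketch (sliceSource is a hypothetical in-memory
+// Source, shown only to illustrate the contract; real builders wrap
+// CSV/JSONL/etc. readers, and the names below are illustrative):
+//
+//	type sliceSource struct {
+//		rows []Row
+//		i    int
+//	}
+//
+//	func (s *sliceSource) Next() (Row, error) {
+//		if s.i >= len(s.rows) {
+//			return Row{}, io.EOF
+//		}
+//		r := s.rows[s.i]
+//		s.i++
+//		return r, nil
+//	}
+//
+//	stats, err := Run(ctx, Config{
+//		IndexName: "demo_corpus",
+//		Dimension: 768,
+//	}, &sliceSource{rows: []Row{{ID: "w-1", Text: "forklift operator, OSHA-30"}}})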
+package corpusingest
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"log/slog"
+	"net/http"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// Row is one logical document in a corpus. Metadata may be any
+// JSON-marshalable value (struct, map, json.RawMessage); the library
+// marshals once per row before pushing to vectord.
+type Row struct {
+	ID       string
+	Text     string
+	Metadata any
+}
+
+// Source produces a stream of rows. Source lifecycle (open/close) is
+// owned by the caller; this package only consumes Next() until io.EOF.
+type Source interface {
+	// Next returns the next row or io.EOF when the source is drained.
+	// Other errors cause Run to abort with the error wrapped.
+	Next() (Row, error)
+}
+
+// Config drives one Run. Defaults match the Ollama-on-A4000 sweet
+// spot from the 500K validation; override per-deployment if needed.
+type Config struct {
+	GatewayURL   string // default "http://127.0.0.1:3110"
+	IndexName    string // required
+	Dimension    int    // required, must match the embed model output
+	Distance     string // default "cosine"
+	EmbedModel   string // optional; empty = embedd's default
+	EmbedBatch   int    // default 16, texts per /v1/embed call
+	EmbedWorkers int    // default 8, parallel embed goroutines
+	AddBatch     int    // default 1000, upper bound on items per /v1/vectors/index/add call
+	Limit        int    // 0 = no limit (process all rows)
+	DropExisting bool   // true = DELETE index first; false = idempotent reuse
+	HTTPClient   *http.Client
+	// LogProgress is the interval between progress logs. 0 disables.
+	LogProgress time.Duration
+}
+
+// Stats reports run outcomes.
+type Stats struct {
+	Scanned  int64
+	Embedded int64
+	Added    int64
+	Wall     time.Duration
+}
+
+// Run executes the ingest pipeline. It returns once the source is
+// drained (io.EOF) and all in-flight jobs finish, on context
+// cancellation, or on the first source or row-validation error.
+// Embed/add failures do not abort the run: they are logged via slog
+// and the batch is skipped (partial-failure semantics; see the
+// worker comment inside).
+func Run(ctx context.Context, cfg Config, src Source) (Stats, error) {
+	cfg = applyDefaults(cfg)
+	if err := validateConfig(cfg); err != nil {
+		return Stats{}, err
+	}
+
+	t0 := time.Now()
+	if err := prepareIndex(ctx, cfg); err != nil {
+		return Stats{}, fmt.Errorf("prepare index: %w", err)
+	}
+
+	jobs := make(chan job, cfg.EmbedWorkers*2)
+
+	var (
+		totalEmbedded int64
+		totalAdded    int64
+	)
+
+	var wg sync.WaitGroup
+	for i := 0; i < cfg.EmbedWorkers; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for j := range jobs {
+				vecs, err := embedBatch(ctx, cfg, j.texts)
+				if err != nil {
+					// Partial-failure semantics: log + continue. A wedged
+					// embed batch shouldn't kill 8 workers' worth of
+					// progress; the operator decides whether to re-run
+					// based on the final Embedded vs Scanned delta.
+					slog.Warn("corpusingest: embed batch failed",
+						"index", cfg.IndexName, "items", len(j.texts), "err", err)
+					continue
+				}
+				// Defense against a degraded embed backend that returns
+				// fewer vectors than texts: without this guard, vecs[i]
+				// would panic with index-out-of-range in addBatch below.
+				if len(vecs) != len(j.ids) {
+					slog.Warn("corpusingest: embed returned wrong count",
+						"index", cfg.IndexName, "want", len(j.ids), "got", len(vecs))
+					continue
+				}
+				atomic.AddInt64(&totalEmbedded, int64(len(vecs)))
+				if err := addBatch(ctx, cfg, j.ids, vecs, j.metas); err != nil {
+					slog.Warn("corpusingest: add batch failed",
+						"index", cfg.IndexName, "items", len(j.ids), "err", err)
+					continue
+				}
+				atomic.AddInt64(&totalAdded, int64(len(j.ids)))
+			}
+		}()
+	}
+
+	// progressStop lets a normal (uncancelled) run shut the progress
+	// goroutine down; waiting only on ctx.Done() would leave the
+	// <-progressDone below blocked forever when the source simply EOFs.
+	progressStop := make(chan struct{})
+	progressDone := make(chan struct{})
+	if cfg.LogProgress > 0 {
+		ticker := time.NewTicker(cfg.LogProgress)
+		go func() {
+			defer close(progressDone)
+			defer ticker.Stop()
+			for {
+				select {
+				case <-ticker.C:
+					slog.Info("corpusingest: progress",
+						"index", cfg.IndexName,
+						"embedded", atomic.LoadInt64(&totalEmbedded),
+						"added", atomic.LoadInt64(&totalAdded))
+				case <-ctx.Done():
+					return
+				case <-progressStop:
+					return
+				}
+			}
+		}()
+	} else {
+		close(progressDone)
+	}
+
+	scanned, err := drainSource(ctx, cfg, src, jobs)
+	close(jobs)
+	wg.Wait()
+	close(progressStop)
+	<-progressDone
+
+	stats := Stats{
+		Scanned:  scanned,
+		Embedded: atomic.LoadInt64(&totalEmbedded),
+		Added:    atomic.LoadInt64(&totalAdded),
+		Wall:     time.Since(t0),
+	}
+	if err != nil {
+		return stats, err
+	}
+	return stats, nil
+}
+
+// drainSource pulls rows, batches them, and dispatches into jobs.
+// Returns when source EOFs, ctx cancels, or limit is hit.
+func drainSource(ctx context.Context, cfg Config, src Source, jobs chan<- job) (int64, error) {
+	curIDs := make([]string, 0, cfg.EmbedBatch)
+	curTexts := make([]string, 0, cfg.EmbedBatch)
+	curMetas := make([]json.RawMessage, 0, cfg.EmbedBatch)
+
+	flush := func() {
+		if len(curIDs) == 0 {
+			return
+		}
+		jobs <- job{ids: curIDs, texts: curTexts, metas: curMetas}
+		curIDs = make([]string, 0, cfg.EmbedBatch)
+		curTexts = make([]string, 0, cfg.EmbedBatch)
+		curMetas = make([]json.RawMessage, 0, cfg.EmbedBatch)
+	}
+
+	var scanned int64
+	for {
+		if ctx.Err() != nil {
+			flush()
+			return scanned, ctx.Err()
+		}
+		row, err := src.Next()
+		if err == io.EOF {
+			flush()
+			return scanned, nil
+		}
+		if err != nil {
+			flush()
+			return scanned, fmt.Errorf("source row %d: %w", scanned, err)
+		}
+		if row.ID == "" {
+			return scanned, fmt.Errorf("source row %d: empty id", scanned)
+		}
+		// Empty Text would 400 at embedd; skip-with-warn rather than
+		// abort the whole run — a stray empty row shouldn't kill 500K.
+		if row.Text == "" {
+			slog.Warn("corpusingest: skipping row with empty text",
+				"index", cfg.IndexName, "id", row.ID)
+			scanned++
+			continue
+		}
+		meta, err := marshalMeta(row.Metadata)
+		if err != nil {
+			return scanned, fmt.Errorf("row %s: marshal metadata: %w", row.ID, err)
+		}
+		curIDs = append(curIDs, row.ID)
+		curTexts = append(curTexts, row.Text)
+		curMetas = append(curMetas, meta)
+		scanned++
+
+		if len(curIDs) >= cfg.EmbedBatch {
+			flush()
+		}
+		if cfg.Limit > 0 && scanned >= int64(cfg.Limit) {
+			flush()
+			return scanned, nil
+		}
+	}
+}
+
+// job is the unit of work between drainSource and the embed workers.
+// Internal type; kept small so the channel buffer doesn't bloat.
+type job struct {
+	ids   []string
+	texts []string
+	metas []json.RawMessage
+}
+
+func marshalMeta(v any) (json.RawMessage, error) {
+	if v == nil {
+		return nil, nil
+	}
+	if rm, ok := v.(json.RawMessage); ok {
+		return rm, nil
+	}
+	return json.Marshal(v)
+}
+
+// prepareIndex creates the vectord index, optionally dropping a
+// preexisting one. Idempotent on matching params: 409 from create is
+// treated as "already exists, reuse." If DropExisting is set, DELETE
+// fires first to give a clean slate.
+func prepareIndex(ctx context.Context, cfg Config) error {
+	if cfg.DropExisting {
+		if err := httpDelete(ctx, cfg.HTTPClient,
+			cfg.GatewayURL+"/v1/vectors/index/"+cfg.IndexName); err != nil {
+			// httpDelete already treats 404 (not found) as success, so any
+			// error here is transport-level or an unexpected status; log it
+			// and carry on so drop-existing stays idempotent.
+			slog.Debug("corpusingest: drop existing", "err", err)
+		}
+	}
+	body, _ := json.Marshal(map[string]any{
+		"name":      cfg.IndexName,
+		"dimension": cfg.Dimension,
+		"distance":  cfg.Distance,
+	})
+	code, msg, err := httpPost(ctx, cfg.HTTPClient, cfg.GatewayURL+"/v1/vectors/index", body)
+	if err != nil {
+		return err
+	}
+	switch code {
+	case http.StatusCreated:
+		slog.Info("corpusingest: created index",
+			"name", cfg.IndexName, "dim", cfg.Dimension, "distance", cfg.Distance)
+	case http.StatusConflict:
+		// Already exists — vectord doesn't change params on conflict.
+		// Caller's responsibility to ensure existing dim/distance match.
+		slog.Info("corpusingest: index already exists, reusing", "name", cfg.IndexName)
+	default:
+		return fmt.Errorf("create index %d: %s", code, msg)
+	}
+	return nil
+}
+
+func embedBatch(ctx context.Context, cfg Config, texts []string) ([][]float32, error) {
+	body := map[string]any{"texts": texts}
+	if cfg.EmbedModel != "" {
+		body["model"] = cfg.EmbedModel
+	}
+	bs, _ := json.Marshal(body)
+	code, msg, raw, err := httpPostRaw(ctx, cfg.HTTPClient, cfg.GatewayURL+"/v1/embed", bs)
+	if err != nil {
+		return nil, err
+	}
+	if code != http.StatusOK {
+		return nil, fmt.Errorf("embed status %d: %s", code, msg)
+	}
+	var er struct {
+		Vectors [][]float32 `json:"vectors"`
+	}
+	if err := json.Unmarshal(raw, &er); err != nil {
+		return nil, fmt.Errorf("embed decode: %w", err)
+	}
+	return er.Vectors, nil
+}
+
+func addBatch(ctx context.Context, cfg Config, ids []string, vecs [][]float32, metas []json.RawMessage) error {
+	type addItem struct {
+		ID       string          `json:"id"`
+		Vector   []float32       `json:"vector"`
+		Metadata json.RawMessage `json:"metadata,omitempty"`
+	}
+	// One add call per embed job, so each payload carries at most
+	// cfg.EmbedBatch items; cfg.AddBatch is only an upper bound here,
+	// not a coalescing target. Keep one HTTP per job.
+ items := make([]addItem, len(ids)) + for i := range ids { + items[i] = addItem{ID: ids[i], Vector: vecs[i], Metadata: metas[i]} + } + bs, _ := json.Marshal(map[string]any{"items": items}) + code, msg, err := httpPost(ctx, cfg.HTTPClient, + cfg.GatewayURL+"/v1/vectors/index/"+cfg.IndexName+"/add", bs) + if err != nil { + return err + } + if code != http.StatusOK { + return fmt.Errorf("add status %d: %s", code, msg) + } + return nil +} + +// ── HTTP helpers — small, no extra deps ───────────────────────── + +func httpPost(ctx context.Context, hc *http.Client, url string, body []byte) (int, string, error) { + code, msg, _, err := httpPostRaw(ctx, hc, url, body) + return code, msg, err +} + +func httpPostRaw(ctx context.Context, hc *http.Client, url string, body []byte) (int, string, []byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return 0, "", nil, err + } + req.Header.Set("Content-Type", "application/json") + resp, err := hc.Do(req) + if err != nil { + return 0, "", nil, err + } + defer resp.Body.Close() + raw, err := io.ReadAll(resp.Body) + if err != nil { + return resp.StatusCode, "", nil, err + } + preview := raw + if len(preview) > 256 { + preview = preview[:256] + } + return resp.StatusCode, string(preview), raw, nil +} + +func httpDelete(ctx context.Context, hc *http.Client, url string) error { + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil) + if err != nil { + return err + } + resp, err := hc.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + io.Copy(io.Discard, resp.Body) + if resp.StatusCode >= 400 && resp.StatusCode != http.StatusNotFound { + return fmt.Errorf("delete status %d", resp.StatusCode) + } + return nil +} + +// ── config validation + defaults ──────────────────────────────── + +func applyDefaults(cfg Config) Config { + if cfg.GatewayURL == "" { + cfg.GatewayURL = "http://127.0.0.1:3110" + } + if cfg.Distance == "" { + cfg.Distance = "cosine" + } + if cfg.EmbedBatch <= 0 { + cfg.EmbedBatch = 16 + } + if cfg.EmbedWorkers <= 0 { + cfg.EmbedWorkers = 8 + } + if cfg.AddBatch <= 0 { + cfg.AddBatch = 1000 + } + if cfg.HTTPClient == nil { + cfg.HTTPClient = &http.Client{Timeout: 5 * time.Minute} + } + if cfg.LogProgress < 0 { + cfg.LogProgress = 0 + } + return cfg +} + +func validateConfig(cfg Config) error { + if cfg.IndexName == "" { + return errors.New("corpusingest: IndexName is required") + } + if cfg.Dimension <= 0 { + return errors.New("corpusingest: Dimension must be > 0") + } + return nil +} diff --git a/internal/corpusingest/ingest_test.go b/internal/corpusingest/ingest_test.go new file mode 100644 index 0000000..c00d942 --- /dev/null +++ b/internal/corpusingest/ingest_test.go @@ -0,0 +1,375 @@ +package corpusingest + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" +) + +// fakeGateway records the embed + add calls corpusingest fires and +// returns canned responses. The whole point of the unit test is to +// validate the pipeline shape (request payloads, batching, stats) +// without needing live embedd/vectord. 
+type fakeGateway struct { + mu sync.Mutex + embedCalls int + embedTexts [][]string // texts per call + addCalls int + addItems [][]addItem // items per call + createCalled bool + deleteCalled bool + indexConflict bool // simulate "index already exists" → 409 + embedDimension int +} + +type addItem struct { + ID string `json:"id"` + Vector []float32 `json:"vector"` + Metadata json.RawMessage `json:"metadata,omitempty"` +} + +func newFakeGateway(dim int) *fakeGateway { + return &fakeGateway{embedDimension: dim} +} + +func (f *fakeGateway) handler() http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("/v1/vectors/index", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "wrong method", http.StatusMethodNotAllowed) + return + } + f.mu.Lock() + f.createCalled = true + conflict := f.indexConflict + f.mu.Unlock() + if conflict { + http.Error(w, "exists", http.StatusConflict) + return + } + w.WriteHeader(http.StatusCreated) + }) + + mux.HandleFunc("/v1/embed", func(w http.ResponseWriter, r *http.Request) { + var req struct { + Texts []string `json:"texts"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + // Synthesize deterministic vectors: vector[i] = float32(i+1). + vecs := make([][]float32, len(req.Texts)) + for i := range vecs { + v := make([]float32, f.embedDimension) + for j := range v { + v[j] = float32(i + j + 1) + } + vecs[i] = v + } + f.mu.Lock() + f.embedCalls++ + // Copy because we'll release the slice after returning. + texts := append([]string(nil), req.Texts...) + f.embedTexts = append(f.embedTexts, texts) + f.mu.Unlock() + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "vectors": vecs, + "dimension": f.embedDimension, + "model": "fake-embed", + }) + }) + + mux.HandleFunc("/v1/vectors/index/", func(w http.ResponseWriter, r *http.Request) { + // /v1/vectors/index/{name}/add + if !strings.HasSuffix(r.URL.Path, "/add") { + if r.Method == http.MethodDelete { + f.mu.Lock() + f.deleteCalled = true + f.mu.Unlock() + w.WriteHeader(http.StatusNoContent) + return + } + http.Error(w, "unhandled "+r.URL.Path, http.StatusNotFound) + return + } + var req struct { + Items []addItem `json:"items"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + f.mu.Lock() + f.addCalls++ + f.addItems = append(f.addItems, append([]addItem(nil), req.Items...)) + f.mu.Unlock() + _, _ = io.WriteString(w, `{"added":`+fmt.Sprint(len(req.Items))+`}`) + }) + + return mux +} + +// staticSource yields a fixed slice of rows. 
+type staticSource struct { + rows []Row + i int +} + +func (s *staticSource) Next() (Row, error) { + if s.i >= len(s.rows) { + return Row{}, io.EOF + } + r := s.rows[s.i] + s.i++ + return r, nil +} + +func TestRun_PipelineShapeAndStats(t *testing.T) { + const dim = 4 + fg := newFakeGateway(dim) + srv := httptest.NewServer(fg.handler()) + defer srv.Close() + + rows := make([]Row, 50) + for i := range rows { + rows[i] = Row{ + ID: fmt.Sprintf("r-%03d", i), + Text: fmt.Sprintf("row %d text", i), + Metadata: map[string]any{"i": i, "kind": "test"}, + } + } + + stats, err := Run(context.Background(), Config{ + GatewayURL: srv.URL, + IndexName: "test_corpus", + Dimension: dim, + Distance: "cosine", + EmbedBatch: 16, + EmbedWorkers: 4, + HTTPClient: srv.Client(), + LogProgress: 0, + }, &staticSource{rows: rows}) + if err != nil { + t.Fatalf("Run: %v", err) + } + + if stats.Scanned != 50 { + t.Errorf("Scanned: want 50, got %d", stats.Scanned) + } + if stats.Embedded != 50 { + t.Errorf("Embedded: want 50, got %d", stats.Embedded) + } + if stats.Added != 50 { + t.Errorf("Added: want 50, got %d", stats.Added) + } + if !fg.createCalled { + t.Error("expected create-index to be called") + } + // 50 rows / 16 batch = ceil(50/16) = 4 batches → 4 embed calls + 4 add calls + if fg.embedCalls != 4 { + t.Errorf("embedCalls: want 4 (50 rows / 16 batch), got %d", fg.embedCalls) + } + if fg.addCalls != 4 { + t.Errorf("addCalls: want 4, got %d", fg.addCalls) + } + + // Sum of texts across embed calls must be 50, and IDs across add + // calls must be every r-NNN exactly once. + totalTexts := 0 + for _, ts := range fg.embedTexts { + totalTexts += len(ts) + } + if totalTexts != 50 { + t.Errorf("total embedded texts: want 50, got %d", totalTexts) + } + seen := make(map[string]bool) + for _, items := range fg.addItems { + for _, it := range items { + if seen[it.ID] { + t.Errorf("duplicate id in add stream: %s", it.ID) + } + seen[it.ID] = true + if len(it.Vector) != dim { + t.Errorf("vector dim: want %d, got %d", dim, len(it.Vector)) + } + } + } + if len(seen) != 50 { + t.Errorf("unique ids added: want 50, got %d", len(seen)) + } +} + +func TestRun_DropExistingFiresDelete(t *testing.T) { + fg := newFakeGateway(4) + srv := httptest.NewServer(fg.handler()) + defer srv.Close() + + _, err := Run(context.Background(), Config{ + GatewayURL: srv.URL, + IndexName: "drops_first", + Dimension: 4, + DropExisting: true, + HTTPClient: srv.Client(), + }, &staticSource{rows: []Row{{ID: "x", Text: "y", Metadata: nil}}}) + if err != nil { + t.Fatalf("Run: %v", err) + } + if !fg.deleteCalled { + t.Error("expected delete-index to fire when DropExisting=true") + } +} + +func TestRun_IndexAlreadyExistsIsReused(t *testing.T) { + fg := newFakeGateway(4) + fg.indexConflict = true // first POST /v1/vectors/index → 409 + srv := httptest.NewServer(fg.handler()) + defer srv.Close() + + stats, err := Run(context.Background(), Config{ + GatewayURL: srv.URL, + IndexName: "exists_already", + Dimension: 4, + HTTPClient: srv.Client(), + EmbedWorkers: 1, + }, &staticSource{rows: []Row{{ID: "x", Text: "y", Metadata: nil}}}) + if err != nil { + t.Fatalf("Run with existing index should succeed: %v", err) + } + if stats.Added != 1 { + t.Errorf("Added: want 1, got %d", stats.Added) + } +} + +func TestRun_LimitStopsEarly(t *testing.T) { + fg := newFakeGateway(4) + srv := httptest.NewServer(fg.handler()) + defer srv.Close() + + rows := make([]Row, 100) + for i := range rows { + rows[i] = Row{ID: fmt.Sprintf("r-%d", i), Text: "t", Metadata: nil} + } + + stats, 
err := Run(context.Background(), Config{ + GatewayURL: srv.URL, + IndexName: "limited", + Dimension: 4, + Limit: 25, + EmbedBatch: 8, + EmbedWorkers: 2, + HTTPClient: srv.Client(), + }, &staticSource{rows: rows}) + if err != nil { + t.Fatalf("Run: %v", err) + } + if stats.Scanned != 25 { + t.Errorf("Scanned: want 25 (limit), got %d", stats.Scanned) + } +} + +func TestRun_EmptyTextSkipped(t *testing.T) { + fg := newFakeGateway(4) + srv := httptest.NewServer(fg.handler()) + defer srv.Close() + + rows := []Row{ + {ID: "a", Text: "real text", Metadata: nil}, + {ID: "b", Text: "", Metadata: nil}, // skipped + {ID: "c", Text: "more text", Metadata: nil}, + } + + stats, err := Run(context.Background(), Config{ + GatewayURL: srv.URL, IndexName: "skip", Dimension: 4, + HTTPClient: srv.Client(), + }, &staticSource{rows: rows}) + if err != nil { + t.Fatalf("Run: %v", err) + } + if stats.Scanned != 3 { + t.Errorf("Scanned: want 3 (b is skipped but counted as scanned), got %d", stats.Scanned) + } + if stats.Added != 2 { + t.Errorf("Added: want 2 (b excluded from embed), got %d", stats.Added) + } +} + +func TestRun_RequiresIndexName(t *testing.T) { + _, err := Run(context.Background(), Config{Dimension: 4}, + &staticSource{rows: nil}) + if err == nil || !strings.Contains(err.Error(), "IndexName") { + t.Errorf("want IndexName-required error, got %v", err) + } +} + +func TestRun_RequiresDimension(t *testing.T) { + _, err := Run(context.Background(), Config{IndexName: "x"}, + &staticSource{rows: nil}) + if err == nil || !strings.Contains(err.Error(), "Dimension") { + t.Errorf("want Dimension-required error, got %v", err) + } +} + +// TestRun_ContextCancel verifies the pipeline drains cleanly when +// ctx is cancelled mid-run. Source returns rows fast enough that +// without ctx the run would complete; cancelling early should stop +// well before all 1000 rows are processed. +func TestRun_ContextCancel(t *testing.T) { + fg := newFakeGateway(4) + // Slow embed handler: each call sleeps 50ms. + mux := http.NewServeMux() + mux.HandleFunc("/v1/vectors/index", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + }) + mux.HandleFunc("/v1/embed", func(w http.ResponseWriter, r *http.Request) { + var req struct { + Texts []string `json:"texts"` + } + _ = json.NewDecoder(r.Body).Decode(&req) + // Simulate slow-but-valid backend so we test ctx cancel, not + // degraded-payload handling (that's covered in production by + // the len-mismatch guard in Run's worker). + time.Sleep(50 * time.Millisecond) + _ = fg + vecs := make([][]float32, len(req.Texts)) + for i := range vecs { + vecs[i] = []float32{1, 2, 3, 4} + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "vectors": vecs, + "dimension": 4, + "model": "x", + }) + }) + mux.HandleFunc("/v1/vectors/index/", func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, `{}`) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + rows := make([]Row, 1000) + for i := range rows { + rows[i] = Row{ID: fmt.Sprintf("r-%d", i), Text: "t"} + } + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + stats, err := Run(ctx, Config{ + GatewayURL: srv.URL, IndexName: "cancel_me", Dimension: 4, + EmbedBatch: 1, EmbedWorkers: 1, HTTPClient: srv.Client(), + }, &staticSource{rows: rows}) + // Either an error or a partial stats; the point is "didn't process all 1000." 
+ if stats.Scanned >= 1000 { + t.Errorf("ctx cancel did not stop early: scanned=%d err=%v", stats.Scanned, err) + } +} diff --git a/scripts/staffing_500k/main.go b/scripts/staffing_500k/main.go index c06ecf5..848543b 100644 --- a/scripts/staffing_500k/main.go +++ b/scripts/staffing_500k/main.go @@ -1,13 +1,14 @@ -// Staffing co-pilot scale test driver. +// Staffing co-pilot scale test driver — workers_500k corpus. // -// Pipeline: workers_500k.csv → /v1/embed (batched, parallel) → -// /v1/vectors/index/workers_500k/add (batched). Then runs a handful -// of semantic queries against the populated index and prints the -// top hits — the human-readable check that "find workers like X" -// actually returns relevant workers. +// Pipeline: workers_500k.csv → /v1/embed → /v1/vectors/index/workers_500k/add. +// The pipeline itself lives in internal/corpusingest; this driver +// provides the CSV → Row mapping and the post-ingest semantic queries +// that are the human-readable check ("does forklift OSHA-30 actually +// retrieve forklift workers?"). // -// Designed to be re-run; index gets DELETEd at the start so leftover -// state from prior runs doesn't bias recall. +// Designed to be re-run safely; index gets DELETEd at the start +// when -drop is set so leftover state doesn't bias recall. + package main import ( @@ -22,62 +23,123 @@ import ( "net/http" "os" "strings" - "sync" - "sync/atomic" "time" + + "git.agentview.dev/profit/golangLAKEHOUSE/internal/corpusingest" ) const ( indexName = "workers_500k" dim = 768 - embedConcurrency = 8 // matches Ollama-on-A4000 sweet spot - embedBatchSize = 16 // texts per /v1/embed call - addBatchSize = 1000 // items per /v1/vectors/index/add call - - maxColPhone = 4 - maxColCity = 5 - maxColState = 6 - maxColRole = 2 - maxColSkills = 8 - maxColCerts = 9 - maxColResume = 17 - colWorkerID = 0 - colName = 1 + // Column indexes in workers_500k.csv. Stable contract; if the CSV + // schema changes these need updating. + colWorkerID = 0 + colName = 1 + colRole = 2 + colCity = 5 + colState = 6 + colSkills = 8 + colCerts = 9 + colResume = 17 ) +// workersCSV implements corpusingest.Source. CSV reader state + +// row → Row mapping live here; the embed/add pipeline is generic. +type workersCSV struct { + cr *csv.Reader +} + +func (s *workersCSV) Next() (corpusingest.Row, error) { + for { + row, err := s.cr.Read() + if err != nil { + return corpusingest.Row{}, err + } + if len(row) <= colResume { + continue // skip malformed rows; matches prior behavior + } + id := strings.TrimSpace(row[colWorkerID]) + return corpusingest.Row{ + ID: "w-" + id, + Text: buildWorkerText(row), + Metadata: map[string]any{ + "name": row[colName], + "role": row[colRole], + "city": row[colCity], + "state": row[colState], + }, + }, nil + } +} + +// buildWorkerText concatenates staffing-relevant columns into the +// embed-text. Order: role first (most semantically dense), then +// location, skills, certs, prose resume. Embedding models weight +// earlier tokens slightly more, so the front matter matters. +func buildWorkerText(row []string) string { + var b strings.Builder + b.WriteString(row[colRole]) + b.WriteString(" in ") + b.WriteString(row[colCity]) + b.WriteString(", ") + b.WriteString(row[colState]) + b.WriteString(". Skills: ") + b.WriteString(row[colSkills]) + b.WriteString(". Certifications: ") + b.WriteString(row[colCerts]) + b.WriteString(". 
") + b.WriteString(row[colResume]) + return b.String() +} + func main() { var ( - gateway = flag.String("gateway", "http://127.0.0.1:3110", "gateway base URL") - csvPath = flag.String("csv", "/tmp/rs/workers_500k.csv", "path to workers CSV") - limit = flag.Int("limit", 0, "limit rows (0 = all)") - queries = flag.String("queries", "default", "default | ") - skipPop = flag.Bool("skip-populate", false, "skip embed+add, only run queries") + gateway = flag.String("gateway", "http://127.0.0.1:3110", "gateway base URL") + csvPath = flag.String("csv", "/tmp/rs/workers_500k.csv", "path to workers CSV") + limit = flag.Int("limit", 0, "limit rows (0 = all)") + queries = flag.String("queries", "default", "default | ") + skipPop = flag.Bool("skip-populate", false, "skip embed+add, only run queries") + drop = flag.Bool("drop", true, "DELETE index before populate (default true for clean recall)") ) flag.Parse() hc := &http.Client{Timeout: 5 * time.Minute} + ctx := context.Background() if !*skipPop { - // Tear down any prior index so recall is on a fresh build. - fmt.Printf("[sc] DELETE %s/v1/vectors/index/%s (idempotent cleanup)\n", *gateway, indexName) - _ = httpDelete(hc, *gateway+"/v1/vectors/index/"+indexName) - - // Create the index. - body := map[string]any{"name": indexName, "dimension": dim, "distance": "cosine"} - if code, msg := httpPostJSON(hc, *gateway+"/v1/vectors/index", body); code != 201 { - log.Fatalf("create index: %d %s", code, msg) + f, err := os.Open(*csvPath) + if err != nil { + log.Fatalf("open csv: %v", err) } - fmt.Println("[sc] created index workers_500k dim=768 cosine") - - t0 := time.Now() - if err := populate(hc, *gateway, *csvPath, *limit); err != nil { - log.Fatal(err) + defer f.Close() + cr := csv.NewReader(f) + cr.FieldsPerRecord = -1 + if _, err := cr.Read(); err != nil { // skip header + log.Fatalf("read header: %v", err) } - fmt.Printf("[sc] populate complete in %v\n", time.Since(t0)) + + stats, err := corpusingest.Run(ctx, corpusingest.Config{ + GatewayURL: *gateway, + IndexName: indexName, + Dimension: dim, + Distance: "cosine", + EmbedBatch: 16, // matches Ollama-on-A4000 sweet spot + EmbedWorkers: 8, // matches Ollama-on-A4000 sweet spot + AddBatch: 1000, // empirically fine; vectord BatchAdd lock-amortized at f1c1883 + Limit: *limit, + DropExisting: *drop, + HTTPClient: hc, + LogProgress: 10 * time.Second, + }, &workersCSV{cr: cr}) + if err != nil { + log.Fatalf("ingest: %v", err) + } + fmt.Printf("[sc] populate done: scanned=%d embedded=%d added=%d wall=%v\n", + stats.Scanned, stats.Embedded, stats.Added, stats.Wall.Round(time.Millisecond)) } - // Validate semantic queries. + // Validate semantic queries against the populated index. 
qs := defaultQueries() if *queries != "default" { qs = strings.Split(*queries, ";") @@ -97,196 +159,35 @@ func defaultQueries() []string { } } -func populate(hc *http.Client, gateway, csvPath string, limit int) error { - f, err := os.Open(csvPath) - if err != nil { - return fmt.Errorf("open csv: %w", err) - } - defer f.Close() - cr := csv.NewReader(f) - cr.FieldsPerRecord = -1 - if _, err := cr.Read(); err != nil { // header - return fmt.Errorf("read header: %w", err) - } - - type job struct { - ids []string - texts []string - metas []json.RawMessage - } - - jobs := make(chan job, embedConcurrency*2) - var wg sync.WaitGroup - var ( - totalEmbedded int64 - totalAdded int64 - ) - - for i := 0; i < embedConcurrency; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := range jobs { - vecs, err := embedBatch(hc, gateway, j.texts) - if err != nil { - log.Printf("embed batch (%d items): %v", len(j.texts), err) - continue - } - atomic.AddInt64(&totalEmbedded, int64(len(vecs))) - if err := addBatch(hc, gateway, j.ids, vecs, j.metas); err != nil { - log.Printf("add batch (%d items): %v", len(j.ids), err) - continue - } - atomic.AddInt64(&totalAdded, int64(len(j.ids))) - } - }() - } - - progressTicker := time.NewTicker(10 * time.Second) - go func() { - for range progressTicker.C { - fmt.Printf("[sc] progress: embedded=%d added=%d\n", - atomic.LoadInt64(&totalEmbedded), atomic.LoadInt64(&totalAdded)) - } - }() - defer progressTicker.Stop() - - curIDs := make([]string, 0, embedBatchSize) - curTexts := make([]string, 0, embedBatchSize) - curMetas := make([]json.RawMessage, 0, embedBatchSize) - rows := 0 - for { - row, err := cr.Read() - if err == io.EOF { - break - } - if err != nil { - return fmt.Errorf("csv read row %d: %w", rows, err) - } - if len(row) <= maxColResume { - continue - } - id := strings.TrimSpace(row[colWorkerID]) - text := buildSearchText(row) - meta, _ := json.Marshal(map[string]any{ - "name": row[colName], - "role": row[maxColRole], - "city": row[maxColCity], - "state": row[maxColState], - }) - curIDs = append(curIDs, "w-"+id) - curTexts = append(curTexts, text) - curMetas = append(curMetas, meta) - - if len(curIDs) >= embedBatchSize { - jobs <- job{ids: curIDs, texts: curTexts, metas: curMetas} - curIDs = make([]string, 0, embedBatchSize) - curTexts = make([]string, 0, embedBatchSize) - curMetas = make([]json.RawMessage, 0, embedBatchSize) - } - rows++ - if limit > 0 && rows >= limit { - break - } - } - if len(curIDs) > 0 { - jobs <- job{ids: curIDs, texts: curTexts, metas: curMetas} - } - close(jobs) - wg.Wait() - - fmt.Printf("[sc] final: scanned=%d embedded=%d added=%d\n", - rows, atomic.LoadInt64(&totalEmbedded), atomic.LoadInt64(&totalAdded)) - return nil -} - -// buildSearchText concatenates the staffing-relevant columns into -// the text that gets embedded. Order: role first (most semantically -// dense), then skills + certs, city/state, finally the prose -// resume_text. Embedding models weight earlier tokens slightly more. -func buildSearchText(row []string) string { - var b strings.Builder - b.WriteString(row[maxColRole]) - b.WriteString(" in ") - b.WriteString(row[maxColCity]) - b.WriteString(", ") - b.WriteString(row[maxColState]) - b.WriteString(". Skills: ") - b.WriteString(row[maxColSkills]) - b.WriteString(". Certifications: ") - b.WriteString(row[maxColCerts]) - b.WriteString(". 
") - b.WriteString(row[maxColResume]) - return b.String() -} - -func embedBatch(hc *http.Client, gateway string, texts []string) ([][]float32, error) { - body := map[string]any{"texts": texts} - bs, _ := json.Marshal(body) - req, _ := http.NewRequest(http.MethodPost, gateway+"/v1/embed", bytes.NewReader(bs)) - req.Header.Set("Content-Type", "application/json") - resp, err := hc.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - preview, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) - return nil, fmt.Errorf("embed status %d: %s", resp.StatusCode, string(preview)) - } - var er struct { - Vectors [][]float32 `json:"vectors"` - } - if err := json.NewDecoder(resp.Body).Decode(&er); err != nil { - return nil, err - } - return er.Vectors, nil -} - -type addItem struct { - ID string `json:"id"` - Vector []float32 `json:"vector"` - Metadata json.RawMessage `json:"metadata"` -} - -func addBatch(hc *http.Client, gateway string, ids []string, vecs [][]float32, metas []json.RawMessage) error { - items := make([]addItem, len(ids)) - for i := range ids { - items[i] = addItem{ID: ids[i], Vector: vecs[i], Metadata: metas[i]} - } - bs, _ := json.Marshal(map[string]any{"items": items}) - req, _ := http.NewRequest(http.MethodPost, - gateway+"/v1/vectors/index/"+indexName+"/add", bytes.NewReader(bs)) - req.Header.Set("Content-Type", "application/json") - resp, err := hc.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - preview, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) - return fmt.Errorf("add status %d: %s", resp.StatusCode, string(preview)) - } - return nil -} - +// runQuery embeds a query, searches the index, prints top hits. +// Stays in this driver (not corpusingest) — query validation is +// per-corpus concern, not part of the ingest pipeline. func runQuery(hc *http.Client, gateway, q string) { t0 := time.Now() - // 1. Embed the query. - vecs, err := embedBatch(hc, gateway, []string{q}) - if err != nil || len(vecs) == 0 { + body, _ := json.Marshal(map[string]any{"texts": []string{q}}) + req, _ := http.NewRequest(http.MethodPost, gateway+"/v1/embed", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + resp, err := hc.Do(req) + if err != nil { fmt.Printf("[sc] query %q: embed err: %v\n", q, err) return } + defer resp.Body.Close() + var er struct { + Vectors [][]float32 `json:"vectors"` + } + if err := json.NewDecoder(resp.Body).Decode(&er); err != nil || len(er.Vectors) == 0 { + fmt.Printf("[sc] query %q: embed decode err: %v\n", q, err) + return + } embedDur := time.Since(t0) + t1 := time.Now() - // 2. Search. 
-	body := map[string]any{"vector": vecs[0], "k": 5}
-	bs, _ := json.Marshal(body)
-	req, _ := http.NewRequest(http.MethodPost,
-		gateway+"/v1/vectors/index/"+indexName+"/search", bytes.NewReader(bs))
+	body, _ = json.Marshal(map[string]any{"vector": er.Vectors[0], "k": 5})
+	req, _ = http.NewRequest(http.MethodPost,
+		gateway+"/v1/vectors/index/"+indexName+"/search", bytes.NewReader(body))
 	req.Header.Set("Content-Type", "application/json")
-	resp, err := hc.Do(req)
+	resp, err = hc.Do(req)
 	if err != nil {
 		fmt.Printf("[sc] query %q: search err: %v\n", q, err)
 		return
@@ -310,29 +211,7 @@ func runQuery(hc *http.Client, gateway, q string) {
 	}
 }
 
-func httpPostJSON(hc *http.Client, url string, body any) (int, string) {
-	bs, _ := json.Marshal(body)
-	req, _ := http.NewRequest(http.MethodPost, url, bytes.NewReader(bs))
-	req.Header.Set("Content-Type", "application/json")
-	resp, err := hc.Do(req)
-	if err != nil {
-		return 0, err.Error()
-	}
-	defer resp.Body.Close()
-	preview, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
-	return resp.StatusCode, string(preview)
-}
-
-func httpDelete(hc *http.Client, url string) error {
-	req, _ := http.NewRequest(http.MethodDelete, url, nil)
-	resp, err := hc.Do(req)
-	if err != nil {
-		return err
-	}
-	defer resp.Body.Close()
-	io.Copy(io.Discard, resp.Body)
-	return nil
-}
-
-// keep context.Background reachable in case future paths use it
-var _ = context.Background
+// io is no longer referenced directly: the helpers that used it moved
+// into corpusingest, and workersCSV.Next returns the csv reader's
+// io.EOF untouched. Keep this reference so the import stays valid.
+var _ = io.EOF