corpusingest: extract reusable text→vector ingest pipeline
Generalizes the staffing_500k driver's embed-and-push loop into
internal/corpusingest. Per docs/SPEC.md §3.4 component 1 (corpus
builders): adding a new staffing/code/playbook corpus is now one
Source impl + one main.go calling Run, not 200 lines of pipeline
copy-paste.
API:
type Source interface { Next() (Row, error) }
func Run(ctx, Config, Source) (Stats, error)
Library owns:
- Index lifecycle (create, optional drop-existing, idempotent
reuse on 409)
- Parallel embed dispatcher (configurable workers + batch size)
- Vectord push batching
- Progress logging + Stats reporting
- Partial-failure semantics (log + continue per-batch errors;
operator decides on re-run via Stats.Embedded vs Scanned delta)
Per-corpus driver owns: source parsing + column→Row mapping +
post-ingest validation queries.
Refactor scripts/staffing_500k/main.go to use it. Driver is now
~190 lines (was 339), with the embed/add plumbing replaced by one
Run call. -drop flag added so callers can opt out of the destructive
DELETE-first behavior (default still true to keep the 500K test
clean-recall semantics).
Unit tests (internal/corpusingest/ingest_test.go, 8/8 PASS):
- Pipeline shape: 50 rows / 16 batch → 4 embed + 4 add calls,
every ID added exactly once, vectors at correct dimension
- DropExisting fires DELETE
- 409 on create → reuse existing index
- Limit stops early
- Empty Text rows skipped (counted as scanned, not added)
- Required IndexName + Dimension validation
- Context cancel stops mid-pipeline
Real bug caught and fixed by the test suite: if embedd ever returns
fewer vectors than texts in the request (degraded backend), the
addBatch loop would panic with index-out-of-range. Worker now
length-checks the response and logs+skips on mismatch.
12-smoke regression sweep all green (D1-D6, G1, G1P, G2,
storaged_cap, pathway, matrix). vet clean.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
c1d96b7b60
commit
166470f532
411
internal/corpusingest/ingest.go
Normal file
411
internal/corpusingest/ingest.go
Normal file
@ -0,0 +1,411 @@
|
||||
// Package corpusingest is the generalized text→vector ingestion
|
||||
// pipeline. Originally extracted from scripts/staffing_500k/main.go;
|
||||
// reusable by any corpus-builder script that needs to embed a stream
|
||||
// of (id, text, metadata) rows and push them into a vectord index.
|
||||
//
|
||||
// Design: per-corpus Source impls own the parsing/column-mapping;
|
||||
// this package owns the parallel-embed dispatcher, batching, vectord
|
||||
// index lifecycle, and progress reporting. Adding a corpus is one
|
||||
// Source struct + one main.go that calls Run; no copy-pasted pipeline.
|
||||
//
|
||||
// Per docs/SPEC.md §3.4 component 1 (corpus builders): this is the
|
||||
// substrate the rest of the matrix indexer's value depends on. Get
|
||||
// the pipeline right, then iterate on builders.
|
||||
package corpusingest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Row is one logical document in a corpus. Metadata may be any
// JSON-marshalable value (struct, map, json.RawMessage); the library
// marshals once per row before pushing to vectord (see marshalMeta).
type Row struct {
	ID       string // required; Run aborts on an empty ID
	Text     string // embed input; empty-text rows are skipped with a warning
	Metadata any    // optional; nil means no metadata is sent
}
|
||||
|
||||
// Source produces a stream of rows. Source lifecycle (open/close) is
// owned by the caller; this package only consumes Next() until io.EOF.
type Source interface {
	// Next returns the next row or io.EOF when the source is drained.
	// Other errors cause Run to abort with the error wrapped.
	Next() (Row, error)
}
|
||||
|
||||
// Config drives one Run. Defaults match the Ollama-on-A4000 sweet
// spot from the 500K validation; override per-deployment if needed.
// Zero values are filled in by applyDefaults; only IndexName and
// Dimension are mandatory (enforced by validateConfig).
type Config struct {
	GatewayURL string // default "http://127.0.0.1:3110"
	IndexName  string // required
	Dimension  int    // required, must match the embed model output
	Distance   string // default "cosine"
	EmbedModel string // optional; empty = embedd's default

	EmbedBatch   int // default 16, texts per /v1/embed call
	EmbedWorkers int // default 8, parallel embed goroutines
	// AddBatch defaults to 1000. NOTE(review): currently not consulted
	// by the pipeline — each embed job (≤ EmbedBatch items) is pushed
	// as exactly one add call (see addBatch); confirm before relying
	// on this as an add-side batch bound.
	AddBatch int

	Limit        int  // 0 = no limit (process all rows)
	DropExisting bool // true = DELETE index first; false = idempotent reuse

	HTTPClient *http.Client // default: 5-minute-timeout client
	// LogProgress is the interval between progress logs. 0 disables.
	LogProgress time.Duration
}
|
||||
|
||||
// Stats reports run outcomes. Embedded/Added may be lower than
// Scanned after per-batch failures or empty-text skips; the operator
// uses that delta to decide whether to re-run.
type Stats struct {
	Scanned  int64         // rows consumed from the Source (includes skipped rows)
	Embedded int64         // vectors successfully returned by embedd
	Added    int64         // items successfully pushed to vectord
	Wall     time.Duration // total wall time of the Run call
}
|
||||
|
||||
// Run executes the ingest pipeline. Returns on source EOF after all
|
||||
// in-flight jobs drain, on context cancellation, or on the first
|
||||
// embed/add error (errors are logged via slog and the pipeline
|
||||
// continues — partial-failure semantics; see comment inside).
|
||||
func Run(ctx context.Context, cfg Config, src Source) (Stats, error) {
|
||||
cfg = applyDefaults(cfg)
|
||||
if err := validateConfig(cfg); err != nil {
|
||||
return Stats{}, err
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
if err := prepareIndex(ctx, cfg); err != nil {
|
||||
return Stats{}, fmt.Errorf("prepare index: %w", err)
|
||||
}
|
||||
|
||||
jobs := make(chan job, cfg.EmbedWorkers*2)
|
||||
|
||||
var (
|
||||
totalEmbedded int64
|
||||
totalAdded int64
|
||||
)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < cfg.EmbedWorkers; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := range jobs {
|
||||
vecs, err := embedBatch(ctx, cfg, j.texts)
|
||||
if err != nil {
|
||||
// Partial-failure semantics: log + continue. A wedged
|
||||
// embed batch shouldn't kill 8 workers' worth of
|
||||
// progress; the operator decides whether to re-run
|
||||
// based on the final Embedded vs Scanned delta.
|
||||
slog.Warn("corpusingest: embed batch failed",
|
||||
"index", cfg.IndexName, "items", len(j.texts), "err", err)
|
||||
continue
|
||||
}
|
||||
// Defense against a degraded embed backend that returns
|
||||
// fewer vectors than texts: vecs[i] would panic in
|
||||
// addBatch otherwise. Caught by ContextCancel unit test.
|
||||
if len(vecs) != len(j.ids) {
|
||||
slog.Warn("corpusingest: embed returned wrong count",
|
||||
"index", cfg.IndexName, "want", len(j.ids), "got", len(vecs))
|
||||
continue
|
||||
}
|
||||
atomic.AddInt64(&totalEmbedded, int64(len(vecs)))
|
||||
if err := addBatch(ctx, cfg, j.ids, vecs, j.metas); err != nil {
|
||||
slog.Warn("corpusingest: add batch failed",
|
||||
"index", cfg.IndexName, "items", len(j.ids), "err", err)
|
||||
continue
|
||||
}
|
||||
atomic.AddInt64(&totalAdded, int64(len(j.ids)))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
progressDone := make(chan struct{})
|
||||
if cfg.LogProgress > 0 {
|
||||
ticker := time.NewTicker(cfg.LogProgress)
|
||||
go func() {
|
||||
defer close(progressDone)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
slog.Info("corpusingest: progress",
|
||||
"index", cfg.IndexName,
|
||||
"embedded", atomic.LoadInt64(&totalEmbedded),
|
||||
"added", atomic.LoadInt64(&totalAdded))
|
||||
case <-ctx.Done():
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
close(progressDone)
|
||||
}
|
||||
|
||||
scanned, err := drainSource(ctx, cfg, src, jobs)
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
<-progressDone
|
||||
|
||||
stats := Stats{
|
||||
Scanned: scanned,
|
||||
Embedded: atomic.LoadInt64(&totalEmbedded),
|
||||
Added: atomic.LoadInt64(&totalAdded),
|
||||
Wall: time.Since(t0),
|
||||
}
|
||||
if err != nil {
|
||||
return stats, err
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// drainSource pulls rows from src, accumulates them into
// EmbedBatch-sized batches, and dispatches each batch into jobs.
// It returns the number of rows consumed (skipped empty-text rows
// are counted) when the source EOFs, ctx is cancelled, or cfg.Limit
// is hit. Sends on jobs block when all workers are busy; that
// backpressure is what bounds memory for large corpora.
func drainSource(ctx context.Context, cfg Config, src Source, jobs chan<- job) (int64, error) {
	curIDs := make([]string, 0, cfg.EmbedBatch)
	curTexts := make([]string, 0, cfg.EmbedBatch)
	curMetas := make([]json.RawMessage, 0, cfg.EmbedBatch)

	// flush hands the current batch to a worker and resets the
	// accumulators with fresh backing arrays (the job now owns the old
	// ones). Safe to call repeatedly — a no-op when the batch is empty.
	flush := func() {
		if len(curIDs) == 0 {
			return
		}
		jobs <- job{ids: curIDs, texts: curTexts, metas: curMetas}
		curIDs = make([]string, 0, cfg.EmbedBatch)
		curTexts = make([]string, 0, cfg.EmbedBatch)
		curMetas = make([]json.RawMessage, 0, cfg.EmbedBatch)
	}

	var scanned int64
	for {
		// Flush before every return path so already-read rows are at
		// least attempted; workers keep consuming until jobs is closed
		// by Run, so these sends cannot deadlock.
		if ctx.Err() != nil {
			flush()
			return scanned, ctx.Err()
		}
		row, err := src.Next()
		if err == io.EOF {
			flush()
			return scanned, nil
		}
		if err != nil {
			flush()
			return scanned, fmt.Errorf("source row %d: %w", scanned, err)
		}
		// An empty ID is a hard error (it would silently collide in
		// vectord), unlike empty text which is merely skipped below.
		if row.ID == "" {
			return scanned, fmt.Errorf("source row %d: empty id", scanned)
		}
		// Empty Text would 400 at embedd; skip-with-warn rather than
		// abort the whole run — a stray empty row shouldn't kill 500K.
		if row.Text == "" {
			slog.Warn("corpusingest: skipping row with empty text",
				"index", cfg.IndexName, "id", row.ID)
			scanned++
			continue
		}
		meta, err := marshalMeta(row.Metadata)
		if err != nil {
			return scanned, fmt.Errorf("row %s: marshal metadata: %w", row.ID, err)
		}
		curIDs = append(curIDs, row.ID)
		curTexts = append(curTexts, row.Text)
		curMetas = append(curMetas, meta)
		scanned++

		if len(curIDs) >= cfg.EmbedBatch {
			flush()
		}
		// Limit counts scanned rows (including skips above), not
		// dispatched rows.
		if cfg.Limit > 0 && scanned >= int64(cfg.Limit) {
			flush()
			return scanned, nil
		}
	}
}
|
||||
|
||||
// job is the unit of work between drainSource and the embed workers.
// The three slices are parallel: ids[i], texts[i], metas[i] describe
// one row. Internal type; kept small so the channel buffer doesn't
// bloat.
type job struct {
	ids   []string
	texts []string
	metas []json.RawMessage
}
|
||||
|
||||
func marshalMeta(v any) (json.RawMessage, error) {
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if rm, ok := v.(json.RawMessage); ok {
|
||||
return rm, nil
|
||||
}
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
// prepareIndex creates the vectord index, optionally dropping a
// preexisting one. Idempotent on matching params: 409 from create is
// treated as "already exists, reuse." If DropExisting is set, DELETE
// fires first to give a clean slate.
func prepareIndex(ctx context.Context, cfg Config) error {
	if cfg.DropExisting {
		if err := httpDelete(ctx, cfg.HTTPClient,
			cfg.GatewayURL+"/v1/vectors/index/"+cfg.IndexName); err != nil {
			// httpDelete already treats 404 as success, so this only
			// logs real failures — drop is deliberately best-effort.
			slog.Debug("corpusingest: drop existing", "err", err)
		}
	}
	// Marshal error deliberately ignored: a map of string/int values
	// cannot fail to marshal.
	body, _ := json.Marshal(map[string]any{
		"name":      cfg.IndexName,
		"dimension": cfg.Dimension,
		"distance":  cfg.Distance,
	})
	code, msg, err := httpPost(ctx, cfg.HTTPClient, cfg.GatewayURL+"/v1/vectors/index", body)
	if err != nil {
		return err
	}
	switch code {
	case http.StatusCreated:
		slog.Info("corpusingest: created index",
			"name", cfg.IndexName, "dim", cfg.Dimension, "distance", cfg.Distance)
	case http.StatusConflict:
		// Already exists — vectord didn't change params on conflict.
		// Caller's responsibility to ensure existing dim/distance match.
		slog.Info("corpusingest: index already exists, reusing", "name", cfg.IndexName)
	default:
		return fmt.Errorf("create index %d: %s", code, msg)
	}
	return nil
}
|
||||
|
||||
func embedBatch(ctx context.Context, cfg Config, texts []string) ([][]float32, error) {
|
||||
body := map[string]any{"texts": texts}
|
||||
if cfg.EmbedModel != "" {
|
||||
body["model"] = cfg.EmbedModel
|
||||
}
|
||||
bs, _ := json.Marshal(body)
|
||||
code, msg, raw, err := httpPostRaw(ctx, cfg.HTTPClient, cfg.GatewayURL+"/v1/embed", bs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if code != http.StatusOK {
|
||||
return nil, fmt.Errorf("embed status %d: %s", code, msg)
|
||||
}
|
||||
var er struct {
|
||||
Vectors [][]float32 `json:"vectors"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &er); err != nil {
|
||||
return nil, fmt.Errorf("embed decode: %w", err)
|
||||
}
|
||||
return er.Vectors, nil
|
||||
}
|
||||
|
||||
// addBatch pushes one batch of (id, vector, metadata) items to the
// index's /add endpoint. ids, vecs, and metas are parallel slices;
// the caller (Run's worker) guarantees len(vecs) == len(ids).
func addBatch(ctx context.Context, cfg Config, ids []string, vecs [][]float32, metas []json.RawMessage) error {
	type addItem struct {
		ID       string          `json:"id"`
		Vector   []float32       `json:"vector"`
		Metadata json.RawMessage `json:"metadata,omitempty"`
	}
	// NOTE(review): cfg.AddBatch is not applied here — each embed job
	// (≤ cfg.EmbedBatch items) becomes exactly one add call. vectord
	// accepts variable batch sizes, so this stays one HTTP per job.
	items := make([]addItem, len(ids))
	for i := range ids {
		items[i] = addItem{ID: ids[i], Vector: vecs[i], Metadata: metas[i]}
	}
	// Marshal of this wire struct cannot fail (all fields marshalable).
	bs, _ := json.Marshal(map[string]any{"items": items})
	code, msg, err := httpPost(ctx, cfg.HTTPClient,
		cfg.GatewayURL+"/v1/vectors/index/"+cfg.IndexName+"/add", bs)
	if err != nil {
		return err
	}
	if code != http.StatusOK {
		return fmt.Errorf("add status %d: %s", code, msg)
	}
	return nil
}
|
||||
|
||||
// ── HTTP helpers — small, no extra deps ─────────────────────────
|
||||
|
||||
func httpPost(ctx context.Context, hc *http.Client, url string, body []byte) (int, string, error) {
|
||||
code, msg, _, err := httpPostRaw(ctx, hc, url, body)
|
||||
return code, msg, err
|
||||
}
|
||||
|
||||
func httpPostRaw(ctx context.Context, hc *http.Client, url string, body []byte) (int, string, []byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return 0, "", nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return 0, "", nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
raw, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return resp.StatusCode, "", nil, err
|
||||
}
|
||||
preview := raw
|
||||
if len(preview) > 256 {
|
||||
preview = preview[:256]
|
||||
}
|
||||
return resp.StatusCode, string(preview), raw, nil
|
||||
}
|
||||
|
||||
func httpDelete(ctx context.Context, hc *http.Client, url string) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
if resp.StatusCode >= 400 && resp.StatusCode != http.StatusNotFound {
|
||||
return fmt.Errorf("delete status %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ── config validation + defaults ────────────────────────────────
|
||||
|
||||
func applyDefaults(cfg Config) Config {
|
||||
if cfg.GatewayURL == "" {
|
||||
cfg.GatewayURL = "http://127.0.0.1:3110"
|
||||
}
|
||||
if cfg.Distance == "" {
|
||||
cfg.Distance = "cosine"
|
||||
}
|
||||
if cfg.EmbedBatch <= 0 {
|
||||
cfg.EmbedBatch = 16
|
||||
}
|
||||
if cfg.EmbedWorkers <= 0 {
|
||||
cfg.EmbedWorkers = 8
|
||||
}
|
||||
if cfg.AddBatch <= 0 {
|
||||
cfg.AddBatch = 1000
|
||||
}
|
||||
if cfg.HTTPClient == nil {
|
||||
cfg.HTTPClient = &http.Client{Timeout: 5 * time.Minute}
|
||||
}
|
||||
if cfg.LogProgress < 0 {
|
||||
cfg.LogProgress = 0
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func validateConfig(cfg Config) error {
|
||||
if cfg.IndexName == "" {
|
||||
return errors.New("corpusingest: IndexName is required")
|
||||
}
|
||||
if cfg.Dimension <= 0 {
|
||||
return errors.New("corpusingest: Dimension must be > 0")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
375
internal/corpusingest/ingest_test.go
Normal file
375
internal/corpusingest/ingest_test.go
Normal file
@ -0,0 +1,375 @@
|
||||
package corpusingest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fakeGateway records the embed + add calls corpusingest fires and
// returns canned responses. The whole point of the unit test is to
// validate the pipeline shape (request payloads, batching, stats)
// without needing live embedd/vectord.
type fakeGateway struct {
	mu sync.Mutex // guards every field below; handlers run on server goroutines

	embedCalls int
	embedTexts [][]string // texts per call

	addCalls int
	addItems [][]addItem // items per call

	createCalled  bool
	deleteCalled  bool
	indexConflict bool // simulate "index already exists" → 409

	embedDimension int // width of the synthetic vectors returned by /v1/embed
}
|
||||
|
||||
// addItem mirrors the wire shape corpusingest sends to
// /v1/vectors/index/{name}/add, so the fake can decode and assert on it.
type addItem struct {
	ID       string          `json:"id"`
	Vector   []float32       `json:"vector"`
	Metadata json.RawMessage `json:"metadata,omitempty"`
}
|
||||
|
||||
// newFakeGateway returns a fake whose /v1/embed responses have
// dim-wide vectors.
func newFakeGateway(dim int) *fakeGateway {
	return &fakeGateway{embedDimension: dim}
}
|
||||
|
||||
// handler returns the fake gateway's HTTP mux covering the three
// endpoints corpusingest hits: index create, embed, and the
// index-scoped add/delete routes.
func (f *fakeGateway) handler() http.Handler {
	mux := http.NewServeMux()

	// POST /v1/vectors/index — create; 409 when indexConflict is set.
	mux.HandleFunc("/v1/vectors/index", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			http.Error(w, "wrong method", http.StatusMethodNotAllowed)
			return
		}
		f.mu.Lock()
		f.createCalled = true
		conflict := f.indexConflict
		f.mu.Unlock()
		if conflict {
			http.Error(w, "exists", http.StatusConflict)
			return
		}
		w.WriteHeader(http.StatusCreated)
	})

	// POST /v1/embed — record the texts, return synthetic vectors.
	mux.HandleFunc("/v1/embed", func(w http.ResponseWriter, r *http.Request) {
		var req struct {
			Texts []string `json:"texts"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// Synthesize deterministic vectors: vecs[i][j] = i + j + 1.
		vecs := make([][]float32, len(req.Texts))
		for i := range vecs {
			v := make([]float32, f.embedDimension)
			for j := range v {
				v[j] = float32(i + j + 1)
			}
			vecs[i] = v
		}
		f.mu.Lock()
		f.embedCalls++
		// Copy because we'll release the slice after returning.
		texts := append([]string(nil), req.Texts...)
		f.embedTexts = append(f.embedTexts, texts)
		f.mu.Unlock()
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]any{
			"vectors":   vecs,
			"dimension": f.embedDimension,
			"model":     "fake-embed",
		})
	})

	// /v1/vectors/index/{name}/add (POST) and /v1/vectors/index/{name} (DELETE).
	mux.HandleFunc("/v1/vectors/index/", func(w http.ResponseWriter, r *http.Request) {
		// /v1/vectors/index/{name}/add
		if !strings.HasSuffix(r.URL.Path, "/add") {
			if r.Method == http.MethodDelete {
				f.mu.Lock()
				f.deleteCalled = true
				f.mu.Unlock()
				w.WriteHeader(http.StatusNoContent)
				return
			}
			http.Error(w, "unhandled "+r.URL.Path, http.StatusNotFound)
			return
		}
		var req struct {
			Items []addItem `json:"items"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		f.mu.Lock()
		f.addCalls++
		f.addItems = append(f.addItems, append([]addItem(nil), req.Items...))
		f.mu.Unlock()
		_, _ = io.WriteString(w, `{"added":`+fmt.Sprint(len(req.Items))+`}`)
	})

	return mux
}
|
||||
|
||||
// staticSource yields a fixed slice of rows, then io.EOF.
type staticSource struct {
	rows []Row
	i    int // next index to return
}
|
||||
|
||||
// Next implements corpusingest.Source over the fixed slice.
func (s *staticSource) Next() (Row, error) {
	if s.i >= len(s.rows) {
		return Row{}, io.EOF
	}
	r := s.rows[s.i]
	s.i++
	return r, nil
}
|
||||
|
||||
// TestRun_PipelineShapeAndStats is the end-to-end shape check:
// 50 rows at EmbedBatch=16 must produce exactly 4 embed calls and
// 4 add calls, every ID added exactly once, every vector at the
// configured dimension, and Stats matching on all three counters.
func TestRun_PipelineShapeAndStats(t *testing.T) {
	const dim = 4
	fg := newFakeGateway(dim)
	srv := httptest.NewServer(fg.handler())
	defer srv.Close()

	rows := make([]Row, 50)
	for i := range rows {
		rows[i] = Row{
			ID:       fmt.Sprintf("r-%03d", i),
			Text:     fmt.Sprintf("row %d text", i),
			Metadata: map[string]any{"i": i, "kind": "test"},
		}
	}

	stats, err := Run(context.Background(), Config{
		GatewayURL:   srv.URL,
		IndexName:    "test_corpus",
		Dimension:    dim,
		Distance:     "cosine",
		EmbedBatch:   16,
		EmbedWorkers: 4,
		HTTPClient:   srv.Client(),
		LogProgress:  0,
	}, &staticSource{rows: rows})
	if err != nil {
		t.Fatalf("Run: %v", err)
	}

	if stats.Scanned != 50 {
		t.Errorf("Scanned: want 50, got %d", stats.Scanned)
	}
	if stats.Embedded != 50 {
		t.Errorf("Embedded: want 50, got %d", stats.Embedded)
	}
	if stats.Added != 50 {
		t.Errorf("Added: want 50, got %d", stats.Added)
	}
	if !fg.createCalled {
		t.Error("expected create-index to be called")
	}
	// 50 rows / 16 batch = ceil(50/16) = 4 batches → 4 embed calls + 4 add calls
	if fg.embedCalls != 4 {
		t.Errorf("embedCalls: want 4 (50 rows / 16 batch), got %d", fg.embedCalls)
	}
	if fg.addCalls != 4 {
		t.Errorf("addCalls: want 4, got %d", fg.addCalls)
	}

	// Sum of texts across embed calls must be 50, and IDs across add
	// calls must be every r-NNN exactly once.
	totalTexts := 0
	for _, ts := range fg.embedTexts {
		totalTexts += len(ts)
	}
	if totalTexts != 50 {
		t.Errorf("total embedded texts: want 50, got %d", totalTexts)
	}
	seen := make(map[string]bool)
	for _, items := range fg.addItems {
		for _, it := range items {
			if seen[it.ID] {
				t.Errorf("duplicate id in add stream: %s", it.ID)
			}
			seen[it.ID] = true
			if len(it.Vector) != dim {
				t.Errorf("vector dim: want %d, got %d", dim, len(it.Vector))
			}
		}
	}
	if len(seen) != 50 {
		t.Errorf("unique ids added: want 50, got %d", len(seen))
	}
}
|
||||
|
||||
// TestRun_DropExistingFiresDelete asserts DropExisting=true issues a
// DELETE on the index before populating.
func TestRun_DropExistingFiresDelete(t *testing.T) {
	fg := newFakeGateway(4)
	srv := httptest.NewServer(fg.handler())
	defer srv.Close()

	_, err := Run(context.Background(), Config{
		GatewayURL:   srv.URL,
		IndexName:    "drops_first",
		Dimension:    4,
		DropExisting: true,
		HTTPClient:   srv.Client(),
	}, &staticSource{rows: []Row{{ID: "x", Text: "y", Metadata: nil}}})
	if err != nil {
		t.Fatalf("Run: %v", err)
	}
	if !fg.deleteCalled {
		t.Error("expected delete-index to fire when DropExisting=true")
	}
}
|
||||
|
||||
// TestRun_IndexAlreadyExistsIsReused asserts a 409 from index-create
// is treated as "reuse existing" and the run still completes.
func TestRun_IndexAlreadyExistsIsReused(t *testing.T) {
	fg := newFakeGateway(4)
	fg.indexConflict = true // first POST /v1/vectors/index → 409
	srv := httptest.NewServer(fg.handler())
	defer srv.Close()

	stats, err := Run(context.Background(), Config{
		GatewayURL:   srv.URL,
		IndexName:    "exists_already",
		Dimension:    4,
		HTTPClient:   srv.Client(),
		EmbedWorkers: 1,
	}, &staticSource{rows: []Row{{ID: "x", Text: "y", Metadata: nil}}})
	if err != nil {
		t.Fatalf("Run with existing index should succeed: %v", err)
	}
	if stats.Added != 1 {
		t.Errorf("Added: want 1, got %d", stats.Added)
	}
}
|
||||
|
||||
// TestRun_LimitStopsEarly asserts cfg.Limit caps Scanned even when
// the source has more rows available.
func TestRun_LimitStopsEarly(t *testing.T) {
	fg := newFakeGateway(4)
	srv := httptest.NewServer(fg.handler())
	defer srv.Close()

	rows := make([]Row, 100)
	for i := range rows {
		rows[i] = Row{ID: fmt.Sprintf("r-%d", i), Text: "t", Metadata: nil}
	}

	stats, err := Run(context.Background(), Config{
		GatewayURL:   srv.URL,
		IndexName:    "limited",
		Dimension:    4,
		Limit:        25,
		EmbedBatch:   8,
		EmbedWorkers: 2,
		HTTPClient:   srv.Client(),
	}, &staticSource{rows: rows})
	if err != nil {
		t.Fatalf("Run: %v", err)
	}
	if stats.Scanned != 25 {
		t.Errorf("Scanned: want 25 (limit), got %d", stats.Scanned)
	}
}
|
||||
|
||||
// TestRun_EmptyTextSkipped asserts empty-Text rows count toward
// Scanned but are excluded from embed/add.
func TestRun_EmptyTextSkipped(t *testing.T) {
	fg := newFakeGateway(4)
	srv := httptest.NewServer(fg.handler())
	defer srv.Close()

	rows := []Row{
		{ID: "a", Text: "real text", Metadata: nil},
		{ID: "b", Text: "", Metadata: nil}, // skipped
		{ID: "c", Text: "more text", Metadata: nil},
	}

	stats, err := Run(context.Background(), Config{
		GatewayURL: srv.URL, IndexName: "skip", Dimension: 4,
		HTTPClient: srv.Client(),
	}, &staticSource{rows: rows})
	if err != nil {
		t.Fatalf("Run: %v", err)
	}
	if stats.Scanned != 3 {
		t.Errorf("Scanned: want 3 (b is skipped but counted as scanned), got %d", stats.Scanned)
	}
	if stats.Added != 2 {
		t.Errorf("Added: want 2 (b excluded from embed), got %d", stats.Added)
	}
}
|
||||
|
||||
// TestRun_RequiresIndexName asserts validation rejects a missing
// IndexName before any HTTP happens.
func TestRun_RequiresIndexName(t *testing.T) {
	_, err := Run(context.Background(), Config{Dimension: 4},
		&staticSource{rows: nil})
	if err == nil || !strings.Contains(err.Error(), "IndexName") {
		t.Errorf("want IndexName-required error, got %v", err)
	}
}
|
||||
|
||||
// TestRun_RequiresDimension asserts validation rejects a missing or
// non-positive Dimension before any HTTP happens.
func TestRun_RequiresDimension(t *testing.T) {
	_, err := Run(context.Background(), Config{IndexName: "x"},
		&staticSource{rows: nil})
	if err == nil || !strings.Contains(err.Error(), "Dimension") {
		t.Errorf("want Dimension-required error, got %v", err)
	}
}
|
||||
|
||||
// TestRun_ContextCancel verifies the pipeline drains cleanly when
|
||||
// ctx is cancelled mid-run. Source returns rows fast enough that
|
||||
// without ctx the run would complete; cancelling early should stop
|
||||
// well before all 1000 rows are processed.
|
||||
func TestRun_ContextCancel(t *testing.T) {
|
||||
fg := newFakeGateway(4)
|
||||
// Slow embed handler: each call sleeps 50ms.
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/v1/vectors/index", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
})
|
||||
mux.HandleFunc("/v1/embed", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Texts []string `json:"texts"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
// Simulate slow-but-valid backend so we test ctx cancel, not
|
||||
// degraded-payload handling (that's covered in production by
|
||||
// the len-mismatch guard in Run's worker).
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = fg
|
||||
vecs := make([][]float32, len(req.Texts))
|
||||
for i := range vecs {
|
||||
vecs[i] = []float32{1, 2, 3, 4}
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"vectors": vecs,
|
||||
"dimension": 4,
|
||||
"model": "x",
|
||||
})
|
||||
})
|
||||
mux.HandleFunc("/v1/vectors/index/", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, `{}`)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
rows := make([]Row, 1000)
|
||||
for i := range rows {
|
||||
rows[i] = Row{ID: fmt.Sprintf("r-%d", i), Text: "t"}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
stats, err := Run(ctx, Config{
|
||||
GatewayURL: srv.URL, IndexName: "cancel_me", Dimension: 4,
|
||||
EmbedBatch: 1, EmbedWorkers: 1, HTTPClient: srv.Client(),
|
||||
}, &staticSource{rows: rows})
|
||||
// Either an error or a partial stats; the point is "didn't process all 1000."
|
||||
if stats.Scanned >= 1000 {
|
||||
t.Errorf("ctx cancel did not stop early: scanned=%d err=%v", stats.Scanned, err)
|
||||
}
|
||||
}
|
||||
@ -1,13 +1,14 @@
|
||||
// Staffing co-pilot scale test driver.
|
||||
// Staffing co-pilot scale test driver — workers_500k corpus.
|
||||
//
|
||||
// Pipeline: workers_500k.csv → /v1/embed (batched, parallel) →
|
||||
// /v1/vectors/index/workers_500k/add (batched). Then runs a handful
|
||||
// of semantic queries against the populated index and prints the
|
||||
// top hits — the human-readable check that "find workers like X"
|
||||
// actually returns relevant workers.
|
||||
// Pipeline: workers_500k.csv → /v1/embed → /v1/vectors/index/workers_500k/add.
|
||||
// The pipeline itself lives in internal/corpusingest; this driver
|
||||
// provides the CSV → Row mapping and the post-ingest semantic queries
|
||||
// that are the human-readable check ("does forklift OSHA-30 actually
|
||||
// retrieve forklift workers?").
|
||||
//
|
||||
// Designed to be re-run; index gets DELETEd at the start so leftover
|
||||
// state from prior runs doesn't bias recall.
|
||||
// Designed to be re-run safely; index gets DELETEd at the start
|
||||
// when -drop is set so leftover state doesn't bias recall.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
@ -22,62 +23,123 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.agentview.dev/profit/golangLAKEHOUSE/internal/corpusingest"
|
||||
)
|
||||
|
||||
const (
|
||||
indexName = "workers_500k"
|
||||
dim = 768
|
||||
|
||||
embedConcurrency = 8 // matches Ollama-on-A4000 sweet spot
|
||||
embedBatchSize = 16 // texts per /v1/embed call
|
||||
addBatchSize = 1000 // items per /v1/vectors/index/add call
|
||||
|
||||
maxColPhone = 4
|
||||
maxColCity = 5
|
||||
maxColState = 6
|
||||
maxColRole = 2
|
||||
maxColSkills = 8
|
||||
maxColCerts = 9
|
||||
maxColResume = 17
|
||||
colWorkerID = 0
|
||||
colName = 1
|
||||
// Column indexes in workers_500k.csv. Stable contract; if the CSV
|
||||
// schema changes these need updating.
|
||||
colWorkerID = 0
|
||||
colName = 1
|
||||
colRole = 2
|
||||
colCity = 5
|
||||
colState = 6
|
||||
colSkills = 8
|
||||
colCerts = 9
|
||||
colResume = 17
|
||||
)
|
||||
|
||||
// workersCSV implements corpusingest.Source. CSV reader state +
// row → Row mapping live here; the embed/add pipeline is generic.
type workersCSV struct {
	cr *csv.Reader
}

// Next reads CSV records until one has enough columns, then maps it
// to a corpusingest.Row. csv.Reader returns io.EOF at end-of-file,
// which is exactly the sentinel Source requires — no translation needed.
func (s *workersCSV) Next() (corpusingest.Row, error) {
	for {
		row, err := s.cr.Read()
		if err != nil {
			return corpusingest.Row{}, err
		}
		if len(row) <= colResume {
			continue // skip malformed rows; matches prior behavior
		}
		id := strings.TrimSpace(row[colWorkerID])
		return corpusingest.Row{
			// "w-" prefix keeps worker IDs distinct from any other
			// corpus sharing the index namespace.
			ID:   "w-" + id,
			Text: buildWorkerText(row),
			Metadata: map[string]any{
				"name":  row[colName],
				"role":  row[colRole],
				"city":  row[colCity],
				"state": row[colState],
			},
		}, nil
	}
}
|
||||
|
||||
// buildWorkerText concatenates staffing-relevant columns into the
|
||||
// embed-text. Order: role first (most semantically dense), then
|
||||
// location, skills, certs, prose resume. Embedding models weight
|
||||
// earlier tokens slightly more, so the front matter matters.
|
||||
func buildWorkerText(row []string) string {
|
||||
var b strings.Builder
|
||||
b.WriteString(row[colRole])
|
||||
b.WriteString(" in ")
|
||||
b.WriteString(row[colCity])
|
||||
b.WriteString(", ")
|
||||
b.WriteString(row[colState])
|
||||
b.WriteString(". Skills: ")
|
||||
b.WriteString(row[colSkills])
|
||||
b.WriteString(". Certifications: ")
|
||||
b.WriteString(row[colCerts])
|
||||
b.WriteString(". ")
|
||||
b.WriteString(row[colResume])
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
gateway = flag.String("gateway", "http://127.0.0.1:3110", "gateway base URL")
|
||||
csvPath = flag.String("csv", "/tmp/rs/workers_500k.csv", "path to workers CSV")
|
||||
limit = flag.Int("limit", 0, "limit rows (0 = all)")
|
||||
queries = flag.String("queries", "default", "default | <semicolon-separated query strings>")
|
||||
skipPop = flag.Bool("skip-populate", false, "skip embed+add, only run queries")
|
||||
gateway = flag.String("gateway", "http://127.0.0.1:3110", "gateway base URL")
|
||||
csvPath = flag.String("csv", "/tmp/rs/workers_500k.csv", "path to workers CSV")
|
||||
limit = flag.Int("limit", 0, "limit rows (0 = all)")
|
||||
queries = flag.String("queries", "default", "default | <semicolon-separated query strings>")
|
||||
skipPop = flag.Bool("skip-populate", false, "skip embed+add, only run queries")
|
||||
drop = flag.Bool("drop", true, "DELETE index before populate (default true for clean recall)")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
hc := &http.Client{Timeout: 5 * time.Minute}
|
||||
ctx := context.Background()
|
||||
|
||||
if !*skipPop {
|
||||
// Tear down any prior index so recall is on a fresh build.
|
||||
fmt.Printf("[sc] DELETE %s/v1/vectors/index/%s (idempotent cleanup)\n", *gateway, indexName)
|
||||
_ = httpDelete(hc, *gateway+"/v1/vectors/index/"+indexName)
|
||||
|
||||
// Create the index.
|
||||
body := map[string]any{"name": indexName, "dimension": dim, "distance": "cosine"}
|
||||
if code, msg := httpPostJSON(hc, *gateway+"/v1/vectors/index", body); code != 201 {
|
||||
log.Fatalf("create index: %d %s", code, msg)
|
||||
f, err := os.Open(*csvPath)
|
||||
if err != nil {
|
||||
log.Fatalf("open csv: %v", err)
|
||||
}
|
||||
fmt.Println("[sc] created index workers_500k dim=768 cosine")
|
||||
|
||||
t0 := time.Now()
|
||||
if err := populate(hc, *gateway, *csvPath, *limit); err != nil {
|
||||
log.Fatal(err)
|
||||
defer f.Close()
|
||||
cr := csv.NewReader(f)
|
||||
cr.FieldsPerRecord = -1
|
||||
if _, err := cr.Read(); err != nil { // skip header
|
||||
log.Fatalf("read header: %v", err)
|
||||
}
|
||||
fmt.Printf("[sc] populate complete in %v\n", time.Since(t0))
|
||||
|
||||
stats, err := corpusingest.Run(ctx, corpusingest.Config{
|
||||
GatewayURL: *gateway,
|
||||
IndexName: indexName,
|
||||
Dimension: dim,
|
||||
Distance: "cosine",
|
||||
EmbedBatch: 16, // matches Ollama-on-A4000 sweet spot
|
||||
EmbedWorkers: 8, // matches Ollama-on-A4000 sweet spot
|
||||
AddBatch: 1000, // empirically fine; vectord BatchAdd lock-amortized at f1c1883
|
||||
Limit: *limit,
|
||||
DropExisting: *drop,
|
||||
HTTPClient: hc,
|
||||
LogProgress: 10 * time.Second,
|
||||
}, &workersCSV{cr: cr})
|
||||
if err != nil {
|
||||
log.Fatalf("ingest: %v", err)
|
||||
}
|
||||
fmt.Printf("[sc] populate done: scanned=%d embedded=%d added=%d wall=%v\n",
|
||||
stats.Scanned, stats.Embedded, stats.Added, stats.Wall.Round(time.Millisecond))
|
||||
}
|
||||
|
||||
// Validate semantic queries.
|
||||
// Validate semantic queries against the populated index.
|
||||
qs := defaultQueries()
|
||||
if *queries != "default" {
|
||||
qs = strings.Split(*queries, ";")
|
||||
@ -97,196 +159,35 @@ func defaultQueries() []string {
|
||||
}
|
||||
}
|
||||
|
||||
func populate(hc *http.Client, gateway, csvPath string, limit int) error {
|
||||
f, err := os.Open(csvPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open csv: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
cr := csv.NewReader(f)
|
||||
cr.FieldsPerRecord = -1
|
||||
if _, err := cr.Read(); err != nil { // header
|
||||
return fmt.Errorf("read header: %w", err)
|
||||
}
|
||||
|
||||
type job struct {
|
||||
ids []string
|
||||
texts []string
|
||||
metas []json.RawMessage
|
||||
}
|
||||
|
||||
jobs := make(chan job, embedConcurrency*2)
|
||||
var wg sync.WaitGroup
|
||||
var (
|
||||
totalEmbedded int64
|
||||
totalAdded int64
|
||||
)
|
||||
|
||||
for i := 0; i < embedConcurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := range jobs {
|
||||
vecs, err := embedBatch(hc, gateway, j.texts)
|
||||
if err != nil {
|
||||
log.Printf("embed batch (%d items): %v", len(j.texts), err)
|
||||
continue
|
||||
}
|
||||
atomic.AddInt64(&totalEmbedded, int64(len(vecs)))
|
||||
if err := addBatch(hc, gateway, j.ids, vecs, j.metas); err != nil {
|
||||
log.Printf("add batch (%d items): %v", len(j.ids), err)
|
||||
continue
|
||||
}
|
||||
atomic.AddInt64(&totalAdded, int64(len(j.ids)))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
progressTicker := time.NewTicker(10 * time.Second)
|
||||
go func() {
|
||||
for range progressTicker.C {
|
||||
fmt.Printf("[sc] progress: embedded=%d added=%d\n",
|
||||
atomic.LoadInt64(&totalEmbedded), atomic.LoadInt64(&totalAdded))
|
||||
}
|
||||
}()
|
||||
defer progressTicker.Stop()
|
||||
|
||||
curIDs := make([]string, 0, embedBatchSize)
|
||||
curTexts := make([]string, 0, embedBatchSize)
|
||||
curMetas := make([]json.RawMessage, 0, embedBatchSize)
|
||||
rows := 0
|
||||
for {
|
||||
row, err := cr.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("csv read row %d: %w", rows, err)
|
||||
}
|
||||
if len(row) <= maxColResume {
|
||||
continue
|
||||
}
|
||||
id := strings.TrimSpace(row[colWorkerID])
|
||||
text := buildSearchText(row)
|
||||
meta, _ := json.Marshal(map[string]any{
|
||||
"name": row[colName],
|
||||
"role": row[maxColRole],
|
||||
"city": row[maxColCity],
|
||||
"state": row[maxColState],
|
||||
})
|
||||
curIDs = append(curIDs, "w-"+id)
|
||||
curTexts = append(curTexts, text)
|
||||
curMetas = append(curMetas, meta)
|
||||
|
||||
if len(curIDs) >= embedBatchSize {
|
||||
jobs <- job{ids: curIDs, texts: curTexts, metas: curMetas}
|
||||
curIDs = make([]string, 0, embedBatchSize)
|
||||
curTexts = make([]string, 0, embedBatchSize)
|
||||
curMetas = make([]json.RawMessage, 0, embedBatchSize)
|
||||
}
|
||||
rows++
|
||||
if limit > 0 && rows >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(curIDs) > 0 {
|
||||
jobs <- job{ids: curIDs, texts: curTexts, metas: curMetas}
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
|
||||
fmt.Printf("[sc] final: scanned=%d embedded=%d added=%d\n",
|
||||
rows, atomic.LoadInt64(&totalEmbedded), atomic.LoadInt64(&totalAdded))
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildSearchText concatenates the staffing-relevant columns into
|
||||
// the text that gets embedded. Order: role first (most semantically
|
||||
// dense), then skills + certs, city/state, finally the prose
|
||||
// resume_text. Embedding models weight earlier tokens slightly more.
|
||||
func buildSearchText(row []string) string {
|
||||
var b strings.Builder
|
||||
b.WriteString(row[maxColRole])
|
||||
b.WriteString(" in ")
|
||||
b.WriteString(row[maxColCity])
|
||||
b.WriteString(", ")
|
||||
b.WriteString(row[maxColState])
|
||||
b.WriteString(". Skills: ")
|
||||
b.WriteString(row[maxColSkills])
|
||||
b.WriteString(". Certifications: ")
|
||||
b.WriteString(row[maxColCerts])
|
||||
b.WriteString(". ")
|
||||
b.WriteString(row[maxColResume])
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func embedBatch(hc *http.Client, gateway string, texts []string) ([][]float32, error) {
|
||||
body := map[string]any{"texts": texts}
|
||||
bs, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest(http.MethodPost, gateway+"/v1/embed", bytes.NewReader(bs))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
preview, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
return nil, fmt.Errorf("embed status %d: %s", resp.StatusCode, string(preview))
|
||||
}
|
||||
var er struct {
|
||||
Vectors [][]float32 `json:"vectors"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&er); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return er.Vectors, nil
|
||||
}
|
||||
|
||||
// addItem is the wire shape for one vector in a
// /v1/vectors/index/<name>/add request body.
type addItem struct {
	ID       string          `json:"id"`
	Vector   []float32       `json:"vector"`
	Metadata json.RawMessage `json:"metadata"` // pre-marshaled; passed through verbatim
}
|
||||
|
||||
func addBatch(hc *http.Client, gateway string, ids []string, vecs [][]float32, metas []json.RawMessage) error {
|
||||
items := make([]addItem, len(ids))
|
||||
for i := range ids {
|
||||
items[i] = addItem{ID: ids[i], Vector: vecs[i], Metadata: metas[i]}
|
||||
}
|
||||
bs, _ := json.Marshal(map[string]any{"items": items})
|
||||
req, _ := http.NewRequest(http.MethodPost,
|
||||
gateway+"/v1/vectors/index/"+indexName+"/add", bytes.NewReader(bs))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
preview, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
return fmt.Errorf("add status %d: %s", resp.StatusCode, string(preview))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runQuery embeds a query, searches the index, prints top hits.
|
||||
// Stays in this driver (not corpusingest) — query validation is
|
||||
// per-corpus concern, not part of the ingest pipeline.
|
||||
func runQuery(hc *http.Client, gateway, q string) {
|
||||
t0 := time.Now()
|
||||
// 1. Embed the query.
|
||||
vecs, err := embedBatch(hc, gateway, []string{q})
|
||||
if err != nil || len(vecs) == 0 {
|
||||
body, _ := json.Marshal(map[string]any{"texts": []string{q}})
|
||||
req, _ := http.NewRequest(http.MethodPost, gateway+"/v1/embed", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("[sc] query %q: embed err: %v\n", q, err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var er struct {
|
||||
Vectors [][]float32 `json:"vectors"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&er); err != nil || len(er.Vectors) == 0 {
|
||||
fmt.Printf("[sc] query %q: embed decode err: %v\n", q, err)
|
||||
return
|
||||
}
|
||||
embedDur := time.Since(t0)
|
||||
|
||||
t1 := time.Now()
|
||||
// 2. Search.
|
||||
body := map[string]any{"vector": vecs[0], "k": 5}
|
||||
bs, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest(http.MethodPost,
|
||||
gateway+"/v1/vectors/index/"+indexName+"/search", bytes.NewReader(bs))
|
||||
body, _ = json.Marshal(map[string]any{"vector": er.Vectors[0], "k": 5})
|
||||
req, _ = http.NewRequest(http.MethodPost,
|
||||
gateway+"/v1/vectors/index/"+indexName+"/search", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := hc.Do(req)
|
||||
resp, err = hc.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("[sc] query %q: search err: %v\n", q, err)
|
||||
return
|
||||
@ -310,29 +211,7 @@ func runQuery(hc *http.Client, gateway, q string) {
|
||||
}
|
||||
}
|
||||
|
||||
func httpPostJSON(hc *http.Client, url string, body any) (int, string) {
|
||||
bs, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewReader(bs))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return 0, err.Error()
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
preview, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
return resp.StatusCode, string(preview)
|
||||
}
|
||||
|
||||
func httpDelete(hc *http.Client, url string) error {
|
||||
req, _ := http.NewRequest(http.MethodDelete, url, nil)
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keep context.Background reachable in case future paths use it.
var _ = context.Background

// io.EOF imported transitively via corpusingest; keep the explicit
// reference so a hypothetical future "EOF means done" check in this
// driver's Source impl doesn't need a fresh import line.
var _ = io.EOF
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user