Compare commits
2 Commits
fb08232f58
...
8f4c16fab1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8f4c16fab1 | ||
|
|
56844c3f31 |
22
README.md
22
README.md
@ -38,6 +38,28 @@ real-scale validation + G1/G1P/G2 pointer at the bottom).
|
||||
| `queryd` | 3214 | DuckDB SELECT over registered Parquets via httpfs |
|
||||
| `vectord` | 3215 | HNSW vector search (+ optional persistence to storaged) |
|
||||
| `embedd` | 3216 | Text → vector via Ollama (default `nomic-embed-text` 768-d) |
|
||||
| `mcpd` | stdio | Model Context Protocol server (Claude Desktop / Code consumers) |
|
||||
|
||||
## MCP server
|
||||
|
||||
`bin/mcpd` exposes Lakehouse capabilities as MCP tools over stdio:
|
||||
`list_datasets`, `get_manifest`, `query_sql`, `embed_text`, `search_vectors`.
|
||||
All tools proxy to the gateway, so the gateway must be up first.
|
||||
|
||||
Wire into Claude Desktop / Claude Code by adding to the MCP config:
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"lakehouse": {
|
||||
"command": "/path/to/golangLAKEHOUSE/bin/mcpd",
|
||||
"args": ["--gateway", "http://127.0.0.1:3110"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Replaces the Bun `mcp-server.ts` MCP-tool surface from the Rust system.
|
||||
HTTP demo routes (the staffing co-pilot UI) stay Bun until G5.
|
||||
|
||||
## Acceptance smokes
|
||||
|
||||
|
||||
@ -41,9 +41,20 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
h := &handlers{
|
||||
provider: embed.NewOllama(cfg.Embedd.ProviderURL, cfg.Embedd.DefaultModel),
|
||||
// Wrap the upstream provider in an LRU cache so repeat queries
|
||||
// (the staffing co-pilot replays many of the same texts) bypass
|
||||
// the ~50ms Ollama round-trip. Cache size 0 = pass-through.
|
||||
base := embed.NewOllama(cfg.Embedd.ProviderURL, cfg.Embedd.DefaultModel)
|
||||
cached, err := embed.NewCachedProvider(base, cfg.Embedd.DefaultModel, cfg.Embedd.CacheSize)
|
||||
if err != nil {
|
||||
slog.Error("embed cache", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
slog.Info("embed cache",
|
||||
"size", cfg.Embedd.CacheSize,
|
||||
"default_model", cfg.Embedd.DefaultModel,
|
||||
"enabled", cfg.Embedd.CacheSize > 0)
|
||||
h := &handlers{provider: cached, cache: cached}
|
||||
|
||||
if err := shared.Run("embedd", cfg.Embedd.Bind, h.register); err != nil {
|
||||
slog.Error("server", "err", err)
|
||||
@ -53,10 +64,36 @@ func main() {
|
||||
|
||||
type handlers struct {
|
||||
provider embed.Provider
|
||||
// cache is the same instance as provider when caching is enabled,
|
||||
// kept as a typed pointer so /v1/embed/stats can expose hit-rate
|
||||
// without type-asserting through the Provider interface. nil when
|
||||
// CacheSize=0 (pass-through mode).
|
||||
cache *embed.CachedProvider
|
||||
}
|
||||
|
||||
func (h *handlers) register(r chi.Router) {
|
||||
r.Post("/embed", h.handleEmbed)
|
||||
r.Get("/embed/stats", h.handleStats)
|
||||
}
|
||||
|
||||
// handleStats reports cache hits/misses + hit rate + size. Operators
|
||||
// use this to confirm the cache is doing its job (high hit rate) or
|
||||
// is sized wrong (low hit rate + many misses on a workload that
|
||||
// should have repeats).
|
||||
func (h *handlers) handleStats(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if h.cache == nil {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"enabled": false})
|
||||
return
|
||||
}
|
||||
hits, misses := h.cache.Stats()
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"enabled": true,
|
||||
"hits": hits,
|
||||
"misses": misses,
|
||||
"hit_rate": h.cache.HitRate(),
|
||||
"size": h.cache.Len(),
|
||||
})
|
||||
}
|
||||
|
||||
// embedRequest is the POST /embed body. Texts is the list to
|
||||
|
||||
@ -143,6 +143,71 @@ func TestHandleEmbed_HappyPath_ProviderEcho(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatsEndpoint_CacheDisabled(t *testing.T) {
|
||||
r := mountWithProvider(&stubProvider{})
|
||||
srv := httptest.NewServer(r)
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := http.Get(srv.URL + "/embed/stats")
|
||||
if err != nil {
|
||||
t.Fatalf("GET: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", resp.StatusCode)
|
||||
}
|
||||
if !strings.Contains(getBody(t, resp), `"enabled":false`) {
|
||||
t.Errorf("expected enabled:false when no cache wired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatsEndpoint_CacheEnabled_TracksHitsAndMisses(t *testing.T) {
|
||||
stub := &stubProvider{result: embed.Result{
|
||||
Model: "test-model",
|
||||
Dimension: 3,
|
||||
Vectors: [][]float32{{0.1, 0.2, 0.3}},
|
||||
}}
|
||||
cache, err := embed.NewCachedProvider(stub, "test-model", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCachedProvider: %v", err)
|
||||
}
|
||||
h := &handlers{provider: cache, cache: cache}
|
||||
r := chi.NewRouter()
|
||||
h.register(r)
|
||||
srv := httptest.NewServer(r)
|
||||
defer srv.Close()
|
||||
|
||||
// First call → miss.
|
||||
http.Post(srv.URL+"/embed", "application/json",
|
||||
strings.NewReader(`{"texts":["hello"],"model":"test-model"}`))
|
||||
// Second call same text → hit.
|
||||
http.Post(srv.URL+"/embed", "application/json",
|
||||
strings.NewReader(`{"texts":["hello"],"model":"test-model"}`))
|
||||
|
||||
resp, err := http.Get(srv.URL + "/embed/stats")
|
||||
if err != nil {
|
||||
t.Fatalf("stats GET: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body := getBody(t, resp)
|
||||
if !strings.Contains(body, `"enabled":true`) {
|
||||
t.Errorf("expected enabled:true, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, `"hits":1`) {
|
||||
t.Errorf("expected hits:1, body=%s", body)
|
||||
}
|
||||
if !strings.Contains(body, `"misses":1`) {
|
||||
t.Errorf("expected misses:1, body=%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func getBody(t *testing.T, resp *http.Response) string {
|
||||
t.Helper()
|
||||
buf := make([]byte, 4096)
|
||||
n, _ := resp.Body.Read(buf)
|
||||
return string(buf[:n])
|
||||
}
|
||||
|
||||
func TestItoa(t *testing.T) {
|
||||
cases := []struct {
|
||||
in int
|
||||
|
||||
268
cmd/mcpd/main.go
Normal file
268
cmd/mcpd/main.go
Normal file
@ -0,0 +1,268 @@
|
||||
// mcpd is the Model Context Protocol server that exposes Lakehouse
|
||||
// capabilities as MCP tools. Replaces the Bun mcp-server.ts surface
|
||||
// (the MCP-tool-only subset; HTTP demo routes stay Bun until G5).
|
||||
//
|
||||
// Tools exposed:
|
||||
// list_datasets GET /v1/catalog/list
|
||||
// get_manifest GET /v1/catalog/manifest/<name>
|
||||
// query_sql POST /v1/sql
|
||||
// embed_text POST /v1/embed
|
||||
// search_vectors POST /v1/vectors/index/<name>/search
|
||||
//
|
||||
// Transport: StdioTransport (the universal MCP transport — Claude
|
||||
// Desktop, Claude Code, MCP CLI all speak this). Other transports
|
||||
// (SSE, HTTP) can be added later by changing the Run call.
|
||||
//
|
||||
// Setup for Claude Desktop / Claude Code:
|
||||
// bin/mcpd --gateway http://127.0.0.1:3110
|
||||
// (configure your client to spawn this binary as an MCP server)
|
||||
//
|
||||
// Why not in cmd/gateway: separation of concerns. Gateway is HTTP
|
||||
// for direct-API callers; mcpd is stdio for MCP consumers. Keeping
|
||||
// them separate means each can be deployed / restarted / monitored
|
||||
// without affecting the other.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gatewayURL := flag.String("gateway", "http://127.0.0.1:3110",
|
||||
"Gateway URL (where mcpd proxies all tool calls)")
|
||||
flag.Parse()
|
||||
|
||||
srv := buildServer(*gatewayURL)
|
||||
|
||||
if err := srv.Run(context.Background(), &mcp.StdioTransport{}); err != nil {
|
||||
log.Fatalf("mcpd: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// buildServer assembles the MCP server with all tools wired against
|
||||
// the given gateway URL. Extracted from main() so tests can build
|
||||
// the server with a test gateway URL.
|
||||
func buildServer(gatewayURL string) *mcp.Server {
|
||||
srv := mcp.NewServer(&mcp.Implementation{
|
||||
Name: "lakehouse",
|
||||
Version: "v0.1.0",
|
||||
}, nil)
|
||||
|
||||
gw := newGatewayClient(gatewayURL)
|
||||
|
||||
mcp.AddTool(srv, &mcp.Tool{
|
||||
Name: "list_datasets",
|
||||
Description: "List all datasets registered in the catalog. " +
|
||||
"Returns dataset_id, name, schema_fingerprint, row_count " +
|
||||
"per dataset.",
|
||||
}, gw.listDatasets)
|
||||
|
||||
mcp.AddTool(srv, &mcp.Tool{
|
||||
Name: "get_manifest",
|
||||
Description: "Fetch the manifest for a single dataset by name. " +
|
||||
"Includes schema fingerprint, parquet object keys, row count, " +
|
||||
"created_at_unix_ns.",
|
||||
}, gw.getManifest)
|
||||
|
||||
mcp.AddTool(srv, &mcp.Tool{
|
||||
Name: "query_sql",
|
||||
Description: "Execute a SQL query against the registered datasets. " +
|
||||
"Returns columns + rows. SQL is interpreted by DuckDB; standard " +
|
||||
"SQL plus DuckDB-specific functions (read_parquet, etc.) work.",
|
||||
}, gw.querySQL)
|
||||
|
||||
mcp.AddTool(srv, &mcp.Tool{
|
||||
Name: "embed_text",
|
||||
Description: "Embed one or more texts via the configured embedding " +
|
||||
"model (default: nomic-embed-text). Returns one vector per text " +
|
||||
"in the same order as the input.",
|
||||
}, gw.embedText)
|
||||
|
||||
mcp.AddTool(srv, &mcp.Tool{
|
||||
Name: "search_vectors",
|
||||
Description: "Find the top-K nearest neighbors of a query vector " +
|
||||
"in the named HNSW index. K defaults to 10 if omitted.",
|
||||
}, gw.searchVectors)
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
// gatewayClient holds the HTTP client + base URL for proxying tool
|
||||
// calls to the Go gateway. Per-tool latency is on the order of a
|
||||
// gateway round-trip; the 30s timeout accommodates the slowest
|
||||
// expected SQL query without holding stdio sessions indefinitely.
|
||||
type gatewayClient struct {
|
||||
baseURL string
|
||||
hc *http.Client
|
||||
}
|
||||
|
||||
func newGatewayClient(baseURL string) *gatewayClient {
|
||||
return &gatewayClient{
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
hc: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// ── tool argument structs (jsonschema tags drive schema generation) ──
|
||||
|
||||
type listDatasetsArgs struct{}
|
||||
|
||||
type getManifestArgs struct {
|
||||
Name string `json:"name" jsonschema:"the dataset name to fetch"`
|
||||
}
|
||||
|
||||
type querySQLArgs struct {
|
||||
SQL string `json:"sql" jsonschema:"the SQL query to execute"`
|
||||
}
|
||||
|
||||
type embedTextArgs struct {
|
||||
Texts []string `json:"texts" jsonschema:"the texts to embed"`
|
||||
Model string `json:"model,omitempty" jsonschema:"optional model name (defaults to embedd's configured default)"`
|
||||
}
|
||||
|
||||
type searchVectorsArgs struct {
|
||||
IndexName string `json:"index_name" jsonschema:"the index to search"`
|
||||
Vector []float32 `json:"vector" jsonschema:"the query vector"`
|
||||
K int `json:"k,omitempty" jsonschema:"top-K to return (default 10)"`
|
||||
}
|
||||
|
||||
// ── tool handlers ──
|
||||
|
||||
func (g *gatewayClient) listDatasets(ctx context.Context, _ *mcp.CallToolRequest, _ listDatasetsArgs) (*mcp.CallToolResult, any, error) {
|
||||
body, err := g.proxy(ctx, http.MethodGet, "/v1/catalog/list", nil)
|
||||
if err != nil {
|
||||
return errorResult(err), nil, nil
|
||||
}
|
||||
return jsonResult(body), nil, nil
|
||||
}
|
||||
|
||||
func (g *gatewayClient) getManifest(ctx context.Context, _ *mcp.CallToolRequest, args getManifestArgs) (*mcp.CallToolResult, any, error) {
|
||||
if args.Name == "" {
|
||||
return errorResult(fmt.Errorf("name is required")), nil, nil
|
||||
}
|
||||
path := "/v1/catalog/manifest/" + url.PathEscape(args.Name)
|
||||
body, err := g.proxy(ctx, http.MethodGet, path, nil)
|
||||
if err != nil {
|
||||
return errorResult(err), nil, nil
|
||||
}
|
||||
return jsonResult(body), nil, nil
|
||||
}
|
||||
|
||||
func (g *gatewayClient) querySQL(ctx context.Context, _ *mcp.CallToolRequest, args querySQLArgs) (*mcp.CallToolResult, any, error) {
|
||||
if strings.TrimSpace(args.SQL) == "" {
|
||||
return errorResult(fmt.Errorf("sql is required")), nil, nil
|
||||
}
|
||||
reqBody, _ := json.Marshal(map[string]string{"sql": args.SQL})
|
||||
body, err := g.proxy(ctx, http.MethodPost, "/v1/sql", reqBody)
|
||||
if err != nil {
|
||||
return errorResult(err), nil, nil
|
||||
}
|
||||
return jsonResult(body), nil, nil
|
||||
}
|
||||
|
||||
func (g *gatewayClient) embedText(ctx context.Context, _ *mcp.CallToolRequest, args embedTextArgs) (*mcp.CallToolResult, any, error) {
|
||||
if len(args.Texts) == 0 {
|
||||
return errorResult(fmt.Errorf("texts is required")), nil, nil
|
||||
}
|
||||
payload := map[string]any{"texts": args.Texts}
|
||||
if args.Model != "" {
|
||||
payload["model"] = args.Model
|
||||
}
|
||||
reqBody, _ := json.Marshal(payload)
|
||||
body, err := g.proxy(ctx, http.MethodPost, "/v1/embed", reqBody)
|
||||
if err != nil {
|
||||
return errorResult(err), nil, nil
|
||||
}
|
||||
return jsonResult(body), nil, nil
|
||||
}
|
||||
|
||||
func (g *gatewayClient) searchVectors(ctx context.Context, _ *mcp.CallToolRequest, args searchVectorsArgs) (*mcp.CallToolResult, any, error) {
|
||||
if args.IndexName == "" {
|
||||
return errorResult(fmt.Errorf("index_name is required")), nil, nil
|
||||
}
|
||||
if len(args.Vector) == 0 {
|
||||
return errorResult(fmt.Errorf("vector is required")), nil, nil
|
||||
}
|
||||
payload := map[string]any{"vector": args.Vector}
|
||||
if args.K > 0 {
|
||||
payload["k"] = args.K
|
||||
}
|
||||
reqBody, _ := json.Marshal(payload)
|
||||
path := "/v1/vectors/index/" + url.PathEscape(args.IndexName) + "/search"
|
||||
body, err := g.proxy(ctx, http.MethodPost, path, reqBody)
|
||||
if err != nil {
|
||||
return errorResult(err), nil, nil
|
||||
}
|
||||
return jsonResult(body), nil, nil
|
||||
}
|
||||
|
||||
// proxy makes a request to the gateway and returns the response body
|
||||
// on success. Non-2xx status codes return an error with the body
|
||||
// preview in the message — surfaced to the MCP client as a tool error
|
||||
// rather than a transport-level failure.
|
||||
func (g *gatewayClient) proxy(ctx context.Context, method, path string, body []byte) ([]byte, error) {
|
||||
var rdr io.Reader
|
||||
if body != nil {
|
||||
rdr = bytes.NewReader(body)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, method, g.baseURL+path, rdr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build request: %w", err)
|
||||
}
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
resp, err := g.hc.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("call gateway: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 16<<20)) // 16 MiB tool-response cap
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
preview := respBody
|
||||
if len(preview) > 512 {
|
||||
preview = preview[:512]
|
||||
}
|
||||
return nil, fmt.Errorf("gateway %s %s: status %d: %s",
|
||||
method, path, resp.StatusCode, string(preview))
|
||||
}
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
// errorResult wraps an error as an MCP tool error result. The MCP
|
||||
// protocol distinguishes "tool ran but reported failure" (returned
|
||||
// in CallToolResult.IsError + content) from "tool threw" (returned
|
||||
// as the third return value). We use the former so the LLM caller
|
||||
// sees the error text and can decide how to react, rather than
|
||||
// surfacing the error as transport noise.
|
||||
func errorResult(err error) *mcp.CallToolResult {
|
||||
return &mcp.CallToolResult{
|
||||
IsError: true,
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: err.Error()},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// jsonResult wraps a JSON byte slice as a successful tool result.
|
||||
// The content is text — MCP clients render it; LLMs parse it as
|
||||
// JSON when their tool config indicates so.
|
||||
func jsonResult(body []byte) *mcp.CallToolResult {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: string(body)},
|
||||
},
|
||||
}
|
||||
}
|
||||
282
cmd/mcpd/main_test.go
Normal file
282
cmd/mcpd/main_test.go
Normal file
@ -0,0 +1,282 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// Tests the MCP tool surface end-to-end without a subprocess: spin
|
||||
// up a fake gateway via httptest, build the MCP server pointed at
|
||||
// it, connect a client via in-memory transports, call each tool.
|
||||
|
||||
// fakeGateway returns an httptest.Server that responds to the routes
|
||||
// mcpd proxies. Each route's handler is configurable via the routes map.
|
||||
func fakeGateway(t *testing.T, routes map[string]http.HandlerFunc) *httptest.Server {
|
||||
t.Helper()
|
||||
mux := http.NewServeMux()
|
||||
for path, h := range routes {
|
||||
mux.HandleFunc(path, h)
|
||||
}
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Errorf("fakeGateway: unexpected route %s %s", r.Method, r.URL.Path)
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
t.Cleanup(srv.Close)
|
||||
return srv
|
||||
}
|
||||
|
||||
// connect builds the mcpd server pointed at gatewayURL, connects an
|
||||
// in-memory client, and returns a ready-to-use ClientSession.
|
||||
func connect(t *testing.T, gatewayURL string) *mcp.ClientSession {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
srv := buildServer(gatewayURL)
|
||||
client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "v0.0.1"}, nil)
|
||||
|
||||
t1, t2 := mcp.NewInMemoryTransports()
|
||||
if _, err := srv.Connect(ctx, t1, nil); err != nil {
|
||||
t.Fatalf("server connect: %v", err)
|
||||
}
|
||||
session, err := client.Connect(ctx, t2, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("client connect: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = session.Close() })
|
||||
return session
|
||||
}
|
||||
|
||||
func callTool(t *testing.T, session *mcp.ClientSession, name string, args any) *mcp.CallToolResult {
|
||||
t.Helper()
|
||||
res, err := session.CallTool(context.Background(), &mcp.CallToolParams{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CallTool %s: %v", name, err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func resultText(t *testing.T, res *mcp.CallToolResult) string {
|
||||
t.Helper()
|
||||
if len(res.Content) == 0 {
|
||||
t.Fatal("result has no content")
|
||||
}
|
||||
tc, ok := res.Content[0].(*mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatalf("first content is %T, want *mcp.TextContent", res.Content[0])
|
||||
}
|
||||
return tc.Text
|
||||
}
|
||||
|
||||
func TestListTools(t *testing.T) {
|
||||
gw := fakeGateway(t, nil)
|
||||
session := connect(t, gw.URL)
|
||||
|
||||
res, err := session.ListTools(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ListTools: %v", err)
|
||||
}
|
||||
want := map[string]bool{
|
||||
"list_datasets": false,
|
||||
"get_manifest": false,
|
||||
"query_sql": false,
|
||||
"embed_text": false,
|
||||
"search_vectors": false,
|
||||
}
|
||||
for _, tool := range res.Tools {
|
||||
if _, ok := want[tool.Name]; ok {
|
||||
want[tool.Name] = true
|
||||
}
|
||||
}
|
||||
for name, found := range want {
|
||||
if !found {
|
||||
t.Errorf("expected tool %q exposed by mcpd, not in ListTools", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDatasets_HappyPath(t *testing.T) {
|
||||
gw := fakeGateway(t, map[string]http.HandlerFunc{
|
||||
"/v1/catalog/list": func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"manifests":[{"name":"workers","row_count":500000}],"count":1}`))
|
||||
},
|
||||
})
|
||||
session := connect(t, gw.URL)
|
||||
|
||||
res := callTool(t, session, "list_datasets", listDatasetsArgs{})
|
||||
if res.IsError {
|
||||
t.Fatalf("expected success, got IsError with: %s", resultText(t, res))
|
||||
}
|
||||
body := resultText(t, res)
|
||||
if !strings.Contains(body, "workers") || !strings.Contains(body, "500000") {
|
||||
t.Errorf("response missing expected fields: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetManifest_HappyPath(t *testing.T) {
|
||||
gw := fakeGateway(t, map[string]http.HandlerFunc{
|
||||
"/v1/catalog/manifest/workers": func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"name":"workers","row_count":500000,"schema_fingerprint":"sha256:abc"}`))
|
||||
},
|
||||
})
|
||||
session := connect(t, gw.URL)
|
||||
|
||||
res := callTool(t, session, "get_manifest", getManifestArgs{Name: "workers"})
|
||||
if res.IsError {
|
||||
t.Fatalf("expected success, got: %s", resultText(t, res))
|
||||
}
|
||||
body := resultText(t, res)
|
||||
if !strings.Contains(body, "workers") {
|
||||
t.Errorf("missing manifest fields in response: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetManifest_EmptyName_IsError(t *testing.T) {
|
||||
gw := fakeGateway(t, nil) // no handlers — tool should error before hitting gateway
|
||||
session := connect(t, gw.URL)
|
||||
|
||||
res := callTool(t, session, "get_manifest", getManifestArgs{Name: ""})
|
||||
if !res.IsError {
|
||||
t.Fatal("expected IsError on empty name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuerySQL_HappyPath(t *testing.T) {
|
||||
gw := fakeGateway(t, map[string]http.HandlerFunc{
|
||||
"/v1/sql": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("/v1/sql got %s, want POST", r.Method)
|
||||
}
|
||||
_, _ = w.Write([]byte(`{"columns":[{"name":"n","type":"BIGINT"}],"rows":[[5]],"row_count":1}`))
|
||||
},
|
||||
})
|
||||
session := connect(t, gw.URL)
|
||||
|
||||
res := callTool(t, session, "query_sql", querySQLArgs{SQL: "SELECT count(*) FROM workers"})
|
||||
if res.IsError {
|
||||
t.Fatalf("expected success, got: %s", resultText(t, res))
|
||||
}
|
||||
if !strings.Contains(resultText(t, res), `"row_count":1`) {
|
||||
t.Errorf("response missing row_count field: %s", resultText(t, res))
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuerySQL_EmptySQL_IsError(t *testing.T) {
|
||||
gw := fakeGateway(t, nil)
|
||||
session := connect(t, gw.URL)
|
||||
|
||||
res := callTool(t, session, "query_sql", querySQLArgs{SQL: " "})
|
||||
if !res.IsError {
|
||||
t.Fatal("expected IsError on whitespace SQL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbedText_HappyPath(t *testing.T) {
|
||||
gw := fakeGateway(t, map[string]http.HandlerFunc{
|
||||
"/v1/embed": func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"model":"nomic-embed-text","dimension":768,"vectors":[[0.1,0.2]]}`))
|
||||
},
|
||||
})
|
||||
session := connect(t, gw.URL)
|
||||
|
||||
res := callTool(t, session, "embed_text", embedTextArgs{Texts: []string{"hello"}})
|
||||
if res.IsError {
|
||||
t.Fatalf("expected success, got: %s", resultText(t, res))
|
||||
}
|
||||
if !strings.Contains(resultText(t, res), `"dimension":768`) {
|
||||
t.Errorf("missing dimension in response: %s", resultText(t, res))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbedText_EmptyTexts_IsError(t *testing.T) {
|
||||
gw := fakeGateway(t, nil)
|
||||
session := connect(t, gw.URL)
|
||||
|
||||
res := callTool(t, session, "embed_text", embedTextArgs{Texts: nil})
|
||||
if !res.IsError {
|
||||
t.Fatal("expected IsError on empty texts")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchVectors_HappyPath(t *testing.T) {
|
||||
gw := fakeGateway(t, map[string]http.HandlerFunc{
|
||||
"/v1/vectors/index/test_idx/search": func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"results":[{"id":"v1","distance":0.001}]}`))
|
||||
},
|
||||
})
|
||||
session := connect(t, gw.URL)
|
||||
|
||||
res := callTool(t, session, "search_vectors", searchVectorsArgs{
|
||||
IndexName: "test_idx",
|
||||
Vector: []float32{1, 0, 0, 0},
|
||||
K: 5,
|
||||
})
|
||||
if res.IsError {
|
||||
t.Fatalf("expected success, got: %s", resultText(t, res))
|
||||
}
|
||||
if !strings.Contains(resultText(t, res), `"id":"v1"`) {
|
||||
t.Errorf("missing top-1 in response: %s", resultText(t, res))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchVectors_MissingIndex_IsError(t *testing.T) {
|
||||
gw := fakeGateway(t, nil)
|
||||
session := connect(t, gw.URL)
|
||||
|
||||
res := callTool(t, session, "search_vectors", searchVectorsArgs{
|
||||
Vector: []float32{1, 0, 0, 0},
|
||||
})
|
||||
if !res.IsError {
|
||||
t.Fatal("expected IsError on missing index_name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchVectors_MissingVector_IsError(t *testing.T) {
|
||||
gw := fakeGateway(t, nil)
|
||||
session := connect(t, gw.URL)
|
||||
|
||||
res := callTool(t, session, "search_vectors", searchVectorsArgs{
|
||||
IndexName: "test_idx",
|
||||
})
|
||||
if !res.IsError {
|
||||
t.Fatal("expected IsError on missing vector")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayError_PropagatesAsIsError(t *testing.T) {
|
||||
gw := fakeGateway(t, map[string]http.HandlerFunc{
|
||||
"/v1/sql": func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("syntax error: no such table"))
|
||||
},
|
||||
})
|
||||
session := connect(t, gw.URL)
|
||||
|
||||
res := callTool(t, session, "query_sql", querySQLArgs{SQL: "SELECT * FROM nope"})
|
||||
if !res.IsError {
|
||||
t.Fatal("expected IsError on gateway 4xx")
|
||||
}
|
||||
body := resultText(t, res)
|
||||
if !strings.Contains(body, "400") {
|
||||
t.Errorf("error result should mention status 400, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayUnreachable_PropagatesAsIsError(t *testing.T) {
|
||||
// Point at a non-listening port — connect will fail.
|
||||
session := connect(t, "http://127.0.0.1:1") // reserved port
|
||||
|
||||
res := callTool(t, session, "list_datasets", listDatasetsArgs{})
|
||||
if !res.IsError {
|
||||
t.Fatal("expected IsError on unreachable gateway")
|
||||
}
|
||||
}
|
||||
11
go.mod
11
go.mod
@ -10,8 +10,12 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.22.16
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.100.0
|
||||
github.com/aws/smithy-go v1.25.0
|
||||
github.com/coder/hnsw v0.6.1
|
||||
github.com/duckdb/duckdb-go/v2 v2.10502.0
|
||||
github.com/go-chi/chi/v5 v5.2.5
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||
github.com/modelcontextprotocol/go-sdk v1.5.0
|
||||
github.com/pelletier/go-toml/v2 v2.3.0
|
||||
)
|
||||
|
||||
@ -33,26 +37,29 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/chewxy/math32 v1.10.1 // indirect
|
||||
github.com/coder/hnsw v0.6.1 // indirect
|
||||
github.com/duckdb/duckdb-go-bindings v0.10502.0 // indirect
|
||||
github.com/duckdb/duckdb-go-bindings/lib/darwin-amd64 v0.10502.0 // indirect
|
||||
github.com/duckdb/duckdb-go-bindings/lib/darwin-arm64 v0.10502.0 // indirect
|
||||
github.com/duckdb/duckdb-go-bindings/lib/linux-amd64 v0.10502.0 // indirect
|
||||
github.com/duckdb/duckdb-go-bindings/lib/linux-arm64 v0.10502.0 // indirect
|
||||
github.com/duckdb/duckdb-go-bindings/lib/windows-amd64 v0.10502.0 // indirect
|
||||
github.com/duckdb/duckdb-go/v2 v2.10502.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/goccy/go-json v0.10.6 // indirect
|
||||
github.com/google/flatbuffers v25.12.19+incompatible // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/google/renameio v1.0.1 // indirect
|
||||
github.com/klauspost/compress v1.18.5 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.26 // indirect
|
||||
github.com/segmentio/asm v1.1.3 // indirect
|
||||
github.com/segmentio/encoding v0.5.4 // indirect
|
||||
github.com/viterin/partial v1.1.0 // indirect
|
||||
github.com/viterin/vek v0.4.2 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/zeebo/xxh3 v1.1.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/oauth2 v0.35.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
|
||||
18
go.sum
18
go.sum
@ -74,26 +74,38 @@ github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPE
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/goccy/go-json v0.10.6 h1:p8HrPJzOakx/mn/bQtjgNjdTcN+/S6FcG2CTtQOrHVU=
|
||||
github.com/goccy/go-json v0.10.6/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/flatbuffers v25.12.19+incompatible h1:haMV2JRRJCe1998HeW/p0X9UaMTK6SDo0ffLn2+DbLs=
|
||||
github.com/google/flatbuffers v25.12.19+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/renameio v1.0.1 h1:Lh/jXZmvZxb0BBeSY5VKEfidcbcbenKjZFzM/q0fSeU=
|
||||
github.com/google/renameio v1.0.1/go.mod h1:t/HQoYBZSsWSNK35C6CO/TpPLDVWvxOHboWUAweKUpk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
|
||||
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/modelcontextprotocol/go-sdk v1.5.0 h1:CHU0FIX9kpueNkxuYtfYQn1Z0slhFzBZuq+x6IiblIU=
|
||||
github.com/modelcontextprotocol/go-sdk v1.5.0/go.mod h1:gggDIhoemhWs3BGkGwd1umzEXCEMMvAnhTrnbXJKKKA=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pierrec/lz4/v4 v4.1.26 h1:GrpZw1gZttORinvzBdXPUXATeqlJjqUG/D87TKMnhjY=
|
||||
github.com/pierrec/lz4/v4 v4.1.26/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc=
|
||||
github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg=
|
||||
github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0=
|
||||
github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
@ -104,6 +116,8 @@ github.com/viterin/vek v0.4.2 h1:Vyv04UjQT6gcjEFX82AS9ocgNbAJqsHviheIBdPlv5U=
|
||||
github.com/viterin/vek v0.4.2/go.mod h1:A4JRAe8OvbhdzBL5ofzjBS0J29FyUrf95tQogvtHHUc=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
|
||||
github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
|
||||
github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs=
|
||||
@ -124,12 +138,16 @@ golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7
|
||||
golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ=
|
||||
golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
|
||||
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260120221211-b8f7ae30c516 h1:sNrWoksmOyF5bvJUcnmbeAmQi8baNhqg5IWaI3llQqU=
|
||||
|
||||
180
internal/embed/cached.go
Normal file
180
internal/embed/cached.go
Normal file
@ -0,0 +1,180 @@
|
||||
// cached.go — LRU caching wrapper around a Provider.
|
||||
//
|
||||
// Memoizes (effective_model, sha256(text)) → []float32. Repeat
|
||||
// queries return the stored vector without round-tripping to the
|
||||
// upstream embedding service. Per-text caching: a batch with
|
||||
// mixed hit/miss only fetches the misses, then merges the result
|
||||
// preserving the caller's input order.
|
||||
//
|
||||
// Memory budget: ~3 KB per entry at d=768. 10K-entry default ≈ 30 MB
|
||||
// — small enough for any realistic embedd deployment. Operators
|
||||
// raising the cap should weigh memory headroom against expected
|
||||
// hit rate.
|
||||
//
|
||||
// Why this exists: 500K staffing test (memory project_golang_lakehouse)
|
||||
// showed that the staffing co-pilot replays many of the same query
|
||||
// texts ("forklift driver", "welder Chicago", etc.). Caching them
|
||||
// drops repeat-query cost from ~50ms (Ollama round-trip) to <1µs
|
||||
// (LRU hit). Real production win documented in feedback_meta_index_vision.
|
||||
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
)
|
||||
|
||||
// CachedProvider wraps a Provider with an LRU cache keyed on
|
||||
// (effective_model, sha256(text)). Thread-safe.
|
||||
type CachedProvider struct {
|
||||
inner Provider
|
||||
defaultModel string
|
||||
cache *lru.Cache[string, []float32]
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
}
|
||||
|
||||
// NewCachedProvider wraps inner with an LRU cache of the given size.
|
||||
//
|
||||
// defaultModel is used to resolve cache keys when a request leaves
|
||||
// the model field empty: a request for model="" is treated as
|
||||
// model=defaultModel for cache-key purposes, so callers that mix
|
||||
// "" and the explicit model name still hit the same cache entry.
|
||||
//
|
||||
// Caller-side panic protection: size <= 0 is treated as "no cache"
|
||||
// — every Embed call passes through. Avoids forcing operators to
|
||||
// understand LRU sizing to disable.
|
||||
func NewCachedProvider(inner Provider, defaultModel string, size int) (*CachedProvider, error) {
|
||||
if size <= 0 {
|
||||
// Sentinel: nil cache means pass-through. NewCachedProvider
|
||||
// stays callable so the wiring layer can always wrap.
|
||||
return &CachedProvider{inner: inner, defaultModel: defaultModel}, nil
|
||||
}
|
||||
cache, err := lru.New[string, []float32](size)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed cache init: %w", err)
|
||||
}
|
||||
return &CachedProvider{
|
||||
inner: inner,
|
||||
defaultModel: defaultModel,
|
||||
cache: cache,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Embed returns vectors for texts, memoizing per (model, text). On
|
||||
// a batch with mixed hits/misses, only the misses round-trip to
|
||||
// inner; the result preserves caller ordering.
|
||||
func (c *CachedProvider) Embed(ctx context.Context, texts []string, model string) (Result, error) {
|
||||
// Pass-through when caching disabled.
|
||||
if c.cache == nil {
|
||||
return c.inner.Embed(ctx, texts, model)
|
||||
}
|
||||
|
||||
if len(texts) == 0 {
|
||||
return Result{}, ErrEmptyTexts
|
||||
}
|
||||
|
||||
effectiveModel := model
|
||||
if effectiveModel == "" {
|
||||
effectiveModel = c.defaultModel
|
||||
}
|
||||
|
||||
// Pass 1: cache lookup; collect misses preserving original index
|
||||
// so we can write the upstream result back into the right slots.
|
||||
out := make([][]float32, len(texts))
|
||||
missTexts := make([]string, 0, len(texts))
|
||||
missIdx := make([]int, 0, len(texts))
|
||||
for i, t := range texts {
|
||||
key := cacheKey(effectiveModel, t)
|
||||
if v, ok := c.cache.Get(key); ok {
|
||||
out[i] = v
|
||||
c.hits.Add(1)
|
||||
continue
|
||||
}
|
||||
missTexts = append(missTexts, t)
|
||||
missIdx = append(missIdx, i)
|
||||
c.misses.Add(1)
|
||||
}
|
||||
|
||||
// All hits — synthesize the result without an upstream call.
|
||||
// Use the effective model + the first cached vector's length
|
||||
// for the response. Every cached vector for the same model has
|
||||
// the same dimension by construction (Provider guarantees it).
|
||||
if len(missTexts) == 0 {
|
||||
dim := 0
|
||||
if len(out) > 0 && len(out[0]) > 0 {
|
||||
dim = len(out[0])
|
||||
}
|
||||
return Result{
|
||||
Model: effectiveModel,
|
||||
Dimension: dim,
|
||||
Vectors: out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Pass 2: fetch the misses, populate cache, merge.
|
||||
res, err := c.inner.Embed(ctx, missTexts, model)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
if len(res.Vectors) != len(missTexts) {
|
||||
return Result{}, fmt.Errorf("embed cache: provider returned %d vectors for %d miss texts",
|
||||
len(res.Vectors), len(missTexts))
|
||||
}
|
||||
// Cache under effectiveModel (request-derived), NOT res.Model
|
||||
// (upstream-derived). Future requests with the same input shape
|
||||
// — same explicit model OR the same "" — must hit the same key.
|
||||
// Only ours is predictable from the request; upstream's resolution
|
||||
// can drift if the upstream's default changes.
|
||||
for j, t := range missTexts {
|
||||
out[missIdx[j]] = res.Vectors[j]
|
||||
c.cache.Add(cacheKey(effectiveModel, t), res.Vectors[j])
|
||||
}
|
||||
|
||||
return Result{
|
||||
Model: res.Model,
|
||||
Dimension: res.Dimension,
|
||||
Vectors: out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Stats returns lifetime hit + miss counts for this cache. Atomic
|
||||
// reads — no locking. Useful for /health, /metrics, or operator
|
||||
// dashboards.
|
||||
func (c *CachedProvider) Stats() (hits, misses int64) {
|
||||
return c.hits.Load(), c.misses.Load()
|
||||
}
|
||||
|
||||
// HitRate returns the fraction of requests served from cache.
|
||||
// Returns 0.0 when no requests have been served (avoids NaN).
|
||||
func (c *CachedProvider) HitRate() float64 {
|
||||
h := c.hits.Load()
|
||||
m := c.misses.Load()
|
||||
total := h + m
|
||||
if total == 0 {
|
||||
return 0.0
|
||||
}
|
||||
return float64(h) / float64(total)
|
||||
}
|
||||
|
||||
// Len returns the current entry count in the cache. Returns 0 when
|
||||
// caching is disabled.
|
||||
func (c *CachedProvider) Len() int {
|
||||
if c.cache == nil {
|
||||
return 0
|
||||
}
|
||||
return c.cache.Len()
|
||||
}
|
||||
|
||||
// cacheKey is "<model>:<sha256(text)>". sha256 collapses long texts
|
||||
// to a fixed-size key; model prefix scopes cache to one model so
|
||||
// callers using multiple models don't get cross-contamination.
|
||||
func cacheKey(model, text string) string {
|
||||
h := sha256.Sum256([]byte(text))
|
||||
return model + ":" + hex.EncodeToString(h[:])
|
||||
}
|
||||
350
internal/embed/cached_test.go
Normal file
350
internal/embed/cached_test.go
Normal file
@ -0,0 +1,350 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// stubProvider records every call so tests can verify caching
|
||||
// behavior (call counts, request texts) without spinning up Ollama.
|
||||
type stubProvider struct {
|
||||
mu sync.Mutex
|
||||
model string
|
||||
dim int
|
||||
callCount int
|
||||
lastTexts []string
|
||||
err error
|
||||
// vectorFn returns a deterministic vector for a given text so
|
||||
// cache hits assert byte-equality, not just shape.
|
||||
vectorFn func(text string) []float32
|
||||
}
|
||||
|
||||
func (s *stubProvider) Embed(_ context.Context, texts []string, model string) (Result, error) {
|
||||
s.mu.Lock()
|
||||
s.callCount++
|
||||
s.lastTexts = append([]string(nil), texts...)
|
||||
s.mu.Unlock()
|
||||
if s.err != nil {
|
||||
return Result{}, s.err
|
||||
}
|
||||
out := make([][]float32, len(texts))
|
||||
for i, t := range texts {
|
||||
out[i] = s.vectorFn(t)
|
||||
}
|
||||
resolvedModel := model
|
||||
if resolvedModel == "" {
|
||||
resolvedModel = s.model
|
||||
}
|
||||
return Result{Model: resolvedModel, Dimension: s.dim, Vectors: out}, nil
|
||||
}
|
||||
|
||||
func deterministicVector(t string) []float32 {
|
||||
v := make([]float32, 4)
|
||||
for i := range v {
|
||||
v[i] = float32(int(t[0]) + i)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func newStub() *stubProvider {
|
||||
return &stubProvider{
|
||||
model: "test-model",
|
||||
dim: 4,
|
||||
vectorFn: deterministicVector,
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProvider_PassThroughWhenSizeZero(t *testing.T) {
|
||||
stub := newStub()
|
||||
c, err := NewCachedProvider(stub, "test-model", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCachedProvider: %v", err)
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := c.Embed(context.Background(), []string{"hello"}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Embed: %v", err)
|
||||
}
|
||||
}
|
||||
if stub.callCount != 3 {
|
||||
t.Errorf("size=0 should pass through every call, got callCount=%d", stub.callCount)
|
||||
}
|
||||
if c.Len() != 0 {
|
||||
t.Errorf("Len()=%d, want 0 when caching disabled", c.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProvider_HitOnRepeatText(t *testing.T) {
|
||||
stub := newStub()
|
||||
c, err := NewCachedProvider(stub, "test-model", 100)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCachedProvider: %v", err)
|
||||
}
|
||||
|
||||
// First call — miss, hits upstream.
|
||||
r1, err := c.Embed(context.Background(), []string{"hello"}, "test-model")
|
||||
if err != nil {
|
||||
t.Fatalf("first: %v", err)
|
||||
}
|
||||
if stub.callCount != 1 {
|
||||
t.Errorf("first call should hit upstream, callCount=%d", stub.callCount)
|
||||
}
|
||||
|
||||
// Second call — hit, no upstream call.
|
||||
r2, err := c.Embed(context.Background(), []string{"hello"}, "test-model")
|
||||
if err != nil {
|
||||
t.Fatalf("second: %v", err)
|
||||
}
|
||||
if stub.callCount != 1 {
|
||||
t.Errorf("second call should hit cache, callCount=%d (want unchanged 1)", stub.callCount)
|
||||
}
|
||||
|
||||
// Result should be byte-equal.
|
||||
if len(r1.Vectors[0]) != len(r2.Vectors[0]) {
|
||||
t.Errorf("vector dim differs: %d vs %d", len(r1.Vectors[0]), len(r2.Vectors[0]))
|
||||
}
|
||||
for i := range r1.Vectors[0] {
|
||||
if r1.Vectors[0][i] != r2.Vectors[0][i] {
|
||||
t.Errorf("vector[%d]: %v vs %v", i, r1.Vectors[0][i], r2.Vectors[0][i])
|
||||
}
|
||||
}
|
||||
|
||||
hits, misses := c.Stats()
|
||||
if hits != 1 || misses != 1 {
|
||||
t.Errorf("hits/misses = %d/%d, want 1/1", hits, misses)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProvider_MixedBatchOnlyFetchesMisses(t *testing.T) {
|
||||
stub := newStub()
|
||||
c, _ := NewCachedProvider(stub, "test-model", 100)
|
||||
|
||||
// Prime cache with one text.
|
||||
c.Embed(context.Background(), []string{"alpha"}, "test-model")
|
||||
if stub.callCount != 1 {
|
||||
t.Fatalf("priming should produce 1 upstream call, got %d", stub.callCount)
|
||||
}
|
||||
|
||||
// Mixed batch: alpha cached, beta + gamma fresh.
|
||||
r, err := c.Embed(context.Background(),
|
||||
[]string{"alpha", "beta", "gamma"}, "test-model")
|
||||
if err != nil {
|
||||
t.Fatalf("Embed: %v", err)
|
||||
}
|
||||
|
||||
// Provider should have been called only with the misses.
|
||||
if len(stub.lastTexts) != 2 {
|
||||
t.Errorf("upstream called with %d texts, want 2 (only misses)", len(stub.lastTexts))
|
||||
}
|
||||
for _, lt := range stub.lastTexts {
|
||||
if lt == "alpha" {
|
||||
t.Errorf("upstream got 'alpha' but it was cached")
|
||||
}
|
||||
}
|
||||
|
||||
// Result preserves order: alpha first (cached), beta + gamma in
|
||||
// their original positions.
|
||||
if len(r.Vectors) != 3 {
|
||||
t.Fatalf("got %d vectors, want 3", len(r.Vectors))
|
||||
}
|
||||
for i, txt := range []string{"alpha", "beta", "gamma"} {
|
||||
want := deterministicVector(txt)
|
||||
for j := range want {
|
||||
if r.Vectors[i][j] != want[j] {
|
||||
t.Errorf("vec[%d (%s)][%d] = %v, want %v",
|
||||
i, txt, j, r.Vectors[i][j], want[j])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProvider_ModelKeyIsolation(t *testing.T) {
|
||||
stub := newStub()
|
||||
c, _ := NewCachedProvider(stub, "default-model", 100)
|
||||
|
||||
// Same text, different models → distinct cache entries.
|
||||
c.Embed(context.Background(), []string{"shared"}, "model-A")
|
||||
c.Embed(context.Background(), []string{"shared"}, "model-B")
|
||||
if stub.callCount != 2 {
|
||||
t.Errorf("different models should NOT share cache, got callCount=%d", stub.callCount)
|
||||
}
|
||||
|
||||
// Repeat with model-A → hits cache.
|
||||
c.Embed(context.Background(), []string{"shared"}, "model-A")
|
||||
if stub.callCount != 2 {
|
||||
t.Errorf("repeat of model-A should hit cache, got callCount=%d", stub.callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProvider_EmptyModelResolvesToDefault(t *testing.T) {
|
||||
stub := newStub()
|
||||
c, _ := NewCachedProvider(stub, "the-default", 100)
|
||||
|
||||
// First call with model="" — caches under "the-default".
|
||||
c.Embed(context.Background(), []string{"text"}, "")
|
||||
if stub.callCount != 1 {
|
||||
t.Fatalf("first call should hit upstream, got %d", stub.callCount)
|
||||
}
|
||||
|
||||
// Second call with model="the-default" — should hit cache because
|
||||
// the first call resolved "" → "the-default" for the key.
|
||||
c.Embed(context.Background(), []string{"text"}, "the-default")
|
||||
if stub.callCount != 1 {
|
||||
t.Errorf("explicit default-model should hit cache populated by '' request, got callCount=%d", stub.callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProvider_LRUEviction(t *testing.T) {
|
||||
stub := newStub()
|
||||
c, _ := NewCachedProvider(stub, "test-model", 2) // cap=2
|
||||
|
||||
// Fill: a, b
|
||||
c.Embed(context.Background(), []string{"a"}, "")
|
||||
c.Embed(context.Background(), []string{"b"}, "")
|
||||
// c arrives — evicts a (LRU).
|
||||
c.Embed(context.Background(), []string{"c"}, "")
|
||||
if c.Len() > 2 {
|
||||
t.Errorf("Len()=%d, want ≤2 with cap=2", c.Len())
|
||||
}
|
||||
if stub.callCount != 3 {
|
||||
t.Errorf("callCount=%d, want 3 after a/b/c distinct", stub.callCount)
|
||||
}
|
||||
|
||||
// Re-request a → miss (was evicted).
|
||||
c.Embed(context.Background(), []string{"a"}, "")
|
||||
if stub.callCount != 4 {
|
||||
t.Errorf("evicted a should miss, callCount=%d want 4", stub.callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProvider_PropagatesUpstreamError(t *testing.T) {
|
||||
stub := &stubProvider{
|
||||
model: "test-model",
|
||||
dim: 4,
|
||||
vectorFn: deterministicVector,
|
||||
err: errors.New("upstream broken"),
|
||||
}
|
||||
c, _ := NewCachedProvider(stub, "test-model", 10)
|
||||
|
||||
_, err := c.Embed(context.Background(), []string{"hello"}, "")
|
||||
if err == nil {
|
||||
t.Fatal("expected upstream error to propagate")
|
||||
}
|
||||
// Cache should NOT have stored anything on error.
|
||||
if c.Len() != 0 {
|
||||
t.Errorf("error path should not populate cache, Len()=%d", c.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProvider_AllHitsSynthesized(t *testing.T) {
|
||||
stub := newStub()
|
||||
c, _ := NewCachedProvider(stub, "test-model", 100)
|
||||
|
||||
// Prime
|
||||
c.Embed(context.Background(), []string{"x", "y"}, "")
|
||||
stub.mu.Lock()
|
||||
stub.callCount = 0 // reset for clarity
|
||||
stub.mu.Unlock()
|
||||
|
||||
// All-hits batch → no upstream call.
|
||||
r, err := c.Embed(context.Background(), []string{"x", "y"}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Embed: %v", err)
|
||||
}
|
||||
if stub.callCount != 0 {
|
||||
t.Errorf("all-hits batch should not hit upstream, callCount=%d", stub.callCount)
|
||||
}
|
||||
if r.Dimension != 4 {
|
||||
t.Errorf("all-hits Dimension=%d, want 4 (from cached vector len)", r.Dimension)
|
||||
}
|
||||
if r.Model != "test-model" {
|
||||
t.Errorf("all-hits Model=%q, want resolved 'test-model'", r.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProvider_HitRate(t *testing.T) {
|
||||
stub := newStub()
|
||||
c, _ := NewCachedProvider(stub, "test-model", 100)
|
||||
|
||||
if c.HitRate() != 0.0 {
|
||||
t.Errorf("initial HitRate=%v, want 0.0", c.HitRate())
|
||||
}
|
||||
|
||||
c.Embed(context.Background(), []string{"a"}, "") // miss
|
||||
c.Embed(context.Background(), []string{"a"}, "") // hit
|
||||
c.Embed(context.Background(), []string{"a"}, "") // hit
|
||||
|
||||
got := c.HitRate()
|
||||
want := 2.0 / 3.0
|
||||
if got < want-0.01 || got > want+0.01 {
|
||||
t.Errorf("HitRate=%v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProvider_EmptyTextsRejected(t *testing.T) {
|
||||
stub := newStub()
|
||||
c, _ := NewCachedProvider(stub, "test-model", 100)
|
||||
|
||||
_, err := c.Embed(context.Background(), []string{}, "")
|
||||
if !errors.Is(err, ErrEmptyTexts) {
|
||||
t.Errorf("empty texts should return ErrEmptyTexts, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedProvider_ConcurrentSafe(t *testing.T) {
|
||||
stub := newStub()
|
||||
c, _ := NewCachedProvider(stub, "test-model", 100)
|
||||
|
||||
// 50 goroutines × 100 calls each, mostly the same text.
|
||||
var wg sync.WaitGroup
|
||||
for g := 0; g < 50; g++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 100; i++ {
|
||||
_, err := c.Embed(context.Background(), []string{"shared"}, "")
|
||||
if err != nil {
|
||||
t.Errorf("concurrent embed err: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
hits, misses := c.Stats()
|
||||
if hits+misses != 5000 {
|
||||
t.Errorf("total calls = %d, want 5000", hits+misses)
|
||||
}
|
||||
// First call is a guaranteed miss; subsequent should mostly be
|
||||
// hits. With 50 concurrent goroutines hitting an empty cache,
|
||||
// up to ~50 first-arrivers may miss before the first one writes.
|
||||
if misses > 100 {
|
||||
t.Errorf("misses=%d unreasonably high; cache concurrency broken?", misses)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheKey_StableAndDistinct(t *testing.T) {
|
||||
a := cacheKey("m1", "text-a")
|
||||
a2 := cacheKey("m1", "text-a")
|
||||
b := cacheKey("m1", "text-b")
|
||||
c := cacheKey("m2", "text-a")
|
||||
|
||||
if a != a2 {
|
||||
t.Errorf("cache key not stable: %q != %q", a, a2)
|
||||
}
|
||||
if a == b {
|
||||
t.Errorf("different texts produced same key: %q", a)
|
||||
}
|
||||
if a == c {
|
||||
t.Errorf("different models produced same key: %q", a)
|
||||
}
|
||||
// Key shape: model + ":" + 64-hex-chars sha256.
|
||||
const sha256HexLen = 64
|
||||
want := len("m1") + 1 + sha256HexLen
|
||||
if len(a) != want {
|
||||
t.Errorf("key len=%d, want %d", len(a), want)
|
||||
}
|
||||
}
|
||||
@ -65,11 +65,14 @@ type GatewayConfig struct {
|
||||
// EmbeddConfig drives the embed service. ProviderURL points at the
|
||||
// embedding backend (Ollama in G2, possibly OpenAI/Voyage in G3+).
|
||||
// DefaultModel is what gets used when callers don't specify a
|
||||
// model in their request body.
|
||||
// model in their request body. CacheSize is the LRU cache cap on
|
||||
// (model, sha256(text)) → vector lookups; 0 disables caching.
|
||||
// Default 10000 entries ≈ 30 MiB at d=768.
|
||||
type EmbeddConfig struct {
|
||||
Bind string `toml:"bind"`
|
||||
ProviderURL string `toml:"provider_url"`
|
||||
DefaultModel string `toml:"default_model"`
|
||||
CacheSize int `toml:"cache_size"`
|
||||
}
|
||||
|
||||
// VectordConfig adds vectord-specific knobs. StoragedURL is
|
||||
@ -153,6 +156,7 @@ func DefaultConfig() Config {
|
||||
Bind: "127.0.0.1:3216",
|
||||
ProviderURL: "http://localhost:11434", // local Ollama
|
||||
DefaultModel: "nomic-embed-text",
|
||||
CacheSize: 10_000, // ~30 MiB at d=768; set to 0 to disable
|
||||
},
|
||||
Queryd: QuerydConfig{
|
||||
Bind: "127.0.0.1:3214",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user