package main import ( "context" "net/http" "net/http/httptest" "strings" "testing" "github.com/modelcontextprotocol/go-sdk/mcp" ) // Tests the MCP tool surface end-to-end without a subprocess: spin // up a fake gateway via httptest, build the MCP server pointed at // it, connect a client via in-memory transports, call each tool. // fakeGateway returns an httptest.Server that responds to the routes // mcpd proxies. Each route's handler is configurable via the routes map. func fakeGateway(t *testing.T, routes map[string]http.HandlerFunc) *httptest.Server { t.Helper() mux := http.NewServeMux() for path, h := range routes { mux.HandleFunc(path, h) } mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { t.Errorf("fakeGateway: unexpected route %s %s", r.Method, r.URL.Path) http.NotFound(w, r) }) srv := httptest.NewServer(mux) t.Cleanup(srv.Close) return srv } // connect builds the mcpd server pointed at gatewayURL, connects an // in-memory client, and returns a ready-to-use ClientSession. func connect(t *testing.T, gatewayURL string) *mcp.ClientSession { t.Helper() ctx := context.Background() srv := buildServer(gatewayURL) client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "v0.0.1"}, nil) t1, t2 := mcp.NewInMemoryTransports() if _, err := srv.Connect(ctx, t1, nil); err != nil { t.Fatalf("server connect: %v", err) } session, err := client.Connect(ctx, t2, nil) if err != nil { t.Fatalf("client connect: %v", err) } t.Cleanup(func() { _ = session.Close() }) return session } func callTool(t *testing.T, session *mcp.ClientSession, name string, args any) *mcp.CallToolResult { t.Helper() res, err := session.CallTool(context.Background(), &mcp.CallToolParams{ Name: name, Arguments: args, }) if err != nil { t.Fatalf("CallTool %s: %v", name, err) } return res } func resultText(t *testing.T, res *mcp.CallToolResult) string { t.Helper() if len(res.Content) == 0 { t.Fatal("result has no content") } tc, ok := res.Content[0].(*mcp.TextContent) if !ok { t.Fatalf("first content is %T, want *mcp.TextContent", res.Content[0]) } return tc.Text } func TestListTools(t *testing.T) { gw := fakeGateway(t, nil) session := connect(t, gw.URL) res, err := session.ListTools(context.Background(), nil) if err != nil { t.Fatalf("ListTools: %v", err) } want := map[string]bool{ "list_datasets": false, "get_manifest": false, "query_sql": false, "embed_text": false, "search_vectors": false, } for _, tool := range res.Tools { if _, ok := want[tool.Name]; ok { want[tool.Name] = true } } for name, found := range want { if !found { t.Errorf("expected tool %q exposed by mcpd, not in ListTools", name) } } } func TestListDatasets_HappyPath(t *testing.T) { gw := fakeGateway(t, map[string]http.HandlerFunc{ "/v1/catalog/list": func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"manifests":[{"name":"workers","row_count":500000}],"count":1}`)) }, }) session := connect(t, gw.URL) res := callTool(t, session, "list_datasets", listDatasetsArgs{}) if res.IsError { t.Fatalf("expected success, got IsError with: %s", resultText(t, res)) } body := resultText(t, res) if !strings.Contains(body, "workers") || !strings.Contains(body, "500000") { t.Errorf("response missing expected fields: %s", body) } } func TestGetManifest_HappyPath(t *testing.T) { gw := fakeGateway(t, map[string]http.HandlerFunc{ "/v1/catalog/manifest/workers": func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte(`{"name":"workers","row_count":500000,"schema_fingerprint":"sha256:abc"}`)) }, }) session := connect(t, gw.URL) res := callTool(t, session, "get_manifest", getManifestArgs{Name: "workers"}) if res.IsError { t.Fatalf("expected success, got: %s", resultText(t, res)) } body := resultText(t, res) if !strings.Contains(body, "workers") { t.Errorf("missing manifest fields in response: %s", body) } } func TestGetManifest_EmptyName_IsError(t *testing.T) { gw := fakeGateway(t, nil) // no handlers — tool should error before hitting gateway session := connect(t, gw.URL) res := callTool(t, session, "get_manifest", getManifestArgs{Name: ""}) if !res.IsError { t.Fatal("expected IsError on empty name") } } func TestQuerySQL_HappyPath(t *testing.T) { gw := fakeGateway(t, map[string]http.HandlerFunc{ "/v1/sql": func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { t.Errorf("/v1/sql got %s, want POST", r.Method) } _, _ = w.Write([]byte(`{"columns":[{"name":"n","type":"BIGINT"}],"rows":[[5]],"row_count":1}`)) }, }) session := connect(t, gw.URL) res := callTool(t, session, "query_sql", querySQLArgs{SQL: "SELECT count(*) FROM workers"}) if res.IsError { t.Fatalf("expected success, got: %s", resultText(t, res)) } if !strings.Contains(resultText(t, res), `"row_count":1`) { t.Errorf("response missing row_count field: %s", resultText(t, res)) } } func TestQuerySQL_EmptySQL_IsError(t *testing.T) { gw := fakeGateway(t, nil) session := connect(t, gw.URL) res := callTool(t, session, "query_sql", querySQLArgs{SQL: " "}) if !res.IsError { t.Fatal("expected IsError on whitespace SQL") } } func TestEmbedText_HappyPath(t *testing.T) { gw := fakeGateway(t, map[string]http.HandlerFunc{ "/v1/embed": func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte(`{"model":"nomic-embed-text","dimension":768,"vectors":[[0.1,0.2]]}`)) }, }) session := connect(t, gw.URL) res := callTool(t, session, "embed_text", embedTextArgs{Texts: []string{"hello"}}) if res.IsError { t.Fatalf("expected success, got: %s", resultText(t, res)) } if !strings.Contains(resultText(t, res), `"dimension":768`) { t.Errorf("missing dimension in response: %s", resultText(t, res)) } } func TestEmbedText_EmptyTexts_IsError(t *testing.T) { gw := fakeGateway(t, nil) session := connect(t, gw.URL) res := callTool(t, session, "embed_text", embedTextArgs{Texts: nil}) if !res.IsError { t.Fatal("expected IsError on empty texts") } } func TestSearchVectors_HappyPath(t *testing.T) { gw := fakeGateway(t, map[string]http.HandlerFunc{ "/v1/vectors/index/test_idx/search": func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte(`{"results":[{"id":"v1","distance":0.001}]}`)) }, }) session := connect(t, gw.URL) res := callTool(t, session, "search_vectors", searchVectorsArgs{ IndexName: "test_idx", Vector: []float32{1, 0, 0, 0}, K: 5, }) if res.IsError { t.Fatalf("expected success, got: %s", resultText(t, res)) } if !strings.Contains(resultText(t, res), `"id":"v1"`) { t.Errorf("missing top-1 in response: %s", resultText(t, res)) } } func TestSearchVectors_MissingIndex_IsError(t *testing.T) { gw := fakeGateway(t, nil) session := connect(t, gw.URL) res := callTool(t, session, "search_vectors", searchVectorsArgs{ Vector: []float32{1, 0, 0, 0}, }) if !res.IsError { t.Fatal("expected IsError on missing index_name") } } func TestSearchVectors_MissingVector_IsError(t *testing.T) { gw := fakeGateway(t, nil) session := connect(t, gw.URL) res := callTool(t, session, "search_vectors", searchVectorsArgs{ IndexName: "test_idx", }) if !res.IsError { t.Fatal("expected IsError on missing vector") } } func TestGatewayError_PropagatesAsIsError(t *testing.T) { gw := fakeGateway(t, map[string]http.HandlerFunc{ "/v1/sql": func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusBadRequest) _, _ = w.Write([]byte("syntax error: no such table")) }, }) session := connect(t, gw.URL) res := callTool(t, session, "query_sql", querySQLArgs{SQL: "SELECT * FROM nope"}) if !res.IsError { t.Fatal("expected IsError on gateway 4xx") } body := resultText(t, res) if !strings.Contains(body, "400") { t.Errorf("error result should mention status 400, got: %s", body) } } func TestGatewayUnreachable_PropagatesAsIsError(t *testing.T) { // Point at a non-listening port — connect will fail. session := connect(t, "http://127.0.0.1:1") // reserved port res := callTool(t, session, "list_datasets", listDatasetsArgs{}) if !res.IsError { t.Fatal("expected IsError on unreachable gateway") } }