// ollama.go — Provider backed by an Ollama HTTP server. Compatible
// with Ollama 0.21+ via the per-text /api/embeddings endpoint.
// Newer Ollama (0.4+) exposes /api/embed for batched calls, but
// the per-text loop is forward-compatible with both.
package embed

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)

// OllamaProvider hits an Ollama server at the configured base URL.
type OllamaProvider struct {
	baseURL      string
	defaultModel string
	hc           *http.Client
}

// NewOllama builds a provider against baseURL (e.g.
// "http://localhost:11434"). defaultModel is used when callers pass
// an empty model name.
func NewOllama(baseURL, defaultModel string) *OllamaProvider {
	return &OllamaProvider{
		baseURL:      strings.TrimRight(baseURL, "/"),
		defaultModel: defaultModel,
		hc: &http.Client{
			// Embeddings are CPU-bound on the server side; 60s gives
			// plenty of headroom for a single-text call. Callers can
			// add an outer ctx deadline for a batch-level cap.
			Timeout: 60 * time.Second,
		},
	}
}

// ollamaRequest is Ollama's /api/embeddings body shape.
type ollamaRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
}

// ollamaResponse mirrors the success body. Embedding is float64
// from Ollama; we convert to float32 at the boundary.
type ollamaResponse struct {
	Embedding []float64 `json:"embedding"`
}

// Embed loops over texts, issuing one HTTP call per text. Errors
// short-circuit — if call N fails, we return the error and the
// caller sees no partial Result.
func (p *OllamaProvider) Embed(ctx context.Context, texts []string, model string) (Result, error) {
	if len(texts) == 0 {
		return Result{}, ErrEmptyTexts
	}
	if model == "" {
		model = p.defaultModel
	}
	if model == "" {
		return Result{}, fmt.Errorf("embed: no model (empty request, no default)")
	}

	out := Result{Model: model, Vectors: make([][]float32, 0, len(texts))}
	for i, text := range texts {
		vec, err := p.embedOne(ctx, model, text)
		if err != nil {
			return Result{}, fmt.Errorf("embed text[%d]: %w", i, err)
		}
		// Per-text vectors must agree on dimension. Lock on first.
		if out.Dimension == 0 {
			out.Dimension = len(vec)
		} else if len(vec) != out.Dimension {
			return Result{}, fmt.Errorf("%w: text[%d] returned %d, prior were %d",
				ErrModelMismatch, i, len(vec), out.Dimension)
		}
		out.Vectors = append(out.Vectors, vec)
	}
	return out, nil
}

func (p *OllamaProvider) embedOne(ctx context.Context, model, text string) ([]float32, error) {
	body, err := json.Marshal(ollamaRequest{Model: model, Prompt: text})
	if err != nil {
		return nil, fmt.Errorf("marshal: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/api/embeddings", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("req: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.ContentLength = int64(len(body))

	resp, err := p.hc.Do(req)
	if err != nil {
		return nil, fmt.Errorf("do: %w", err)
	}
	defer drainAndClose(resp.Body)

	if resp.StatusCode != http.StatusOK {
		preview, _ := io.ReadAll(io.LimitReader(resp.Body, 256))
		return nil, fmt.Errorf("upstream status %d: %s", resp.StatusCode, string(preview))
	}

	var or ollamaResponse
	if err := json.NewDecoder(resp.Body).Decode(&or); err != nil {
		return nil, fmt.Errorf("decode: %w", err)
	}
	if len(or.Embedding) == 0 {
		return nil, fmt.Errorf("upstream returned empty embedding")
	}

	// float64 → float32. Loss of precision is acceptable for HNSW
	// search; float32 matches the rest of the system.
	out := make([]float32, len(or.Embedding))
	for i, v := range or.Embedding {
		out[i] = float32(v)
	}
	return out, nil
}

// drainAndClose reads up to 64 KiB of any remaining body so the
// underlying connection can be reused, then closes it.
func drainAndClose(body io.ReadCloser) {
	_, _ = io.Copy(io.Discard, io.LimitReader(body, 64<<10))
	_ = body.Close()
}
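
// Usage sketch (illustrative only, not part of the provider surface):
// a minimal example of wiring up NewOllama and embedding a small batch.
// The base URL, the model name "nomic-embed-text", and the helper name
// exampleEmbedUsage are assumptions for the sketch; Result, ErrEmptyTexts,
// and ErrModelMismatch are expected to be defined elsewhere in package
// embed, as used above.
func exampleEmbedUsage(ctx context.Context) error {
	// An empty model argument falls back to the provider's default.
	p := NewOllama("http://localhost:11434", "nomic-embed-text")
	res, err := p.Embed(ctx, []string{"first text", "second text"}, "")
	if err != nil {
		return err
	}
	fmt.Printf("model=%s dim=%d vectors=%d\n", res.Model, res.Dimension, len(res.Vectors))
	return nil
}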