"""Reranking endpoint backed by an Ollama generate model.

Each candidate document is scored against the query by prompting the
model to rate relevance on a 0-10 scale; results are returned sorted by
score, descending.
"""

import asyncio
import os
import re

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from .ollama import client

router = APIRouter()

# Default model used when the request does not name one.
RERANK_MODEL = os.environ.get("RERANK_MODEL", "qwen2.5")

# First (optionally signed) decimal number in the model's reply.
_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")


class RerankRequest(BaseModel):
    query: str
    documents: list[str]
    model: str | None = None  # overrides RERANK_MODEL when set
    top_k: int | None = None  # truncate results; None returns all


class ScoredDocument(BaseModel):
    index: int  # position of the document in the original request
    text: str
    score: float  # clamped to [0.0, 10.0]


class RerankResponse(BaseModel):
    results: list[ScoredDocument]
    model: str


def _parse_score(text: str) -> float:
    """Extract the first number found in *text*, clamped to [0.0, 10.0].

    Robust to chatty replies such as "Score: 8", "8/10", or "8." —
    the previous ``float(text.split()[0])`` raised on those and scored
    the document 0.0, and a literal "nan" reply clamped to 10.0 because
    NaN comparisons are always False.  Returns 0.0 ("irrelevant") when
    no number is present at all.
    """
    match = _NUMBER_RE.search(text)
    if match is None:
        return 0.0
    return max(0.0, min(10.0, float(match.group())))


async def _score_document(c, model: str, query: str, index: int, doc: str) -> ScoredDocument:
    """Score one document against *query* with a single /api/generate call.

    Raises HTTPException(502) when Ollama returns a non-200 response.
    """
    prompt = (
        f"Rate the relevance of the following document to the query on a scale of 0 to 10. "
        f"Respond with ONLY a number.\n\n"
        f"Query: {query}\n\n"
        f"Document: {doc}\n\n"
        f"Score:"
    )
    resp = await c.post(
        "/api/generate",
        json={
            "model": model,
            "prompt": prompt,
            "stream": False,
            # Deterministic, short completion: we only need one number.
            "options": {"temperature": 0.0, "num_predict": 8},
        },
    )
    if resp.status_code != 200:
        raise HTTPException(502, f"Ollama error: {resp.text}")
    text = resp.json().get("response", "").strip()
    return ScoredDocument(index=index, text=doc, score=_parse_score(text))


@router.post("", response_model=RerankResponse)
async def rerank(req: RerankRequest):
    """Cross-encoder reranking via Ollama generate.

    Scores each document against the query by asking the model to rate
    relevance 0-10, then sorts by score descending.  Documents are
    scored concurrently (one generate call each) rather than
    sequentially, so latency is bounded by the slowest call instead of
    the sum of all calls; Ollama queues requests it cannot parallelize.
    """
    model = req.model or RERANK_MODEL
    async with client() as c:
        scored = list(
            await asyncio.gather(
                *(
                    _score_document(c, model, req.query, i, doc)
                    for i, doc in enumerate(req.documents)
                )
            )
        )
    # Stable sort: ties keep the original document order.
    scored.sort(key=lambda x: x.score, reverse=True)
    # `is not None` so an explicit top_k=0 yields an empty result list
    # instead of being treated as "no limit" (old truthiness bug).
    if req.top_k is not None:
        scored = scored[: req.top_k]
    return RerankResponse(results=scored, model=model)