"""Admin endpoints: model lifecycle + VRAM introspection.

Phase 17 / Phase C — the VRAM-aware profile swap story.

Ollama loads models lazily and unloads after a TTL (default 5min). For
predictable swaps between model profiles, we need explicit control:

- unload: send keep_alive=0 to force immediate unload
- preload: send keep_alive=5m with a one-token generation to proactively load
- ps: query Ollama's /api/ps to see what's currently loaded
- vram: combine /api/ps with an nvidia-smi snapshot (when nvidia-smi exists)

unload/preload/ps are thin wrappers over Ollama's own API. No state held here.
"""

import asyncio
import shutil
import subprocess

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from .ollama import client

router = APIRouter()


class ModelRequest(BaseModel):
    """Request body naming the Ollama model to act on."""

    model: str


@router.post("/unload")
async def unload(req: ModelRequest):
    """Force Ollama to unload a model immediately (keep_alive=0 trick).

    Raises:
        HTTPException: 502 if Ollama answers with any non-200 status.
    """
    async with client() as c:
        resp = await c.post(
            "/api/generate",
            json={"model": req.model, "prompt": "", "keep_alive": 0, "stream": False},
        )
        # Ollama returns 200 even on unload; anything else is abnormal.
        if resp.status_code != 200:
            raise HTTPException(502, f"Ollama unload error: {resp.text}")
    return {"unloaded": req.model}


@router.post("/preload")
async def preload(req: ModelRequest):
    """Force Ollama to load a model into VRAM and keep it resident for 5 minutes.

    Runs a single-token generation with keep_alive="5m", so the model stays
    loaded for that window after this call (not indefinitely).

    Returns:
        The model name plus Ollama's ``load_duration`` / ``total_duration``
        fields from the generate response (exposed with an ``_ns`` suffix).

    Raises:
        HTTPException: 502 if Ollama answers with any non-200 status.
    """
    async with client() as c:
        resp = await c.post(
            "/api/generate",
            json={
                "model": req.model,
                # "" can confuse some models; a single token is safer and still ~instant.
                "prompt": " ",
                "keep_alive": "5m",
                "stream": False,
                "options": {"num_predict": 1},
            },
        )
        if resp.status_code != 200:
            raise HTTPException(502, f"Ollama preload error: {resp.text}")
        data = resp.json()
    return {
        "preloaded": req.model,
        "load_duration_ns": data.get("load_duration"),
        "total_duration_ns": data.get("total_duration"),
    }


@router.get("/ps")
async def list_loaded():
    """What models does Ollama currently have in VRAM?

    Raises:
        HTTPException: 502 if Ollama's /api/ps answers with a non-200 status.
    """
    async with client() as c:
        resp = await c.get("/api/ps")
        if resp.status_code != 200:
            raise HTTPException(502, f"Ollama ps error: {resp.text}")
        data = resp.json()
    # Flatten Ollama's fields into something small + useful.
    models = [
        {
            "name": m.get("name"),
            "size_bytes": m.get("size"),
            "size_vram_bytes": m.get("size_vram"),
            "expires_at": m.get("expires_at"),
        }
        for m in data.get("models", [])
    ]
    return {"models": models}


@router.get("/vram")
async def vram_summary():
    """Combined: nvidia-smi VRAM + Ollama loaded models.

    Shells out to nvidia-smi; if it's not on PATH (or the poll fails), "gpu"
    is None and only the Ollama view is meaningful. The blocking subprocess
    runs via asyncio.to_thread so it doesn't stall the event loop. A non-200
    from Ollama's /api/ps yields an empty "ollama_loaded" list, not an error.
    """
    gpu = None
    if shutil.which("nvidia-smi"):
        gpu = await asyncio.to_thread(_nvidia_smi_snapshot)

    async with client() as c:
        resp = await c.get("/api/ps")
        loaded = resp.json().get("models", []) if resp.status_code == 200 else []

    return {
        "gpu": gpu,
        "ollama_loaded": [
            {
                "name": m.get("name"),
                # size_vram may be missing or null; treat both as 0.
                "size_vram_mib": (m.get("size_vram", 0) or 0) // (1024 * 1024),
                "expires_at": m.get("expires_at"),
            }
            for m in loaded
        ],
    }


def _nvidia_smi_snapshot():
    """One-shot nvidia-smi poll. Returns None on failure.

    nvidia-smi prints one CSV line per GPU; only the first line (the first
    GPU listed) is reported, so multi-GPU hosts still get a result while the
    return shape stays a single dict.
    """
    try:
        raw = subprocess.check_output(
            [
                "nvidia-smi",
                "--query-gpu=memory.used,memory.total,utilization.gpu,name",
                "--format=csv,noheader,nounits",
            ],
            timeout=2,
        ).decode()
    except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
        # Best-effort: missing binary, non-zero exit, timeout, or undecodable
        # output all mean "no GPU view" rather than a failed request.
        return None
    return _parse_nvidia_smi_csv(raw)


def _parse_nvidia_smi_csv(raw):
    """Parse the first GPU line of `nvidia-smi --query-gpu` CSV output, or None."""
    first = next((line for line in raw.splitlines() if line.strip()), None)
    if first is None:
        return None
    # maxsplit=3: name is the last column, so a comma inside it stays intact.
    fields = [s.strip() for s in first.split(",", 3)]
    if len(fields) != 4:
        return None
    used_mib, total_mib, util_pct, name = fields
    try:
        used, total = int(used_mib), int(total_mib)
    except ValueError:
        return None
    return {
        "name": name,
        "used_mib": used,
        "total_mib": total,
        # Utilization can be non-numeric (e.g. "[N/A]"); keep the memory
        # numbers and report utilization as None in that case.
        "utilization_pct": int(util_pct) if util_pct.isdigit() else None,
    }