#!/usr/bin/env python3
"""
Loop Detector Proxy — Agent0 Plugin

An Ollama-compatible HTTP proxy that intercepts LLM completions, detects
repeated outputs, and applies configurable mitigations before returning the
response to the caller (agent-zero).

Architecture:
    agent-zero → loop-detector:11434 → ollama:11434 (host)

POC limitations:
- Forces stream=false on all forwarded /api/chat and /api/generate requests
  to simplify response buffering. A caller that asked for streaming receives
  a single JSON object instead, which agent-zero handles correctly for task
  execution.
- State is in-memory; restarting the proxy clears all session history. The
  per-session tables are never pruned, so they grow with the number of
  distinct sessions seen.
"""

import hashlib
import json
import logging
import time
from collections import defaultdict, deque
from pathlib import Path
from typing import Any

import httpx
import yaml
from fastapi import FastAPI, Request, Response

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-7s %(message)s",
)
log = logging.getLogger("loop_detector")

CONFIG_PATH = Path(__file__).parent / "config.yaml"

# Request headers that describe a single connection hop (or are recomputed by
# httpx from the forwarded body) and must not be copied to the upstream call.
_HOP_BY_HOP_HEADERS = frozenset({"host", "content-length", "transfer-encoding", "connection"})


# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

def load_config() -> dict[str, Any]:
    """Read and parse ``config.yaml`` located next to this script.

    Called on every request, so edits to the file take effect without a
    restart. An empty file yields ``{}`` rather than ``None``.
    """
    with open(CONFIG_PATH, encoding="utf-8") as f:
        return yaml.safe_load(f) or {}


# ---------------------------------------------------------------------------
# Session state
# ---------------------------------------------------------------------------

# Key: (model, session_fingerprint) → deque of the last 20 response hashes.
# Only its size is read (by /health); detection itself uses _last_hash.
_history: dict[tuple, deque] = defaultdict(lambda: deque(maxlen=20))
# Key: same tuple → current consecutive-repeat count
_consecutive: dict[tuple, int] = defaultdict(int)
# Key: same tuple → hash of the most recent recorded response
_last_hash: dict[tuple, str] = {}


def _session_key(model: str, messages: list[dict]) -> tuple[str, str]:
    """Derive a stable session identifier for a conversation.

    The fingerprint is a 16-hex-char SHA-256 prefix of the first 512
    characters of the first system message (intended to be unique per agent
    profile). If there is no system message, it falls back to a hash of all
    non-assistant messages.

    Returns:
        ``(model, fingerprint)``.
    """
    system_msgs = [m for m in messages if m.get("role") == "system"]
    if system_msgs:
        # Tolerate a null/missing content; for ordinary string content this
        # hashes exactly the same text as before.
        content = str(system_msgs[0].get("content") or "")
        fingerprint = hashlib.sha256(content[:512].encode()).hexdigest()[:16]
    else:
        non_assistant = [m for m in messages if m.get("role") != "assistant"]
        blob = json.dumps(non_assistant, sort_keys=True)
        fingerprint = hashlib.sha256(blob.encode()).hexdigest()[:16]
    return (model, fingerprint)


def _hash_response(text: str) -> str:
    """Return the SHA-256 hex digest of *text* with surrounding whitespace stripped."""
    return hashlib.sha256(text.strip().encode()).hexdigest()


def record_and_check(session: tuple, text: str, min_length: int) -> int:
    """Record a completion and return the current consecutive-repeat count.

    A response identical (after stripping) to the session's previous recorded
    response increments the count; anything else resets it to 1.

    Returns 1 (no loop) if the stripped text is shorter than *min_length*.
    Such short texts are not recorded and do not reset an existing streak.
    """
    if len(text.strip()) < min_length:
        return 1
    h = _hash_response(text)
    history = _history[session]
    if _last_hash.get(session) == h:
        _consecutive[session] += 1
    else:
        _consecutive[session] = 1
    _last_hash[session] = h
    history.append(h)
    return _consecutive[session]


# ---------------------------------------------------------------------------
# Mitigations
# ---------------------------------------------------------------------------

def apply_mitigations(
    request_body: dict,
    count: int,
    config: dict,
) -> tuple[dict, str | None]:
    """Walk the configured mitigation list in order.

    Each enabled mitigation whose ``trigger_count`` (default 2) is <= *count*
    is applied to *request_body*, which is modified in place.

    Returns:
        ``(modified_request_body, override_response_text_or_None)``. An
        override response means "return this to the caller, skip the LLM";
        it is produced by ``circuit_breaker``, which also stops the walk.
    """
    mitigations = config.get("mitigations") or []
    override: str | None = None

    for mitigation in mitigations:
        if not mitigation.get("enabled", True):
            continue
        if count < mitigation.get("trigger_count", 2):
            continue

        strategy = mitigation["strategy"]

        if strategy == "temperature_boost":
            # `options` / `temperature` may be present but null in the request.
            opts = request_body.get("options") or {}
            request_body["options"] = opts
            current = opts.get("temperature")
            if current is None:
                current = 0.7
            boosted = min(
                current + mitigation.get("boost_amount", 0.35),
                mitigation.get("max_temperature", 1.4),
            )
            opts["temperature"] = boosted
            log.warning("mitigation=temperature_boost %.2f → %.2f (count=%d)", current, boosted, count)

        elif strategy == "forbidden_action":
            msg = mitigation.get(
                "injection_message", "STOP. Try something completely different."
            ).format(count=count)
            messages = request_body.get("messages") or []
            messages.append({"role": "user", "content": msg})
            request_body["messages"] = messages
            log.warning("mitigation=forbidden_action injected (count=%d)", count)

        elif strategy == "history_truncation":
            messages = request_body.get("messages") or []
            truncate = mitigation.get("truncate_turns", 6)
            system_msgs = [msg for msg in messages if msg.get("role") == "system"]
            non_system = [msg for msg in messages if msg.get("role") != "system"]
            # Drop the most recent `truncate` non-system turns. Slicing with
            # an explicit end index keeps truncate_turns=0 a no-op
            # (`[:-0]` would drop everything). With too few turns, keep only
            # the latest exchange so the request is never left empty.
            # NOTE: a message injected by an earlier forbidden_action sits at
            # the end of the list and is therefore dropped here as well.
            if len(non_system) > truncate:
                trimmed = non_system[: len(non_system) - truncate]
            else:
                trimmed = non_system[-2:]
            request_body["messages"] = system_msgs + trimmed
            log.warning(
                "mitigation=history_truncation dropped %d turns (count=%d)",
                len(non_system) - len(trimmed),
                count,
            )

        elif strategy == "circuit_breaker":
            override = mitigation.get(
                "response_message", "[LOOP DETECTOR] Halted due to repeated responses."
            )
            log.error("mitigation=circuit_breaker TRIGGERED (count=%d)", count)
            break  # nothing further makes sense

    return request_body, override


# ---------------------------------------------------------------------------
# Ollama forwarding
# ---------------------------------------------------------------------------

async def call_ollama(path: str, body: dict, upstream: str) -> tuple[str, dict]:
    """Forward a request to upstream Ollama with ``stream`` forced to False.

    *body* itself is not modified (a shallow copy is sent).

    Returns:
        ``(assistant_text, raw_response_dict)``. The text is
        ``message.content`` for ``/api/chat`` and ``response`` otherwise.

    Raises:
        httpx.HTTPStatusError: if upstream answers with an error status.
    """
    body = dict(body)
    body["stream"] = False
    async with httpx.AsyncClient(timeout=300.0) as client:
        r = await client.post(f"{upstream}{path}", json=body)
        r.raise_for_status()
        data = r.json()
    if path == "/api/chat":
        text = (data.get("message") or {}).get("content", "")
    else:
        text = data.get("response", "")
    return text, data


def _json_response(payload: dict) -> Response:
    """Serialize *payload* as an ``application/json`` response."""
    return Response(content=json.dumps(payload), media_type="application/json")


# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------

app = FastAPI(title="Loop Detector Proxy")


@app.post("/api/chat")
async def chat(request: Request) -> Response:
    """Proxy ``/api/chat`` with loop detection.

    On a repeat (count >= 2) the mitigations are applied and either the
    circuit-breaker message is returned (flagged with ``loop_detected``) or
    the mitigated request is sent upstream once more and that answer is
    returned.
    """
    config = load_config()
    body = await request.json()
    model = body.get("model", "unknown")
    upstream = config["upstream_ollama"]
    min_len = config["detection"]["min_length"]
    session = _session_key(model, body.get("messages") or [])

    text, raw = await call_ollama("/api/chat", body, upstream)
    count = record_and_check(session, text, min_len)

    if count >= 2:
        log.warning("loop detected model=%s session=%s count=%d", model, session[1], count)
        body, override = apply_mitigations(body, count, config)

        if override is not None:
            raw["message"] = {"role": "assistant", "content": override}
            raw["loop_detected"] = True
            return _json_response(raw)

        # Retry with mitigations applied. The upstream message is returned
        # unchanged so any fields other than role/content are preserved.
        text, raw = await call_ollama("/api/chat", body, upstream)
        record_and_check(session, text, min_len)

    return _json_response(raw)


@app.post("/api/generate")
async def generate(request: Request) -> Response:
    """Proxy ``/api/generate`` with the same detect/mitigate/retry flow as chat."""
    config = load_config()
    body = await request.json()
    model = body.get("model", "unknown")
    upstream = config["upstream_ollama"]
    min_len = config["detection"]["min_length"]

    # /api/generate has no messages list; use prompt as fingerprint
    messages = [{"role": "user", "content": body.get("prompt", "")}]
    session = _session_key(model, messages)

    text, raw = await call_ollama("/api/generate", body, upstream)
    count = record_and_check(session, text, min_len)

    if count >= 2:
        log.warning("loop detected model=%s session=%s count=%d", model, session[1], count)
        body, override = apply_mitigations(body, count, config)

        if override is not None:
            raw["response"] = override
            raw["loop_detected"] = True
            return _json_response(raw)

        text, raw = await call_ollama("/api/generate", body, upstream)
        record_and_check(session, text, min_len)

    return _json_response(raw)


@app.get("/health")
async def health() -> dict:
    """Report proxy status, the configured upstream and tracked session count."""
    cfg = load_config()
    return {
        "status": "ok",
        "upstream": cfg["upstream_ollama"],
        "active_sessions": len(_history),
        "timestamp": int(time.time()),
    }


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "HEAD"])
async def passthrough(path: str, request: Request) -> Response:
    """Forward anything we don't handle (model listing, embeddings, etc.) straight through.

    Method, body, query string and end-to-end headers are forwarded; the
    upstream status, body and content type are returned. The response body is
    fully buffered before being returned.
    """
    config = load_config()
    upstream = config["upstream_ollama"]
    body = await request.body()
    headers = {
        k: v for k, v in request.headers.items() if k.lower() not in _HOP_BY_HOP_HEADERS
    }
    url = f"{upstream}/{path}"
    if request.url.query:
        # Preserve the original (already-encoded) query string verbatim.
        url = f"{url}?{request.url.query}"

    async with httpx.AsyncClient(timeout=120.0) as client:
        r = await client.request(
            request.method,
            url,
            content=body,
            headers=headers,
        )
    return Response(
        content=r.content,
        status_code=r.status_code,
        media_type=r.headers.get("content-type"),
    )


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import uvicorn

    cfg = load_config()
    uvicorn.run(app, host="0.0.0.0", port=cfg["proxy_port"], log_level="info")