293 lines
10 KiB
Python
293 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Loop Detector Proxy — Agent0 Plugin
|
|
|
|
An Ollama-compatible HTTP proxy that intercepts LLM completions, detects
|
|
repeated outputs, and applies configurable mitigations before returning the
|
|
response to the caller (agent-zero).
|
|
|
|
Architecture:
|
|
agent-zero → loop-detector:11434 → ollama:11434 (host)
|
|
|
|
POC limitations:
|
|
  - Forces stream=false on all forwarded requests to simplify response
    buffering. A caller that asked for a streamed response therefore
    receives one complete JSON object instead of a chunk stream, which
    agent-zero handles correctly for task execution.
|
|
- State is in-memory; restarting the proxy clears all session history.
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import time
|
|
from collections import defaultdict, deque
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import httpx
|
|
import yaml
|
|
from fastapi import FastAPI, Request, Response
|
|
|
|
# Root logging setup: timestamp, level name padded to 7 columns, message.
logging.basicConfig(format="%(asctime)s %(levelname)-7s %(message)s", level=logging.INFO)

# Logger shared by every function in this module.
log = logging.getLogger("loop_detector")

# The configuration file is looked up in the same directory as this module.
CONFIG_PATH = Path(__file__).with_name("config.yaml")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def load_config() -> dict:
    """
    Read and parse the YAML configuration stored at CONFIG_PATH.

    The file is opened explicitly as UTF-8 so parsing does not depend on the
    platform's locale encoding.  An empty file, which ``yaml.safe_load``
    parses as ``None``, yields an empty dict so the declared return type
    holds.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the content is not valid YAML.
    """
    with open(CONFIG_PATH, encoding="utf-8") as f:
        parsed = yaml.safe_load(f)
    return parsed if parsed is not None else {}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Session state
|
|
# ---------------------------------------------------------------------------
|
|
def _new_history() -> deque:
    """Create the bounded buffer holding one session's recent response hashes."""
    return deque(maxlen=20)


# All three maps are keyed by the same (model, session_fingerprint) tuple.

# Recent response hashes per session; each deque keeps at most 20 entries.
_history: dict[tuple, deque] = defaultdict(_new_history)

# Current consecutive-repeat count per session.
_consecutive: dict[tuple, int] = defaultdict(int)

# Hash of the most recently recorded response per session.
_last_hash: dict[tuple, str] = {}
|
|
|
|
|
|
def _session_key(model: str, messages: list[dict]) -> tuple[str, str]:
    """
    Derive a session identifier from the model name and the conversation.

    The fingerprint is the first 16 hex characters of a SHA-256 digest taken
    over the first 512 characters of the first system message's content
    (presumably unique per agent profile).  When there is no system message,
    the digest instead covers a JSON dump of every non-assistant message.

    A null system content is treated as empty text, and a non-string content
    is fingerprinted through its JSON form, so neither case raises.

    Returns:
        ``(model, fingerprint)``.
    """
    first_system = next((m for m in messages if m.get("role") == "system"), None)

    if first_system is not None:
        content = first_system.get("content")
        if content is None:
            content = ""
        elif not isinstance(content, str):
            # Slicing/encoding below needs text; serialise structured content.
            content = json.dumps(content, sort_keys=True)
        source = content[:512]
    else:
        non_assistant = [m for m in messages if m.get("role") != "assistant"]
        source = json.dumps(non_assistant, sort_keys=True)

    fingerprint = hashlib.sha256(source.encode()).hexdigest()[:16]
    return (model, fingerprint)
|
|
|
|
|
|
def _hash_response(text: str) -> str:
    """Return the SHA-256 hex digest of *text* with surrounding whitespace removed."""
    normalized = text.strip()
    digest = hashlib.sha256()
    digest.update(normalized.encode())
    return digest.hexdigest()
|
|
|
|
|
|
def record_and_check(session: tuple, text: str, min_length: int) -> int:
    """
    Log one completion for *session* and report its repeat streak.

    Text whose stripped length is below *min_length* is not recorded at all
    and the function answers 1 (no loop).  Otherwise the response hash is
    compared with the previous one recorded for the session: a match extends
    the streak by one, anything else restarts it at 1.  The hash is also
    appended to the session's history buffer.
    """
    if len(text.strip()) < min_length:
        return 1

    digest = _hash_response(text)
    is_repeat = _last_hash.get(session) == digest
    streak = _consecutive[session] + 1 if is_repeat else 1

    _consecutive[session] = streak
    _last_hash[session] = digest
    _history[session].append(digest)
    return streak
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Mitigations
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def apply_mitigations(
|
|
request_body: dict,
|
|
count: int,
|
|
config: dict,
|
|
) -> tuple[dict, str | None]:
|
|
"""
|
|
Walk the configured mitigation list in order.
|
|
Returns (modified_request_body, override_response_text_or_None).
|
|
An override_response means "return this to the caller, skip the LLM".
|
|
"""
|
|
mitigations = config.get("mitigations", [])
|
|
override: str | None = None
|
|
|
|
for m in mitigations:
|
|
if not m.get("enabled", True):
|
|
continue
|
|
if count < m.get("trigger_count", 2):
|
|
continue
|
|
|
|
strategy = m["strategy"]
|
|
|
|
if strategy == "temperature_boost":
|
|
opts = request_body.setdefault("options", {})
|
|
current = opts.get("temperature", 0.7)
|
|
boosted = min(current + m.get("boost_amount", 0.35), m.get("max_temperature", 1.4))
|
|
opts["temperature"] = boosted
|
|
log.warning("mitigation=temperature_boost %.2f → %.2f (count=%d)", current, boosted, count)
|
|
|
|
elif strategy == "forbidden_action":
|
|
msg = m.get("injection_message", "STOP. Try something completely different.").format(count=count)
|
|
request_body.setdefault("messages", []).append({"role": "user", "content": msg})
|
|
log.warning("mitigation=forbidden_action injected (count=%d)", count)
|
|
|
|
elif strategy == "history_truncation":
|
|
messages = request_body.get("messages", [])
|
|
truncate = m.get("truncate_turns", 6)
|
|
system_msgs = [m_ for m_ in messages if m_.get("role") == "system"]
|
|
non_system = [m_ for m_ in messages if m_.get("role") != "system"]
|
|
# Keep at least the most recent exchange after truncation
|
|
trimmed = non_system[:-truncate] if len(non_system) > truncate else non_system[-2:]
|
|
request_body["messages"] = system_msgs + trimmed
|
|
log.warning("mitigation=history_truncation dropped %d turns (count=%d)", truncate, count)
|
|
|
|
elif strategy == "circuit_breaker":
|
|
override = m.get("response_message", "[LOOP DETECTOR] Halted due to repeated responses.")
|
|
log.error("mitigation=circuit_breaker TRIGGERED (count=%d)", count)
|
|
break # nothing further makes sense
|
|
|
|
return request_body, override
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Ollama forwarding
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def call_ollama(path: str, body: dict, upstream: str) -> tuple[str, dict]:
    """
    Forward a request to upstream Ollama with streaming disabled.

    Args:
        path: Upstream endpoint path, e.g. "/api/chat" or "/api/generate".
        body: Request payload.  It is shallow-copied first, so the caller's
            dict does not receive the forced ``stream`` flag.
        upstream: Base URL of the Ollama server; ``path`` is appended to it.

    Returns:
        ``(assistant_text, raw_response_dict)``.  A missing or null
        message/content/response yields "" for ``assistant_text`` so callers
        can safely treat it as text.

    Raises:
        httpx.HTTPStatusError: if upstream answers with an error status.
    """
    payload = dict(body)
    payload["stream"] = False

    async with httpx.AsyncClient(timeout=300.0) as client:
        r = await client.post(f"{upstream}{path}", json=payload)
        r.raise_for_status()
        data = r.json()

    if path == "/api/chat":
        # ``or`` covers both an absent key and an explicit JSON null; either
        # would otherwise surface later as a failure on ``text.strip()``.
        text = (data.get("message") or {}).get("content") or ""
    else:
        text = data.get("response") or ""

    return text, data
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# FastAPI app
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# ASGI application object; the route decorators in this module register their handlers on it.
app = FastAPI(title="Loop Detector Proxy")
|
|
|
|
|
|
@app.post("/api/chat")
async def chat(request: Request) -> Response:
    """
    Proxy an Ollama ``/api/chat`` call with loop detection.

    The request is forwarded upstream once and the completion is recorded for
    the session.  When the consecutive-repeat count reaches 2 or more, the
    configured mitigations are applied and then either:

    - a mitigation supplied an override text, which is returned as the
      assistant message with ``loop_detected: true`` and no further upstream
      call; or
    - the mitigated request is sent upstream a second time and that
      completion is recorded and returned.

    In the non-override path the upstream ``message`` object is kept intact
    and only its ``content`` is set, so additional fields Ollama may attach to
    the message (e.g. ``tool_calls``) reach the caller.

    Configuration is re-read on every request.
    """
    config = load_config()
    body = await request.json()
    model = body.get("model", "unknown")
    upstream = config["upstream_ollama"]
    min_len = config["detection"]["min_length"]

    text, raw = await call_ollama("/api/chat", body, upstream)
    # Fingerprint the session from the original messages, before any
    # mitigation gets a chance to rewrite them.
    session = _session_key(model, body.get("messages", []))
    count = record_and_check(session, text, min_len)

    if count >= 2:
        log.warning("loop detected model=%s session=%s count=%d", model, session[1], count)
        body, override = apply_mitigations(body, count, config)

        if override is not None:
            raw["message"] = {"role": "assistant", "content": override}
            raw["loop_detected"] = True
            return Response(content=json.dumps(raw), media_type="application/json")

        # Retry with mitigations applied
        text, raw = await call_ollama("/api/chat", body, upstream)
        record_and_check(session, text, min_len)

    message = raw.get("message")
    if isinstance(message, dict):
        # Update in place instead of replacing the dict, which would drop
        # every field other than role/content.
        message.setdefault("role", "assistant")
        message["content"] = text
    else:
        raw["message"] = {"role": "assistant", "content": text}
    return Response(content=json.dumps(raw), media_type="application/json")
|
|
|
|
|
|
@app.post("/api/generate")
async def generate(request: Request) -> Response:
    """
    Proxy an Ollama ``/api/generate`` call with loop detection.

    Because this endpoint carries a single prompt rather than a message list,
    the prompt is wrapped as one user message to derive the session key.
    The completion is recorded; on a repeat count of 2 or more the configured
    mitigations run, and either their override text is returned with
    ``loop_detected: true`` or the mitigated request is retried once upstream.
    """

    def as_json(payload: dict) -> Response:
        # Every exit path serialises the upstream dict the same way.
        return Response(content=json.dumps(payload), media_type="application/json")

    cfg = load_config()
    payload = await request.json()
    model_name = payload.get("model", "unknown")
    ollama_url = cfg["upstream_ollama"]
    threshold = cfg["detection"]["min_length"]

    # /api/generate has no messages list; use prompt as fingerprint
    pseudo_history = [{"role": "user", "content": payload.get("prompt", "")}]
    session = _session_key(model_name, pseudo_history)

    completion, reply = await call_ollama("/api/generate", payload, ollama_url)
    repeats = record_and_check(session, completion, threshold)

    if repeats < 2:
        reply["response"] = completion
        return as_json(reply)

    log.warning("loop detected model=%s session=%s count=%d", model_name, session[1], repeats)
    payload, forced = apply_mitigations(payload, repeats, cfg)

    if forced is not None:
        reply["response"] = forced
        reply["loop_detected"] = True
        return as_json(reply)

    # Second attempt, now carrying whatever the mitigations changed.
    completion, reply = await call_ollama("/api/generate", payload, ollama_url)
    record_and_check(session, completion, threshold)
    reply["response"] = completion
    return as_json(reply)
|
|
|
|
|
|
@app.get("/health")
async def health() -> dict:
    """Report liveness, the configured upstream URL, the number of tracked sessions and the current Unix time."""
    upstream = load_config()["upstream_ollama"]

    report: dict = {"status": "ok"}
    report["upstream"] = upstream
    report["active_sessions"] = len(_history)
    report["timestamp"] = int(time.time())
    return report
|
|
|
|
|
|
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "HEAD"])
async def passthrough(path: str, request: Request) -> Response:
    """
    Forward anything we don't handle (model listing, embeddings, etc.) straight through.

    The method, raw body, query string and all headers except ``Host`` are
    relayed to the upstream server.  The upstream body and status code are
    returned, with the upstream ``content-type`` as the media type; other
    upstream response headers are not copied.

    Args:
        path: The request path captured by the catch-all route, without a
            leading slash.
        request: The incoming request to relay.
    """
    config = load_config()
    upstream = config["upstream_ollama"]

    body = await request.body()
    headers = {k: v for k, v in request.headers.items() if k.lower() != "host"}

    url = f"{upstream}/{path}"
    # The route parameter holds only the path; re-attach the raw query string
    # so parameters are not lost on the way upstream.
    query = request.url.query
    if query:
        url = f"{url}?{query}"

    async with httpx.AsyncClient(timeout=120.0) as client:
        r = await client.request(
            request.method,
            url,
            content=body,
            headers=headers,
        )

    return Response(
        content=r.content,
        status_code=r.status_code,
        media_type=r.headers.get("content-type"),
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Entry point
|
|
# ---------------------------------------------------------------------------
|
|
|
|
if __name__ == "__main__":
    # Direct execution: read the port from the config file and serve the app
    # on all interfaces.
    import uvicorn

    settings = load_config()
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=settings["proxy_port"],
        log_level="info",
    )
|