Adding Festinger with wordnet
This commit is contained in:
@@ -0,0 +1,292 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Loop Detector Proxy — Agent0 Plugin
|
||||
|
||||
An Ollama-compatible HTTP proxy that intercepts LLM completions, detects
|
||||
repeated outputs, and applies configurable mitigations before returning the
|
||||
response to the caller (agent-zero).
|
||||
|
||||
Architecture:
|
||||
agent-zero → loop-detector:11434 → ollama:11434 (host)
|
||||
|
||||
POC limitations:
|
||||
- Forces stream=false on all forwarded requests to simplify response
|
||||
buffering. Streaming responses are re-emitted as single JSON objects,
|
||||
which agent-zero handles correctly for task execution.
|
||||
- State is in-memory; restarting the proxy clears all session history.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import FastAPI, Request, Response
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)-7s %(message)s",
|
||||
)
|
||||
log = logging.getLogger("loop_detector")
|
||||
|
||||
CONFIG_PATH = Path(__file__).parent / "config.yaml"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_config() -> dict:
|
||||
with open(CONFIG_PATH) as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session state
|
||||
# ---------------------------------------------------------------------------
|
||||
# Key: (model, session_fingerprint) → deque of response hashes
|
||||
_history: dict[tuple, deque] = defaultdict(lambda: deque(maxlen=20))
|
||||
# Key: same tuple → current consecutive-repeat count
|
||||
_consecutive: dict[tuple, int] = defaultdict(int)
|
||||
_last_hash: dict[tuple, str] = {}
|
||||
|
||||
|
||||
def _session_key(model: str, messages: list[dict]) -> tuple[str, str]:
|
||||
"""
|
||||
Derive a stable session identifier from the model name and the content
|
||||
of the first system message (which is unique per agent profile).
|
||||
Falls back to a hash of all non-assistant messages if there is no system
|
||||
message.
|
||||
"""
|
||||
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||||
if system_msgs:
|
||||
fingerprint = hashlib.sha256(
|
||||
system_msgs[0].get("content", "")[:512].encode()
|
||||
).hexdigest()[:16]
|
||||
else:
|
||||
non_assistant = [m for m in messages if m.get("role") != "assistant"]
|
||||
blob = json.dumps(non_assistant, sort_keys=True)
|
||||
fingerprint = hashlib.sha256(blob.encode()).hexdigest()[:16]
|
||||
return (model, fingerprint)
|
||||
|
||||
|
||||
def _hash_response(text: str) -> str:
|
||||
return hashlib.sha256(text.strip().encode()).hexdigest()
|
||||
|
||||
|
||||
def record_and_check(session: tuple, text: str, min_length: int) -> int:
|
||||
"""
|
||||
Record a completion and return the current consecutive-repeat count.
|
||||
Returns 1 (no loop) if the text is below min_length.
|
||||
"""
|
||||
if len(text.strip()) < min_length:
|
||||
return 1
|
||||
|
||||
h = _hash_response(text)
|
||||
history = _history[session]
|
||||
|
||||
if _last_hash.get(session) == h:
|
||||
_consecutive[session] += 1
|
||||
else:
|
||||
_consecutive[session] = 1
|
||||
_last_hash[session] = h
|
||||
|
||||
history.append(h)
|
||||
return _consecutive[session]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mitigations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def apply_mitigations(
|
||||
request_body: dict,
|
||||
count: int,
|
||||
config: dict,
|
||||
) -> tuple[dict, str | None]:
|
||||
"""
|
||||
Walk the configured mitigation list in order.
|
||||
Returns (modified_request_body, override_response_text_or_None).
|
||||
An override_response means "return this to the caller, skip the LLM".
|
||||
"""
|
||||
mitigations = config.get("mitigations", [])
|
||||
override: str | None = None
|
||||
|
||||
for m in mitigations:
|
||||
if not m.get("enabled", True):
|
||||
continue
|
||||
if count < m.get("trigger_count", 2):
|
||||
continue
|
||||
|
||||
strategy = m["strategy"]
|
||||
|
||||
if strategy == "temperature_boost":
|
||||
opts = request_body.setdefault("options", {})
|
||||
current = opts.get("temperature", 0.7)
|
||||
boosted = min(current + m.get("boost_amount", 0.35), m.get("max_temperature", 1.4))
|
||||
opts["temperature"] = boosted
|
||||
log.warning("mitigation=temperature_boost %.2f → %.2f (count=%d)", current, boosted, count)
|
||||
|
||||
elif strategy == "forbidden_action":
|
||||
msg = m.get("injection_message", "STOP. Try something completely different.").format(count=count)
|
||||
request_body.setdefault("messages", []).append({"role": "user", "content": msg})
|
||||
log.warning("mitigation=forbidden_action injected (count=%d)", count)
|
||||
|
||||
elif strategy == "history_truncation":
|
||||
messages = request_body.get("messages", [])
|
||||
truncate = m.get("truncate_turns", 6)
|
||||
system_msgs = [m_ for m_ in messages if m_.get("role") == "system"]
|
||||
non_system = [m_ for m_ in messages if m_.get("role") != "system"]
|
||||
# Keep at least the most recent exchange after truncation
|
||||
trimmed = non_system[:-truncate] if len(non_system) > truncate else non_system[-2:]
|
||||
request_body["messages"] = system_msgs + trimmed
|
||||
log.warning("mitigation=history_truncation dropped %d turns (count=%d)", truncate, count)
|
||||
|
||||
elif strategy == "circuit_breaker":
|
||||
override = m.get("response_message", "[LOOP DETECTOR] Halted due to repeated responses.")
|
||||
log.error("mitigation=circuit_breaker TRIGGERED (count=%d)", count)
|
||||
break # nothing further makes sense
|
||||
|
||||
return request_body, override
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ollama forwarding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def call_ollama(path: str, body: dict, upstream: str) -> tuple[str, dict]:
|
||||
"""
|
||||
Forward a request to upstream Ollama (stream=False) and return
|
||||
(assistant_text, raw_response_dict).
|
||||
"""
|
||||
body = dict(body)
|
||||
body["stream"] = False
|
||||
|
||||
async with httpx.AsyncClient(timeout=300.0) as client:
|
||||
r = await client.post(f"{upstream}{path}", json=body)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
||||
if path == "/api/chat":
|
||||
text = data.get("message", {}).get("content", "")
|
||||
else:
|
||||
text = data.get("response", "")
|
||||
|
||||
return text, data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FastAPI app
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
app = FastAPI(title="Loop Detector Proxy")
|
||||
|
||||
|
||||
@app.post("/api/chat")
|
||||
async def chat(request: Request) -> Response:
|
||||
config = load_config()
|
||||
body = await request.json()
|
||||
model = body.get("model", "unknown")
|
||||
upstream = config["upstream_ollama"]
|
||||
min_len = config["detection"]["min_length"]
|
||||
|
||||
text, raw = await call_ollama("/api/chat", body, upstream)
|
||||
session = _session_key(model, body.get("messages", []))
|
||||
count = record_and_check(session, text, min_len)
|
||||
|
||||
if count >= 2:
|
||||
log.warning("loop detected model=%s session=%s count=%d", model, session[1], count)
|
||||
body, override = apply_mitigations(body, count, config)
|
||||
|
||||
if override is not None:
|
||||
raw["message"] = {"role": "assistant", "content": override}
|
||||
raw["loop_detected"] = True
|
||||
return Response(content=json.dumps(raw), media_type="application/json")
|
||||
|
||||
# Retry with mitigations applied
|
||||
text, raw = await call_ollama("/api/chat", body, upstream)
|
||||
record_and_check(session, text, min_len)
|
||||
|
||||
raw["message"] = {"role": "assistant", "content": text}
|
||||
return Response(content=json.dumps(raw), media_type="application/json")
|
||||
|
||||
|
||||
@app.post("/api/generate")
|
||||
async def generate(request: Request) -> Response:
|
||||
config = load_config()
|
||||
body = await request.json()
|
||||
model = body.get("model", "unknown")
|
||||
upstream = config["upstream_ollama"]
|
||||
min_len = config["detection"]["min_length"]
|
||||
|
||||
# /api/generate has no messages list; use prompt as fingerprint
|
||||
messages = [{"role": "user", "content": body.get("prompt", "")}]
|
||||
session = _session_key(model, messages)
|
||||
|
||||
text, raw = await call_ollama("/api/generate", body, upstream)
|
||||
count = record_and_check(session, text, min_len)
|
||||
|
||||
if count >= 2:
|
||||
log.warning("loop detected model=%s session=%s count=%d", model, session[1], count)
|
||||
body, override = apply_mitigations(body, count, config)
|
||||
|
||||
if override is not None:
|
||||
raw["response"] = override
|
||||
raw["loop_detected"] = True
|
||||
return Response(content=json.dumps(raw), media_type="application/json")
|
||||
|
||||
text, raw = await call_ollama("/api/generate", body, upstream)
|
||||
record_and_check(session, text, min_len)
|
||||
|
||||
raw["response"] = text
|
||||
return Response(content=json.dumps(raw), media_type="application/json")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> dict:
|
||||
cfg = load_config()
|
||||
return {
|
||||
"status": "ok",
|
||||
"upstream": cfg["upstream_ollama"],
|
||||
"active_sessions": len(_history),
|
||||
"timestamp": int(time.time()),
|
||||
}
|
||||
|
||||
|
||||
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "HEAD"])
|
||||
async def passthrough(path: str, request: Request) -> Response:
|
||||
"""Forward anything we don't handle (model listing, embeddings, etc.) straight through."""
|
||||
config = load_config()
|
||||
upstream = config["upstream_ollama"]
|
||||
|
||||
body = await request.body()
|
||||
headers = {k: v for k, v in request.headers.items() if k.lower() != "host"}
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
r = await client.request(
|
||||
request.method,
|
||||
f"{upstream}/{path}",
|
||||
content=body,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
return Response(
|
||||
content=r.content,
|
||||
status_code=r.status_code,
|
||||
media_type=r.headers.get("content-type"),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
cfg = load_config()
|
||||
uvicorn.run(app, host="0.0.0.0", port=cfg["proxy_port"], log_level="info")
|
||||
Reference in New Issue
Block a user