Initial commit: agent-inference service
Moved from gnommoweb/agent-inference. Generic LLM inference bridge supporting litellm (anthropic/openai/ollama/lm_studio), Agent Zero MCP, and Hermes JSON-RPC WebSocket agent types. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,662 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Agent Inference Service
|
||||
|
||||
A generic LLM inference service that can serve any configured agent.
|
||||
Receives the full agent profile (identity, knowledge, guardrails) and model
|
||||
configuration in each request.
|
||||
|
||||
Three connection modes are supported, selected by agent_type (primary) or
|
||||
the endpoint's type field:
|
||||
|
||||
agent_type = "standard" / type = anthropic|openai|ollama|lm_studio
|
||||
Assembles a full system prompt from the agent profile and calls the model
|
||||
via litellm. The LLM must respond with JSON {"message", "pose"}.
|
||||
|
||||
agent_type = "agent0" / type = "agent0"
|
||||
Connects to a live Agent Zero instance via its MCP server (streamable-http).
|
||||
Calls the built-in send_message tool, which has full tool use, memory, and
|
||||
persistent context. Pose is chosen heuristically from the plain-text reply.
|
||||
MCP URL: {endpoint}/mcp/t-{mcp_key}/http
|
||||
|
||||
agent_type = "hermes"
|
||||
Connects to a Hermes Agent dashboard via JSON-RPC 2.0 WebSocket.
|
||||
Fetches the ephemeral session token from the dashboard HTML automatically.
|
||||
WebSocket: {endpoint}/api/ws?token={token}
|
||||
Session: session.create → prompt.submit → stream message.delta events.
|
||||
|
||||
Nothing is summarised or truncated by this service. Context-window management
|
||||
for LLM calls is the responsibility of the caller (gnommoweb).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import logging
|
||||
from fastapi import FastAPI, HTTPException, Header
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
import httpx
|
||||
import litellm
|
||||
|
||||
# --- Config ---
|
||||
API_KEY = os.getenv("AGENT_INFERENCE_KEY", "agent-inference-dev-key")
|
||||
FALLBACK_MODEL = os.getenv("AGENT_INFERENCE_MODEL", "anthropic/claude-sonnet-4-20250514")
|
||||
ANTHROPIC_KEY = os.getenv("API_KEY_ANTHROPIC", "")
|
||||
AGENT_MAX_MESSAGES = int(os.getenv("AGENT_INFERENCE_MAX_MESSAGES", "10"))
|
||||
|
||||
if ANTHROPIC_KEY:
|
||||
os.environ["ANTHROPIC_API_KEY"] = ANTHROPIC_KEY
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
log = logging.getLogger("agent-inference")
|
||||
|
||||
# ── In-memory Agent0 session map ──────────────────────────────────────────────
|
||||
# Maps gnommoweb conversation_id → Agent Zero chat_id (context id).
|
||||
# Allows conversation continuity across turns for Agent0 MCP connections.
|
||||
# Cleared on service restart — acceptable for current single-user usage.
|
||||
_a0_sessions: dict[int, str] = {}
|
||||
|
||||
# ── In-memory Hermes session map ───────────────────────────────────────────────
|
||||
# Maps gnommoweb conversation_id → Hermes session_id.
|
||||
# Also caches the dashboard token per endpoint (re-fetched on 401).
|
||||
_hermes_sessions: dict[int, str] = {}
|
||||
_hermes_token_cache: dict[str, str] = {} # endpoint → session token
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────────────
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
name: Optional[str] = None
|
||||
model_id: str = "" # litellm model string; label-only for Agent0 MCP
|
||||
endpoint: Optional[str] = None # base URL; api_base for litellm, or Agent0 host
|
||||
api_key: Optional[str] = None # LLM API key; not used for Agent0 MCP
|
||||
type: Optional[str] = None # anthropic | openai | ollama | lm_studio | hermes | agent0
|
||||
mcp_key: Optional[str] = None # Agent0 MCP token
|
||||
|
||||
class ModelsPayload(BaseModel):
|
||||
heaviness: int = 1 # 1=light, 2=medium, 3=heavy
|
||||
light: Optional[ModelConfig] = None
|
||||
medium: Optional[ModelConfig] = None
|
||||
heavy: Optional[ModelConfig] = None
|
||||
|
||||
class AgentProfile(BaseModel):
|
||||
"""Full agent profile received from gnommoweb. All fields sent verbatim."""
|
||||
name: str
|
||||
role: str = ""
|
||||
from_name: str = ""
|
||||
poses: list[str] = Field(default_factory=list)
|
||||
identity_document: str = ""
|
||||
job_description: str = ""
|
||||
guardrails: str = ""
|
||||
best_practices: str = ""
|
||||
knowledge_base: str = ""
|
||||
mock_responses: list[str] = Field(default_factory=list)
|
||||
agent_type: str = "standard" # standard | agent0 | hermes
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
user_id: int
|
||||
agent_id: int = 1
|
||||
message: str
|
||||
conversation_id: Optional[int] = None
|
||||
history: list = Field(default_factory=list)
|
||||
agent: Optional[AgentProfile] = None
|
||||
models: Optional[ModelsPayload] = None
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
letter_id: int = 0
|
||||
timestamp: int
|
||||
message: str
|
||||
pose: str
|
||||
conversation_id: Optional[int] = None
|
||||
actions: Optional[list] = None
|
||||
|
||||
# ── System prompt (LLM path only) ─────────────────────────────────────────────
|
||||
|
||||
def build_system_prompt(agent: AgentProfile, message_count: int = 0) -> str:
|
||||
"""
|
||||
Assemble the complete system prompt for one LLM call.
|
||||
|
||||
Sections (all verbatim from the agent's DB fields):
|
||||
1. Identity anchor — always present; name + role + from_name
|
||||
2. Identity document — full persona, backstory, values (if set)
|
||||
3. Job description — responsibilities and objectives (if set)
|
||||
4. Guardrails — hard constraints (if set)
|
||||
5. Best practices — response style and soft rules (if set)
|
||||
6. Knowledge base — factual grounding (if set)
|
||||
7. Poses — complete list of valid pose identifiers
|
||||
8. Response format — JSON output structure
|
||||
9. Wrap-up hint — injected near the conversation limit
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# ── 1. Identity anchor (always present) ──────────────────────────────────
|
||||
anchor_lines = [f"You are {agent.name}"]
|
||||
if agent.role:
|
||||
anchor_lines[0] += f", {agent.role}"
|
||||
if agent.from_name and agent.from_name not in (agent.name, f"{agent.name}, {agent.role}"):
|
||||
anchor_lines.append(f"You are known as: {agent.from_name}")
|
||||
anchor_lines.append(
|
||||
"You MUST maintain this character and persona throughout the entire conversation. "
|
||||
"Do NOT break character. Do NOT acknowledge being an AI or a language model."
|
||||
)
|
||||
parts.append("\n".join(anchor_lines))
|
||||
|
||||
# ── 2. Identity document (verbatim) ──────────────────────────────────────
|
||||
if agent.identity_document:
|
||||
parts.append(f"## IDENTITY\n\n{agent.identity_document}")
|
||||
|
||||
# ── 3. Job description (verbatim) ────────────────────────────────────────
|
||||
if agent.job_description:
|
||||
parts.append(f"## YOUR ROLE\n\n{agent.job_description}")
|
||||
|
||||
# ── 4. Guardrails (verbatim) ─────────────────────────────────────────────
|
||||
if agent.guardrails:
|
||||
parts.append(f"## GUARDRAILS — FOLLOW EXACTLY\n\n{agent.guardrails}")
|
||||
|
||||
# ── 5. Best practices (verbatim) ─────────────────────────────────────────
|
||||
if agent.best_practices:
|
||||
parts.append(f"## BEST PRACTICES\n\n{agent.best_practices}")
|
||||
|
||||
# ── 6. Knowledge base (verbatim) ─────────────────────────────────────────
|
||||
if agent.knowledge_base:
|
||||
parts.append(f"## KNOWLEDGE BASE\n\n{agent.knowledge_base}")
|
||||
|
||||
# ── 7. Poses ──────────────────────────────────────────────────────────────
|
||||
valid_poses = agent.poses if agent.poses else ["neutral"]
|
||||
pose_lines = [
|
||||
"## POSES",
|
||||
"",
|
||||
"Every response MUST include a pose field set to one of these exact identifiers:",
|
||||
"",
|
||||
]
|
||||
for p in valid_poses:
|
||||
pose_lines.append(f" {p}")
|
||||
pose_lines += [
|
||||
"",
|
||||
"Choose the pose that best reflects your character's genuine emotional state in this moment.",
|
||||
"Pose names are descriptive — use them accordingly.",
|
||||
"Any identifier not in the list above will be rejected.",
|
||||
]
|
||||
parts.append("\n".join(pose_lines))
|
||||
|
||||
# ── 8. Response format ────────────────────────────────────────────────────
|
||||
parts.append(
|
||||
"## RESPONSE FORMAT\n\n"
|
||||
"Respond ONLY with a valid JSON object. No prose before or after it.\n\n"
|
||||
"Required fields:\n"
|
||||
' "message" — your in-character response text\n'
|
||||
' "pose" — one identifier from the POSES list above\n\n'
|
||||
"Optional field:\n"
|
||||
' "actions" — array of action buttons to show the user, e.g.:\n'
|
||||
' [{"emoji": "🫳🏽", "label": "Do"}, {"emoji": "👁️", "label": "Look"}]\n'
|
||||
" Omit entirely to keep the current action buttons unchanged.\n"
|
||||
" Maximum 4 actions. The Leave action is always implicit.\n\n"
|
||||
"Example:\n"
|
||||
'{\n'
|
||||
' "message": "Your response here.",\n'
|
||||
' "pose": "neutral"\n'
|
||||
'}'
|
||||
)
|
||||
|
||||
# ── 9. Wrap-up hint (injected near the conversation limit) ────────────────
|
||||
wrap_up_threshold = max(AGENT_MAX_MESSAGES - 3, 4)
|
||||
if message_count >= AGENT_MAX_MESSAGES:
|
||||
parts.append(
|
||||
"## END THIS CONVERSATION NOW\n\n"
|
||||
"This conversation has reached its limit. You MUST wrap up in this response. "
|
||||
"Thank the user, guide them toward whatever next step suits your role, "
|
||||
"and set your pose to 'closed'. This is your final message."
|
||||
)
|
||||
elif message_count >= wrap_up_threshold:
|
||||
parts.append(
|
||||
"NOTE: This conversation is approaching its limit. "
|
||||
"Begin guiding the user toward a natural conclusion."
|
||||
)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
# ── Model selection ───────────────────────────────────────────────────────────
|
||||
|
||||
def select_model(payload: Optional[ModelsPayload]) -> Optional[ModelConfig]:
|
||||
"""
|
||||
Pick the ModelConfig based on requested heaviness.
|
||||
Fallback chain: requested heaviness → lower levels → None (use env fallback).
|
||||
"""
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
heaviness = payload.heaviness
|
||||
candidates = []
|
||||
if heaviness >= 3 and payload.heavy:
|
||||
candidates.append(payload.heavy)
|
||||
if heaviness >= 2 and payload.medium:
|
||||
candidates.append(payload.medium)
|
||||
if payload.light:
|
||||
candidates.append(payload.light)
|
||||
|
||||
chosen = candidates[0] if candidates else None
|
||||
if not chosen:
|
||||
log.warning(f"No model configured for heaviness={heaviness}, falling back to env model")
|
||||
return chosen
|
||||
|
||||
# ── Pose heuristic (Agent0 MCP path) ─────────────────────────────────────────
|
||||
|
||||
def pick_pose(text: str, valid_poses: list[str]) -> str:
|
||||
"""
|
||||
Choose a pose based on simple keyword heuristics.
|
||||
Used for Agent0 MCP responses, which are plain text not JSON.
|
||||
"""
|
||||
t = text.lower()
|
||||
candidates: list[tuple[list[str], list[str]]] = [
|
||||
(["error", "fail", "sorry", "cannot", "unable", "permission denied"], ["sorry", "annoyed", "bored"]),
|
||||
(["done", "complete", "finished", "success", "pulled", "updated", "deployed", "pushed"], ["engaged", "interested", "naughty"]),
|
||||
(["?", "what", "which", "how", "why", "could you"], ["inquisitive"]),
|
||||
(["warning", "careful", "note", "attention"], ["annoyed", "bored"]),
|
||||
]
|
||||
for keywords, poses in candidates:
|
||||
if any(k in t for k in keywords):
|
||||
for p in poses:
|
||||
if p in valid_poses:
|
||||
return p
|
||||
for p in ("neutral", "neutral2", "engaged"):
|
||||
if p in valid_poses:
|
||||
return p
|
||||
return valid_poses[0] if valid_poses else "neutral"
|
||||
|
||||
# ── Agent0 MCP connection ─────────────────────────────────────────────────────
|
||||
|
||||
async def handle_agent0_mcp(
|
||||
req: ChatRequest,
|
||||
agent: AgentProfile,
|
||||
model: ModelConfig,
|
||||
valid_poses: list[str],
|
||||
) -> ChatResponse:
|
||||
"""
|
||||
Send a message to a live Agent Zero instance via its MCP streamable-http
|
||||
server and return the response as a ChatResponse.
|
||||
|
||||
MCP URL format: {endpoint}/mcp/t-{mcp_key}/http
|
||||
Tool called: send_message (built into every Agent Zero instance)
|
||||
|
||||
Conversation continuity is maintained via Agent Zero's chat_id, which maps
|
||||
to a gnommoweb conversation_id in _a0_sessions (in-memory).
|
||||
"""
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp import ClientSession
|
||||
|
||||
endpoint = (model.endpoint or "").rstrip("/")
|
||||
mcp_key = model.mcp_key or ""
|
||||
mcp_url = f"{endpoint}/mcp/t-{mcp_key}/http"
|
||||
|
||||
# Retrieve existing Agent0 chat_id for this conversation (if any)
|
||||
a0_chat_id = _a0_sessions.get(req.conversation_id) if req.conversation_id else None
|
||||
|
||||
log.info(
|
||||
f"[{agent.name}] → Agent0 MCP url={mcp_url} "
|
||||
f"conv={req.conversation_id} a0_chat={a0_chat_id or 'new'} "
|
||||
f"msg='{req.message[:60]}'"
|
||||
)
|
||||
|
||||
try:
|
||||
async with streamablehttp_client(mcp_url) as (read, write, _):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
|
||||
tool_args: dict = {
|
||||
"message": req.message,
|
||||
"persistent_chat": True,
|
||||
}
|
||||
if a0_chat_id:
|
||||
tool_args["chat_id"] = a0_chat_id
|
||||
|
||||
result = await session.call_tool("send_message", tool_args)
|
||||
|
||||
# The tool returns a JSON-serialised ToolResponse / ToolError
|
||||
raw = result.content[0].text if result.content else "{}"
|
||||
log.info(f"[{agent.name}] ← Agent0 MCP raw: {raw[:300]}")
|
||||
|
||||
parsed = json.loads(raw)
|
||||
|
||||
if parsed.get("status") == "error":
|
||||
raise RuntimeError(parsed.get("error", "Unknown Agent0 error"))
|
||||
|
||||
response_text = parsed.get("response", "")
|
||||
new_chat_id = parsed.get("chat_id", "")
|
||||
|
||||
# Persist the Agent0 chat_id so the next turn continues the same context
|
||||
if new_chat_id and req.conversation_id is not None:
|
||||
_a0_sessions[req.conversation_id] = new_chat_id
|
||||
|
||||
pose = pick_pose(response_text, valid_poses)
|
||||
|
||||
return ChatResponse(
|
||||
letter_id=0,
|
||||
timestamp=int(time.time()),
|
||||
message=response_text,
|
||||
pose=pose,
|
||||
conversation_id=req.conversation_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"[{agent.name}] Agent0 MCP error: {e}")
|
||||
fallback_pose = next(
|
||||
(p for p in ("sorry", "annoyed", "neutral") if p in valid_poses),
|
||||
valid_poses[0],
|
||||
)
|
||||
return ChatResponse(
|
||||
letter_id=0,
|
||||
timestamp=int(time.time()),
|
||||
message="*connection interference* The response channel is temporarily disrupted. Try again in a moment.",
|
||||
pose=fallback_pose,
|
||||
conversation_id=req.conversation_id,
|
||||
)
|
||||
|
||||
# ── Hermes Agent connection ───────────────────────────────────────────────────
|
||||
|
||||
async def _fetch_hermes_token(endpoint: str) -> str:
|
||||
"""
|
||||
Fetch the ephemeral session token injected into the Hermes dashboard HTML.
|
||||
|
||||
The token is generated fresh on every server start (secrets.token_urlsafe(32))
|
||||
and embedded as: window.__HERMES_SESSION_TOKEN__="<token>"
|
||||
|
||||
Caches the result per endpoint; callers should invalidate on 401.
|
||||
"""
|
||||
cached = _hermes_token_cache.get(endpoint)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.get(f"{endpoint}/")
|
||||
resp.raise_for_status()
|
||||
html = resp.text
|
||||
|
||||
m = re.search(r'__HERMES_SESSION_TOKEN__\s*=\s*"([^"]+)"', html)
|
||||
if not m:
|
||||
raise RuntimeError(
|
||||
f"Could not extract Hermes session token from {endpoint}/ "
|
||||
"(is the dashboard running?)"
|
||||
)
|
||||
|
||||
token = m.group(1)
|
||||
_hermes_token_cache[endpoint] = token
|
||||
log.info(f"[hermes] fetched new session token from {endpoint}")
|
||||
return token
|
||||
|
||||
|
||||
async def handle_hermes(
|
||||
req: "ChatRequest",
|
||||
agent: AgentProfile,
|
||||
model: ModelConfig,
|
||||
valid_poses: list[str],
|
||||
) -> "ChatResponse":
|
||||
"""
|
||||
Send a message to a Hermes Agent dashboard via JSON-RPC 2.0 WebSocket.
|
||||
|
||||
Protocol:
|
||||
1. Fetch ephemeral session token from dashboard HTML (cached, re-fetched on 401).
|
||||
2. Open WebSocket at {endpoint}/api/ws?token={token}
|
||||
3. JSON-RPC session.create {} → { session_id }
|
||||
4. JSON-RPC prompt.submit { session_id, text } → (ack)
|
||||
5. Accumulate message.delta payloads until message.complete
|
||||
6. Return assembled text with pose heuristic.
|
||||
|
||||
Conversation continuity: gnommoweb conversation_id → Hermes session_id
|
||||
stored in _hermes_sessions (in-memory, cleared on restart).
|
||||
"""
|
||||
try:
|
||||
from websockets.asyncio.client import connect as ws_connect
|
||||
except ImportError:
|
||||
from websockets.client import connect as ws_connect # type: ignore
|
||||
|
||||
endpoint = (model.endpoint or "").rstrip("/")
|
||||
hermes_session_id = _hermes_sessions.get(req.conversation_id) if req.conversation_id else None
|
||||
|
||||
token = await _fetch_hermes_token(endpoint)
|
||||
|
||||
ws_scheme = "wss" if endpoint.startswith("https") else "ws"
|
||||
ws_url = f"{ws_scheme}://{endpoint.split('://', 1)[-1]}/api/ws?token={token}"
|
||||
|
||||
log.info(
|
||||
f"[{agent.name}] → Hermes url={endpoint} "
|
||||
f"conv={req.conversation_id} hermes_sess={hermes_session_id or 'new'} "
|
||||
f"msg='{req.message[:60]}'"
|
||||
)
|
||||
|
||||
async def _do_chat(tok: str) -> ChatResponse:
|
||||
nonlocal hermes_session_id
|
||||
ws_url_inner = f"{ws_scheme}://{endpoint.split('://', 1)[-1]}/api/ws?token={tok}"
|
||||
|
||||
async with ws_connect(ws_url_inner) as ws:
|
||||
req_id = 0
|
||||
|
||||
async def rpc(method: str, params: dict) -> dict:
|
||||
nonlocal req_id
|
||||
req_id += 1
|
||||
rid = f"h{req_id}"
|
||||
await ws.send(json.dumps({"jsonrpc": "2.0", "id": rid, "method": method, "params": params}))
|
||||
# Wait for the matching response (skip events that arrive first)
|
||||
while True:
|
||||
raw = await ws.recv()
|
||||
msg = json.loads(raw)
|
||||
if msg.get("id") == rid:
|
||||
if "error" in msg:
|
||||
raise RuntimeError(msg["error"].get("message", "RPC error"))
|
||||
return msg.get("result") or {}
|
||||
# Event frame — queue it for later (ignored for now, we don't miss
|
||||
# message.delta here because prompt.submit is called after create)
|
||||
|
||||
# ── 1. Create or reuse session ────────────────────────────────
|
||||
if not hermes_session_id:
|
||||
created = await rpc("session.create", {})
|
||||
hermes_session_id = created.get("session_id") or created.get("id", "")
|
||||
if req.conversation_id is not None and hermes_session_id:
|
||||
_hermes_sessions[req.conversation_id] = hermes_session_id
|
||||
log.info(f"[{agent.name}] created Hermes session {hermes_session_id}")
|
||||
|
||||
# ── 2. Submit prompt ──────────────────────────────────────────
|
||||
await rpc("prompt.submit", {"session_id": hermes_session_id, "text": req.message})
|
||||
|
||||
# ── 3. Stream events until message.complete ───────────────────
|
||||
full_text = ""
|
||||
while True:
|
||||
raw = await asyncio.wait_for(ws.recv(), timeout=120.0)
|
||||
msg = json.loads(raw)
|
||||
if msg.get("method") != "event":
|
||||
continue
|
||||
params = msg.get("params") or {}
|
||||
etype = params.get("type", "")
|
||||
payload = params.get("payload") or {}
|
||||
if etype == "message.delta":
|
||||
full_text += payload.get("text", "")
|
||||
elif etype == "message.complete":
|
||||
break
|
||||
elif etype == "error":
|
||||
raise RuntimeError(payload.get("message", "Hermes error"))
|
||||
|
||||
pose = pick_pose(full_text, valid_poses)
|
||||
return ChatResponse(
|
||||
letter_id=0,
|
||||
timestamp=int(time.time()),
|
||||
message=full_text or "*silence*",
|
||||
pose=pose,
|
||||
conversation_id=req.conversation_id,
|
||||
)
|
||||
|
||||
try:
|
||||
return await _do_chat(token)
|
||||
except Exception as e:
|
||||
err_str = str(e)
|
||||
# 401 / rejected — token likely regenerated; clear cache and retry once
|
||||
if "401" in err_str or "403" in err_str or "Unauthorized" in err_str.lower():
|
||||
log.warning(f"[{agent.name}] Hermes token rejected, re-fetching…")
|
||||
_hermes_token_cache.pop(endpoint, None)
|
||||
try:
|
||||
token = await _fetch_hermes_token(endpoint)
|
||||
return await _do_chat(token)
|
||||
except Exception as retry_err:
|
||||
log.error(f"[{agent.name}] Hermes retry failed: {retry_err}")
|
||||
err_str = str(retry_err)
|
||||
|
||||
log.error(f"[{agent.name}] Hermes error: {err_str}")
|
||||
fallback_pose = next(
|
||||
(p for p in ("sorry", "annoyed", "neutral") if p in valid_poses),
|
||||
valid_poses[0],
|
||||
)
|
||||
return ChatResponse(
|
||||
letter_id=0,
|
||||
timestamp=int(time.time()),
|
||||
message="*connection interference* The response channel is temporarily disrupted. Try again in a moment.",
|
||||
pose=fallback_pose,
|
||||
conversation_id=req.conversation_id,
|
||||
)
|
||||
|
||||
|
||||
# ── Fallback agent ────────────────────────────────────────────────────────────
|
||||
|
||||
def make_fallback_agent() -> AgentProfile:
|
||||
return AgentProfile(
|
||||
name="Agent",
|
||||
role="University Agent",
|
||||
poses=["neutral", "bored", "annoyed", "engaged", "closed"],
|
||||
identity_document="You are a helpful university agent.",
|
||||
)
|
||||
|
||||
# ── App ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
app = FastAPI(title="Agent Inference Service", version="1.2.0")
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "service": "agent-inference", "fallback_model": FALLBACK_MODEL}
|
||||
|
||||
|
||||
@app.post("/v1/agent/chat", response_model=ChatResponse)
|
||||
async def agent_chat(req: ChatRequest, authorization: str = Header(default="")):
|
||||
if authorization != f"Bearer {API_KEY}":
|
||||
raise HTTPException(status_code=401, detail="Unauthorised.")
|
||||
|
||||
agent = req.agent or make_fallback_agent()
|
||||
valid_poses = agent.poses if agent.poses else ["neutral"]
|
||||
model = select_model(req.models)
|
||||
|
||||
# ── Route: Hermes Agent ───────────────────────────────────────────────────
|
||||
if agent.agent_type == "hermes":
|
||||
if not model or not model.endpoint:
|
||||
log.error(f"[{agent.name}] Hermes agent has no model with endpoint configured")
|
||||
return ChatResponse(
|
||||
letter_id=0,
|
||||
timestamp=int(time.time()),
|
||||
message="*configuration error* No Hermes endpoint configured for this agent.",
|
||||
pose=valid_poses[0],
|
||||
conversation_id=req.conversation_id,
|
||||
)
|
||||
return await handle_hermes(req, agent, model, valid_poses)
|
||||
|
||||
# ── Route: Agent0 MCP ─────────────────────────────────────────────────────
|
||||
if agent.agent_type == "agent0" or (model and model.type == "agent0"):
|
||||
return await handle_agent0_mcp(req, agent, model, valid_poses)
|
||||
|
||||
# ── Route: standard LLM via litellm ──────────────────────────────────────
|
||||
model_id = model.model_id if model else FALLBACK_MODEL
|
||||
api_base = model.endpoint if model else None
|
||||
api_key = model.api_key if model else None
|
||||
|
||||
if not model_id:
|
||||
model_id = FALLBACK_MODEL
|
||||
|
||||
total_messages = len(req.history) + 1
|
||||
system_prompt = build_system_prompt(agent, message_count=total_messages)
|
||||
|
||||
# Structure: [system] + history + [current user message]
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
for msg in req.history:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
if role not in ("user", "assistant"):
|
||||
role = "user"
|
||||
messages.append({"role": role, "content": content})
|
||||
messages.append({"role": "user", "content": req.message})
|
||||
|
||||
log.info(
|
||||
f"[{agent.name}] → {model_id} "
|
||||
f"heaviness={req.models.heaviness if req.models else 1} "
|
||||
f"history={len(req.history)} msgs "
|
||||
f"prompt={len(system_prompt)} chars "
|
||||
f"user={req.user_id} "
|
||||
f"msg='{req.message[:60]}'"
|
||||
)
|
||||
|
||||
completion_kwargs = dict(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
temperature=0.8,
|
||||
max_tokens=400,
|
||||
)
|
||||
if api_base:
|
||||
completion_kwargs["api_base"] = api_base
|
||||
if api_key:
|
||||
completion_kwargs["api_key"] = api_key
|
||||
|
||||
try:
|
||||
response = await litellm.acompletion(**completion_kwargs)
|
||||
raw = response.choices[0].message.content.strip()
|
||||
log.info(f"[{agent.name}] ← {raw[:200]}")
|
||||
|
||||
# Parse JSON response
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
start = raw.find("{")
|
||||
end = raw.rfind("}") + 1
|
||||
if start >= 0 and end > start:
|
||||
parsed = json.loads(raw[start:end])
|
||||
else:
|
||||
parsed = {"message": raw, "pose": "neutral"}
|
||||
|
||||
# Validate pose
|
||||
pose = parsed.get("pose", "neutral")
|
||||
if pose not in valid_poses:
|
||||
log.warning(f"[{agent.name}] Invalid pose '{pose}', defaulting to '{valid_poses[0]}'")
|
||||
pose = valid_poses[0]
|
||||
|
||||
# Validate optional actions
|
||||
raw_actions = parsed.get("actions")
|
||||
actions = None
|
||||
if isinstance(raw_actions, list) and raw_actions:
|
||||
actions = [
|
||||
{"emoji": a["emoji"], "label": a["label"]}
|
||||
for a in raw_actions
|
||||
if isinstance(a, dict) and "emoji" in a and "label" in a
|
||||
] or None
|
||||
|
||||
return ChatResponse(
|
||||
letter_id=0,
|
||||
timestamp=int(time.time()),
|
||||
message=parsed.get("message", raw),
|
||||
pose=pose,
|
||||
conversation_id=req.conversation_id,
|
||||
actions=actions,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"[{agent.name}] LLM error: {e}")
|
||||
fallback_pose = next(
|
||||
(p for p in ("annoyed", "bored", "sorry", "neutral") if p in valid_poses),
|
||||
valid_poses[0],
|
||||
)
|
||||
return ChatResponse(
|
||||
letter_id=0,
|
||||
timestamp=int(time.time()),
|
||||
message="*connection interference* The response channel is temporarily disrupted. Try again in a moment.",
|
||||
pose=fallback_pose,
|
||||
conversation_id=req.conversation_id,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
port = int(os.getenv("AGENT_INFERENCE_PORT", "8089"))
|
||||
log.info(f"Starting agent inference service on port {port}")
|
||||
uvicorn.run(app, host="0.0.0.0", port=port)
|
||||
Reference in New Issue
Block a user