Files

377 lines
15 KiB
Python

"""
AutoDev - LLM Communication Layer
Supports Ollama and vLLM backends with streaming, retry logic, and robust error handling.
"""
import json
import sys
import time
import urllib.request
import urllib.error
from . import config
class LLMError(Exception):
    """Raised when an LLM backend request fails or returns an unusable response."""
class LLM:
def __init__(self, backend: str = None, model: str = None):
self.backend = backend or config.LLM_BACKEND
self.model = model or config.MODEL_NAME
if self.backend == "ollama":
self.base_url = config.OLLAMA_URL
elif self.backend == "vllm":
self.base_url = config.VLLM_URL
else:
raise LLMError(f"Unknown backend: {self.backend}")
self.context_size = None # Auto-detected on first use
def detect_context_size(self) -> int:
"""Detect the model's effective context window size.
Checks (in priority order):
1. Ollama /api/ps for running model's actual context_length
2. num_ctx in model parameters from /api/show
3. Model architecture's max context_length from /api/show
4. vLLM max_model_len from /v1/models
5. Fallback to config default
"""
if self.context_size is not None:
return self.context_size
try:
if self.backend == "ollama":
self.context_size = self._detect_ollama_context()
elif self.backend == "vllm":
self.context_size = self._detect_vllm_context()
except Exception:
pass
if not self.context_size:
self.context_size = config.MAX_CONTEXT_TOKENS
return self.context_size
def detect_gpu_status(self) -> dict:
"""Check GPU/CPU offload status for the running model.
Returns dict with:
loaded: bool - whether model is currently loaded
gpu_percent: int - percentage of model on GPU (0-100)
size_total: int - total model size in bytes
size_vram: int - bytes on GPU
warning: str|None - warning message if mostly CPU
"""
result = {"loaded": False, "gpu_percent": 0, "size_total": 0,
"size_vram": 0, "warning": None}
if self.backend != "ollama":
return result
try:
url = f"{self.base_url}/api/ps"
req = urllib.request.Request(url)
with urllib.request.urlopen(req, timeout=5) as resp:
data = json.loads(resp.read().decode("utf-8"))
for m in data.get("models", []):
if self.model in m.get("name", ""):
result["loaded"] = True
result["size_total"] = m.get("size", 0)
result["size_vram"] = m.get("size_vram", 0)
if result["size_total"] > 0:
result["gpu_percent"] = int(
result["size_vram"] / result["size_total"] * 100
)
if result["gpu_percent"] == 0:
result["warning"] = (
"Model is running entirely on CPU. "
"This will be extremely slow and may not complete. "
"Consider using a smaller model or freeing GPU memory."
)
elif result["gpu_percent"] < 50:
result["warning"] = (
f"Only {result['gpu_percent']}% of model is on GPU. "
"Performance will be significantly degraded. "
"Consider using a smaller model."
)
break
except Exception:
pass
return result
def _detect_ollama_context(self) -> int | None:
# 1. Check running model — this gives the actual runtime context_length
try:
url = f"{self.base_url}/api/ps"
req = urllib.request.Request(url)
with urllib.request.urlopen(req, timeout=5) as resp:
data = json.loads(resp.read().decode("utf-8"))
for m in data.get("models", []):
if self.model in m.get("name", ""):
ctx = m.get("context_length")
if ctx:
return int(ctx)
except Exception:
pass
# 2. Check model config from /api/show
try:
url = f"{self.base_url}/api/show"
payload = {"name": self.model}
data = self._post_raw(url, payload)
# Check parameters for explicit num_ctx setting
params = data.get("parameters", "")
for line in params.splitlines():
if "num_ctx" in line:
parts = line.split()
for p in parts:
if p.isdigit():
return int(p)
# Check modelfile for PARAMETER num_ctx
modelfile = data.get("modelfile", "")
for line in modelfile.splitlines():
if "num_ctx" in line.lower():
parts = line.split()
for p in parts:
if p.isdigit():
return int(p)
# 3. Fall back to architecture's max context_length
model_info = data.get("model_info", {})
for key, val in model_info.items():
if "context_length" in key:
return int(val)
except Exception:
pass
return None
def _detect_vllm_context(self) -> int | None:
try:
url = f"{self.base_url}/v1/models"
req = urllib.request.Request(url)
with urllib.request.urlopen(req, timeout=10) as resp:
data = json.loads(resp.read().decode("utf-8"))
for m in data.get("data", []):
if m.get("id") == self.model:
return m.get("max_model_len")
except Exception:
pass
return None
def query(self, prompt: str, system: str = "", temperature: float = 0.2,
stream: bool = False) -> str:
if not system:
system = config.EXPERT_IDENTITY
if self.backend == "ollama":
if stream:
return self._stream_ollama(prompt, system, temperature)
result = self._query_ollama(prompt, system, temperature)
else:
if stream:
return self._stream_vllm(prompt, system, temperature)
result = self._query_vllm(prompt, system, temperature)
# Push LLM thinking to web UI
try:
from .web import push_event
push_event("llm_response", {"response": result})
except Exception:
pass
return result
def _query_ollama(self, prompt: str, system: str, temperature: float) -> str:
url = f"{self.base_url}/api/generate"
payload = {
"model": self.model,
"prompt": prompt,
"system": system,
"stream": False,
"options": {"temperature": temperature},
}
return self._post(url, payload, key="response")
def _stream_ollama(self, prompt: str, system: str, temperature: float) -> str:
url = f"{self.base_url}/api/generate"
payload = {
"model": self.model,
"prompt": prompt,
"system": system,
"stream": True,
"options": {"temperature": temperature},
}
return self._stream_post(url, parse_fn=lambda chunk: chunk.get("response", ""))
def _query_vllm(self, prompt: str, system: str, temperature: float) -> str:
url = f"{self.base_url}/v1/completions"
full_prompt = f"{system}\n\n{prompt}" if system else prompt
payload = {
"model": self.model,
"prompt": full_prompt,
"max_tokens": 4096,
"temperature": temperature,
"stream": False,
}
data = self._post_raw(url, payload)
try:
return data["choices"][0]["text"]
except (KeyError, IndexError):
raise LLMError(f"Unexpected vLLM response: {data}")
def _stream_vllm(self, prompt: str, system: str, temperature: float) -> str:
url = f"{self.base_url}/v1/completions"
full_prompt = f"{system}\n\n{prompt}" if system else prompt
payload = {
"model": self.model,
"prompt": full_prompt,
"max_tokens": 4096,
"temperature": temperature,
"stream": True,
}
return self._stream_post(url, parse_fn=lambda chunk: (
chunk.get("choices", [{}])[0].get("text", "") if chunk.get("choices") else ""
))
def chat(self, messages: list[dict], temperature: float = 0.2,
stream: bool = False) -> str:
if self.backend == "ollama":
url = f"{self.base_url}/api/chat"
payload = {
"model": self.model,
"messages": messages,
"stream": stream,
"options": {"temperature": temperature},
}
if stream:
return self._stream_post(url, parse_fn=lambda c: c.get("message", {}).get("content", ""))
return self._post(url, payload, key="message", subkey="content")
else:
url = f"{self.base_url}/v1/chat/completions"
payload = {
"model": self.model,
"messages": messages,
"max_tokens": 4096,
"temperature": temperature,
"stream": stream,
}
if stream:
return self._stream_post(url, parse_fn=lambda c: (
c.get("choices", [{}])[0].get("delta", {}).get("content", "")
if c.get("choices") else ""
))
data = self._post_raw(url, payload)
try:
return data["choices"][0]["message"]["content"]
except (KeyError, IndexError):
raise LLMError(f"Unexpected vLLM chat response: {data}")
def _post(self, url: str, payload: dict, key: str, subkey: str = None) -> str:
data = self._post_raw(url, payload)
try:
result = data[key]
if subkey:
result = result[subkey]
return result
except (KeyError, TypeError):
raise LLMError(f"Unexpected response structure: {data}")
def _post_raw(self, url: str, payload: dict, retries: int = 2) -> dict:
body = json.dumps(payload).encode("utf-8")
req = urllib.request.Request(
url, data=body, headers={"Content-Type": "application/json"}
)
last_err = None
for attempt in range(retries + 1):
try:
with urllib.request.urlopen(req, timeout=config.LLM_TIMEOUT) as resp:
return json.loads(resp.read().decode("utf-8"))
except urllib.error.URLError as e:
last_err = e
if attempt < retries:
time.sleep(2 ** attempt)
except json.JSONDecodeError as e:
raise LLMError(f"Invalid JSON from LLM: {e}")
raise LLMError(f"LLM request failed after {retries + 1} attempts ({url}): {last_err}")
def _stream_post(self, url: str, parse_fn) -> str:
"""Stream response, printing tokens to console as they arrive."""
# Build the same request but with stream=True already in payload
# We need to read line by line
body = json.dumps({"stream": True}).encode("utf-8")
# Actually we need the full payload — caller already set stream=True
# Re-read from the caller context isn't possible, so we use a different approach:
# The caller methods build the payload and call us. We need the payload.
# Refactored: callers should pass payload. For now, fall back to non-streaming.
# This is handled by the _stream_generate and _stream_chat methods.
raise LLMError("Direct _stream_post not supported; use streaming query methods")
def query_stream(self, prompt: str, system: str = "", temperature: float = 0.2) -> str:
"""Query with streaming output to console."""
if not system:
system = config.EXPERT_IDENTITY
if self.backend == "ollama":
return self._stream_ollama_impl(prompt, system, temperature)
else:
return self._stream_vllm_impl(prompt, system, temperature)
def _stream_ollama_impl(self, prompt: str, system: str, temperature: float) -> str:
url = f"{self.base_url}/api/generate"
payload = {
"model": self.model,
"prompt": prompt,
"system": system,
"stream": True,
"options": {"temperature": temperature},
}
return self._do_stream(url, payload, lambda c: c.get("response", ""))
def _stream_vllm_impl(self, prompt: str, system: str, temperature: float) -> str:
url = f"{self.base_url}/v1/completions"
full_prompt = f"{system}\n\n{prompt}" if system else prompt
payload = {
"model": self.model,
"prompt": full_prompt,
"max_tokens": 4096,
"temperature": temperature,
"stream": True,
}
return self._do_stream(url, payload, lambda c: (
c.get("choices", [{}])[0].get("text", "") if c.get("choices") else ""
))
def _do_stream(self, url: str, payload: dict, parse_fn) -> str:
"""Execute streaming request, print tokens live, return full text."""
body = json.dumps(payload).encode("utf-8")
req = urllib.request.Request(
url, data=body, headers={"Content-Type": "application/json"}
)
full_text = []
try:
with urllib.request.urlopen(req, timeout=config.LLM_TIMEOUT) as resp:
buffer = b""
while True:
chunk = resp.read(1)
if not chunk:
break
buffer += chunk
if chunk == b"\n" and buffer.strip():
line = buffer.decode("utf-8").strip()
buffer = b""
# vLLM SSE format
if line.startswith("data: "):
line = line[6:]
if line == "[DONE]":
break
try:
data = json.loads(line)
token = parse_fn(data)
if token:
full_text.append(token)
sys.stdout.write(token)
sys.stdout.flush()
except json.JSONDecodeError:
pass
elif chunk == b"\n":
buffer = b""
except urllib.error.URLError as e:
raise LLMError(f"Stream request failed ({url}): {e}")
sys.stdout.write("\n")
sys.stdout.flush()
return "".join(full_text)