377 lines
15 KiB
Python
377 lines
15 KiB
Python
"""
|
|
AutoDev - LLM Communication Layer
|
|
Supports Ollama and vLLM backends with streaming, retry logic, and robust error handling.
|
|
"""
|
|
|
|
import json
|
|
import sys
|
|
import time
|
|
import urllib.request
|
|
import urllib.error
|
|
from . import config
|
|
|
|
|
|
class LLMError(Exception):
|
|
pass
|
|
|
|
|
|
class LLM:
|
|
def __init__(self, backend: str = None, model: str = None):
|
|
self.backend = backend or config.LLM_BACKEND
|
|
self.model = model or config.MODEL_NAME
|
|
if self.backend == "ollama":
|
|
self.base_url = config.OLLAMA_URL
|
|
elif self.backend == "vllm":
|
|
self.base_url = config.VLLM_URL
|
|
else:
|
|
raise LLMError(f"Unknown backend: {self.backend}")
|
|
self.context_size = None # Auto-detected on first use
|
|
|
|
def detect_context_size(self) -> int:
|
|
"""Detect the model's effective context window size.
|
|
|
|
Checks (in priority order):
|
|
1. Ollama /api/ps for running model's actual context_length
|
|
2. num_ctx in model parameters from /api/show
|
|
3. Model architecture's max context_length from /api/show
|
|
4. vLLM max_model_len from /v1/models
|
|
5. Fallback to config default
|
|
"""
|
|
if self.context_size is not None:
|
|
return self.context_size
|
|
try:
|
|
if self.backend == "ollama":
|
|
self.context_size = self._detect_ollama_context()
|
|
elif self.backend == "vllm":
|
|
self.context_size = self._detect_vllm_context()
|
|
except Exception:
|
|
pass
|
|
if not self.context_size:
|
|
self.context_size = config.MAX_CONTEXT_TOKENS
|
|
return self.context_size
|
|
|
|
def detect_gpu_status(self) -> dict:
|
|
"""Check GPU/CPU offload status for the running model.
|
|
|
|
Returns dict with:
|
|
loaded: bool - whether model is currently loaded
|
|
gpu_percent: int - percentage of model on GPU (0-100)
|
|
size_total: int - total model size in bytes
|
|
size_vram: int - bytes on GPU
|
|
warning: str|None - warning message if mostly CPU
|
|
"""
|
|
result = {"loaded": False, "gpu_percent": 0, "size_total": 0,
|
|
"size_vram": 0, "warning": None}
|
|
if self.backend != "ollama":
|
|
return result
|
|
try:
|
|
url = f"{self.base_url}/api/ps"
|
|
req = urllib.request.Request(url)
|
|
with urllib.request.urlopen(req, timeout=5) as resp:
|
|
data = json.loads(resp.read().decode("utf-8"))
|
|
for m in data.get("models", []):
|
|
if self.model in m.get("name", ""):
|
|
result["loaded"] = True
|
|
result["size_total"] = m.get("size", 0)
|
|
result["size_vram"] = m.get("size_vram", 0)
|
|
if result["size_total"] > 0:
|
|
result["gpu_percent"] = int(
|
|
result["size_vram"] / result["size_total"] * 100
|
|
)
|
|
if result["gpu_percent"] == 0:
|
|
result["warning"] = (
|
|
"Model is running entirely on CPU. "
|
|
"This will be extremely slow and may not complete. "
|
|
"Consider using a smaller model or freeing GPU memory."
|
|
)
|
|
elif result["gpu_percent"] < 50:
|
|
result["warning"] = (
|
|
f"Only {result['gpu_percent']}% of model is on GPU. "
|
|
"Performance will be significantly degraded. "
|
|
"Consider using a smaller model."
|
|
)
|
|
break
|
|
except Exception:
|
|
pass
|
|
return result
|
|
|
|
def _detect_ollama_context(self) -> int | None:
|
|
# 1. Check running model — this gives the actual runtime context_length
|
|
try:
|
|
url = f"{self.base_url}/api/ps"
|
|
req = urllib.request.Request(url)
|
|
with urllib.request.urlopen(req, timeout=5) as resp:
|
|
data = json.loads(resp.read().decode("utf-8"))
|
|
for m in data.get("models", []):
|
|
if self.model in m.get("name", ""):
|
|
ctx = m.get("context_length")
|
|
if ctx:
|
|
return int(ctx)
|
|
except Exception:
|
|
pass
|
|
|
|
# 2. Check model config from /api/show
|
|
try:
|
|
url = f"{self.base_url}/api/show"
|
|
payload = {"name": self.model}
|
|
data = self._post_raw(url, payload)
|
|
|
|
# Check parameters for explicit num_ctx setting
|
|
params = data.get("parameters", "")
|
|
for line in params.splitlines():
|
|
if "num_ctx" in line:
|
|
parts = line.split()
|
|
for p in parts:
|
|
if p.isdigit():
|
|
return int(p)
|
|
|
|
# Check modelfile for PARAMETER num_ctx
|
|
modelfile = data.get("modelfile", "")
|
|
for line in modelfile.splitlines():
|
|
if "num_ctx" in line.lower():
|
|
parts = line.split()
|
|
for p in parts:
|
|
if p.isdigit():
|
|
return int(p)
|
|
|
|
# 3. Fall back to architecture's max context_length
|
|
model_info = data.get("model_info", {})
|
|
for key, val in model_info.items():
|
|
if "context_length" in key:
|
|
return int(val)
|
|
except Exception:
|
|
pass
|
|
return None
|
|
|
|
def _detect_vllm_context(self) -> int | None:
|
|
try:
|
|
url = f"{self.base_url}/v1/models"
|
|
req = urllib.request.Request(url)
|
|
with urllib.request.urlopen(req, timeout=10) as resp:
|
|
data = json.loads(resp.read().decode("utf-8"))
|
|
for m in data.get("data", []):
|
|
if m.get("id") == self.model:
|
|
return m.get("max_model_len")
|
|
except Exception:
|
|
pass
|
|
return None
|
|
|
|
def query(self, prompt: str, system: str = "", temperature: float = 0.2,
|
|
stream: bool = False) -> str:
|
|
if not system:
|
|
system = config.EXPERT_IDENTITY
|
|
if self.backend == "ollama":
|
|
if stream:
|
|
return self._stream_ollama(prompt, system, temperature)
|
|
result = self._query_ollama(prompt, system, temperature)
|
|
else:
|
|
if stream:
|
|
return self._stream_vllm(prompt, system, temperature)
|
|
result = self._query_vllm(prompt, system, temperature)
|
|
# Push LLM thinking to web UI
|
|
try:
|
|
from .web import push_event
|
|
push_event("llm_response", {"response": result})
|
|
except Exception:
|
|
pass
|
|
return result
|
|
|
|
def _query_ollama(self, prompt: str, system: str, temperature: float) -> str:
|
|
url = f"{self.base_url}/api/generate"
|
|
payload = {
|
|
"model": self.model,
|
|
"prompt": prompt,
|
|
"system": system,
|
|
"stream": False,
|
|
"options": {"temperature": temperature},
|
|
}
|
|
return self._post(url, payload, key="response")
|
|
|
|
def _stream_ollama(self, prompt: str, system: str, temperature: float) -> str:
|
|
url = f"{self.base_url}/api/generate"
|
|
payload = {
|
|
"model": self.model,
|
|
"prompt": prompt,
|
|
"system": system,
|
|
"stream": True,
|
|
"options": {"temperature": temperature},
|
|
}
|
|
return self._stream_post(url, parse_fn=lambda chunk: chunk.get("response", ""))
|
|
|
|
def _query_vllm(self, prompt: str, system: str, temperature: float) -> str:
|
|
url = f"{self.base_url}/v1/completions"
|
|
full_prompt = f"{system}\n\n{prompt}" if system else prompt
|
|
payload = {
|
|
"model": self.model,
|
|
"prompt": full_prompt,
|
|
"max_tokens": 4096,
|
|
"temperature": temperature,
|
|
"stream": False,
|
|
}
|
|
data = self._post_raw(url, payload)
|
|
try:
|
|
return data["choices"][0]["text"]
|
|
except (KeyError, IndexError):
|
|
raise LLMError(f"Unexpected vLLM response: {data}")
|
|
|
|
def _stream_vllm(self, prompt: str, system: str, temperature: float) -> str:
|
|
url = f"{self.base_url}/v1/completions"
|
|
full_prompt = f"{system}\n\n{prompt}" if system else prompt
|
|
payload = {
|
|
"model": self.model,
|
|
"prompt": full_prompt,
|
|
"max_tokens": 4096,
|
|
"temperature": temperature,
|
|
"stream": True,
|
|
}
|
|
return self._stream_post(url, parse_fn=lambda chunk: (
|
|
chunk.get("choices", [{}])[0].get("text", "") if chunk.get("choices") else ""
|
|
))
|
|
|
|
def chat(self, messages: list[dict], temperature: float = 0.2,
|
|
stream: bool = False) -> str:
|
|
if self.backend == "ollama":
|
|
url = f"{self.base_url}/api/chat"
|
|
payload = {
|
|
"model": self.model,
|
|
"messages": messages,
|
|
"stream": stream,
|
|
"options": {"temperature": temperature},
|
|
}
|
|
if stream:
|
|
return self._stream_post(url, parse_fn=lambda c: c.get("message", {}).get("content", ""))
|
|
return self._post(url, payload, key="message", subkey="content")
|
|
else:
|
|
url = f"{self.base_url}/v1/chat/completions"
|
|
payload = {
|
|
"model": self.model,
|
|
"messages": messages,
|
|
"max_tokens": 4096,
|
|
"temperature": temperature,
|
|
"stream": stream,
|
|
}
|
|
if stream:
|
|
return self._stream_post(url, parse_fn=lambda c: (
|
|
c.get("choices", [{}])[0].get("delta", {}).get("content", "")
|
|
if c.get("choices") else ""
|
|
))
|
|
data = self._post_raw(url, payload)
|
|
try:
|
|
return data["choices"][0]["message"]["content"]
|
|
except (KeyError, IndexError):
|
|
raise LLMError(f"Unexpected vLLM chat response: {data}")
|
|
|
|
def _post(self, url: str, payload: dict, key: str, subkey: str = None) -> str:
|
|
data = self._post_raw(url, payload)
|
|
try:
|
|
result = data[key]
|
|
if subkey:
|
|
result = result[subkey]
|
|
return result
|
|
except (KeyError, TypeError):
|
|
raise LLMError(f"Unexpected response structure: {data}")
|
|
|
|
def _post_raw(self, url: str, payload: dict, retries: int = 2) -> dict:
|
|
body = json.dumps(payload).encode("utf-8")
|
|
req = urllib.request.Request(
|
|
url, data=body, headers={"Content-Type": "application/json"}
|
|
)
|
|
last_err = None
|
|
for attempt in range(retries + 1):
|
|
try:
|
|
with urllib.request.urlopen(req, timeout=config.LLM_TIMEOUT) as resp:
|
|
return json.loads(resp.read().decode("utf-8"))
|
|
except urllib.error.URLError as e:
|
|
last_err = e
|
|
if attempt < retries:
|
|
time.sleep(2 ** attempt)
|
|
except json.JSONDecodeError as e:
|
|
raise LLMError(f"Invalid JSON from LLM: {e}")
|
|
raise LLMError(f"LLM request failed after {retries + 1} attempts ({url}): {last_err}")
|
|
|
|
def _stream_post(self, url: str, parse_fn) -> str:
|
|
"""Stream response, printing tokens to console as they arrive."""
|
|
# Build the same request but with stream=True already in payload
|
|
# We need to read line by line
|
|
body = json.dumps({"stream": True}).encode("utf-8")
|
|
# Actually we need the full payload — caller already set stream=True
|
|
# Re-read from the caller context isn't possible, so we use a different approach:
|
|
# The caller methods build the payload and call us. We need the payload.
|
|
# Refactored: callers should pass payload. For now, fall back to non-streaming.
|
|
# This is handled by the _stream_generate and _stream_chat methods.
|
|
raise LLMError("Direct _stream_post not supported; use streaming query methods")
|
|
|
|
def query_stream(self, prompt: str, system: str = "", temperature: float = 0.2) -> str:
|
|
"""Query with streaming output to console."""
|
|
if not system:
|
|
system = config.EXPERT_IDENTITY
|
|
if self.backend == "ollama":
|
|
return self._stream_ollama_impl(prompt, system, temperature)
|
|
else:
|
|
return self._stream_vllm_impl(prompt, system, temperature)
|
|
|
|
def _stream_ollama_impl(self, prompt: str, system: str, temperature: float) -> str:
|
|
url = f"{self.base_url}/api/generate"
|
|
payload = {
|
|
"model": self.model,
|
|
"prompt": prompt,
|
|
"system": system,
|
|
"stream": True,
|
|
"options": {"temperature": temperature},
|
|
}
|
|
return self._do_stream(url, payload, lambda c: c.get("response", ""))
|
|
|
|
def _stream_vllm_impl(self, prompt: str, system: str, temperature: float) -> str:
|
|
url = f"{self.base_url}/v1/completions"
|
|
full_prompt = f"{system}\n\n{prompt}" if system else prompt
|
|
payload = {
|
|
"model": self.model,
|
|
"prompt": full_prompt,
|
|
"max_tokens": 4096,
|
|
"temperature": temperature,
|
|
"stream": True,
|
|
}
|
|
return self._do_stream(url, payload, lambda c: (
|
|
c.get("choices", [{}])[0].get("text", "") if c.get("choices") else ""
|
|
))
|
|
|
|
def _do_stream(self, url: str, payload: dict, parse_fn) -> str:
|
|
"""Execute streaming request, print tokens live, return full text."""
|
|
body = json.dumps(payload).encode("utf-8")
|
|
req = urllib.request.Request(
|
|
url, data=body, headers={"Content-Type": "application/json"}
|
|
)
|
|
full_text = []
|
|
try:
|
|
with urllib.request.urlopen(req, timeout=config.LLM_TIMEOUT) as resp:
|
|
buffer = b""
|
|
while True:
|
|
chunk = resp.read(1)
|
|
if not chunk:
|
|
break
|
|
buffer += chunk
|
|
if chunk == b"\n" and buffer.strip():
|
|
line = buffer.decode("utf-8").strip()
|
|
buffer = b""
|
|
# vLLM SSE format
|
|
if line.startswith("data: "):
|
|
line = line[6:]
|
|
if line == "[DONE]":
|
|
break
|
|
try:
|
|
data = json.loads(line)
|
|
token = parse_fn(data)
|
|
if token:
|
|
full_text.append(token)
|
|
sys.stdout.write(token)
|
|
sys.stdout.flush()
|
|
except json.JSONDecodeError:
|
|
pass
|
|
elif chunk == b"\n":
|
|
buffer = b""
|
|
except urllib.error.URLError as e:
|
|
raise LLMError(f"Stream request failed ({url}): {e}")
|
|
sys.stdout.write("\n")
|
|
sys.stdout.flush()
|
|
return "".join(full_text)
|