First commit, working quite ok!
This commit is contained in:
@@ -0,0 +1,376 @@
|
||||
"""
|
||||
AutoDev - LLM Communication Layer
|
||||
Supports Ollama and vLLM backends with streaming, retry logic, and robust error handling.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from . import config
|
||||
|
||||
|
||||
class LLMError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class LLM:
|
||||
def __init__(self, backend: str = None, model: str = None):
|
||||
self.backend = backend or config.LLM_BACKEND
|
||||
self.model = model or config.MODEL_NAME
|
||||
if self.backend == "ollama":
|
||||
self.base_url = config.OLLAMA_URL
|
||||
elif self.backend == "vllm":
|
||||
self.base_url = config.VLLM_URL
|
||||
else:
|
||||
raise LLMError(f"Unknown backend: {self.backend}")
|
||||
self.context_size = None # Auto-detected on first use
|
||||
|
||||
def detect_context_size(self) -> int:
|
||||
"""Detect the model's effective context window size.
|
||||
|
||||
Checks (in priority order):
|
||||
1. Ollama /api/ps for running model's actual context_length
|
||||
2. num_ctx in model parameters from /api/show
|
||||
3. Model architecture's max context_length from /api/show
|
||||
4. vLLM max_model_len from /v1/models
|
||||
5. Fallback to config default
|
||||
"""
|
||||
if self.context_size is not None:
|
||||
return self.context_size
|
||||
try:
|
||||
if self.backend == "ollama":
|
||||
self.context_size = self._detect_ollama_context()
|
||||
elif self.backend == "vllm":
|
||||
self.context_size = self._detect_vllm_context()
|
||||
except Exception:
|
||||
pass
|
||||
if not self.context_size:
|
||||
self.context_size = config.MAX_CONTEXT_TOKENS
|
||||
return self.context_size
|
||||
|
||||
def detect_gpu_status(self) -> dict:
|
||||
"""Check GPU/CPU offload status for the running model.
|
||||
|
||||
Returns dict with:
|
||||
loaded: bool - whether model is currently loaded
|
||||
gpu_percent: int - percentage of model on GPU (0-100)
|
||||
size_total: int - total model size in bytes
|
||||
size_vram: int - bytes on GPU
|
||||
warning: str|None - warning message if mostly CPU
|
||||
"""
|
||||
result = {"loaded": False, "gpu_percent": 0, "size_total": 0,
|
||||
"size_vram": 0, "warning": None}
|
||||
if self.backend != "ollama":
|
||||
return result
|
||||
try:
|
||||
url = f"{self.base_url}/api/ps"
|
||||
req = urllib.request.Request(url)
|
||||
with urllib.request.urlopen(req, timeout=5) as resp:
|
||||
data = json.loads(resp.read().decode("utf-8"))
|
||||
for m in data.get("models", []):
|
||||
if self.model in m.get("name", ""):
|
||||
result["loaded"] = True
|
||||
result["size_total"] = m.get("size", 0)
|
||||
result["size_vram"] = m.get("size_vram", 0)
|
||||
if result["size_total"] > 0:
|
||||
result["gpu_percent"] = int(
|
||||
result["size_vram"] / result["size_total"] * 100
|
||||
)
|
||||
if result["gpu_percent"] == 0:
|
||||
result["warning"] = (
|
||||
"Model is running entirely on CPU. "
|
||||
"This will be extremely slow and may not complete. "
|
||||
"Consider using a smaller model or freeing GPU memory."
|
||||
)
|
||||
elif result["gpu_percent"] < 50:
|
||||
result["warning"] = (
|
||||
f"Only {result['gpu_percent']}% of model is on GPU. "
|
||||
"Performance will be significantly degraded. "
|
||||
"Consider using a smaller model."
|
||||
)
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
def _detect_ollama_context(self) -> int | None:
|
||||
# 1. Check running model — this gives the actual runtime context_length
|
||||
try:
|
||||
url = f"{self.base_url}/api/ps"
|
||||
req = urllib.request.Request(url)
|
||||
with urllib.request.urlopen(req, timeout=5) as resp:
|
||||
data = json.loads(resp.read().decode("utf-8"))
|
||||
for m in data.get("models", []):
|
||||
if self.model in m.get("name", ""):
|
||||
ctx = m.get("context_length")
|
||||
if ctx:
|
||||
return int(ctx)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. Check model config from /api/show
|
||||
try:
|
||||
url = f"{self.base_url}/api/show"
|
||||
payload = {"name": self.model}
|
||||
data = self._post_raw(url, payload)
|
||||
|
||||
# Check parameters for explicit num_ctx setting
|
||||
params = data.get("parameters", "")
|
||||
for line in params.splitlines():
|
||||
if "num_ctx" in line:
|
||||
parts = line.split()
|
||||
for p in parts:
|
||||
if p.isdigit():
|
||||
return int(p)
|
||||
|
||||
# Check modelfile for PARAMETER num_ctx
|
||||
modelfile = data.get("modelfile", "")
|
||||
for line in modelfile.splitlines():
|
||||
if "num_ctx" in line.lower():
|
||||
parts = line.split()
|
||||
for p in parts:
|
||||
if p.isdigit():
|
||||
return int(p)
|
||||
|
||||
# 3. Fall back to architecture's max context_length
|
||||
model_info = data.get("model_info", {})
|
||||
for key, val in model_info.items():
|
||||
if "context_length" in key:
|
||||
return int(val)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def _detect_vllm_context(self) -> int | None:
|
||||
try:
|
||||
url = f"{self.base_url}/v1/models"
|
||||
req = urllib.request.Request(url)
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
data = json.loads(resp.read().decode("utf-8"))
|
||||
for m in data.get("data", []):
|
||||
if m.get("id") == self.model:
|
||||
return m.get("max_model_len")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def query(self, prompt: str, system: str = "", temperature: float = 0.2,
|
||||
stream: bool = False) -> str:
|
||||
if not system:
|
||||
system = config.EXPERT_IDENTITY
|
||||
if self.backend == "ollama":
|
||||
if stream:
|
||||
return self._stream_ollama(prompt, system, temperature)
|
||||
result = self._query_ollama(prompt, system, temperature)
|
||||
else:
|
||||
if stream:
|
||||
return self._stream_vllm(prompt, system, temperature)
|
||||
result = self._query_vllm(prompt, system, temperature)
|
||||
# Push LLM thinking to web UI
|
||||
try:
|
||||
from .web import push_event
|
||||
push_event("llm_response", {"response": result})
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
def _query_ollama(self, prompt: str, system: str, temperature: float) -> str:
|
||||
url = f"{self.base_url}/api/generate"
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"prompt": prompt,
|
||||
"system": system,
|
||||
"stream": False,
|
||||
"options": {"temperature": temperature},
|
||||
}
|
||||
return self._post(url, payload, key="response")
|
||||
|
||||
def _stream_ollama(self, prompt: str, system: str, temperature: float) -> str:
|
||||
url = f"{self.base_url}/api/generate"
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"prompt": prompt,
|
||||
"system": system,
|
||||
"stream": True,
|
||||
"options": {"temperature": temperature},
|
||||
}
|
||||
return self._stream_post(url, parse_fn=lambda chunk: chunk.get("response", ""))
|
||||
|
||||
def _query_vllm(self, prompt: str, system: str, temperature: float) -> str:
|
||||
url = f"{self.base_url}/v1/completions"
|
||||
full_prompt = f"{system}\n\n{prompt}" if system else prompt
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"prompt": full_prompt,
|
||||
"max_tokens": 4096,
|
||||
"temperature": temperature,
|
||||
"stream": False,
|
||||
}
|
||||
data = self._post_raw(url, payload)
|
||||
try:
|
||||
return data["choices"][0]["text"]
|
||||
except (KeyError, IndexError):
|
||||
raise LLMError(f"Unexpected vLLM response: {data}")
|
||||
|
||||
def _stream_vllm(self, prompt: str, system: str, temperature: float) -> str:
|
||||
url = f"{self.base_url}/v1/completions"
|
||||
full_prompt = f"{system}\n\n{prompt}" if system else prompt
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"prompt": full_prompt,
|
||||
"max_tokens": 4096,
|
||||
"temperature": temperature,
|
||||
"stream": True,
|
||||
}
|
||||
return self._stream_post(url, parse_fn=lambda chunk: (
|
||||
chunk.get("choices", [{}])[0].get("text", "") if chunk.get("choices") else ""
|
||||
))
|
||||
|
||||
def chat(self, messages: list[dict], temperature: float = 0.2,
|
||||
stream: bool = False) -> str:
|
||||
if self.backend == "ollama":
|
||||
url = f"{self.base_url}/api/chat"
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"options": {"temperature": temperature},
|
||||
}
|
||||
if stream:
|
||||
return self._stream_post(url, parse_fn=lambda c: c.get("message", {}).get("content", ""))
|
||||
return self._post(url, payload, key="message", subkey="content")
|
||||
else:
|
||||
url = f"{self.base_url}/v1/chat/completions"
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": 4096,
|
||||
"temperature": temperature,
|
||||
"stream": stream,
|
||||
}
|
||||
if stream:
|
||||
return self._stream_post(url, parse_fn=lambda c: (
|
||||
c.get("choices", [{}])[0].get("delta", {}).get("content", "")
|
||||
if c.get("choices") else ""
|
||||
))
|
||||
data = self._post_raw(url, payload)
|
||||
try:
|
||||
return data["choices"][0]["message"]["content"]
|
||||
except (KeyError, IndexError):
|
||||
raise LLMError(f"Unexpected vLLM chat response: {data}")
|
||||
|
||||
def _post(self, url: str, payload: dict, key: str, subkey: str = None) -> str:
|
||||
data = self._post_raw(url, payload)
|
||||
try:
|
||||
result = data[key]
|
||||
if subkey:
|
||||
result = result[subkey]
|
||||
return result
|
||||
except (KeyError, TypeError):
|
||||
raise LLMError(f"Unexpected response structure: {data}")
|
||||
|
||||
def _post_raw(self, url: str, payload: dict, retries: int = 2) -> dict:
|
||||
body = json.dumps(payload).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
url, data=body, headers={"Content-Type": "application/json"}
|
||||
)
|
||||
last_err = None
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=config.LLM_TIMEOUT) as resp:
|
||||
return json.loads(resp.read().decode("utf-8"))
|
||||
except urllib.error.URLError as e:
|
||||
last_err = e
|
||||
if attempt < retries:
|
||||
time.sleep(2 ** attempt)
|
||||
except json.JSONDecodeError as e:
|
||||
raise LLMError(f"Invalid JSON from LLM: {e}")
|
||||
raise LLMError(f"LLM request failed after {retries + 1} attempts ({url}): {last_err}")
|
||||
|
||||
def _stream_post(self, url: str, parse_fn) -> str:
|
||||
"""Stream response, printing tokens to console as they arrive."""
|
||||
# Build the same request but with stream=True already in payload
|
||||
# We need to read line by line
|
||||
body = json.dumps({"stream": True}).encode("utf-8")
|
||||
# Actually we need the full payload — caller already set stream=True
|
||||
# Re-read from the caller context isn't possible, so we use a different approach:
|
||||
# The caller methods build the payload and call us. We need the payload.
|
||||
# Refactored: callers should pass payload. For now, fall back to non-streaming.
|
||||
# This is handled by the _stream_generate and _stream_chat methods.
|
||||
raise LLMError("Direct _stream_post not supported; use streaming query methods")
|
||||
|
||||
def query_stream(self, prompt: str, system: str = "", temperature: float = 0.2) -> str:
|
||||
"""Query with streaming output to console."""
|
||||
if not system:
|
||||
system = config.EXPERT_IDENTITY
|
||||
if self.backend == "ollama":
|
||||
return self._stream_ollama_impl(prompt, system, temperature)
|
||||
else:
|
||||
return self._stream_vllm_impl(prompt, system, temperature)
|
||||
|
||||
def _stream_ollama_impl(self, prompt: str, system: str, temperature: float) -> str:
|
||||
url = f"{self.base_url}/api/generate"
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"prompt": prompt,
|
||||
"system": system,
|
||||
"stream": True,
|
||||
"options": {"temperature": temperature},
|
||||
}
|
||||
return self._do_stream(url, payload, lambda c: c.get("response", ""))
|
||||
|
||||
def _stream_vllm_impl(self, prompt: str, system: str, temperature: float) -> str:
|
||||
url = f"{self.base_url}/v1/completions"
|
||||
full_prompt = f"{system}\n\n{prompt}" if system else prompt
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"prompt": full_prompt,
|
||||
"max_tokens": 4096,
|
||||
"temperature": temperature,
|
||||
"stream": True,
|
||||
}
|
||||
return self._do_stream(url, payload, lambda c: (
|
||||
c.get("choices", [{}])[0].get("text", "") if c.get("choices") else ""
|
||||
))
|
||||
|
||||
def _do_stream(self, url: str, payload: dict, parse_fn) -> str:
|
||||
"""Execute streaming request, print tokens live, return full text."""
|
||||
body = json.dumps(payload).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
url, data=body, headers={"Content-Type": "application/json"}
|
||||
)
|
||||
full_text = []
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=config.LLM_TIMEOUT) as resp:
|
||||
buffer = b""
|
||||
while True:
|
||||
chunk = resp.read(1)
|
||||
if not chunk:
|
||||
break
|
||||
buffer += chunk
|
||||
if chunk == b"\n" and buffer.strip():
|
||||
line = buffer.decode("utf-8").strip()
|
||||
buffer = b""
|
||||
# vLLM SSE format
|
||||
if line.startswith("data: "):
|
||||
line = line[6:]
|
||||
if line == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(line)
|
||||
token = parse_fn(data)
|
||||
if token:
|
||||
full_text.append(token)
|
||||
sys.stdout.write(token)
|
||||
sys.stdout.flush()
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
elif chunk == b"\n":
|
||||
buffer = b""
|
||||
except urllib.error.URLError as e:
|
||||
raise LLMError(f"Stream request failed ({url}): {e}")
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
return "".join(full_text)
|
||||
Reference in New Issue
Block a user