#!/usr/bin/env python3
"""
cursor-llama-mcp-bridge: Minimal MCP server that proxies a local llama.cpp `llama-server`
over its OpenAI-compatible HTTP API. Works great with Cursor via stdio transport.
Tools exposed:
  - llama_health() -> dict
  - llama_models() -> list[str]
  - llama_chat(messages: list, max_tokens?: int, temperature?: float, top_p?: float, model?: str, stream?: bool) -> dict
  - llama_completion(prompt: str, max_tokens?: int, temperature?: float, top_p?: float, model?: str) -> dict
ENV:
  LLAMA_BASE_URL       default: http://127.0.0.1:8081
  LLAMA_API_KEY        optional
  LLAMA_TIMEOUT_S      optional, default 60
  LLAMA_DEFAULT_MODEL  optional, e.g. gemma-3-1b-it-Q8_0.gguf
  LLAMA_LOG_JSON       if "1", emit newline-delimited JSON logs for easy support bundles
"""
from __future__ import annotations
import json
import os
import random
import sys
import time
from http.client import HTTPConnection, HTTPSConnection
from typing import Any, Dict, List, Optional, Tuple
# --- MCP glue (high-level FastMCP API from the "mcp" python package) ---
# Note: the low-level mcp Server class has no .tool() decorator; FastMCP is the
# class that provides @tool() and run(transport="stdio"), so that is what we import.
try:
    from mcp.server.fastmcp import FastMCP  # type: ignore
except Exception:  # pragma: no cover
    # Extremely small fallback so the file can still be imported/tested.
    class _DummyToolReg:
        def __init__(self, name: str = ""): self.name, self.tools = name, {}
        def tool(self, name: Optional[str] = None):
            def deco(fn):
                self.tools[name or fn.__name__] = fn
                return fn
            return deco
        def run(self, transport: str = "stdio"):
            raise SystemExit("Install 'mcp' package to run the server.")
    FastMCP = _DummyToolReg  # type: ignore
mcp = FastMCP("cursor-llama-mcp-bridge")  # real FastMCP in production; dummy in tests
# --- Config ---
BASE_URL = os.environ.get("LLAMA_BASE_URL", "http://127.0.0.1:8081").rstrip("/")
API_KEY = os.environ.get("LLAMA_API_KEY", "")
TIMEOUT = float(os.environ.get("LLAMA_TIMEOUT_S", "60"))
DEF_MODEL = os.environ.get("LLAMA_DEFAULT_MODEL", "")
LOG_JSON = os.environ.get("LLAMA_LOG_JSON", "") == "1"
def _log(event: str, **kw: Any) -> None:
if LOG_JSON:
rec = {"ts": time.time(), "event": event}
rec.update(kw)
sys.stderr.write(json.dumps(rec, ensure_ascii=False) + "\n")
sys.stderr.flush()
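# Example (illustrative) record emitted when LLAMA_LOG_JSON=1; it goes to stderr
# so it stays off the stdout channel that the MCP stdio transport uses:
#   {"ts": 1700000000.0, "event": "llama_chat", "latency_ms": 512, "used_model": "server-default"}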
# --- HTTP helper with retries/backoff ---
def _parse_base(url: str) -> Tuple[str, str, int]:
if not (url.startswith("http://") or url.startswith("https://")):
raise ValueError("LLAMA_BASE_URL must start with http:// or https://")
scheme, rest = url.split("://", 1)
if "/" in rest:
hostport, _ = rest.split("/", 1)
else:
hostport = rest
if ":" in hostport:
host, port_s = hostport.split(":", 1)
port = int(port_s)
else:
host, port = hostport, 80 if scheme == "http" else 443
return scheme, host, port
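# e.g. _parse_base("http://127.0.0.1:8081") -> ("http", "127.0.0.1", 8081)
# (any path component after host:port is ignored by this helper).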
SCHEME, HOST, PORT = _parse_base(BASE_URL)
CONN_CLS = HTTPSConnection if SCHEME == "https" else HTTPConnection
def _headers() -> Dict[str, str]:
h = {"Content-Type": "application/json", "Accept": "application/json"}
if API_KEY:
h["Authorization"] = f"Bearer {API_KEY}"
return h
def _request_json(method: str, path: str, payload: Optional[Dict[str, Any]]=None,
timeout: float=TIMEOUT, retries: int=2, backoff_s: float=0.3) -> Dict[str, Any]:
last_err = None
for attempt in range(retries + 1):
try:
conn = CONN_CLS(HOST, PORT, timeout=timeout)
body = None if payload is None else json.dumps(payload).encode("utf-8")
conn.request(method, path, body=body, headers=_headers())
resp = conn.getresponse()
data = resp.read()
            status = resp.status
            conn.close()  # one short-lived connection per attempt; release it once the body is read
try:
parsed = json.loads(data.decode("utf-8") if data else "{}")
except Exception:
snippet = data[:256] if isinstance(data, (bytes, bytearray)) else (str(data)[:256])
raise RuntimeError(f"Non-JSON response ({status}): {snippet!r}")
if 200 <= status < 300:
return parsed
msg = parsed if isinstance(parsed, dict) else {"raw": parsed}
raise RuntimeError(f"HTTP {status} error: {json.dumps(msg)[:400]}")
except Exception as e:
last_err = e
if attempt < retries:
sleep = backoff_s * (1.0 + random.random())
time.sleep(sleep)
continue
raise
# Should not reach
raise last_err # type: ignore
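# Example (illustrative) direct use of the helper, e.g. from a REPL:
#   models = _request_json("GET", "/v1/models", None, retries=1)
#   ids = [m.get("id") for m in models.get("data", [])]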
# --- Tools ---
@mcp.tool()
def llama_health() -> Dict[str, Any]:
"""
Return connectivity + config. If /v1/models fails, include error message.
"""
info = {"base_url": BASE_URL, "timeout_s": TIMEOUT, "default_model": DEF_MODEL or None}
try:
m = _request_json("GET", "/v1/models", None)
ids = [x.get("id") for x in (m.get("data") or []) if isinstance(x, dict)]
info.update({"status": "ok", "models": ids})
except Exception as e:
info.update({"status": "degraded", "error": str(e)})
_log("llama_health", **info)
return info
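# Example (illustrative) healthy result, assuming one loaded model:
#   {"base_url": "http://127.0.0.1:8081", "timeout_s": 60.0, "default_model": None,
#    "status": "ok", "models": ["gemma-3-1b-it-Q8_0.gguf"]}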
@mcp.tool()
def llama_models() -> List[str]:
"""
Return the list of model IDs advertised by llama-server.
"""
m = _request_json("GET", "/v1/models", None)
ids = [x.get("id") for x in (m.get("data") or []) if isinstance(x, dict)]
_log("llama_models", count=len(ids))
return ids
@mcp.tool()
def llama_chat(messages: List[Dict[str, str]],
max_tokens: int = 256,
temperature: float = 0.7,
top_p: float = 0.95,
model: Optional[str] = None,
stream: bool = False,
**extra: Any) -> Dict[str, Any]:
"""
Call /v1/chat/completions on llama-server.
"""
eff_model = model or DEF_MODEL or None
payload: Dict[str, Any] = {
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"stream": stream,
}
if eff_model:
payload["model"] = eff_model
payload.update({k: v for k, v in extra.items() if v is not None})
t0 = time.time()
out = _request_json("POST", "/v1/chat/completions", payload)
_log("llama_chat", latency_ms=int((time.time()-t0)*1000), used_model=eff_model or "server-default")
return out
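# Example (illustrative) call from an MCP client:
#   llama_chat(messages=[{"role": "user", "content": "Say hi"}], max_tokens=32)
# Note: stream=True is forwarded as-is, but this bridge reads the whole body and
# expects JSON, so a backend that answers with SSE chunks will surface a
# "Non-JSON response" error from _request_json.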
@mcp.tool()
def llama_completion(prompt: str,
max_tokens: int = 256,
temperature: float = 0.7,
top_p: float = 0.95,
model: Optional[str] = None,
**extra: Any) -> Dict[str, Any]:
"""
Call /v1/completions on llama-server.
"""
eff_model = model or DEF_MODEL or None
payload: Dict[str, Any] = {
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
if eff_model:
payload["model"] = eff_model
payload.update({k: v for k, v in extra.items() if v is not None})
t0 = time.time()
out = _request_json("POST", "/v1/completions", payload)
_log("llama_completion", latency_ms=int((time.time()-t0)*1000), used_model=eff_model or "server-default")
return out
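# Example (illustrative):
#   llama_completion("Once upon a time", max_tokens=64, temperature=0.9)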
def _has_real_stdio() -> bool:
return hasattr(sys.stdout, "buffer") and hasattr(sys.stdin, "buffer")
if __name__ == "__main__":
if not _has_real_stdio():
sys.stderr.write("ERROR: Real stdio buffers not available. Run in a normal terminal (Cursor will spawn it).\n")
sys.exit(2)
# The real 'mcp' package exposes .run(transport="stdio")
try:
mcp.run(transport="stdio") # type: ignore
except AttributeError:
raise SystemExit("Install 'mcp' package and run this under Cursor.")
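# Manual smoke test (illustrative; assumes a llama-server instance is already
# listening on http://127.0.0.1:8081):
#   LLAMA_BASE_URL=http://127.0.0.1:8081 LLAMA_LOG_JSON=1 python3 llama_mcp_server.py
# Cursor normally spawns this process itself over stdio, so running it by hand is
# mainly useful for checking that the 'mcp' package and the backend are reachable.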