Diffcontext / diffcontext /mcp_server.py
trakshan-mishra
Update inline analysis to support structured lists for ChatGPT Actions compatibility
ed862c3
Raw
History Blame Contribute Delete
15.8 kB
"""
DiffContext MCP Server
Lets Claude Desktop / Cursor call DiffContext natively — no copy-paste needed.
How to use:
1. pip install mcp
2. Run: python -m diffcontext.mcp_server
"""
import ast
import json
import os
import sys
import tempfile
import shutil
from pathlib import Path
from typing import Any, Dict, List, Optional
# ── Try to import MCP; give a clear error if not installed ──────────────────
try:
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent
HAS_MCP = True
except ImportError:
HAS_MCP = False
# ── Try to import DiffContext ────────────────────────────────────────────────
try:
from .pipeline import index_repository, analyze_impact
from .pipeline import compile as dc_compile
HAS_DIFFCONTEXT = True
except ImportError:
try:
from diffcontext.pipeline import index_repository, analyze_impact
from diffcontext.pipeline import compile as dc_compile
HAS_DIFFCONTEXT = True
except ImportError:
HAS_DIFFCONTEXT = False
# ── Sessions helper for remote hosted servers ────────────────────────────────
def _get_sessions() -> Dict[str, str]:
    """Return the ``SESSIONS`` table of the hosted backend, if one is importable.

    The backend module is looked up under two import names, in order. The
    first one that imports and exposes a ``SESSIONS`` attribute wins. Any
    failure (module missing, attribute missing, error raised at import time)
    moves on to the next name; when neither works an empty dict is returned.

    Returns:
        The backend's ``SESSIONS`` object, or ``{}`` when no backend is found.
        (Callers in this file index it by ``repo_id`` to obtain a directory
        path.)
    """
    from importlib import import_module

    candidate_modules = ("diffcontext-service.backend.main", "backend.main")
    for module_name in candidate_modules:
        try:
            return import_module(module_name).SESSIONS
        except Exception:
            # Deliberate best effort: try the next candidate name.
            continue
    return {}
# ── Shared analysis logic (same as web service) ──────────────────────────────
def _list_symbols_in_dir(repo_path: str) -> List[str]:
    """List every function defined in the Python files under *repo_path*.

    Each ``*.py`` file found recursively is parsed with :mod:`ast`, and every
    ``def`` / ``async def`` (at any nesting depth, including methods) yields
    one symbol ID of the form ``./relative/path.py:function_name``.

    Files that cannot be read or parsed are skipped so that one bad entry does
    not abort the whole listing.

    Args:
        repo_path: Directory to scan. A path that does not exist yields an
            empty list.

    Returns:
        The symbol IDs, sorted. Only the bare function name is used, so two
        functions with the same name in one file produce two equal entries.
    """
    root = Path(repo_path)
    symbols: List[str] = []
    for py_file in root.rglob("*.py"):
        try:
            # Decode as UTF-8 (Python's default source encoding) instead of
            # the platform locale; undecodable bytes are dropped.
            source = py_file.read_text(encoding="utf-8", errors="ignore")
            tree = ast.parse(source)
        except (OSError, SyntaxError, ValueError):
            # OSError: unreadable entry, e.g. a directory named "x.py".
            # ValueError: some Python versions raise it for NUL bytes.
            continue
        # as_posix() keeps the documented "./a/b.py" form on Windows too.
        rel = "./" + py_file.relative_to(root).as_posix()
        symbols.extend(
            f"{rel}:{node.name}"
            for node in ast.walk(tree)
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
        )
    return sorted(symbols)
def _group_symbols_by_file(symbols) -> Dict[str, List[str]]:
    """Group ``file:name`` symbol IDs into ``{file: [name, ...]}``, keeping order."""
    grouped: Dict[str, List[str]] = {}
    for sym in symbols:
        grouped.setdefault(sym.split(":")[0], []).append(sym.split(":")[-1])
    return grouped


def _invert_call_graph(graph) -> Dict[str, List[str]]:
    """Turn a ``caller -> callees`` mapping into ``callee -> callers``."""
    reverse: Dict[str, List[str]] = {}
    for caller, callees in graph.items():
        for callee in callees:
            reverse.setdefault(callee, []).append(caller)
    return reverse


def _ast_call_graph(repo_path: str) -> Dict[str, List[str]]:
    """Build a name-resolved call graph from the ``*.py`` files under *repo_path*.

    Every ``def`` / ``async def`` becomes a node keyed ``./rel/path.py:name``.
    Calls are matched purely by name: a call to ``foo(...)`` or ``x.foo(...)``
    is linked to every known function called ``foo`` except the caller itself.
    Each callee appears at most once per caller, however many call sites exist.

    Functions sharing a name within one file share a key, so the one visited
    last wins. Unreadable or unparsable files are skipped.
    """
    root = Path(repo_path)
    called_names: Dict[str, List[str]] = {}
    for py_file in root.rglob("*.py"):
        try:
            source = py_file.read_text(encoding="utf-8", errors="ignore")
            tree = ast.parse(source)
        except (OSError, SyntaxError, ValueError):
            continue
        rel = "./" + py_file.relative_to(root).as_posix()
        for node in ast.walk(tree):
            if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            names: List[str] = []
            for child in ast.walk(node):
                if not isinstance(child, ast.Call):
                    continue
                if isinstance(child.func, ast.Name):
                    names.append(child.func.id)
                elif isinstance(child.func, ast.Attribute):
                    names.append(child.func.attr)
            called_names[f"{rel}:{node.name}"] = names

    ids_by_name: Dict[str, List[str]] = {}
    for fid in called_names:
        ids_by_name.setdefault(fid.split(":")[-1], []).append(fid)

    resolved: Dict[str, List[str]] = {}
    for fid, names in called_names.items():
        targets = [
            candidate
            for name in names
            for candidate in ids_by_name.get(name, [])
            if candidate != fid
        ]
        # dict.fromkeys de-duplicates while preserving first-seen order, so
        # repeated call sites do not produce repeated edges.
        resolved[fid] = list(dict.fromkeys(targets))
    return resolved


def _blast_radius(repo_path: str, symbol: str) -> dict:
    """Compute which functions are affected by changing *symbol*.

    When DiffContext is importable (``HAS_DIFFCONTEXT``), the repository is
    indexed with ``index_repository`` and the blast radius is taken from
    ``analyze_impact``. Otherwise an AST-based fallback builds a call graph by
    name matching and walks callers transitively, breadth first.

    Args:
        repo_path: Directory containing the Python project.
        symbol: Symbol ID, e.g. ``./pkg/mod.py:func``.

    Returns:
        On success, a dict with keys ``symbol``, ``direct_callers``,
        ``direct_callees``, ``blast_radius_count``, ``blast_radius_by_file``
        and ``all_affected`` (the latter capped at 50 entries).
        When the symbol is unknown, ``{"error": ..., "suggestions": [...]}``
        with up to three close matches.
    """
    import difflib

    if HAS_DIFFCONTEXT:
        idx = index_repository(repo_path)
        if symbol not in idx.symbols and symbol not in idx.graph:
            suggestions = difflib.get_close_matches(
                symbol, list(idx.symbols.keys()), n=3, cutoff=0.4
            )
            return {"error": f"'{symbol}' not found.", "suggestions": suggestions}
        impact = analyze_impact(idx, [symbol])
        reverse = _invert_call_graph(idx.graph)
        return {
            "symbol": symbol,
            "direct_callers": reverse.get(symbol, []),
            "direct_callees": idx.graph.get(symbol, []),
            "blast_radius_count": len(impact.blast_radius),
            "blast_radius_by_file": _group_symbols_by_file(impact.blast_radius),
            "all_affected": impact.blast_radius[:50],
        }

    # AST fallback
    graph = _ast_call_graph(repo_path)
    if symbol not in graph:
        suggestions = difflib.get_close_matches(symbol, list(graph.keys()), n=3, cutoff=0.4)
        return {"error": f"'{symbol}' not found.", "suggestions": suggestions}

    reverse = _invert_call_graph(graph)

    # Level-by-level walk over callers; `visited` guards against cycles.
    visited = {symbol}
    blast: List[str] = []
    frontier = [symbol]
    while frontier:
        next_frontier: List[str] = []
        for current in frontier:
            for caller in reverse.get(current, []):
                if caller not in visited:
                    visited.add(caller)
                    blast.append(caller)
                    next_frontier.append(caller)
        frontier = next_frontier

    return {
        "symbol": symbol,
        "direct_callers": reverse.get(symbol, []),
        "direct_callees": graph.get(symbol, []),
        "blast_radius_count": len(blast),
        "blast_radius_by_file": _group_symbols_by_file(blast),
        "all_affected": blast[:50],
    }
# ── MCP Tool definitions ──────────────────────────────────────────────────────
# Tool catalogue advertised to MCP clients. It is only built when the `mcp`
# package imported successfully, because `Tool` comes from that package.
if HAS_MCP:
    TOOLS = [
        # Enumerate function symbols in a repository.
        Tool(
            name="list_symbols",
            description=(
                "List all Python functions/methods in a repository. "
                "Returns symbol IDs in the format './file.py:function_name'. "
                "Provide either repo_path (local absolute path) or repo_id (from /upload on remote hosted servers)."
            ),
            inputSchema={
                "type": "object",
                "properties": {
                    "repo_path": {
                        "type": "string",
                        "description": "Absolute path to the Python project directory on the local machine.",
                    },
                    "repo_id": {
                        "type": "string",
                        "description": "The session ID of an uploaded repository (for remote MCP servers).",
                    },
                },
                # Neither key is individually required; the handler rejects
                # calls that supply neither.
                "required": [],
            },
        ),
        # Callers / callees / transitive impact of one symbol.
        Tool(
            name="blast_radius",
            description=(
                "Find the blast radius of a Python function — which other functions "
                "will break if you change it. Returns direct callers, direct callees, "
                "and the full transitive set of affected functions grouped by file."
            ),
            inputSchema={
                "type": "object",
                "properties": {
                    "repo_path": {
                        "type": "string",
                        "description": "Absolute path to the Python project directory.",
                    },
                    "repo_id": {
                        "type": "string",
                        "description": "The session ID of an uploaded repository (for remote MCP servers).",
                    },
                    "symbol": {
                        "type": "string",
                        "description": "Symbol ID in format './relative/path.py:function_name' or './path.py:ClassName.method'",
                    },
                },
                "required": ["symbol"],
            },
        ),
        # Token-budgeted context package for a change to one symbol.
        Tool(
            name="compile_context",
            description=(
                "Compile a token-efficient LLM context package for a code change. "
                "Returns the relevant functions and their relationships, trimmed to a token budget."
            ),
            inputSchema={
                "type": "object",
                "properties": {
                    "repo_path": {
                        "type": "string",
                        "description": "Absolute path to the Python project directory.",
                    },
                    "repo_id": {
                        "type": "string",
                        "description": "The session ID of an uploaded repository (for remote MCP servers).",
                    },
                    "symbol": {
                        "type": "string",
                        "description": "Symbol ID to analyze, e.g. './src/auth.py:validate_jwt'",
                    },
                    "max_tokens": {
                        "type": "integer",
                        "description": "Token budget for context (default 8000). Lower = tighter selection.",
                        "default": 8000,
                    },
                },
                "required": ["symbol"],
            },
        ),
        # Same analysis as blast_radius, but on source text supplied inline.
        Tool(
            name="analyze_inline",
            description=(
                "Analyze Python code you paste directly — no file upload or local repo needed. "
                "Useful for quick analysis of code snippets. "
                "Returns blast radius and caller/callee relationships."
            ),
            inputSchema={
                "type": "object",
                "properties": {
                    # NOTE(review): the handler also accepts a list of
                    # {"filename", "content"} objects, but this schema only
                    # declares the dict form — confirm whether clients that
                    # validate against the schema should be allowed to send
                    # the list form.
                    "files": {
                        "type": "object",
                        "description": "Dict of filename -> Python source code. E.g. {'service.py': 'def foo(): ...'}",
                    },
                    "symbol": {
                        "type": "string",
                        "description": "Symbol to analyze, e.g. './service.py:foo'",
                    },
                },
                "required": ["files", "symbol"],
            },
        ),
    ]
else:
    TOOLS = []
# ── MCP Server ────────────────────────────────────────────────────────────────
# Arguments each tool cannot run without (checked before dispatch so the
# caller gets a readable message instead of a bare KeyError text).
_REQUIRED_TOOL_ARGS = {
    "blast_radius": ("symbol",),
    "compile_context": ("symbol",),
    "analyze_inline": ("files", "symbol"),
}


def _normalize_inline_files(files: Any) -> Dict[str, str]:
    """Convert the ``files`` argument of ``analyze_inline`` to ``{filename: source}``.

    Two shapes are accepted: a dict of filename -> source, or a list of
    ``{"filename": ..., "content": ...}`` dicts (non-dict list items are
    ignored).

    Raises:
        ValueError: if ``files`` has another shape, or a filename/content is
            not a string.
    """
    if isinstance(files, list):
        pairs = [(item.get("filename"), item.get("content")) for item in files if isinstance(item, dict)]
    elif isinstance(files, dict):
        pairs = list(files.items())
    else:
        raise ValueError(
            "'files' must be a dict of filename -> source, "
            "or a list of {'filename': ..., 'content': ...} objects."
        )
    normalized: Dict[str, str] = {}
    for filename, content in pairs:
        if not isinstance(filename, str) or not isinstance(content, str):
            raise ValueError("Each inline file needs a string filename and string content.")
        normalized[filename] = content
    return normalized


def _analyze_inline(files: Any, symbol: str) -> Dict[str, Any]:
    """Write the inline files to a scratch directory and run ``_blast_radius`` on it.

    Only the basename of each filename is used, so every file lands directly
    inside the scratch directory. The directory is removed afterwards, whether
    or not the analysis succeeded.
    """
    files_dict = _normalize_inline_files(files)
    tmp_dir = tempfile.mkdtemp(prefix="diffctx_mcp_")
    try:
        for filename, code in files_dict.items():
            safe_name = Path(filename).name
            if safe_name in ("", ".", ".."):
                raise ValueError(f"Invalid inline filename: {filename!r}")
            # Explicit UTF-8 so non-ASCII source does not depend on the locale.
            with open(os.path.join(tmp_dir, safe_name), "w", encoding="utf-8") as fh:
                fh.write(code)
        return _blast_radius(tmp_dir, symbol)
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)


def _dispatch_tool(name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Run the tool called *name* and return its JSON-serialisable result dict.

    Failures the handler anticipates (missing arguments, unknown repo_id,
    missing directory, unknown tool) are reported as ``{"error": ...}``;
    anything else propagates to the caller.
    """
    repo_path = arguments.get("repo_path")
    repo_id = arguments.get("repo_id")

    # Every tool except analyze_inline needs a repository to look at.
    if name != "analyze_inline" and not repo_path and not repo_id:
        return {"error": "Must provide either 'repo_path' or 'repo_id'."}

    # A repo_id takes precedence over repo_path when both are given.
    if repo_id:
        sessions = _get_sessions()
        if repo_id not in sessions:
            return {"error": f"repo_id '{repo_id}' not found. Please upload the repository zip first."}
        repo_path = sessions[repo_id]

    missing = [key for key in _REQUIRED_TOOL_ARGS.get(name, ()) if key not in arguments]
    if missing:
        listed = ", ".join(f"'{key}'" for key in missing)
        return {"error": f"Missing required argument(s) for '{name}': {listed}"}

    if name == "analyze_inline":
        return _analyze_inline(arguments["files"], arguments["symbol"])

    if name not in ("list_symbols", "blast_radius", "compile_context"):
        return {"error": f"Unknown tool: {name}"}

    if not os.path.isdir(repo_path):
        return {"error": f"Directory not found: {repo_path}"}

    if name == "list_symbols":
        symbols = _list_symbols_in_dir(repo_path)
        return {"count": len(symbols), "symbols": symbols}

    symbol = arguments["symbol"]
    if name == "blast_radius":
        return _blast_radius(repo_path, symbol)

    # name == "compile_context"
    max_tokens = arguments.get("max_tokens", 8000)
    if not HAS_DIFFCONTEXT:
        # Without DiffContext there is no compiler; return the raw blast
        # radius as the context instead.
        br = _blast_radius(repo_path, symbol)
        return {
            "context": json.dumps(br, indent=2),
            "note": "Install DiffContext for full compiled context.",
        }
    idx = index_repository(repo_path)
    impact = analyze_impact(idx, [symbol])
    ctx = dc_compile(idx, impact, max_tokens=max_tokens)
    return {
        "context": ctx.text,
        "symbol_count": ctx.symbol_count,
        "token_estimate": ctx.token_estimate,
        "reduction_pct": round(ctx.reduction_pct, 1),
    }


if HAS_MCP:
    server = Server("diffcontext")

    @server.list_tools()
    async def list_tools():
        """Return the tool catalogue (``TOOLS``) to the MCP client."""
        return TOOLS

    @server.call_tool()
    async def call_tool(name: str, arguments: Dict[str, Any]):
        """Handle one tool invocation and reply with a single JSON text item.

        Any exception raised while running the tool is converted into an
        ``{"error": "<message>"}`` payload rather than propagated.
        """
        try:
            result = _dispatch_tool(name, arguments or {})
        except Exception as e:
            result = {"error": str(e)}
        return [TextContent(type="text", text=json.dumps(result, indent=2))]
else:
    server = None
async def run_mcp_server():
    """Run the MCP server on the stdio transport.

    When the ``mcp`` package could not be imported, an install hint is
    printed to stderr and the process exits with status 1.
    """
    if HAS_MCP:
        async with stdio_server() as (incoming, outgoing):
            init_options = server.create_initialization_options()
            await server.run(incoming, outgoing, init_options)
        return

    install_hint = (
        "ERROR: 'mcp' package not installed.\n"
        "Fix: pip install mcp\n"
    )
    print(install_hint, file=sys.stderr)
    sys.exit(1)
def main():
    """Command-line entry point: report what is installed, then serve.

    Status lines go to stderr. Exits with status 1 (after printing an install
    hint) when the ``mcp`` package is unavailable.
    """
    import asyncio

    status_lines = (
        "Starting DiffContext MCP Server...",
        f"DiffContext installed: {HAS_DIFFCONTEXT}",
        f"MCP installed: {HAS_MCP}",
    )
    for line in status_lines:
        print(line, file=sys.stderr)

    if HAS_MCP:
        asyncio.run(run_mcp_server())
        return

    print("\nTo install: pip install mcp", file=sys.stderr)
    sys.exit(1)


if __name__ == "__main__":
    main()