Spaces:
Sleeping
Sleeping
trakshan-mishra
Update inline analysis to support structured lists for ChatGPT Actions compatibility
ed862c3 | """ | |
| DiffContext MCP Server | |
| Lets Claude Desktop / Cursor call DiffContext natively — no copy-paste needed. | |
| How to use: | |
| 1. pip install mcp | |
| 2. Run: python -m diffcontext.mcp_server | |
| """ | |
| import ast | |
| import json | |
| import os | |
| import sys | |
| import tempfile | |
| import shutil | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
# ── Try to import MCP; give a clear error if not installed ──────────────────
# HAS_MCP records whether the SDK is importable; run_mcp_server() and main()
# check it and print an install hint when it is False.
try:
    from mcp.server import Server
    from mcp.server.stdio import stdio_server
    from mcp.types import TextContent, Tool
except ImportError:
    HAS_MCP = False
else:
    HAS_MCP = True
# ── Try to import DiffContext ────────────────────────────────────────────────
# Prefer the sibling module (works under `python -m diffcontext.mcp_server`),
# then the absolute package import. If neither resolves, HAS_DIFFCONTEXT is
# False and callers use their AST-only fallbacks instead.
try:
    from .pipeline import analyze_impact, index_repository
    from .pipeline import compile as dc_compile
except ImportError:
    try:
        from diffcontext.pipeline import analyze_impact, index_repository
        from diffcontext.pipeline import compile as dc_compile
    except ImportError:
        HAS_DIFFCONTEXT = False
    else:
        HAS_DIFFCONTEXT = True
else:
    HAS_DIFFCONTEXT = True
# ── Sessions helper for remote hosted servers ────────────────────────────────
def _get_sessions() -> Dict[str, str]:
    """Return the upload-session table of a co-hosted web backend, if importable.

    The candidate modules ``diffcontext-service.backend.main`` and
    ``backend.main`` are tried in that order. The ``SESSIONS`` attribute of the
    first one that imports successfully is returned. It is presumably a mapping
    of repo_id -> extracted repository directory; verify against the backend.

    Any exception while importing a candidate or reading its attribute moves
    on to the next candidate. If none succeeds, a new empty dict is returned.
    """
    import importlib

    for module_name in ("diffcontext-service.backend.main", "backend.main"):
        try:
            return importlib.import_module(module_name).SESSIONS
        except Exception:
            continue
    return {}
# ── Shared analysis logic (same as web service) ──────────────────────────────
def _list_symbols_in_dir(repo_path: str) -> List[str]:
    """Return the sorted, de-duplicated IDs of every function under *repo_path*.

    Each path matching ``*.py`` below *repo_path* is parsed with :mod:`ast`.
    Every ``def`` / ``async def`` found (methods and nested functions
    included) yields an ID of the form ``'./<relative path>:<function name>'``.
    Class names are not part of the ID, so same-named methods in one file
    collapse into a single entry.

    Paths that cannot be read or parsed are skipped rather than aborting the
    whole listing.
    """
    symbols = set()
    for py_file in Path(repo_path).rglob("*.py"):
        rel = "./" + str(py_file.relative_to(repo_path))
        try:
            # Python source is UTF-8 by default (PEP 3120); don't rely on the locale.
            tree = ast.parse(py_file.read_text(encoding="utf-8", errors="ignore"))
        except (SyntaxError, ValueError, OSError):
            # SyntaxError: invalid code. ValueError: NUL bytes (Python < 3.12).
            # OSError: unreadable path, e.g. a directory named "*.py" or a
            # dangling symlink.
            continue
        symbols.update(
            f"{rel}:{node.name}"
            for node in ast.walk(tree)
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
        )
    return sorted(symbols)
def _unique_in_order(items: Any) -> List[str]:
    """Return *items* as a list with repeats removed, keeping first-seen order."""
    return list(dict.fromkeys(items))


def _invert_call_graph(graph: Any) -> Dict[str, List[str]]:
    """Map each callee in *graph* (caller -> callees) to its distinct callers."""
    reverse: Dict[str, List[str]] = {}
    for caller, callees in graph.items():
        # De-duplicate per caller so a function that calls the same target
        # several times is listed only once as its caller.
        for callee in _unique_in_order(callees):
            reverse.setdefault(callee, []).append(caller)
    return reverse


def _group_by_file(symbols: List[str]) -> Dict[str, List[str]]:
    """Group 'path:name' symbol IDs into {path: [name, ...]}, preserving order."""
    by_file: Dict[str, List[str]] = {}
    for sym in symbols:
        parts = sym.split(":")
        by_file.setdefault(parts[0], []).append(parts[-1])
    return by_file


def _symbol_not_found(symbol: str, known_ids: Any) -> dict:
    """Build the 'not found' result with up to three similar known symbol IDs."""
    import difflib

    suggestions = difflib.get_close_matches(symbol, list(known_ids), n=3, cutoff=0.4)
    return {"error": f"'{symbol}' not found.", "suggestions": suggestions}


def _ast_call_graph(repo_path: str) -> Dict[str, List[str]]:
    """Build a name-resolved call graph of the ``*.py`` files under *repo_path*.

    Keys are './rel/path.py:func' IDs for every function definition found.
    Each value lists the distinct IDs of functions, anywhere in the repo,
    whose bare name matches a call made inside that function, excluding the
    function itself. Resolution is by name only, so it over-approximates.
    """
    raw_calls: Dict[str, List[str]] = {}
    for py_file in Path(repo_path).rglob("*.py"):
        rel = "./" + str(py_file.relative_to(repo_path))
        try:
            tree = ast.parse(py_file.read_text(encoding="utf-8", errors="ignore"))
        except (SyntaxError, ValueError, OSError):
            # Invalid code, NUL bytes (Python < 3.12), or an unreadable path.
            continue
        for node in ast.walk(tree):
            if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            # Same-named functions in one file share an ID; merge their calls
            # instead of letting the last definition overwrite the others.
            calls = raw_calls.setdefault(f"{rel}:{node.name}", [])
            for child in ast.walk(node):
                if isinstance(child, ast.Call):
                    if isinstance(child.func, ast.Name):
                        calls.append(child.func.id)
                    elif isinstance(child.func, ast.Attribute):
                        calls.append(child.func.attr)

    name_to_ids: Dict[str, List[str]] = {}
    for fid in raw_calls:
        name_to_ids.setdefault(fid.split(":")[-1], []).append(fid)

    resolved: Dict[str, List[str]] = {}
    for fid, calls in raw_calls.items():
        targets: List[str] = []
        for call in calls:
            targets.extend(c for c in name_to_ids.get(call, []) if c != fid)
        resolved[fid] = _unique_in_order(targets)
    return resolved


def _transitive_callers(reverse: Dict[str, List[str]], symbol: str) -> List[str]:
    """Breadth-first list of every function that (transitively) calls *symbol*.

    *symbol* itself is never included; each caller appears once, in BFS order.
    """
    visited = {symbol}
    blast: List[str] = []
    frontier = [symbol]
    while frontier:
        next_frontier: List[str] = []
        for node in frontier:
            for caller in reverse.get(node, []):
                if caller not in visited:
                    visited.add(caller)
                    blast.append(caller)
                    next_frontier.append(caller)
        frontier = next_frontier
    return blast


def _impact_report(symbol: str, callers: List[str], callees: List[str], affected: List[str]) -> dict:
    """Assemble the blast-radius result dict (lists at most 50 affected IDs)."""
    return {
        "symbol": symbol,
        "direct_callers": callers,
        "direct_callees": callees,
        "blast_radius_count": len(affected),
        "blast_radius_by_file": _group_by_file(affected),
        "all_affected": affected[:50],
    }


def _blast_radius(repo_path: str, symbol: str) -> dict:
    """Report who calls *symbol* and what else changing it would affect.

    When DiffContext is importable (``HAS_DIFFCONTEXT``), its repository index
    and impact analysis are used. Otherwise a name-based :mod:`ast` call graph
    is built from the ``*.py`` files under *repo_path*.

    Returns a dict with these keys:
        ``symbol``, ``direct_callers``, ``direct_callees`` (each list free of
        duplicates), ``blast_radius_count``, ``blast_radius_by_file``
        (file -> function names) and ``all_affected`` (the first 50 affected
        IDs).

    If *symbol* is unknown, returns ``{"error": ..., "suggestions": [...]}``
    instead, with up to three close matches.
    """
    if HAS_DIFFCONTEXT:
        idx = index_repository(repo_path)
        if symbol not in idx.symbols and symbol not in idx.graph:
            return _symbol_not_found(symbol, idx.symbols.keys())
        impact = analyze_impact(idx, [symbol])
        reverse = _invert_call_graph(idx.graph)
        return _impact_report(
            symbol,
            reverse.get(symbol, []),
            _unique_in_order(idx.graph.get(symbol, [])),
            list(impact.blast_radius),
        )

    # AST fallback
    graph = _ast_call_graph(repo_path)
    if symbol not in graph:
        return _symbol_not_found(symbol, graph.keys())
    reverse = _invert_call_graph(graph)
    return _impact_report(
        symbol,
        reverse.get(symbol, []),
        graph[symbol],
        _transitive_callers(reverse, symbol),
    )
# ── MCP Tool definitions ──────────────────────────────────────────────────────
# JSON-schema descriptors advertised to MCP clients. The repository-backed
# tools accept either ``repo_path`` or ``repo_id``. Neither is listed as
# required because call_tool enforces "at least one of the two" at runtime.
if HAS_MCP:
    TOOLS = [
        Tool(
            name="list_symbols",
            description=(
                "List all Python functions/methods in a repository. "
                "Returns symbol IDs in the format './file.py:function_name'. "
                "Provide either repo_path (local absolute path) or repo_id (from /upload on remote hosted servers)."
            ),
            inputSchema={
                "type": "object",
                "properties": {
                    "repo_path": {
                        "type": "string",
                        "description": "Absolute path to the Python project directory on the local machine.",
                    },
                    "repo_id": {
                        "type": "string",
                        "description": "The session ID of an uploaded repository (for remote MCP servers).",
                    },
                },
                "required": [],
            },
        ),
        Tool(
            name="blast_radius",
            description=(
                "Find the blast radius of a Python function — which other functions "
                "will break if you change it. Returns direct callers, direct callees, "
                "and the full transitive set of affected functions grouped by file. "
                "Provide either repo_path (local absolute path) or repo_id (from /upload on remote hosted servers)."
            ),
            inputSchema={
                "type": "object",
                "properties": {
                    "repo_path": {
                        "type": "string",
                        "description": "Absolute path to the Python project directory.",
                    },
                    "repo_id": {
                        "type": "string",
                        "description": "The session ID of an uploaded repository (for remote MCP servers).",
                    },
                    "symbol": {
                        "type": "string",
                        "description": "Symbol ID in format './relative/path.py:function_name' or './path.py:ClassName.method'",
                    },
                },
                "required": ["symbol"],
            },
        ),
        Tool(
            name="compile_context",
            description=(
                "Compile a token-efficient LLM context package for a code change. "
                "Returns the relevant functions and their relationships, trimmed to a token budget. "
                "Provide either repo_path (local absolute path) or repo_id (from /upload on remote hosted servers)."
            ),
            inputSchema={
                "type": "object",
                "properties": {
                    "repo_path": {
                        "type": "string",
                        "description": "Absolute path to the Python project directory.",
                    },
                    "repo_id": {
                        "type": "string",
                        "description": "The session ID of an uploaded repository (for remote MCP servers).",
                    },
                    "symbol": {
                        "type": "string",
                        "description": "Symbol ID to analyze, e.g. './src/auth.py:validate_jwt'",
                    },
                    "max_tokens": {
                        "type": "integer",
                        "description": "Token budget for context (default 8000). Lower = tighter selection.",
                        "default": 8000,
                    },
                },
                "required": ["symbol"],
            },
        ),
        Tool(
            name="analyze_inline",
            description=(
                "Analyze Python code you paste directly — no file upload or local repo needed. "
                "Useful for quick analysis of code snippets. "
                "Returns blast radius and caller/callee relationships."
            ),
            inputSchema={
                "type": "object",
                "properties": {
                    # call_tool accepts both shapes below; advertise both so
                    # schema-validating clients don't reject the list form.
                    "files": {
                        "description": (
                            "Python source to analyze: either an object mapping filename -> source code "
                            "(e.g. {'service.py': 'def foo(): ...'}) or a list of "
                            "{'filename': ..., 'content': ...} objects. "
                            "Only the base filename is kept (directories are dropped)."
                        ),
                        "anyOf": [
                            {
                                "type": "object",
                                "additionalProperties": {"type": "string"},
                            },
                            {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "filename": {"type": "string"},
                                        "content": {"type": "string"},
                                    },
                                    "required": ["filename", "content"],
                                },
                            },
                        ],
                    },
                    "symbol": {
                        "type": "string",
                        "description": "Symbol to analyze, e.g. './service.py:foo'",
                    },
                },
                "required": ["files", "symbol"],
            },
        ),
    ]
else:
    TOOLS = []
# ── MCP Server ────────────────────────────────────────────────────────────────
# Tool dispatch lives in plain synchronous helpers. The MCP handlers further
# down only wrap them and serialize the result.
_TOOL_NAMES = ("list_symbols", "blast_radius", "compile_context", "analyze_inline")


def _normalize_inline_files(files: Any) -> Dict[str, str]:
    """Coerce the ``files`` argument of ``analyze_inline`` into ``{filename: source}``.

    Accepted shapes:
        - a mapping of filename -> source, or
        - a list of ``{"filename": ..., "content": ...}`` objects. Non-dict
          list items are ignored; later duplicates override earlier ones.

    Raises:
        ValueError: if ``files`` has any other shape, or an entry's filename
            or content is not a string.
    """
    if isinstance(files, dict):
        pairs = list(files.items())
    elif isinstance(files, list):
        pairs = [(f.get("filename"), f.get("content")) for f in files if isinstance(f, dict)]
    else:
        raise ValueError(
            "'files' must be an object mapping filename -> source code, "
            "or a list of {filename, content} objects."
        )
    normalized: Dict[str, str] = {}
    for filename, code in pairs:
        if not isinstance(filename, str) or not isinstance(code, str):
            raise ValueError(
                f"Invalid file entry {filename!r}: filename and content must both be strings."
            )
        normalized[filename] = code
    return normalized


def _run_analyze_inline(files: Any, symbol: str) -> Dict[str, Any]:
    """Write the pasted *files* to a throwaway directory and analyze *symbol* there.

    The directory is always removed afterwards, even if the analysis fails.
    """
    files_dict = _normalize_inline_files(files)
    tmp_dir = tempfile.mkdtemp(prefix="diffctx_mcp_")
    try:
        for filename, code in files_dict.items():
            # Keep only the final path component so a name such as
            # "../../x.py" cannot write outside the temporary directory.
            safe_name = Path(filename).name
            if safe_name in ("", ".", ".."):
                raise ValueError(f"Invalid filename: {filename!r}")
            Path(tmp_dir, safe_name).write_text(code, encoding="utf-8")
        return _blast_radius(tmp_dir, symbol)
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)


def _run_tool(name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Execute tool *name* with *arguments*; return a JSON-serializable dict.

    Invalid or missing arguments are reported as ``{"error": ...}`` results.
    Unexpected exceptions propagate to the caller.
    """
    if name not in _TOOL_NAMES:
        return {"error": f"Unknown tool: {name}"}

    symbol = arguments.get("symbol")
    if name != "list_symbols" and not symbol:
        return {"error": "Missing required argument 'symbol'."}

    if name == "analyze_inline":
        # Inline analysis needs no repository, so repo_path/repo_id are ignored.
        if "files" not in arguments:
            return {"error": "Missing required argument 'files'."}
        return _run_analyze_inline(arguments["files"], symbol)

    repo_path = arguments.get("repo_path")
    repo_id = arguments.get("repo_id")
    if not repo_path and not repo_id:
        return {"error": "Must provide either 'repo_path' or 'repo_id'."}
    if repo_id:
        sessions = _get_sessions()
        if repo_id not in sessions:
            return {"error": f"repo_id '{repo_id}' not found. Please upload the repository zip first."}
        repo_path = sessions[repo_id]
    if not os.path.isdir(repo_path):
        return {"error": f"Directory not found: {repo_path}"}

    if name == "list_symbols":
        symbols = _list_symbols_in_dir(repo_path)
        return {"count": len(symbols), "symbols": symbols}
    if name == "blast_radius":
        return _blast_radius(repo_path, symbol)

    # compile_context
    if not HAS_DIFFCONTEXT:
        br = _blast_radius(repo_path, symbol)
        return {
            "context": json.dumps(br, indent=2),
            "note": "Install DiffContext for full compiled context.",
        }
    max_tokens = arguments.get("max_tokens")
    if max_tokens is None:
        max_tokens = 8000
    idx = index_repository(repo_path)
    impact = analyze_impact(idx, [symbol])
    ctx = dc_compile(idx, impact, max_tokens=max_tokens)
    return {
        "context": ctx.text,
        "symbol_count": ctx.symbol_count,
        "token_estimate": ctx.token_estimate,
        "reduction_pct": round(ctx.reduction_pct, 1),
    }


if HAS_MCP:
    server = Server("diffcontext")

    # The decorators register the handlers with the server. Without them the
    # server answers tools/list and tools/call with nothing.
    @server.list_tools()
    async def list_tools() -> List[Tool]:
        """Advertise the DiffContext tools to the MCP client."""
        return TOOLS

    @server.call_tool()
    async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
        """Run one tool call and return its result as a single JSON text block.

        Any failure, including a result that cannot be JSON-encoded, is
        returned as an ``{"error": ...}`` payload instead of being raised.
        """
        try:
            text = json.dumps(_run_tool(name, arguments or {}), indent=2)
        except Exception as e:
            text = json.dumps({"error": f"{type(e).__name__}: {e}"}, indent=2)
        return [TextContent(type="text", text=text)]
else:
    server = None
async def run_mcp_server():
    """Serve the DiffContext tools over the stdio transport.

    If the ``mcp`` package could not be imported, print an install hint to
    stderr and exit with status 1 instead.
    """
    if HAS_MCP:
        async with stdio_server() as streams:
            reader, writer = streams
            await server.run(reader, writer, server.create_initialization_options())
        return
    print(
        "ERROR: 'mcp' package not installed.\n"
        "Fix: pip install mcp\n",
        file=sys.stderr,
    )
    sys.exit(1)
def main():
    """Command-line entry point: report dependency status, then serve.

    Status lines go to stderr, keeping stdout free for the stdio transport
    used by run_mcp_server(). Exits with status 1 when ``mcp`` is missing.
    """
    import asyncio

    status_lines = (
        "Starting DiffContext MCP Server...",
        f"DiffContext installed: {HAS_DIFFCONTEXT}",
        f"MCP installed: {HAS_MCP}",
    )
    for line in status_lines:
        print(line, file=sys.stderr)

    if HAS_MCP:
        asyncio.run(run_mcp_server())
        return
    print("\nTo install: pip install mcp", file=sys.stderr)
    sys.exit(1)


if __name__ == "__main__":
    main()