"""
Inference Script - SQL Query Reviewer
======================================
MANDATORY environment variables:
    API_BASE_URL      The API endpoint for the LLM.
    MODEL_NAME        The model identifier to use for inference.
    HF_TOKEN          Your Hugging Face / API key.
    LOCAL_IMAGE_NAME  The name of the local Docker image for the environment.

STDOUT FORMAT (must match exactly):
    [START] task=<task_name> env=<benchmark> model=<model_name>
    [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
    [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
"""
from __future__ import annotations
import asyncio
import json
import os
import sys
import traceback
from typing import Any, List, Optional
from openai import OpenAI
# ---------------------------------------------------------------------------
# Environment variables - read ALL names the validator might set
# ---------------------------------------------------------------------------
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") or os.getenv("IMAGE_NAME")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
BENCHMARK = "sql-query-reviewer"
MAX_STEPS = 10
SUCCESS_SCORE_THRESHOLD = 0.1
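# MAX_STEPS caps the number of review actions per task; an episode is reported as a
# success when its final score reaches SUCCESS_SCORE_THRESHOLD (see the episode runners below).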
SYSTEM_PROMPT = """You are reviewing a SQL query for correctness, performance, and security.
Return exactly one JSON object with these keys:
- action_type: identify_issue, suggest_fix, approve, or request_more_context
- issue_category: syntax, performance, security, logic, or style (when relevant)
- issue_description: concise issue statement (when relevant)
- suggested_fix: corrected SQL or corrected fragment (when relevant)
- confidence: float between 0.0 and 1.0
Guidelines:
- Prefer identify_issue until you believe all important issues are covered.
- Use approve only when the query looks acceptable or all issues have been identified.
- Return ONLY valid JSON, no prose, no markdown fences.
"""
# ---------------------------------------------------------------------------
# Structured stdout logging - matches hackathon spec EXACTLY
# ---------------------------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
print(f"[START] task={task} env={env} model={model}", flush=True)
def log_step(
step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
done_str = str(done).lower()
error_str = error if error else "null"
print(
f"[STEP] step={step} action={action} reward={reward:.2f} "
f"done={done_str} error={error_str}",
flush=True,
)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
print(
f"[END] success={str(success).lower()} steps={steps} "
f"score={score:.2f} rewards={rewards_str}",
flush=True,
)
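# A short two-step episode might therefore emit lines of the form (values illustrative):
#   [START] task=easy_001 env=sql-query-reviewer model=Qwen/Qwen2.5-72B-Instruct
#   [STEP] step=1 action=identify_issue(security) reward=0.40 done=false error=null
#   [STEP] step=2 action=approve reward=0.60 done=true error=null
#   [END] success=true steps=2 score=0.50 rewards=0.40,0.60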
# ---------------------------------------------------------------------------
# LLM interaction - fully wrapped in try/except
# ---------------------------------------------------------------------------
def extract_json(content: str) -> dict[str, Any]:
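    """Extract the first JSON object from an LLM reply, tolerating markdown code fences."""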
stripped = content.strip()
if stripped.startswith("```"):
lines = [l for l in stripped.splitlines() if not l.startswith("```")]
stripped = "\n".join(lines).strip()
start = stripped.find("{")
end = stripped.rfind("}")
if start == -1 or end == -1 or end <= start:
raise ValueError(f"No JSON object found in: {content[:200]!r}")
return json.loads(stripped[start : end + 1])
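# For instance, a reply such as '```json\n{"action_type": "approve"}\n```' has its fence
# lines dropped and the remaining {...} span parsed with json.loads.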
def choose_action(llm_client: Any, model_name: str, observation: Any) -> dict[str, Any]:
"""Ask LLM for an action. Returns a raw dict. NEVER raises."""
try:
obs_data = {
"query": getattr(observation, "query", ""),
"schema_info": getattr(observation, "schema_info", {}),
"context": getattr(observation, "context", ""),
"issues_found_so_far": [],
"remaining_actions": getattr(observation, "remaining_actions", 0),
"difficulty": getattr(observation, "difficulty", "unknown"),
"feedback": getattr(observation, "feedback", ""),
}
# Safely serialize issues
for i in getattr(observation, "issues_found_so_far", []):
try:
obs_data["issues_found_so_far"].append(
i.model_dump() if hasattr(i, "model_dump") else str(i)
)
except Exception:
pass
response = llm_client.chat.completions.create(
model=model_name,
temperature=0,
max_tokens=300,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": json.dumps(obs_data, indent=2)},
],
)
content = response.choices[0].message.content or ""
parsed = extract_json(content)
if "action_type" not in parsed:
parsed["action_type"] = "approve"
return parsed
except Exception as exc:
print(f"[DEBUG] choose_action error: {exc}", flush=True)
return {"action_type": "approve", "confidence": 0.1}
# ---------------------------------------------------------------------------
# Episode runner (async) - for from_docker_image connection
# ---------------------------------------------------------------------------
async def run_episode_async(
env: Any, llm_client: Any, model_name: str, task_id: str
) -> None:
rewards: List[float] = []
steps_taken = 0
score = 0.0
success = False
log_start(task=task_id, env=BENCHMARK, model=model_name)
try:
from sql_query_reviewer.models import SQLReviewAction
result = await env.reset()
for step_num in range(1, MAX_STEPS + 1):
if result.done:
break
action_dict = choose_action(llm_client, model_name, result.observation)
try:
action = SQLReviewAction.model_validate(action_dict)
except Exception:
action = SQLReviewAction(action_type="approve", confidence=0.1)
action_str = action.action_type
if action.issue_description:
action_str = f"{action.action_type}({action.issue_category})"
try:
result = await env.step(action)
except Exception as step_err:
print(f"[DEBUG] env.step error: {step_err}", flush=True)
log_step(step=step_num, action=action_str, reward=0.0, done=True, error=str(step_err))
rewards.append(0.0)
steps_taken = step_num
break
reward = result.reward if result.reward is not None else 0.0
rewards.append(reward)
steps_taken = step_num
error_msg = None
if hasattr(result, "info") and result.info:
error_msg = result.info.get("error")
log_step(step=step_num, action=action_str, reward=reward, done=result.done, error=error_msg)
if result.done:
break
# Final score
try:
state = await env.state()
if hasattr(state, "final_score") and state.final_score is not None:
score = state.final_score
else:
score = sum(rewards) / max(len(rewards), 1)
except Exception:
score = sum(rewards) / max(len(rewards), 1) if rewards else 0.0
score = max(0.01, min(0.99, score))
success = score >= SUCCESS_SCORE_THRESHOLD
except Exception as exc:
print(f"[DEBUG] Episode error: {exc}", flush=True)
traceback.print_exc(file=sys.stdout)
finally:
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
# ---------------------------------------------------------------------------
# Episode runner (sync) - for direct HTTP connection
# ---------------------------------------------------------------------------
def run_episode(env: Any, llm_client: Any, model_name: str, task_id: str) -> None:
"""Public episode runner expected by tests.
Uses the synchronous env interface: reset(task_id=...), step(...), state().
"""
return run_episode_sync(env, llm_client, model_name, task_id)
def run_episode_sync(
env: Any, llm_client: Any, model_name: str, task_id: str
) -> None:
rewards: List[float] = []
steps_taken = 0
score = 0.0
success = False
log_start(task=task_id, env=BENCHMARK, model=model_name)
try:
from sql_query_reviewer.models import SQLReviewAction
result = env.reset(task_id=task_id)
for step_num in range(1, MAX_STEPS + 1):
if result.done:
break
action_dict = choose_action(llm_client, model_name, result.observation)
try:
action = SQLReviewAction.model_validate(action_dict)
except Exception:
action = SQLReviewAction(action_type="approve", confidence=0.1)
action_str = action.action_type
if action.issue_description:
action_str = f"{action.action_type}({action.issue_category})"
try:
result = env.step(action)
except Exception as step_err:
print(f"[DEBUG] env.step error: {step_err}", flush=True)
log_step(step=step_num, action=action_str, reward=0.0, done=True, error=str(step_err))
rewards.append(0.0)
steps_taken = step_num
break
reward = result.reward if result.reward is not None else 0.0
rewards.append(reward)
steps_taken = step_num
error_msg = None
if hasattr(result, "info") and result.info:
error_msg = result.info.get("error")
log_step(step=step_num, action=action_str, reward=reward, done=result.done, error=error_msg)
if result.done:
break
try:
state = env.state()
if hasattr(state, "final_score") and state.final_score is not None:
score = state.final_score
else:
score = sum(rewards) / max(len(rewards), 1)
except Exception:
score = sum(rewards) / max(len(rewards), 1) if rewards else 0.0
score = max(0.01, min(0.99, score))
success = score >= SUCCESS_SCORE_THRESHOLD
except Exception as exc:
print(f"[DEBUG] Episode error: {exc}", flush=True)
traceback.print_exc(file=sys.stdout)
finally:
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
# ---------------------------------------------------------------------------
# Main - tries Docker image first (validator), then HTTP (local dev)
# ---------------------------------------------------------------------------
async def async_main() -> int:
    # Build the LLM client (even without a key, don't crash - emit logs and exit)
if not API_KEY:
print("[DEBUG] WARNING: No API key found (HF_TOKEN / API_KEY / OPENAI_API_KEY)", flush=True)
_fallback_ids = [
"easy_001", "easy_002", "easy_003", "easy_004", "easy_005", "easy_006", "easy_007",
"medium_001", "medium_002", "medium_003", "medium_004", "medium_005", "medium_006", "medium_007",
"hard_001", "hard_002", "hard_003", "hard_004", "hard_005", "hard_006",
]
for tid in _fallback_ids:
log_start(task=tid, env=BENCHMARK, model=MODEL_NAME)
log_end(success=False, steps=0, score=0.01, rewards=[])
return 1
llm_client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
_default_ids = ",".join([
"easy_001", "easy_002", "easy_003", "easy_004", "easy_005", "easy_006", "easy_007",
"medium_001", "medium_002", "medium_003", "medium_004", "medium_005", "medium_006", "medium_007",
"hard_001", "hard_002", "hard_003", "hard_004", "hard_005", "hard_006",
])
task_ids = tuple(
tid.strip()
for tid in os.getenv("TASK_IDS", _default_ids).split(",")
if tid.strip()
)
env = None
try:
# ------------------------------------------------------------------
# Method 1: from_docker_image (what the hackathon validator uses)
# ------------------------------------------------------------------
if IMAGE_NAME:
print(f"[DEBUG] Connecting via Docker image: {IMAGE_NAME}", flush=True)
try:
from sql_query_reviewer.client import SQLReviewEnv
env = await SQLReviewEnv.from_docker_image(IMAGE_NAME)
print("[DEBUG] Docker connection OK", flush=True)
for task_id in task_ids:
await run_episode_async(env, llm_client, MODEL_NAME, task_id)
return 0
except AttributeError:
                # from_docker_image not implemented in our client - try openenv-core
print("[DEBUG] from_docker_image not in custom client, trying openenv generic", flush=True)
try:
from openenv.core.env_client import GenericEnvClient
env = await GenericEnvClient.from_docker_image(IMAGE_NAME)
print("[DEBUG] GenericEnvClient Docker connection OK", flush=True)
for task_id in task_ids:
await run_episode_async(env, llm_client, MODEL_NAME, task_id)
return 0
except Exception as exc2:
print(f"[DEBUG] GenericEnvClient also failed: {exc2}", flush=True)
except Exception as exc:
print(f"[DEBUG] Docker connection failed: {exc}", flush=True)
# ------------------------------------------------------------------
# Method 2: Async HTTP (fallback for local/URL-based testing)
# ------------------------------------------------------------------
print(f"[DEBUG] Connecting via URL: {ENV_BASE_URL}", flush=True)
try:
from sql_query_reviewer.client import SQLReviewEnv
env = SQLReviewEnv(base_url=ENV_BASE_URL)
await env.__aenter__()
print("[DEBUG] Async URL connection OK", flush=True)
for task_id in task_ids:
await run_episode_async(env, llm_client, MODEL_NAME, task_id)
return 0
except Exception as exc:
print(f"[DEBUG] Async URL failed: {exc}", flush=True)
# ------------------------------------------------------------------
# Method 3: Sync HTTP (last resort)
# ------------------------------------------------------------------
try:
from sql_query_reviewer.client import SyncSQLReviewEnv
sync_env = SyncSQLReviewEnv(base_url=ENV_BASE_URL)
sync_env.__enter__()
print("[DEBUG] Sync HTTP connection OK", flush=True)
for task_id in task_ids:
run_episode_sync(sync_env, llm_client, MODEL_NAME, task_id)
sync_env.close()
return 0
except Exception as exc:
print(f"[DEBUG] Sync HTTP also failed: {exc}", flush=True)
        # All methods failed - still emit valid log lines
print("[DEBUG] All connection methods exhausted", flush=True)
for task_id in task_ids:
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
log_end(success=False, steps=0, score=0.01, rewards=[])
return 1
except Exception as exc:
print(f"[DEBUG] Fatal: {exc}", flush=True)
traceback.print_exc(file=sys.stdout)
return 1
finally:
if env is not None:
try:
await env.close()
except Exception as close_err:
print(f"[DEBUG] env.close() error: {close_err}", flush=True)
def main() -> int:
return asyncio.run(async_main())
if __name__ == "__main__":
raise SystemExit(main())