# Baseline recording harness for the WorkSpace PRD environment.
import json
import logging
import os
import re
import time
from dataclasses import dataclass
from typing import Optional

from pydantic import ValidationError

# Optional dependencies: degrade gracefully when they are not installed.
try:
    from dotenv import load_dotenv
except ImportError:
    def load_dotenv():
        # No-op stand-in so the module still imports without python-dotenv.
        return False
try:
    from openai import OpenAI
except ImportError:
    OpenAI = None  # llm mode refuses to start when the openai package is absent

from envs.environment import WorkSpaceEnvironment
from models.schemas import WorkSpaceAction, WorkspaceState
from prompter.system_prompt import SystemPrompt

# Load .env before any configuration is read from the environment below.
load_dotenv()
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Canned questions the scripted baseline sends to each expert; the keys must
# match the expert names used elsewhere in this file ("Finance", "Security", "UX").
SCRIPTED_QUESTIONS = {
    "Finance": (
        "Hi Finance, what budget guardrails should the PRD lock in for the first release? "
        "Please call out the hard budget cap and any scope discipline we should preserve."
    ),
    "Security": (
        "Hi Security, what authentication requirement is non-negotiable for this app? "
        "Please tell me the strongest user-verification control that must appear in the PRD."
    ),
    "UX": (
        "Hi UX, what checkout experience must the PRD guarantee for launch? "
        "Please describe the required conversion flow in plain terms."
    ),
}
| def normalize_agent_mode(mode: str | None) -> str: | |
| canonical = (mode or "").strip().lower() | |
| aliases = { | |
| "": "scripted", | |
| "scripted": "scripted", | |
| "medium": "medium", | |
| "mock": "scripted", | |
| "deterministic": "scripted", | |
| "llm": "llm", | |
| "local": "local", | |
| "trained": "local", | |
| "live": "llm", | |
| "online": "llm", | |
| "remote": "llm", | |
| "api": "llm", | |
| } | |
| if canonical not in aliases: | |
| raise ValueError(f"Unsupported agent mode: {mode}") | |
| return aliases[canonical] | |
# BUGFIX: the @dataclass decorator was missing. The rest of this file constructs
# AgentDecision with keyword arguments (e.g. AgentDecision(action=..., status=...)),
# which raises TypeError on a plain class with no __init__; `dataclass` was already
# imported above for exactly this purpose. The annotation is quoted so the class
# definition does not depend on WorkSpaceAction being importable.
@dataclass
class AgentDecision:
    """Outcome of one agent decision step.

    action: the validated action to execute, or None when the step failed.
    status: "ok" on success, otherwise "infra_error" / "parse_error" /
        "policy_error" / "completed" describing why no action is available.
    error: human-readable detail for non-"ok" statuses.
    raw_response: the model's raw text, kept for debugging parse failures.
    """
    action: Optional["WorkSpaceAction"]
    status: str = "ok"
    error: Optional[str] = None
    raw_response: Optional[str] = None
class AgentWrapper:
    """Policy wrapper that produces workspace actions in one of several modes.

    Modes (see normalize_agent_mode): "scripted" runs a deterministic canned
    policy, "llm" calls an OpenAI-compatible chat API, "local" runs an
    in-process transformers model. Any other normalized mode (e.g. "medium")
    is currently routed to the llm path by get_action.
    """

    def __init__(self, mode: str | None = None):
        """Configure the wrapper.

        mode: requested agent mode; falls back to BASELINE_AGENT_MODE, then
            "scripted".

        Raises:
            RuntimeError: when the chosen mode's dependency or configuration is
                missing (openai package for llm; torch/transformers and
                LOCAL_AGENT_MODEL_PATH for local).
        """
        requested_mode = mode or os.getenv("BASELINE_AGENT_MODE") or "scripted"
        self.mode = normalize_agent_mode(requested_mode)
        # Model name only matters for the llm/local paths.
        self.model_name = os.getenv("AGENT_MODEL_NAME") or os.getenv("MODEL_NAME") or "llama-3.1-8b-instant"
        self.prompt_builder = SystemPrompt()
        self.client: object | None = None
        self.local_model = None
        self.local_tokenizer = None
        # Kept so _local_action can call torch.no_grad() without re-importing.
        self._torch = None
        if self.mode == "llm":
            if OpenAI is None:
                raise RuntimeError("openai package is required for llm agent mode.")
            self.client = OpenAI(
                base_url=os.getenv("AGENT_API_BASE_URL") or os.getenv("API_BASE_URL_1"),
                api_key=os.getenv("AGENT_API_KEY") or os.getenv("GROQ_API_KEY"),
                timeout=45.0,
                max_retries=2,
            )
        elif self.mode == "local":
            # Imported lazily so scripted/llm modes don't need torch installed.
            try:
                import torch
                from transformers import AutoModelForCausalLM, AutoTokenizer
            except ImportError as exc:
                raise RuntimeError("transformers and torch are required for local agent mode.") from exc
            model_path = os.getenv("LOCAL_AGENT_MODEL_PATH")
            if not model_path:
                raise RuntimeError("Set LOCAL_AGENT_MODEL_PATH for local agent mode.")
            self._torch = torch
            self.local_tokenizer = AutoTokenizer.from_pretrained(model_path)
            if self.local_tokenizer.pad_token is None:
                # Many causal LMs ship without a pad token; reuse EOS for padding.
                self.local_tokenizer.pad_token = self.local_tokenizer.eos_token
            self.local_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype="auto",
                device_map="auto",
            )
        self.reset_episode()
| def reset_episode(self): | |
| self.scripted_targets = ["Finance", "Security", "UX"] | |
| self.final_draft = self._build_final_draft() | |
| def get_action( | |
| self, | |
| observation_text: str, | |
| conversation_history: list[dict[str, str]], | |
| discovered_constraints: str, | |
| ) -> AgentDecision: | |
| if self.mode == "scripted": | |
| return self._scripted_action(observation_text) | |
| if self.mode == "local": | |
| return self._local_action(observation_text, conversation_history, discovered_constraints) | |
| return self._llm_action(observation_text, conversation_history, discovered_constraints) | |
| def _scripted_action(self, observation_text: str) -> AgentDecision: | |
| current_turn = self._extract_turn(observation_text) | |
| # Gather constraints 1-by-1 | |
| if current_turn < len(self.scripted_targets): | |
| target = self.scripted_targets[current_turn] | |
| return AgentDecision( | |
| action=WorkSpaceAction( | |
| action_type="message_expert", | |
| target=target, | |
| content=SCRIPTED_QUESTIONS[target], | |
| ) | |
| ) | |
| # Propose a draft on Turn 4 | |
| if current_turn == len(self.scripted_targets): | |
| return AgentDecision( | |
| action=WorkSpaceAction( | |
| action_type="propose_draft", | |
| target="All", | |
| content=self._build_draft_proposal(), | |
| ) | |
| ) | |
| if current_turn == len(self.scripted_targets) + 1: | |
| return AgentDecision( | |
| action=WorkSpaceAction( | |
| action_type="submit_final", | |
| target=None, | |
| content=self.final_draft, | |
| ) | |
| ) | |
| return AgentDecision(action=None, status="completed") | |
| def _llm_action( | |
| self, | |
| observation_text: str, | |
| conversation_history: list[dict[str, str]], | |
| discovered_constraints: str, | |
| ) -> AgentDecision: | |
| if self.client is None: | |
| return AgentDecision( | |
| action=None, | |
| status="infra_error", | |
| error="Agent client is not configured for llm mode.", | |
| ) | |
| system_prompt = self.prompt_builder.system_prompt( | |
| conversation_history=self._render_history(conversation_history), | |
| discovered=discovered_constraints, | |
| ) | |
| try: | |
| response = self.client.chat.completions.create( | |
| messages=[ | |
| {"role": "system", "content": system_prompt}, | |
| *conversation_history, | |
| {"role": "user", "content": observation_text}, | |
| ], | |
| model=self.model_name, | |
| temperature=0.2, | |
| max_tokens=2048, | |
| response_format={"type": "json_object"} | |
| ) | |
| except Exception as exc: | |
| logger.error(f"Agent API Error: {exc}") | |
| return AgentDecision(action=None, status="infra_error", error=str(exc)) | |
| raw_text = (response.choices[0].message.content or "").strip() | |
| json_match = re.search(r"\{.*?\}", raw_text, re.DOTALL) | |
| if not json_match: | |
| return AgentDecision( | |
| action=None, | |
| status="parse_error", | |
| error="Model response did not contain a JSON object.", | |
| raw_response=raw_text, | |
| ) | |
| try: | |
| payload = json.loads(json_match.group(0)) | |
| except json.JSONDecodeError as exc: | |
| return AgentDecision( | |
| action=None, | |
| status="parse_error", | |
| error=f"Invalid JSON payload: {exc}", | |
| raw_response=raw_text, | |
| ) | |
| try: | |
| action = WorkSpaceAction(**payload) | |
| except ValidationError as exc: | |
| return AgentDecision( | |
| action=None, | |
| status="policy_error", | |
| error=f"Schema validation failed: {exc}", | |
| raw_response=raw_text, | |
| ) | |
| semantic_error = self._validate_action(action) | |
| if semantic_error: | |
| return AgentDecision( | |
| action=None, | |
| status="policy_error", | |
| error=semantic_error, | |
| raw_response=raw_text, | |
| ) | |
| return AgentDecision(action=action, raw_response=raw_text) | |
| def _local_action( | |
| self, | |
| observation_text: str, | |
| conversation_history: list[dict[str, str]], | |
| discovered_constraints: str, | |
| ) -> AgentDecision: | |
| if self.local_model is None or self.local_tokenizer is None: | |
| return AgentDecision( | |
| action=None, | |
| status="infra_error", | |
| error="Local model is not configured for local agent mode.", | |
| ) | |
| system_prompt = self.prompt_builder.system_prompt( | |
| conversation_history=self._render_history(conversation_history), | |
| discovered=discovered_constraints, | |
| ) | |
| messages = [ | |
| {"role": "system", "content": system_prompt}, | |
| *conversation_history, | |
| {"role": "user", "content": observation_text}, | |
| ] | |
| try: | |
| if hasattr(self.local_tokenizer, "apply_chat_template"): | |
| prompt_text = self.local_tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=True, | |
| ) | |
| else: | |
| prompt_text = ( | |
| f"System: {system_prompt}\n" | |
| + "\n".join(f"{m['role']}: {m['content']}" for m in conversation_history) | |
| + f"\nuser: {observation_text}\nassistant:" | |
| ) | |
| inputs = self.local_tokenizer(prompt_text, return_tensors="pt") | |
| inputs = {k: v.to(self.local_model.device) for k, v in inputs.items()} | |
| prompt_len = inputs["input_ids"].shape[1] | |
| with self._torch.no_grad(): | |
| output_ids = self.local_model.generate( | |
| **inputs, | |
| max_new_tokens=256, | |
| do_sample=False, | |
| temperature=0.0, | |
| pad_token_id=self.local_tokenizer.pad_token_id, | |
| ) | |
| completion_ids = output_ids[0][prompt_len:] | |
| raw_text = self.local_tokenizer.decode(completion_ids, skip_special_tokens=True).strip() | |
| except Exception as exc: | |
| logger.error(f"Local Agent Error: {exc}") | |
| return AgentDecision(action=None, status="infra_error", error=str(exc)) | |
| json_match = re.search(r"\{.*?\}", raw_text, re.DOTALL) | |
| if not json_match: | |
| return AgentDecision( | |
| action=None, | |
| status="parse_error", | |
| error="Model response did not contain a JSON object.", | |
| raw_response=raw_text, | |
| ) | |
| try: | |
| payload = json.loads(json_match.group(0)) | |
| except json.JSONDecodeError as exc: | |
| return AgentDecision( | |
| action=None, | |
| status="parse_error", | |
| error=f"Invalid JSON payload: {exc}", | |
| raw_response=raw_text, | |
| ) | |
| try: | |
| action = WorkSpaceAction(**payload) | |
| except ValidationError as exc: | |
| return AgentDecision( | |
| action=None, | |
| status="policy_error", | |
| error=f"Schema validation failed: {exc}", | |
| raw_response=raw_text, | |
| ) | |
| semantic_error = self._validate_action(action) | |
| if semantic_error: | |
| return AgentDecision( | |
| action=None, | |
| status="policy_error", | |
| error=semantic_error, | |
| raw_response=raw_text, | |
| ) | |
| return AgentDecision(action=action, raw_response=raw_text) | |
| def _build_draft_proposal(self) -> str: | |
| return ( | |
| "Draft PRD proposal for the mobile app MVP:\n" | |
| "- Keep the initial release budget capped at $50k and prioritize the highest-ROI scope.\n" | |
| "- Require biometric 2FA for sign-in and sensitive actions.\n" | |
| "- Deliver a true single-click checkout so the purchase flow stays low-friction." | |
| ) | |
| def _build_final_draft(self) -> str: | |
| return ( | |
| "Mobile App PRD Final Draft\n" | |
| "1. Budget and scope: The first release must stay at or below a $50k budget cap, with the MVP limited to the highest-ROI features.\n" | |
| "2. Security: The app must require biometric 2FA for login and other sensitive account actions.\n" | |
| "3. UX: Checkout must be implemented as a single-click checkout flow with minimal friction for the user.\n" | |
| "4. Delivery focus: Product, design, and engineering should keep the implementation lean so these launch requirements are met without scope creep." | |
| ) | |
| def _validate_action(self, action: WorkSpaceAction) -> Optional[str]: | |
| if not action.content.strip(): | |
| return "Action content cannot be empty." | |
| if action.action_type == "message_expert" and action.target is None: | |
| return "message_expert actions must include a target expert." | |
| if action.action_type == "message_expert" and action.target == "All": | |
| return "message_expert must target exactly one expert; do not use target='All'." | |
| if action.action_type == "propose_draft" and action.target != "All": | |
| return "propose_draft actions must use target='All' to collect multi-expert draft feedback." | |
| if action.action_type == "submit_final" and action.target is not None: | |
| return "submit_final actions must use target=null." | |
| return None | |
| def _render_history(self, conversation_history: list[dict[str, str]], max_items: int = 8) -> str: | |
| if not conversation_history: | |
| return "No prior conversation yet." | |
| rendered = [] | |
| for message in conversation_history[-max_items:]: | |
| content = message["content"].replace("\n", " ").strip() | |
| rendered.append(f"{message['role']}: {content}") | |
| return "\n".join(rendered) | |
| def _extract_turn(self, observation_text: str) -> int: | |
| match = re.search(r"Turn\s+(\d+)", observation_text) | |
| return int(match.group(1)) if match else 0 | |
| def get_discovered_constraints(self, state: WorkspaceState) -> str: | |
| lines = [] | |
| for name, expert in state.experts.items(): | |
| if expert.constraint_discovered_by_agent: | |
| lines.append(f"{name}: discovered from prior expert feedback.") | |
| else: | |
| lines.append(f"{name}: still unknown.") | |
| return "\n".join(lines) | |
def summarize_results(results: list[dict], episodes_requested: int, agent_mode: str, env_mode: str) -> dict:
    """Aggregate per-episode result dicts into the baseline summary payload.

    Averages and constraint-discovery rates are computed over completed
    episodes only; they stay None when no episode completed. The completion
    rate is 0.0 when no episodes were requested.
    """
    status_counts: dict[str, int] = {}
    for entry in results:
        status_counts[entry["status"]] = status_counts.get(entry["status"], 0) + 1
    completed = [entry for entry in results if entry["status"] == "completed"]

    avg_cumulative = avg_final = avg_turns = None
    all_rate = finance_rate = security_rate = ux_rate = None
    if completed:
        n_done = len(completed)

        def _mean(field: str, digits: int):
            # Rounded mean of a numeric field across completed episodes.
            return round(sum(entry[field] for entry in completed) / n_done, digits)

        avg_cumulative = _mean("cumulative_reward", 3)
        avg_final = _mean("final_step_reward", 3)
        avg_turns = _mean("turns_completed", 2)

        def _found(entry: dict, expert: str) -> bool:
            # The discovery marker written by get_discovered_constraints.
            needle = f"{expert}: discovered from prior expert feedback."
            return needle in (entry.get("discovered_constraints") or "")

        finance_rate = round(sum(1 for e in completed if _found(e, "Finance")) / n_done, 3)
        security_rate = round(sum(1 for e in completed if _found(e, "Security")) / n_done, 3)
        ux_rate = round(sum(1 for e in completed if _found(e, "UX")) / n_done, 3)
        all_rate = round(
            sum(
                1
                for e in completed
                if _found(e, "Finance") and _found(e, "Security") and _found(e, "UX")
            )
            / n_done,
            3,
        )

    completion_rate = round(len(completed) / episodes_requested, 3) if episodes_requested else 0.0
    return {
        "episodes_requested": episodes_requested,
        "episodes_completed": len(completed),
        "completion_rate": completion_rate,
        "average_cumulative_reward_completed": avg_cumulative,
        "average_final_step_reward_completed": avg_final,
        "average_turns_completed": avg_turns,
        "all_constraints_discovered_rate": all_rate,
        "finance_discovery_rate": finance_rate,
        "security_discovery_rate": security_rate,
        "ux_discovery_rate": ux_rate,
        "status_counts": status_counts,
        "agent_mode": agent_mode,
        "environment_mode": env_mode,
    }
def record_baseline(episodes: Optional[int] = None):
    """Run baseline episodes and persist summary + per-episode logs to baseline_results.json.

    episodes: number of episodes to run; defaults to BASELINE_EPISODES (10).
    Agent and environment modes come from BASELINE_AGENT_MODE / BASELINE_ENV_MODE;
    BASELINE_STEP_DELAY adds a sleep between steps (useful for rate-limited APIs).
    """
    episodes = episodes or int(os.getenv("BASELINE_EPISODES", "10"))
    step_delay = float(os.getenv("BASELINE_STEP_DELAY", "0"))
    agent_mode = os.getenv("BASELINE_AGENT_MODE") or "scripted"
    env_mode = os.getenv("BASELINE_ENV_MODE") or "mock"
    env = WorkSpaceEnvironment(mode=env_mode)
    agent = AgentWrapper(mode=agent_mode)
    all_results = []
    print(
        f"Starting Baseline Recording for {episodes} episodes "
        f"(agent_mode={agent.mode}, env_mode={env.mode})..."
    )
    for i in range(episodes):
        obs = env.reset()
        agent.reset_episode()
        conversation_history: list[dict[str, str]] = []
        cumulative_reward = 0.0
        step_rewards: list[float] = []
        episode_result: Optional[dict] = None
        print(f"\n--- Episode {i + 1} ---")
        while not obs.done:
            prompt = f"Turn {obs.current_turn}. Feedback: {obs.feedback}"
            # BUGFIX: append the deadline override BEFORE querying the agent.
            # Previously it was appended after get_action had already returned,
            # so the forced submit_final instruction never influenced the
            # current turn's decision and only polluted the recorded history.
            if obs.current_turn >= 4:
                prompt += "\n\nCRITICAL SYSTEM OVERRIDE: You are out of time. You MUST output a JSON with action_type: 'submit_final' right now. Do not message anyone else."
            discovered = agent.get_discovered_constraints(env.state())
            decision = agent.get_action(prompt, conversation_history, discovered)
            if decision.status != "ok" or decision.action is None:
                # Agent-side failure (infra/parse/policy) ends the episode early.
                episode_result = {
                    "episode": i + 1,
                    "status": decision.status,
                    "error_source": "agent",
                    "error_detail": decision.error,
                    "raw_response": decision.raw_response,
                    "final_step_reward": step_rewards[-1] if step_rewards else None,
                    "cumulative_reward": round(cumulative_reward, 3),
                    "step_rewards": step_rewards,
                    "turns_completed": obs.current_turn,
                    "discovered_constraints": discovered,
                    "chat_history": env.state().chat_history,
                }
                print(f"  {decision.status.upper()}: Episode {i + 1} ended early")
                break
            action = decision.action
            print(f"Agent Action: {action.action_type} -> {action.target}")
            conversation_history.append({"role": "user", "content": prompt})
            conversation_history.append({"role": "assistant", "content": action.model_dump_json()})
            try:
                obs = env.step(action)
            except Exception as exc:
                # Environment-side failure: log and close out the episode.
                logger.error(f"Environment step failed: {exc}")
                episode_result = {
                    "episode": i + 1,
                    "status": "infra_error",
                    "error_source": "environment",
                    "error_detail": str(exc),
                    "raw_response": None,
                    "final_step_reward": step_rewards[-1] if step_rewards else None,
                    "cumulative_reward": round(cumulative_reward, 3),
                    "step_rewards": step_rewards,
                    "turns_completed": env.state().turn_count,
                    "discovered_constraints": agent.get_discovered_constraints(env.state()),
                    "chat_history": env.state().chat_history,
                }
                print(f"  INFRA_ERROR: Environment failed during episode {i + 1}")
                break
            cumulative_reward = round(cumulative_reward + obs.reward, 3)
            step_rewards.append(obs.reward)
            if step_delay > 0:
                time.sleep(step_delay)
        if episode_result is None:
            # The episode ran to env-signalled completion without errors.
            episode_result = {
                "episode": i + 1,
                "status": "completed",
                "error_source": None,
                "error_detail": None,
                "raw_response": None,
                "final_step_reward": obs.reward,
                "cumulative_reward": cumulative_reward,
                "step_rewards": step_rewards,
                "turns_completed": obs.current_turn,
                "discovered_constraints": agent.get_discovered_constraints(env.state()),
                "chat_history": env.state().chat_history,
            }
            print(
                f"Episode {i + 1} completed in {obs.current_turn} turns. "
                f"Final step reward: {obs.reward:.3f} | Cumulative reward: {cumulative_reward:.3f}"
            )
        else:
            print(f"Episode {i + 1} status: {episode_result['status']}")
        all_results.append(episode_result)
    summary = summarize_results(all_results, episodes, agent.mode, env.mode)
    output_payload = {
        "summary": summary,
        "episodes": all_results,
    }
    with open("baseline_results.json", "w", encoding="utf-8") as file:
        json.dump(output_payload, file, indent=4)
    print("\nBaseline summary:")
    print(json.dumps(summary, indent=4))
    print("Saved to baseline_results.json.")
if __name__ == "__main__":
    # Allow running this module directly to record a baseline with env-var config.
    record_baseline()