"""
inference.py
------------
Inference script – Smart Contract Audit RL Environment.
Implements agents for all three tasks via an OpenAI-compatible client
pointed at the Hugging Face router.
Emits mandatory structured stdout in the OpenEnv format.

MANDATORY ENV VARS:
    HF_TOKEN      Hugging Face token (required)
    MODEL_NAME    Model identifier (default: CohereLabs/tiny-aya-fire:cohere)
    API_BASE_URL  OpenAI-compatible endpoint (default: https://router.huggingface.co/v1/)

MANDATORY STDOUT FORMAT (per episode):
    [START] task=<id> env=smart-contract-audit model=<model>
    [STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<str|null>
    [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>

Usage:
    python inference.py

Output:
    Structured stdout per episode, plus a baseline_scores.json summary.
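
Example episode (illustrative values; action strings depend on ActionType):
    [START] task=task1_vuln_detection env=smart-contract-audit model=CohereLabs/tiny-aya-fire:cohere
    [STEP] step=1 action=list_functions reward=0.00 done=false error=null
    [STEP] step=2 action=get_function_code reward=0.60 done=true error=null
    [END] success=true steps=2 score=0.600 rewards=0.00,0.60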
| """ | |
import asyncio
import json
import os
import sys
from collections import deque
from typing import Any, Callable, Dict, List, Optional, Union

from openai import AsyncOpenAI
from dotenv import load_dotenv

from server import Task1Environment, Task2Environment, Task3Environment
from env.schemas import Action, ActionType
from utils import T1_SYSTEM, T2_SYSTEM, T3_SYSTEM
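
# server, env.schemas and utils are local modules of this repo, providing the
# task environments, the Action/ActionType schema, and the system prompts.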

# ─────────────────────────────────────────────────────────────────────────────
# Configuration
# ─────────────────────────────────────────────────────────────────────────────
load_dotenv()

API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1/")
MODEL_NAME = os.getenv("MODEL_NAME", "CohereLabs/tiny-aya-fire:cohere")
HF_TOKEN = os.getenv("HF_TOKEN", "")

if not HF_TOKEN:
    raise RuntimeError("HF_TOKEN environment variable not set")
if not MODEL_NAME:
    raise RuntimeError("MODEL_NAME not set")

client = AsyncOpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)

# Benchmark / environment identifier (constant for this env)
ENV_BENCHMARK = "smart-contract-audit"

# Episodes per task
NUM_EPISODES = 5
SEED_BASE = 42

# Max steps per task (same for all tasks)
MAX_STEPS = 35

# A grader_score >= this is considered a "success" for the [END] line
SUCCESS_SCORE_THRESHOLD = 0.5

# Throttle concurrent LLM calls
SEMAPHORE = asyncio.Semaphore(3)

# Timeout for each LLM request, in seconds
LLM_TIMEOUT = 20
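
# Example .env (illustrative values; only HF_TOKEN is strictly required):
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
#   MODEL_NAME=CohereLabs/tiny-aya-fire:cohere
#   API_BASE_URL=https://router.huggingface.co/v1/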

# ─────────────────────────────────────────────────────────────────────────────
# Unified LLM call function
# ─────────────────────────────────────────────────────────────────────────────
async def get_llm_response(
    messages: List[Dict[str, str]],
    max_tokens: int = 200,
    temperature: float = 0.0,
) -> str:
    """
    Call the LLM with the given messages and parameters.
    Returns the response content as a string.
    Raises an exception on failure (to be caught by the caller).
    """
    try:
        async with SEMAPHORE:
            completion = await asyncio.wait_for(
                client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=messages,  # type: ignore
                    max_tokens=max_tokens,
                    temperature=temperature,
                ),
                timeout=LLM_TIMEOUT,
            )
        # Guard against a None content field before stripping.
        content = completion.choices[0].message.content or ""
        return content.strip()
    except asyncio.TimeoutError:
        raise RuntimeError("LLM request timed out")
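
# Optional resilience sketch (illustrative, not wired into the main flow):
# retries a call through get_llm_response with exponential backoff. The retry
# count and base delay are assumptions, not part of the benchmark contract.
async def get_llm_response_with_retries(
    messages: List[Dict[str, str]],
    max_tokens: int = 200,
    temperature: float = 0.0,
    retries: int = 3,
    base_delay: float = 1.0,
) -> str:
    last_exc: Optional[Exception] = None
    for attempt in range(retries):
        try:
            return await get_llm_response(
                messages, max_tokens=max_tokens, temperature=temperature
            )
        except Exception as e:  # includes the RuntimeError raised on timeout
            last_exc = e
            # Back off 1s, 2s, 4s, ... before the next attempt.
            await asyncio.sleep(base_delay * (2 ** attempt))
    raise RuntimeError(f"LLM call failed after {retries} attempts: {last_exc}")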

# ─────────────────────────────────────────────────────────────────────────────
# Mandatory stdout helpers
# ─────────────────────────────────────────────────────────────────────────────
def log_start(task: str, env: str, model: str) -> None:
    """Emit the [START] line – one per episode."""
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(
    step: int,
    action: str,
    reward: float,
    done: bool,
    error: Optional[str] = None,
) -> None:
    """Emit a [STEP] line – one per env.step() call."""
    error_val = error if error else "null"
    print(
        f"[STEP] step={step} action={action} "
        f"reward={reward:.2f} done={str(done).lower()} error={error_val}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the [END] line – one per episode, always emitted."""
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={rewards_str}",
        flush=True,
    )

def _clamp(reward: float) -> float:
    """Clamp a score into (0.001, 0.999) for reporting."""
    return max(0.001, min(0.999, reward))

# ─────────────────────────────────────────────────────────────────────────────
# Generic episode runner
# ─────────────────────────────────────────────────────────────────────────────
async def run_episode(
    env: Union[Task1Environment, Task2Environment, Task3Environment],
    seed: int,
    ep_num: int,
    *,
    task_id: str,
    system_prompt: str,
    user_msg_formatter: Callable[[Dict[str, Any]], str],
    max_tokens: int = 200,
    default_action: ActionType = ActionType.LIST_FUNCTIONS,
    extra_fields: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Run one episode against `env`, emitting [START]/[STEP]/[END] lines."""
    r = env.reset(seed=seed)
    obs = r.observation.model_dump()
    log_start(task=task_id, env=ENV_BENCHMARK, model=MODEL_NAME)
    # Keep only the last 2 user-assistant pairs (4 messages).
    history: deque = deque(maxlen=4)
    step_rewards: List[float] = []
    grader_score = 0.0
    steps_taken = 0
    error_msg: Optional[str] = None
    try:
        for step in range(1, MAX_STEPS + 1):
            user_msg_content = user_msg_formatter(obs)
            user_message = {"role": "user", "content": user_msg_content}
            history.append(user_message)
            # Always prepend the system prompt so it survives deque eviction
            messages_for_llm = [
                {"role": "system", "content": system_prompt},
                *list(history),
            ]
            try:
                raw = await get_llm_response(
                    messages_for_llm, max_tokens=max_tokens, temperature=0.0
                )
                error_msg = None
            except Exception as e:
                raw = ""
                error_msg = str(e)[:80]
                print(f"[DEBUG] {task_id} LLM error ep={ep_num} step={step}: {e}", file=sys.stderr)
            # Append the assistant reply so the next step sees the full turn
            history.append({"role": "assistant", "content": raw})
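            # The agent must reply with JSON of the shape
            # {"action": "<ActionType value>", "params": {...}};
            # anything unparseable falls back to default_action below.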
            try:
                parsed = json.loads(raw)
                at = ActionType(parsed["action"])
                params = parsed.get("params", {})
            except Exception as e:
                at, params = default_action, {}
                print(f"[DEBUG] {task_id} parse error ep={ep_num} step={step}: {e}", file=sys.stderr)
            result = env.step(Action(action_type=at, params=params))
            obs = result.observation.model_dump()
            r_val = result.reward.value
            done = result.done
            step_rewards.append(r_val)
            steps_taken = step
            log_step(step=step, action=at.value, reward=r_val, done=done, error=error_msg)
            if done:
                grader_score = r_val
                break
            await asyncio.sleep(0.3)
    finally:
        success = grader_score >= SUCCESS_SCORE_THRESHOLD
        log_end(success=success, steps=steps_taken, score=grader_score, rewards=step_rewards)
    result_dict = {
        "episode": ep_num,
        "seed": seed,
        "grader_score": _clamp(grader_score),
        "contract": obs.get("contract_name", ""),
    }
    if extra_fields:
        result_dict.update(extra_fields(obs))
    return result_dict

# ─────────────────────────────────────────────────────────────────────────────
# Task-specific user message formatters and extra field extractors
# ─────────────────────────────────────────────────────────────────────────────
def t1_user_msg(obs: Dict[str, Any]) -> str:
    return (
        f"Last action : {obs['last_action'] or 'None'}\n"
        f"Last result : {obs['last_action_result'] or 'Episode just started.'}"
    )

def t2_user_msg(obs: Dict[str, Any]) -> str:
    extra = obs.get("extra", {})
    return (
        f"Target Function : {extra.get('target_function', '?')}\n"
        f"Last action : {obs['last_action'] or 'None'}\n"
        f"Last result :\n{obs['last_action_result'] or 'Episode just started.'}"
    )

def t2_extra_fields(obs: Dict[str, Any]) -> Dict[str, Any]:
    return {"function": obs.get("extra", {}).get("target_function", "?")}


def t3_user_msg(obs: Dict[str, Any]) -> str:
    extra = obs.get("extra", {})
    return (
        f"Verify Property : {extra.get('property_english', '(none)')}\n"
        f"Last action : {obs['last_action'] or 'None'}\n"
        f"Last result :\n{obs['last_action_result'] or 'Episode just started.'}"
    )

# ─────────────────────────────────────────────────────────────────────────────
# Generic task runner
# ─────────────────────────────────────────────────────────────────────────────
async def run_task(
    task_id: str,
    task_name: str,
    env_class: type,
    system_prompt: str,
    user_msg_formatter: Callable[[Dict[str, Any]], str],
    max_tokens: int = 200,
    default_action: ActionType = ActionType.LIST_FUNCTIONS,
    extra_fields: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    num_episodes: int = NUM_EPISODES,
) -> Dict[str, Any]:
    """Run multiple episodes for a given task and return aggregated results."""
    print("\n" + "=" * 60, flush=True)
    print(f"TASK: {task_name}", flush=True)
    print("=" * 60, flush=True)
    tasks = [
        run_episode(
            # Fresh environment per episode: episodes run concurrently via
            # gather, so they must not share one stateful env instance.
            env_class(),
            seed=SEED_BASE + i,
            ep_num=i + 1,
            task_id=task_id,
            system_prompt=system_prompt,
            user_msg_formatter=user_msg_formatter,
            max_tokens=max_tokens,
            default_action=default_action,
            extra_fields=extra_fields,
        )
        for i in range(num_episodes)
    ]
    episodes = await asyncio.gather(*tasks)
    avg_score = sum(e["grader_score"] for e in episodes) / num_episodes
    print(f"\n  Avg grader score : {_clamp(avg_score):.4f}", flush=True)
    return {
        "task_id": task_id,
        "name": task_name,
        "status": "active",
        "num_episodes": num_episodes,
        "episodes": episodes,
        "avg_grader_score": _clamp(avg_score),
    }
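
# Illustrative shape of the per-task result dict returned above (values are
# examples only): {"task_id": "task1_vuln_detection", "name": "...",
#   "status": "active", "num_episodes": 5, "episodes": [...],
#   "avg_grader_score": 0.123}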

# ─────────────────────────────────────────────────────────────────────────────
# Task-specific runners (thin wrappers for clarity)
# ─────────────────────────────────────────────────────────────────────────────
async def run_task1(n: int = NUM_EPISODES) -> Dict[str, Any]:
    return await run_task(
        task_id="task1_vuln_detection",
        task_name="Targeted Vulnerability Detection",
        env_class=Task1Environment,
        system_prompt=T1_SYSTEM,
        user_msg_formatter=t1_user_msg,
        max_tokens=200,
        default_action=ActionType.LIST_FUNCTIONS,
        num_episodes=n,
    )


async def run_task2(n: int = NUM_EPISODES) -> Dict[str, Any]:
    return await run_task(
        task_id="task2_property_discovery",
        task_name="Property Discovery",
        env_class=Task2Environment,
        system_prompt=T2_SYSTEM,
        user_msg_formatter=t2_user_msg,
        max_tokens=400,
        default_action=ActionType.GET_FUNCTION_CODE,
        extra_fields=t2_extra_fields,
        num_episodes=n,
    )


async def run_task3(n: int = NUM_EPISODES) -> Dict[str, Any]:
    return await run_task(
        task_id="task3_rule_checker",
        task_name="Rule Checker",
        env_class=Task3Environment,
        system_prompt=T3_SYSTEM,
        user_msg_formatter=t3_user_msg,
        max_tokens=200,
        default_action=ActionType.LIST_FUNCTIONS,
        num_episodes=n,
    )

# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────
async def main() -> None:
    """Async entry point."""
    print("Smart Contract Audit RL Environment – Baseline Inference", flush=True)
    t1 = await run_task1(NUM_EPISODES)
    t2 = await run_task2(NUM_EPISODES)
    t3 = await run_task3(NUM_EPISODES)
    results: Dict[str, Any] = {"tasks": [t1, t2, t3]}
    overall = sum(t["avg_grader_score"] for t in results["tasks"]) / 3
    results["overall_avg_score"] = overall
    print("\n" + "=" * 60, flush=True)
    print("BASELINE SUMMARY", flush=True)
    print("=" * 60, flush=True)
    for t in results["tasks"]:
        print(f"  • {t['name']:40s}: {_clamp(t['avg_grader_score']):.3f}", flush=True)
    print(f"\n  Overall avg grader score: {overall:.4f}", flush=True)
    with open("baseline_scores.json", "w") as f:
        json.dump(results, f, indent=2)
    print("\n  Scores written to baseline_scores.json", flush=True)


if __name__ == "__main__":
    asyncio.run(main())