|
|
import re |
|
|
import os |
|
|
import json |
|
|
import requests |
|
|
import time |
|
|
from typing import List, Optional, Dict |
|
|
from .prompts import DEEPRESEARCH_SYS_PROMPT, SUMMARY_SYS_PROMPT |
|
|
from functools import wraps |
|
|
from together import Together |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
def retry(max: int = 10, sleep: int = 1, fallback=None):
    """Decorator: retry a function up to `max` times, returning `fallback` on failure.

    Args:
        max: Maximum number of attempts. (Shadows the builtin `max`; the name
            is kept because existing call sites pass it as a keyword.)
        sleep: Seconds to pause between attempts; 0/falsy disables the pause.
        fallback: Value returned when every attempt raises. Returning instead
            of re-raising keeps outer driver loops alive.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # A non-positive attempt budget would otherwise fall through the
            # loop and implicitly return None; treat it as immediate failure.
            if max < 1:
                return fallback
            for i in range(max):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    print(f"[retry] attempt {i+1}/{max} failed: {e}")
                    if i == max - 1:
                        # Bug fix: message contained a mojibake "β" character.
                        print(f"[retry] giving up - returning {fallback!r}")
                        return fallback
                    if sleep:
                        time.sleep(sleep)
        return wrapper
    return decorator
|
|
|
|
|
class ReCall(): |
|
|
    # Freshness note appended to every system prompt below. Evaluated once,
    # at class-definition (import) time — NOTE(review): in a long-running
    # process the reported "today" goes stale; confirm that is acceptable.
    date_str = \
f"""

**Note**: Today's Date is {datetime.now().strftime("%Y-%m-%d")}, and time is {datetime.now().strftime("%H:%M:%S")}. This may be useful for answering questions about current events."""

    # Language clamp appended to every system prompt: think and answer in
    # English unless the user explicitly requests another language.
    anti_chinese_str = \
"""

**Note**: Do not respond in chinese, do not think in chinese, only think and respond/answer in English, unless explicitly instructed by the user to respond in some other language."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Placeholder for extra formatting rules; currently empty (six quotes form
    # an empty triple-quoted string) but kept so every prompt concatenation
    # below can append it unconditionally.
    proper_formatting_str = """"""

    # System prompt for the no-tools baseline: answer from parametric
    # knowledge only; search is explicitly forbidden.
    sys_prompt_non_search = """You are a helpful assistant. You will answer the user's question based on your knowledge and reasoning ability. You do not have access to the internet or any external tools. Do not use search. Answer all questions yourself.""" + date_str + anti_chinese_str
|
|
|
|
|
    # Preamble-only variant of the WebSailor prompt: persona and principles,
    # but no {func_schemas} slot and no tool-call format section — presumably
    # used where the tool spec is appended separately; confirm against callers.
    sys_prompt_websailor_start = """
You are a Web Information Seeking Master. Your task is to thoroughly seek the internet for information and provide accurate answers to questions. No matter how complex the query, you will not give up until you find the corresponding information.
In this environment you have access to a set of tools you can use to assist with the user query.
You may perform multiple rounds of function calls. In each round, you can call one or more functions.

As you proceed, adhere to the following principles:

1. **Persistent Actions for Answers**: You will engage in many interactions, delving deeply into the topic to explore all possible aspects until a satisfactory answer is found.

2. **Repeated Verification**: Before presenting a Final Answer, you will **cross-check** and **validate the information** you've gathered to confirm its accuracy and reliability.

3. **Attention to Detail**: You will carefully analyze each information source to ensure that all data is current, relevant, and from credible origins."""
|
|
|
|
|
    # Primary search-enabled system prompt (Qwen/ChatML <tool_call> protocol):
    # WebSailor persona + function schemas ({func_schemas} filled via .format;
    # literal braces in the example are doubled for .format safety).
    sys_prompt_websailor = """
You are a Web Information Seeking Master. Your task is to thoroughly seek the internet for information and provide accurate answers to questions. No matter how complex the query, you will not give up until you find the corresponding information.
In this environment you have access to a set of tools you can use to assist with the user query.
You may perform multiple rounds of function calls. In each round, you can call one or more functions.

As you proceed, adhere to the following principles:

1. **Persistent Actions for Answers**: You will engage in many interactions, delving deeply into the topic to explore all possible aspects until a satisfactory answer is found.

2. **Repeated Verification**: Before presenting a Final Answer, you will **cross-check** and **validate the information** you've gathered to confirm its accuracy and reliability.

3. **Attention to Detail**: You will carefully analyze each information source to ensure that all data is current, relevant, and from credible origins.



Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```

In your response, you need to first think about the reasoning process in the mind and then conduct function calling to get the information or perform the actions if needed. \
The reasoning process and function calling are enclosed within <think> </think> and <tool_call> </tool_call> tags. \
The results of the function calls will be given back to you after execution, \
and you can continue to call functions until you get the final answer for the user's question.

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{{"name": <function-name>, "arguments": <args-json-object>}}
</tool_call>
""" + date_str + anti_chinese_str + proper_formatting_str
|
|
|
|
|
    # DeepSeek-format variant of the WebSailor prompt: same persona, but tool
    # calls are delimited by <tool_calls_begin> / <tool_calls_end> markers (no
    # closing slash) and the final answer goes in LaTeX \boxed{} instead of
    # <answer> tags. See extract_tool_calls_deepseek for the matching parser.
    sys_prompt_websailor_deepseek = """
You are a Web Information Seeking Master. Your task is to thoroughly seek the internet for information and provide accurate answers to questions. No matter how complex the query, you will not give up until you find the corresponding information.
In this environment you have access to a set of tools you can use to assist with the user query.
You may perform multiple rounds of function calls. In each round, you can call one or more functions.

As you proceed, adhere to the following principles:

1. **Persistent Actions for Answers**: You will engage in many interactions, delving deeply into the topic to explore all possible aspects until a satisfactory answer is found.

2. **Repeated Verification**: Before presenting a Final Answer, you will **cross-check** and **validate the information** you've gathered to confirm its accuracy and reliability.

3. **Attention to Detail**: You will carefully analyze each information source to ensure that all data is current, relevant, and from credible origins.



Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```

In your response, you need to first think about the reasoning process in the mind and then conduct function calling to get the information or perform the actions if needed. \
The reasoning process and function calling are enclosed within <think> </think> and <tool_calls_begin> <tool_calls_end> tags. \
The results of the function calls will be given back to you after execution, \
and you can continue to call functions until you get the final answer for the user's question. \
Finally, if you have got the answer, enclose it within \\boxed{{}} with latex format and do not continue to call functions, \
i.e., <think> Based on the response from the function call, I get the weather information. </think> The weather in Beijing on 2025-04-01 is \\[ \\boxed{{20C}} \\].
""" + date_str + anti_chinese_str + proper_formatting_str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Generic (no-persona) tool-use prompt: reason in <think>, call tools in
    # <tool_call>, and wrap the final answer in <answer></answer>.
    # {func_schemas} is filled via .format(); literal braces are doubled.
    # NOTE(review): "β¦" and "Β°C" look like mojibake ("…" / "°C") — confirm.
    system_prompt = """In this environment you have access to a set of tools you can use to assist with the user query. \
You may perform multiple rounds of function calls. \
In each round, you can call one or more functions. \

Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```

In your response, you need to first think about the reasoning process in the mind and then conduct function calling to get the information or perform the actions if needed. \
The reasoning process and function calling are enclosed within <think> </think> and <tool_call> </tool_call> tags. \
The results of the function calls will be given back to you after execution, \
and you can continue to call functions until you get the final answer for the user's question. You are encouraged to utilize as many function calls as possible. \
Finally, if you have got the answer, wrap it in <answer> </answer> **and do not call any more functions**, \
e.g. <think> Based on the tool results β¦ </think> <answer>20 Β°C</answer>.

For each function call, return a JSON object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{{"name": <function-name-1>, "arguments": <args-json-object>}}
</tool_call>""" + date_str + anti_chinese_str + proper_formatting_str
|
|
|
|
|
    # Token-budget variant: the harness retains only the most recent
    # <tool_response> block, so the model is told to restate essential facts
    # inside <think> and to summarise rather than copy raw tool output.
    # NOTE(review): the stray "β" characters look like mojibake of bullet /
    # warning glyphs from the original text — confirm intended.
    system_prompt_budget = """
You are an autonomous reasoning agent with access to external tools.

The conversation will retain only the *most-recent* <tool_response> block; older ones disappear.
As soon as you receive tool results, extract the *essential facts tables links etc* that might be needed for later and restate them inside your <think> section.
β**Never copy large bodies of text** or raw JSON from tool output into your visible reply; summarise instead.

β **Workflow**
1. In every round, start with <think> β¦ </think> to lay out your short reasoning.
2. If you need external information or an action, emit one or more <tool_call> β¦ </tool_call> blocks (JSON spec below).
3. When the environment returns <tool_response>, continue reasoning; you may call more tools.
4. Once you can answer the user, wrap the final result in <answer> β¦ </answer> and STOP calling tools.

β **Tool call format** (do **not** restate the schema or any explanations):
<tool_call>
{{"name": <function-name-1>, "arguments": <args-json-object>}}
</tool_call>

Here are available functions in JSONSchema format: \n```json\n{func_schemas}\n```
""" + date_str + anti_chinese_str + proper_formatting_str
|
|
|
|
|
|
|
|
|
|
|
system_prompt_forcing_tool_call = """ |
|
|
In this environment you have access to a set of tools you can use to assist with the user query. |
|
|
You may perform multiple rounds of function calls upto ten. In each round, you can call upto three functions. |
|
|
|
|
|
ββββββββββββββββββββββββ AVAILABLE TOOLS ββββββββββββββββββββββββ |
|
|
```json |
|
|
[ |
|
|
{ |
|
|
"type": "function", |
|
|
"function": { |
|
|
"name": "pubmed_search", |
|
|
"description": "Search PubMed for Medical related queries.", |
|
|
"parameters": { |
|
|
"type": "object", |
|
|
"properties": { |
|
|
"query": { "type": "string", "description": "Query to search for." }, |
|
|
"top_n": { "type": "integer", "description": "Number of hits", "default": 3 } |
|
|
}, |
|
|
"required": ["query"] |
|
|
} |
|
|
} |
|
|
} |
|
|
] |
|
|
``` |
|
|
|
|
|
ββββββββββββββββββββββββββββββ RULES ββββββββββββββββββββββββββββββ |
|
|
1. You MUST issue one pubmed_search tool call for each answer choice. Each query must relate the clinical context to that option. |
|
|
2. You MAY NOT skip any option or decide based only on internal reasoning. Evidence must be retrieved for all choices. |
|
|
3. You MAY issue follow-up tool calls if your reasoning leads you to need more evidence. |
|
|
4. You MUST wrap all reasoning in <think> </think> tags and all tool usage in <tool_call> </tool_call> tags. Number of <tool_call> and </tool_call> tokens in the entire trace MUST always match. |
|
|
5. Do NOT casually emit the <tool_call> </tool_call> during reasoning unless explicitly calling a tool in the proper format. |
|
|
5. Your final answer must be enclosed a single letter corresponding to the correct option enclosed in the <answer> </answer> tags. Do not output anything else inside these tags. |
|
|
6. DO NOT use any other confusing tags like <thiking> or </thinking>. |
|
|
7. Each <think> </think> block MUST be followed by a <tool_call> </tool_call> or <answer> </answer> or else the program will break without an answer. |
|
|
|
|
|
βββββββββββββββββββββ DUMMY EXAMPLE INTERLEAVED SKELETON βββββββββββββββββββββ |
|
|
<think> |
|
|
We are presented with a 54-year-old woman with invasive ductal carcinoma of the breast and osteolytic lesions in the thoracic spine. This strongly suggests metastatic spread. Our task is to determine the most likely anatomical route of metastasis to the spine. |
|
|
|
|
|
Letβs examine the given options: |
|
|
A. Hemiazygos vein |
|
|
B. Posterior intercostal veins |
|
|
C. Batsonβs vertebral venous plexus |
|
|
D. Internal mammary lymphatics |
|
|
|
|
|
We'll evaluate each option in turn using available literature and known anatomical pathways. |
|
|
**Option A: Hemiazygos vein** |
|
|
We begin by evaluating whether the hemiazygos vein could be involved in metastatic spread from breast cancer to the spine. |
|
|
</think> |
|
|
<tool_call> |
|
|
{"name": "pubmed_search", "arguments": {"query": "breast cancer metastasis hemiazygos vein", "top_n": 2}} |
|
|
</tool_call> |
|
|
<tool_response> |
|
|
... |
|
|
</tool_response> |
|
|
<think> |
|
|
There is limited or no strong evidence suggesting the hemiazygos vein is a common or primary route for vertebral metastasis from breast cancer. |
|
|
Lets explore **Option B: Posterior intercostal veins** and **Option C: Batsonβs vertebral venous plexus** and **Option D:Internal mammary lymphatics** |
|
|
</think> |
|
|
<tool_call> |
|
|
{"name": "pubmed_search", "arguments": {"query": "posterior intercostal veins breast cancer spinal metastasis", "top_n": 3}} |
|
|
</tool_call> |
|
|
<tool_call> |
|
|
{"name": "pubmed_search", "arguments": {"query": "Batson vertebral venous plexus breast cancer metastasis", "top_n": 3}} |
|
|
</tool_call> |
|
|
<tool_call> |
|
|
{"name": "pubmed_search", "arguments": {"query": "Internal mammary lymphatics breast cancer metastasis", "top_n": 3}} |
|
|
</tool_call> |
|
|
<tool_response> |
|
|
... |
|
|
</tool_response> |
|
|
<think> |
|
|
While the posterior intercostal veins may be involved in venous drainage, there is insufficient evidence to support them as a primary route for metastasis to the vertebral column. |
|
|
where as Batsonβs vertebral venous plexus β a valveless venous network that connects the thoracic and abdominal veins directly to the spine. I to find more specific information about option C. |
|
|
</think> |
|
|
<tool_call> |
|
|
{"name": "pubmed_search", "arguments": {"query": ""Batson vertebral venous plexus breast cancer metastasis in people over 50", "top_n": 1}} |
|
|
</tool_call> |
|
|
<think> |
|
|
After evaluating all four options, the most plausible route for breast cancer metastasis to the thoracic spine is clearly via Batsonβs vertebral venous plexus: |
|
|
</think> |
|
|
<answer>C</answer> |
|
|
""" + date_str + anti_chinese_str + proper_formatting_str |
|
|
|
|
|
|
|
|
|
|
|
    def __init__(self, executor_url):
        # Base URL of the remote sandbox that runs tool calls; see
        # execute_tool_calls, which POSTs to f"{executor_url}/execute".
        self.executor_url = executor_url
|
|
|
|
|
def init_prompt(self, func_schemas, question, old_prompt: Optional[str] = None, search_on: bool = True) -> str: |
|
|
if old_prompt is None or len(old_prompt.strip()) == 0: |
|
|
if search_on: |
|
|
system_prompt = f"<|im_start|>system\n{self.sys_prompt_websailor.format(func_schemas=func_schemas)}<|im_end|>" |
|
|
else: |
|
|
system_prompt = f"<|im_start|>system\n{self.sys_prompt_non_search}<|im_end|>" |
|
|
user_prompt = f"<|im_start|>user\n{question}<|im_end|>" |
|
|
assistant_prefix = f"<|im_start|>assistant\n<think>" |
|
|
return system_prompt + "\n" + user_prompt + "\n" + assistant_prefix |
|
|
else: |
|
|
user_prompt = f"<|im_start|>user\n{question}<|im_end|>" |
|
|
assistant_prefix = f"<|im_start|>assistant\n<think>" |
|
|
return old_prompt + "\n" + user_prompt + "\n" + assistant_prefix |
|
|
|
|
|
def replace_sys_prompt(self, old_prompt: str, func_schemas: str, search_on: bool = True) -> str: |
|
|
if search_on: |
|
|
new_sys_prompt = f"<|im_start|>system\n{self.sys_prompt_websailor.format(func_schemas=func_schemas)}<|im_end|>" |
|
|
old_sys_prompt = f"<|im_start|>system\n{self.sys_prompt_non_search}<|im_end|>" |
|
|
else: |
|
|
new_sys_prompt = f"<|im_start|>system\n{self.sys_prompt_non_search}<|im_end|>" |
|
|
old_sys_prompt = f"<|im_start|>system\n{self.sys_prompt_websailor.format(func_schemas=func_schemas)}<|im_end|>" |
|
|
|
|
|
return old_prompt.replace(old_sys_prompt, new_sys_prompt) |
|
|
|
|
|
def _strip_old_tool_responses(self, prompt: str) -> str: |
|
|
TOOL_RESPONSE_RE = re.compile(r"<tool_response>.*?</tool_response>\s*", re.DOTALL) |
|
|
"""Remove every existing <tool_response> β¦ </tool_response> block.""" |
|
|
return TOOL_RESPONSE_RE.sub("", prompt) |
|
|
|
|
|
def cat_assistant_response(self, curr_prompt, assistant_response): |
|
|
return curr_prompt + assistant_response + "<|im_end|>" |
|
|
|
|
|
def cat_tool_results(self, curr_prompt, tool_calls, results): |
|
|
tool_response_str = "" |
|
|
for tool_call, result in zip(tool_calls, results): |
|
|
tool_response_str += f"<tool_response>{tool_call}\n{result}\n</tool_response>\n" |
|
|
tool_response_str = f"<|im_start|>user\n{tool_response_str}<|im_end|>" |
|
|
assistant_prefix = f"<|im_start|>assistant\n<think>" |
|
|
return curr_prompt + "\n" + tool_response_str + "\n" + assistant_prefix |
|
|
|
|
|
def format_tool_call(self, tool_call_str: str): |
|
|
"""Convert JSON function call description to Python executable code string.""" |
|
|
try: |
|
|
call_json = json.loads(tool_call_str) |
|
|
func_name = call_json['name'] |
|
|
arguments = call_json.get('arguments', {}) |
|
|
|
|
|
args_str = ', '.join(f"{k}={repr(v)}" for k, v in arguments.items()) |
|
|
return f"{func_name}({args_str})" |
|
|
except Exception as e: |
|
|
return f"Parse tool call failed: {e}" |
|
|
|
|
|
    def execute_tool_calls(self, env: str, tool_calls: List[str]) -> List[str]:
        """Execute each JSON tool call against the remote sandbox executor.

        Args:
            env: Environment/tool definition source sent verbatim to the
                executor service with each call.
            tool_calls: JSON strings as extracted from <tool_call> blocks.

        Returns:
            One result string per call; transport or HTTP failures become
            error strings, so this method never raises.
        """
        def exe_tool_call(env, call):
            url = self.executor_url + '/execute'

            call_str = self.format_tool_call(call)

            # Parse failures are expected to surface as strings starting with
            # this prefix and are returned without contacting the executor;
            # anything else is forwarded as a call string.
            if call_str.startswith("error: parse tool call failed"):
                return call_str

            try:
                data = {
                    'env': env,
                    'call': call_str
                }
                response = requests.post(url, json=data, timeout=60)
                if response.status_code != 200:
                    return f"error: {response.status_code}"
                response = response.json()
                ret_str = ''
                # The executor reports three channels (return value, stdout,
                # error); include whichever are non-empty.
                if response['result']:
                    ret_str += f'result: \n{response["result"]}\n'
                if response['output']:
                    ret_str += f'output: \n{response["output"]}\n'
                if response['error']:
                    ret_str += f'error: \n{response["error"]}\n'
                return ret_str.strip()
            except requests.exceptions.Timeout:
                return "error: execution timed out"
            except Exception as e:
                # Any other transport/decoding failure becomes the result
                # string so the agent loop keeps running.
                return str(e)

        # Sequential execution: each call's failure is isolated to its own
        # result string.
        results = []
        for tool_call in tool_calls:
            result = exe_tool_call(env, tool_call)
            results.append(result)
        return results
|
|
|
|
|
def validate_tool_calls(self, output_str): |
|
|
start_tags = re.findall(r'<tool_call>', output_str) |
|
|
end_tags = re.findall(r'</tool_call>', output_str) |
|
|
|
|
|
if len(start_tags) != len(end_tags): |
|
|
return False |
|
|
|
|
|
start_positions = [m.start() for m in re.finditer(r'<tool_call>', output_str)] |
|
|
end_positions = [m.start() for m in re.finditer(r'</tool_call>', output_str)] |
|
|
|
|
|
for start, end in zip(start_positions, end_positions): |
|
|
if start >= end: |
|
|
return False |
|
|
|
|
|
return True |
|
|
|
|
|
def extract_tool_calls(self, output_str): |
|
|
if not self.validate_tool_calls(output_str): |
|
|
return [] |
|
|
|
|
|
try: |
|
|
pattern = r'<tool_call>((?:(?!</tool_call>).)*)</tool_call>' |
|
|
matches = re.finditer(pattern, output_str, re.DOTALL) |
|
|
|
|
|
return [match.group(1).strip() for match in matches] |
|
|
except Exception as e: |
|
|
return [] |
|
|
|
|
|
def extract_tool_calls_deepseek(self, output_str): |
|
|
if not self.validate_tool_calls(output_str): |
|
|
return [] |
|
|
|
|
|
try: |
|
|
pattern = r'<tool_calls_begin>((?:(?!</tool_calls_end>).)*)<tool_calls_end>' |
|
|
matches = re.finditer(pattern, output_str, re.DOTALL) |
|
|
|
|
|
return [match.group(1).strip() for match in matches] |
|
|
except Exception as e: |
|
|
return [] |
|
|
|
|
|
|
|
|
|
|
|
    @retry(max=5, sleep=1, fallback={"score": 0})
    def run_ii_searcher(
        self,
        env: str,
        func_schemas: str,
        question: str,
        tokenizer,
        model_url="http://0.0.0.0:1214",
        temperature: float = 0.0,
        max_new_tokens: int = 40960,
    ):
        """Drive the think / tool-call loop against a local /generate endpoint.

        Returns:
            (curr_prompt, all_tool_calls): the final prompt buffer and every
            extracted tool-call string.
            NOTE(review): on repeated failure the @retry fallback returns the
            dict {"score": 0} instead of that tuple — callers must handle
            both shapes.
        """
        curr_prompt = self.init_prompt(func_schemas, question)
        all_tool_calls= []

        # Hard cap of 16 model/tool rounds.
        for _ in range(16):
            prompt_tokens = tokenizer(curr_prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
            # Reserve ~100 tokens of headroom. NOTE(review): unlike `run`,
            # this value is not clamped with max(1, ...), so an over-long
            # prompt would request a negative max_new_tokens — confirm the
            # server tolerates that.
            max_tokens_left = max_new_tokens - len(prompt_tokens) - 100

            response = requests.post(
                f'{model_url}/generate',
                json={
                    "text": curr_prompt,
                    "sampling_params": {
                        "temperature": temperature,
                        "max_new_tokens": max_tokens_left,
                        "repetition_penalty": 1.05
                    },

                }
            ).json()
            # Errors are only logged; the subsequent 'text' lookup then
            # raises KeyError, which the @retry wrapper absorbs.
            if "error" in response.keys():
                print("resp",response)
            curr_prompt = self.cat_assistant_response(curr_prompt, response['text'])

            tool_calls: List[str] = self.extract_tool_calls(response['text'])
            all_tool_calls += tool_calls

            # No tool calls means the model produced its final answer.
            if len(tool_calls) == 0:
                break

            else:
                results: List[str] = self.execute_tool_calls(env, tool_calls)
                curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results)

        return curr_prompt, all_tool_calls
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Dict, Any, Tuple |
|
|
import requests |
|
|
|
|
|
def build_summary_prompt(self, question: str, transcript: str, tool_calls: Any) -> str: |
|
|
"""Assemble a compact but detailed prompt for summarization.""" |
|
|
tool_str = "" |
|
|
if tool_calls is not None: |
|
|
try: |
|
|
tool_str = str(tool_calls) |
|
|
except Exception: |
|
|
tool_str = "<unprintable tool_calls>" |
|
|
return ( |
|
|
"You are given a DeepSearch investigation trace.\n\n" |
|
|
f"Question:\n{question}\n\n" |
|
|
"Trace (model transcript):\n" |
|
|
f"{transcript}\n\n" |
|
|
"Tool Calls (as-recorded):\n" |
|
|
f"{tool_str}\n\n" |
|
|
"β End of trace β" |
|
|
) |
|
|
|
|
|
def reformat_trace(self, s: str) -> str: |
|
|
if not s: |
|
|
return s |
|
|
|
|
|
t = s |
|
|
|
|
|
|
|
|
def _speaker(m: re.Match) -> str: |
|
|
role = (m.group(1) or "").strip().upper() |
|
|
return f"\n{role}:\n" |
|
|
t = re.sub(r"<\|im_start\|\>(\w+)", _speaker, t, flags=re.IGNORECASE) |
|
|
|
|
|
|
|
|
t = re.sub(r"<\|im_end\|\>", "\n", t, flags=re.IGNORECASE) |
|
|
|
|
|
|
|
|
t = re.sub(r"<think\s*>", "", t, flags=re.IGNORECASE) |
|
|
t = re.sub(r"</think\s*>", "\n", t, flags=re.IGNORECASE) |
|
|
|
|
|
|
|
|
t = re.sub(r"<tool_respon[sc]e\s*>", "SEARCH RESULT\n", t, flags=re.IGNORECASE) |
|
|
t = re.sub(r"</tool_respon[sc]e\s*>", "\n", t, flags=re.IGNORECASE) |
|
|
|
|
|
|
|
|
t = re.sub(r"</?tool_call\s*>", "", t, flags=re.IGNORECASE) |
|
|
|
|
|
|
|
|
t = re.sub(r"<\|[^>]+?\|>", "", t) |
|
|
|
|
|
|
|
|
|
|
|
t = re.sub(r"</?[^>\n]+?>", "", t) |
|
|
|
|
|
|
|
|
t = re.sub(r"\n{3,}", "\n\n", t).strip() |
|
|
|
|
|
return t |
|
|
|
|
|
def _openai_client(self): |
|
|
try: |
|
|
from openai import OpenAI |
|
|
except Exception as e: |
|
|
raise RuntimeError("openai package not installed. `pip install openai`") from e |
|
|
return OpenAI() |
|
|
|
|
|
def init_summary_prompt(self, system_prompt: str, prompt: str) -> str: |
|
|
system_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>" |
|
|
user_prompt = f"<|im_start|>user\n{prompt}<|im_end|>" |
|
|
assistant_prefix = f"<|im_start|>assistant\n<think>" |
|
|
return system_prompt + "\n" + user_prompt + "\n" + assistant_prefix |
|
|
|
|
|
def _call_hf_endpoint(self, base_url: str, system_prompt: str, prompt: str, temperature: float, max_tokens: int, deepresearch_on: bool) -> str: |
|
|
curr_prompt = self.init_summary_prompt(system_prompt, prompt) |
|
|
|
|
|
hf_token= os.environ['HF_TOKEN'] |
|
|
|
|
|
headers = { |
|
|
"Accept" : "application/json", |
|
|
"Authorization": f"Bearer {hf_token}", |
|
|
"Content-Type": "application/json" |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
response_summary = requests.post( |
|
|
url=f"{base_url}", |
|
|
headers=headers, |
|
|
json={ |
|
|
"inputs": curr_prompt, |
|
|
"parameters": { |
|
|
"temperature": temperature, |
|
|
"max_new_tokens": max_tokens, |
|
|
"top_p": 0.95, |
|
|
"repetition_penalty": 1.05, |
|
|
}, |
|
|
}, |
|
|
timeout=300, |
|
|
).json() |
|
|
|
|
|
if isinstance(response_summary, list): |
|
|
response_summary = response_summary[0] |
|
|
|
|
|
if isinstance(response_summary, dict) and "error" in response_summary: |
|
|
|
|
|
err_msg = f"[model_error] {response_summary.get('error')}" |
|
|
print("Got error response from summarising model:", err_msg, end="\n\n") |
|
|
|
|
|
assistant_text = response_summary.get("generated_text", "") |
|
|
|
|
|
if curr_prompt == assistant_text[:len(curr_prompt)]: |
|
|
assistant_text = assistant_text[len(curr_prompt):] |
|
|
|
|
|
|
|
|
|
|
|
report = re.split(r"</think\s*>", assistant_text, flags=re.IGNORECASE)[-1] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not deepresearch_on: |
|
|
report = report.strip() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
start_tag = "<answer>" |
|
|
end_tag = "</answer>" |
|
|
pos_start = report.find(start_tag) |
|
|
pos_end = report[pos_start:].find(end_tag) + pos_start |
|
|
answer = report |
|
|
if pos_start != -1 and pos_end != -1: |
|
|
answer = report[pos_start + len(start_tag):pos_end].strip() |
|
|
|
|
|
str_find = "Final Answer:" |
|
|
if str_find in answer: |
|
|
answer = answer[::-1] |
|
|
pos = answer.find(str_find[::-1]) |
|
|
pos += len(str_find) |
|
|
answer = answer[pos:][::-1] |
|
|
answer = answer.rstrip('# \n-').strip(' \n-') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return answer |
|
|
|
|
|
report = report.strip() |
|
|
report = report[::-1] |
|
|
str_find = "Sources used" |
|
|
pos = report.find(str_find[::-1]) |
|
|
pos += len(str_find) |
|
|
report = report[pos:][::-1] |
|
|
report = report.rstrip('# \n-').strip(' \n-') |
|
|
|
|
|
if not report.startswith("##") and report.startswith("#"): |
|
|
report = "#" + report |
|
|
elif not report.startswith("##") and not report.startswith("#"): |
|
|
report = "## " + report |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
urls = {} |
|
|
count = 1 |
|
|
|
|
|
while "[http" in report: |
|
|
start_idx = report.find("[http") |
|
|
end_idx = report.find("]", start_idx) |
|
|
if end_idx != -1: |
|
|
url_string = report[start_idx + 1:end_idx] |
|
|
url_list = [] |
|
|
while len(url_string) > 0: |
|
|
pos1 = url_string.find(";") |
|
|
pos2 = url_string.find(",") |
|
|
pos3 = url_string.find(" ") |
|
|
|
|
|
if pos1 == -1: |
|
|
pos1 = len(url_string) + 1 |
|
|
if pos2 == -1: |
|
|
pos2 = len(url_string) + 1 |
|
|
if pos3 == -1: |
|
|
pos3 = len(url_string) + 1 |
|
|
|
|
|
pos = min(pos1, pos2, pos3) |
|
|
|
|
|
if pos == len(url_string) + 1: |
|
|
url = url_string |
|
|
else: |
|
|
url = url_string[:pos] |
|
|
|
|
|
url_list.append(url) |
|
|
|
|
|
if pos < len(url_string): |
|
|
url_string = url_string[pos + 1:].lstrip(" ,;") |
|
|
else: |
|
|
break |
|
|
|
|
|
report_new = report[:start_idx] + '(**' |
|
|
for url in url_list: |
|
|
if url not in urls: |
|
|
urls[url] = count |
|
|
count += 1 |
|
|
report_new += f'[{urls[url]}], ' |
|
|
report_new = report_new[:-2] |
|
|
report_new += '**)' + report[end_idx+1:] |
|
|
report = report_new |
|
|
else: |
|
|
break |
|
|
|
|
|
if len(urls) > 0: |
|
|
report += "\n\n## Sources used:\n" |
|
|
sorted_urls = sorted(urls.items(), key=lambda x: x[1]) |
|
|
for url, idx in sorted_urls: |
|
|
report += f"- **{idx}**: {url}\n" |
|
|
report += '\n' |
|
|
|
|
|
for url, idx in sorted_urls: |
|
|
report += f"[{idx}]: {url}\n" |
|
|
|
|
|
|
|
|
|
|
|
return report |
|
|
|
|
|
    def _route_and_summarize(
        self,
        summary_llm: str,
        system_prompt: str,
        prompt: str,
        *,
        temperature: float,
        max_tokens: int,
        deepresearch_on: bool,
    ) -> str:
        """
        Route the summarization request to the right backend.

        If `summary_llm` names an OpenAI chat model (starts with 'gpt-',
        case-insensitively), call the OpenAI chat-completions API with
        system+user messages; otherwise treat `summary_llm` as a HuggingFace
        endpoint base URL and delegate to `_call_hf_endpoint`, which builds
        its own ChatML-style prompt. (The previous docstring claiming routing
        on an 'http' prefix did not match the code.)
        """
        if not summary_llm.strip().lower().startswith("gpt-"):
            return self._call_hf_endpoint(summary_llm, system_prompt, prompt, temperature=temperature, max_tokens=max_tokens, deepresearch_on=deepresearch_on)

        else:
            client = self._openai_client()
            rsp = client.chat.completions.create(
                model=summary_llm,
                temperature=temperature,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=max_tokens,
            )

            # Guard against a null content field in the API response.
            return rsp.choices[0].message.content or ""
|
|
|
|
|
    @retry(max=5, sleep=1, fallback={"score": 0})
    # NOTE(review): `run` is a generator (it yields), so @retry only guards
    # generator *creation* — exceptions raised while iterating are NOT retried
    # and the fallback dict is never delivered to consumers. Confirm intended.
    def run(
        self,
        env: str,
        func_schemas: str,
        question: str,
        tokenizer,
        model_url: str = "http://0.0.0.0:1214",
        temperature: float = 0.0,
        max_new_tokens: int = 40960,
        top_p: float = 0.6,
        old_prompt: Optional[str] = None,
        deepresearch_on: bool = True,
        summary_llm: str = "gpt-4.1-mini"
    ):
        """
        Multi-turn agentic tool-use loop against a raw-completion HF endpoint.

        This is a generator. It yields, in order:
          - ("assistant_resp", (assistant_text, tool_calls)) after each model
            turn that produced tool calls (and around the final summary pass);
          - ("tool_results", (results,)) after each batch of tool executions.
        The trailing `return` value, ("end", (curr_prompt,)), is delivered as
        the StopIteration value — it is not yielded.

        State maintained across turns:
            curr_prompt: the final prompt buffer (with assistant/tool traces maintained internally)
            all_tool_calls: flat list of all tool call strings extracted across steps
            chat: a lightweight chat transcript list[{"role": "...", "content": "..."}]
                - 'user' items = the original question + aggregated tool responses
                - 'assistant' items = model responses (and a compact line-list of tool calls)
        """
        # Tool use is hard-enabled here; the `search_on == False` branches
        # below are currently unreachable (kept for symmetry with other paths).
        search_on = True

        # When resuming from a previous trace, refresh its system prompt so the
        # current function schemas take effect.
        if old_prompt is not None:
            old_prompt = self.replace_sys_prompt(old_prompt, func_schemas, search_on)

        curr_prompt = self.init_prompt(func_schemas, question, old_prompt, search_on)
        all_tool_calls: List[str] = []
        chat: List[Dict[str, str]] = []

        chat.append({"role": "user", "content": question})

        # Hard cap of 64 model turns.
        for i in range(64):
            # Budget the completion: total context minus current prompt length,
            # with a 100-token safety margin, floored at 1.
            prompt_tokens = tokenizer(curr_prompt, return_tensors=None, add_special_tokens=False)["input_ids"]
            max_tokens_left = max(1, max_new_tokens - len(prompt_tokens) - 100)

            # HF endpoint auth; raises KeyError if HF_TOKEN is unset.
            hf_token= os.environ['HF_TOKEN']

            headers = {
                "Accept" : "application/json",
                "Authorization": f"Bearer {hf_token}",
                "Content-Type": "application/json"
            }

            # Raw text-generation call (HF "inputs"/"parameters" payload shape).
            response = requests.post(
                url=f"{model_url}",
                headers=headers,
                json={
                    "inputs": curr_prompt,
                    "parameters": {
                        "temperature": temperature,
                        "max_new_tokens": max_tokens_left,
                        "top_p": top_p,
                        "repetition_penalty": 1.05,
                    },
                },
                timeout=300,
            ).json()

            # HF endpoints may wrap the payload in a one-element list.
            if isinstance(response, list):
                response = response[0]

            # Error payload: record it in the transcript and stop the loop.
            # No exception is raised, so @retry is not involved here.
            if isinstance(response, dict) and "error" in response:
                err_msg = f"[model_error] {response.get('error')}"
                print("Got error response from model:", err_msg, end="\n\n")
                chat.append({"role": "assistant", "content": err_msg})
                break

            assistant_text = response.get("generated_text", "")

            # Some backends echo the prompt; strip it so only the new turn remains.
            if curr_prompt == assistant_text[:len(curr_prompt)]:
                assistant_text = assistant_text[len(curr_prompt):]

            chat.append({"role": "assistant", "content": assistant_text})

            curr_prompt = self.cat_assistant_response(curr_prompt, assistant_text)

            if search_on:
                tool_calls: List[str] = self.extract_tool_calls(assistant_text)
            else:
                tool_calls: List[str] = []

            if tool_calls:
                # Tool-using turn: surface the response + calls, execute the
                # tools, then fold results back into the prompt and transcript.
                yield "assistant_resp", (assistant_text, tool_calls)
                all_tool_calls.extend(tool_calls)

                chat.append({"role": "assistant", "content": "\n".join(tool_calls)})

                results: List[str] = self.execute_tool_calls(env, tool_calls)

                yield "tool_results", (results, )

                curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results)

                tool_res_blocks = []
                for idx, (call, res) in enumerate(zip(tool_calls, results), 1):
                    tool_res_blocks.append(f"[Tool {idx}] Result:\n{res}")
                chat.append({"role": "user", "content": "\n\n".join(tool_res_blocks)})

            else:
                # No tool calls: this is the final turn.
                if search_on:
                    prompt = self.build_summary_prompt(question, self.reformat_trace(curr_prompt) or "", all_tool_calls)
                    system_prompt = DEEPRESEARCH_SYS_PROMPT if deepresearch_on else SUMMARY_SYS_PROMPT

                    # Deep-research summaries go to `summary_llm`; otherwise the
                    # serving endpoint itself is reused as the summarizer.
                    summary_text = self._route_and_summarize(
                        summary_llm=summary_llm if deepresearch_on else model_url,
                        system_prompt=system_prompt,
                        prompt=prompt,
                        temperature=0.6,
                        max_tokens=16000,
                        deepresearch_on=deepresearch_on
                    )

                    # Split an optional <think>…</think> prelude from the answer.
                    summary_text_splits = summary_text.split("</think>")
                    summary_text_initial = summary_text_splits[0]
                    summary_text_initial = summary_text_initial.replace("<think>", "").strip()
                    summary_text_final = summary_text_splits[-1]

                    # Only yields when a non-empty think block was present;
                    # `tool_calls` is [] in this branch, so the final yield
                    # carries an empty call list.
                    if len(summary_text_initial) > 0 and "</think>" in summary_text:
                        yield "assistant_resp", (summary_text_initial, [])
                        yield "tool_results", ([], )
                        yield "assistant_resp", (summary_text_final, tool_calls)

                else:
                    yield "assistant_resp", (assistant_text, tool_calls)
                    print(f"Search is off, so no tool calls expected and no tool calls called.\nAssistant Response:\n{assistant_text}\n\n")

                break

        # Delivered to the consumer as StopIteration.value.
        return "end", (curr_prompt, )
|
|
|
|
|
@retry(max=5, sleep=1, fallback={"score": 0}) |
|
|
def run_deepseek( |
|
|
self, |
|
|
env: str, |
|
|
func_schemas: str, |
|
|
question: str, |
|
|
model_name: str, |
|
|
temperature: float = 0.0, |
|
|
top_p: float = 0.95, |
|
|
max_tokens: int = 32768, |
|
|
): |
|
|
|
|
|
""" |
|
|
Chat-based ReCall loop for DeepSeek-R1 on Together. |
|
|
""" |
|
|
sys_content = self.sys_prompt_websailor_deepseek.format(func_schemas=func_schemas) |
|
|
|
|
|
|
|
|
messages = [ |
|
|
{"role": "system", "content": sys_content}, |
|
|
{"role": "user", "content": question}, |
|
|
] |
|
|
|
|
|
|
|
|
client = Together(api_key="") |
|
|
all_tool_calls = [] |
|
|
for turn in range(32): |
|
|
resp = client.chat.completions.create( |
|
|
model=model_name, |
|
|
|
|
|
messages=messages, |
|
|
temperature=temperature, |
|
|
top_p=top_p, |
|
|
max_tokens=39000, |
|
|
stop=["<ο½endβofβsentenceο½>", "<|im_end|>"] |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
assistant_text = resp.choices[0].message.content |
|
|
|
|
|
messages.append({"role": "assistant", "content": assistant_text}) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tool_calls = self.extract_tool_calls_deepseek(assistant_text) |
|
|
print(tool_calls) |
|
|
all_tool_calls += tool_calls |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if "<answer>" in assistant_text: |
|
|
break |
|
|
|
|
|
if len(tool_calls) != 0: |
|
|
results = self.execute_tool_calls(env, tool_calls) |
|
|
tool_resp_block = "".join( |
|
|
f"<tool_response>{c}\n{r}\n</tool_response>\n" |
|
|
for c, r in zip(tool_calls, results) |
|
|
) |
|
|
messages.append({"role": "user", "content": tool_resp_block}) |
|
|
|
|
|
else: |
|
|
print("no answer or tool call") |
|
|
break |
|
|
|
|
|
trajectory = "\n".join( |
|
|
f"<{m['role']}>\n{m['content']}" for m in messages |
|
|
if m["role"] != "system" |
|
|
) |
|
|
return trajectory, all_tool_calls |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@retry(max=5, sleep=1, fallback=None) |
|
|
def run_with_prompt_injection( |
|
|
self, |
|
|
env: str, |
|
|
func_schemas: str, |
|
|
question: str, |
|
|
model_url: str = "http://0.0.0.0:1214", |
|
|
temperature: float = 0.0, |
|
|
max_new_tokens: int = 512, |
|
|
top_n: int = 5, |
|
|
): |
|
|
""" |
|
|
0) call pubmed_search(question, top_n) once via the sandbox |
|
|
1) inject those snippets into the very first user message |
|
|
2) continue with the normal multi-turn ReCall loop against *model_url* |
|
|
""" |
|
|
|
|
|
|
|
|
retrieve_call = json.dumps({ |
|
|
"name": "pubmed_search", |
|
|
"arguments": {"query": question, "top_n": top_n} |
|
|
}) |
|
|
retrieval_raw = self.execute_tool_calls(env, [retrieve_call])[0] |
|
|
try: |
|
|
snippets_block = retrieval_raw.split("result:", 1)[-1].strip() |
|
|
except Exception: |
|
|
snippets_block = "" |
|
|
|
|
|
|
|
|
user_msg = ( |
|
|
f"Question: {question}\n\n" |
|
|
"Here are some relevant PubMed snippets:\n" |
|
|
f"{snippets_block}" |
|
|
) if snippets_block else f"Question: {question}" |
|
|
|
|
|
sys_prompt = self.init_prompt(func_schemas, question) |
|
|
system_prompt = f"<|im_start|>system\n{sys_prompt}<|im_end|>" |
|
|
user_prompt = f"<|im_start|>user\n{user_msg}<|im_end|>" |
|
|
assistant_pref= f"<|im_start|>assistant\n<think>" |
|
|
curr_prompt = system_prompt + "\n" + user_prompt + "\n" + assistant_pref |
|
|
|
|
|
|
|
|
for _ in range(10): |
|
|
resp = requests.post( |
|
|
f"{model_url}/generate", |
|
|
json={ |
|
|
"text": curr_prompt, |
|
|
"sampling_params": { |
|
|
"temperature": temperature, |
|
|
"max_new_tokens": max_new_tokens, |
|
|
} |
|
|
}, |
|
|
timeout=120, |
|
|
).json() |
|
|
if "error" in resp.keys(): |
|
|
print("resp", resp) |
|
|
assistant_txt = resp["text"] |
|
|
curr_prompt = self.cat_assistant_response(curr_prompt, assistant_txt) |
|
|
|
|
|
tool_calls = self.extract_tool_calls(assistant_txt) |
|
|
if len(tool_calls) != 0: |
|
|
|
|
|
|
|
|
results = self.execute_tool_calls(env, tool_calls) |
|
|
curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results) |
|
|
|
|
|
else: |
|
|
continue |
|
|
return curr_prompt |
|
|
|
|
|
|
|
|
|
|
|
@retry(max=5, sleep=1, fallback={"score": 0}) |
|
|
def run_budget( |
|
|
self, |
|
|
env: str, |
|
|
func_schemas: str, |
|
|
question: str, |
|
|
model_url: str = "http://0.0.0.0:1214", |
|
|
temperature: float = 0.0, |
|
|
max_new_tokens: int = 2048, |
|
|
) -> str: |
|
|
""" |
|
|
Execute an agentic dialogue with external tools while *pruning* previous |
|
|
<tool_response> blocks to prevent context-length explosion. |
|
|
""" |
|
|
curr_prompt = self.init_prompt(func_schemas, question) |
|
|
|
|
|
for _ in range(16): |
|
|
|
|
|
rsp = requests.post( |
|
|
f"{model_url}/generate", |
|
|
json={ |
|
|
"text": curr_prompt, |
|
|
"sampling_params": { |
|
|
"temperature": temperature, |
|
|
"max_new_tokens": max_new_tokens, |
|
|
"stop": ["<|im_end|>", "</think>", "</think>\n" "</think>\n\n"], |
|
|
}, |
|
|
|
|
|
}, |
|
|
timeout=120, |
|
|
).json() |
|
|
generated = rsp["text"] |
|
|
matched = rsp["meta_info"]["finish_reason"].get("matched") |
|
|
|
|
|
|
|
|
if matched and not generated.endswith(matched): |
|
|
generated += matched |
|
|
|
|
|
|
|
|
if "error" in rsp: |
|
|
raise RuntimeError(rsp["error"]) |
|
|
|
|
|
assistant_text: str = rsp["text"] |
|
|
curr_prompt = self.cat_assistant_response(curr_prompt, assistant_text) |
|
|
|
|
|
|
|
|
if "<answer>" in assistant_text: |
|
|
break |
|
|
|
|
|
|
|
|
tool_calls: List[str] = self.extract_tool_calls(assistant_text) |
|
|
if not tool_calls: |
|
|
continue |
|
|
|
|
|
results: List[str] = self.execute_tool_calls(env, tool_calls) |
|
|
|
|
|
|
|
|
|
|
|
curr_prompt =self. _strip_old_tool_responses(curr_prompt) |
|
|
|
|
|
|
|
|
curr_prompt = self.cat_tool_results(curr_prompt, tool_calls, results) |
|
|
|
|
|
return curr_prompt |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _strip_old_tool_responses_msgs(self, messages: list[dict]) -> list[dict]: |
|
|
""" |
|
|
Return a copy of `messages` with every *user* message that starts with |
|
|
<tool_response> removed. Keeps assistant turns untouched. |
|
|
""" |
|
|
return [ |
|
|
m for m in messages |
|
|
if not (m["role"] == "user" and m["content"].lstrip().startswith("<tool_response>")) |
|
|
] |
|
|
|
|
|
@retry(max=5, sleep=1, fallback={"score": 0}) |
|
|
def run_deepseek_budget( |
|
|
self, |
|
|
env: str, |
|
|
func_schemas: str, |
|
|
question: str, |
|
|
api_key: str, |
|
|
model_name: str, |
|
|
temperature: float = 0.0, |
|
|
top_p: float = 0.95, |
|
|
max_tokens: int = 32768, |
|
|
max_turns: int = 10, |
|
|
): |
|
|
""" |
|
|
Chat-based ReCall loop for DeepSeek-R1 **with context-budget pruning**. |
|
|
Keeps only the *latest* <tool_response> block to avoid prompt bloat. |
|
|
""" |
|
|
sys_content = self.system_prompt_budget.format(func_schemas=func_schemas) |
|
|
|
|
|
messages = [ |
|
|
{"role": "system", "content": sys_content}, |
|
|
{"role": "user", "content": question}, |
|
|
] |
|
|
|
|
|
client = Together(api_key=api_key) |
|
|
|
|
|
for turn in range(max_turns): |
|
|
|
|
|
resp = client.chat.completions.create( |
|
|
model=model_name, |
|
|
messages=messages, |
|
|
temperature=temperature, |
|
|
top_p=top_p, |
|
|
max_tokens=max_tokens, |
|
|
stop=["</tool_call>", "<ο½endβofβsentenceο½>"], |
|
|
) |
|
|
assistant_text = resp.choices[0].message.content |
|
|
messages.append({"role": "assistant", "content": assistant_text}) |
|
|
|
|
|
print(f"**assistant** \n {assistant_text}") |
|
|
|
|
|
|
|
|
if "<answer>" in assistant_text: |
|
|
break |
|
|
|
|
|
|
|
|
tool_calls = self.extract_tool_calls(assistant_text) |
|
|
print(f"**tool_calls** \n {tool_calls}") |
|
|
if not tool_calls: |
|
|
continue |
|
|
|
|
|
|
|
|
results = self.execute_tool_calls(env, tool_calls) |
|
|
print(f"**tool_response** \n {results}") |
|
|
|
|
|
|
|
|
messages = self._strip_old_tool_responses_msgs(messages) |
|
|
|
|
|
tool_resp_block = "".join( |
|
|
f"<tool_response>{c}\n{r}\n</tool_response>\n" |
|
|
for c, r in zip(tool_calls, results) |
|
|
) |
|
|
messages.append({"role": "user", "content": tool_resp_block}) |
|
|
|
|
|
|
|
|
trajectory = "\n".join( |
|
|
f"<{m['role']}>\n{m['content']}" for m in messages if m["role"] != "system" |
|
|
) |
|
|
return trajectory |
|
|
|
|
|
|
|
|
@retry(max=5, sleep=1, fallback=None) |
|
|
def run_deepseek_with_prompt_injection( |
|
|
self, |
|
|
env: str, |
|
|
func_schemas: str, |
|
|
question: str, |
|
|
api_key: str, |
|
|
model_name: str, |
|
|
temperature: float = 0.0, |
|
|
top_p: float = 0.95, |
|
|
max_tokens: int = 32768, |
|
|
): |
|
|
""" |
|
|
1) Call pubmed_search(question, top_n=5) as a tool to get snippets. |
|
|
2) Inject them into the first user message. |
|
|
3) Proceed with the usual DeepSeek-R1 toolβbased rollout. |
|
|
""" |
|
|
|
|
|
|
|
|
retrieve_call = json.dumps({ |
|
|
"name": "pubmed_search", |
|
|
"arguments": { |
|
|
"query": question, |
|
|
"top_n": 5 |
|
|
} |
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raw_retrieval_results = self.execute_tool_calls(env, [retrieve_call])[0] |
|
|
|
|
|
try: |
|
|
snippets = raw_retrieval_results[9:] |
|
|
|
|
|
except: |
|
|
snippets = "" |
|
|
|
|
|
|
|
|
|
|
|
if snippets: |
|
|
|
|
|
user_content = ( |
|
|
f"Question: {question}\n\n" |
|
|
"Here are some relevant PubMed snippets:\n" |
|
|
f"{snippets}" |
|
|
) |
|
|
else: |
|
|
user_content = f"Question: {question}" |
|
|
|
|
|
|
|
|
sys_content = self.system_prompt_forcing_tool_call |
|
|
messages = [ |
|
|
{"role": "system", "content": sys_content}, |
|
|
{"role": "user", "content": user_content}, |
|
|
] |
|
|
client = Together(api_key=api_key) |
|
|
|
|
|
|
|
|
for turn in range(10): |
|
|
resp = client.chat.completions.create( |
|
|
model = model_name, |
|
|
messages = messages, |
|
|
temperature = temperature, |
|
|
top_p = top_p, |
|
|
max_tokens = max_tokens, |
|
|
stop = ["</tool_call>", "<ο½endβofβsentenceο½>"] |
|
|
) |
|
|
|
|
|
assistant_text = resp.choices[0].message.content |
|
|
messages.append({"role": "assistant", "content": assistant_text}) |
|
|
|
|
|
tool_calls = self.extract_tool_calls(assistant_text) |
|
|
if not tool_calls: |
|
|
break |
|
|
|
|
|
|
|
|
results = self.execute_tool_calls(env, tool_calls) |
|
|
|
|
|
tool_resp_block = "".join( |
|
|
f"<tool_response>{call}\n{out}\n</tool_response>\n" |
|
|
for call, out in zip(tool_calls, results) |
|
|
) |
|
|
messages.append({"role": "user", "content": tool_resp_block}) |
|
|
|
|
|
|
|
|
trajectory = "\n".join( |
|
|
f"<{m['role']}>\n{m['content']}" |
|
|
for m in messages |
|
|
if m["role"] != "system" |
|
|
) |
|
|
return trajectory |
|
|
|
|
|
|