import re
import json
import copy

# BASE_SURVEY_STRUCTURE = """
# # Title: A survey of ...
# # Introduction: None.
# # Section 1: None.
# ## Subsection 1 (if needed): None.
# ## Subsection 2 (if needed): None.
# ### Subsubsection 1 (if needed): None.
# ### Subsubsection 2 (if needed): None.
# ### ...
# # Section 2: None.
# # ...
# # Conclusion: None.
# """


class SurveyManager:
    BASE_SURVEY_STRUCTURE = {
        "title": "",
        "abstract": "",
        "introduction": {
            "content": ""
        },
        "sections": [],
        "conclusion": ""
    }

    def __init__(self):
        pass

    @staticmethod
    def parse_update_pos(update_pos):
        """
        Accepts either:
        (1) "title", "abstract", "introduction", "conclusion", or "plan"
        (2) "section-i/subsection-j/..."
        and returns the normalized position string.
        """
        if update_pos in ["title", "abstract", "introduction", "conclusion", "plan"]:
            return update_pos
        else:
            keys = update_pos.split("/")
            if len(keys) == 1:
                # Section-?
                i = int(keys[0].lower().split("section-")[-1])
                return f"section-{i}"
            elif len(keys) == 2:
                # Section-?/Subsection-?
                i = int(keys[0].lower().split("section-")[-1])
                j = int(keys[1].lower().split("subsection-")[-1])
                return f"section-{i}/subsection-{j}"
            elif len(keys) == 3:
                # Section-?/Subsection-?/Subsubsection-?
                i = int(keys[0].lower().split("section-")[-1])
                j = int(keys[1].lower().split("subsection-")[-1])
                k = int(keys[2].lower().split("subsubsection-")[-1])
                return f"section-{i}/subsection-{j}/subsubsection-{k}"
            else:
                raise ValueError("unsupported update_pos keys")

    @staticmethod
    def _to_one_line(string):
        if isinstance(string, dict):
            if "content" in string and string["content"]:
                return SurveyManager._to_one_line(string["content"])
            else:
                return "[PLAN] " + string.get("plan", "").replace("\n", " ").strip()
        if not string:
            return ""
        else:
            return string  # .replace("\n", " ")

    @staticmethod
    def convert_survey_dict_to_str(current_survey):
        string = ""
        if current_survey == {}:
            return "There is no survey."
        # title
        try:
            content = SurveyManager._to_one_line(current_survey["title"])
            string += f"# {content}\n"
        except:
            string += f"# Title: None\n"
        # abstract
        try:
            content = SurveyManager._to_one_line(current_survey["abstract"])
            string += f"## Abstract\n{content}\n"
        except:
            string += f"## Abstract\nNone\n"
        # introduction
        try:
            content = SurveyManager._to_one_line(current_survey["introduction"])
            string += f"## Introduction\n{content}\n"
        except:
            string += f"## Introduction\nNone\n"
        # sections
        if "sections" in current_survey:
            for i, section in enumerate(current_survey["sections"]):
                title_key = "name" if "name" in section else "title"
                name, content = section[title_key], SurveyManager._to_one_line(section)
                # string += f"# Section-{i+1} [{name}]: {content}\n"
                string += f"## {name}\n{content}\n"
                if "subsections" in section:
                    for j, subsection in enumerate(section["subsections"]):
                        name, content = subsection[title_key], SurveyManager._to_one_line(subsection)
                        # string += f"  ## Subsection-{j+1} [{name}]: {content}\n"
                        string += f"### {name}\n{content}\n"
                        if "subsubsections" in subsection:
                            for k, subsubsection in enumerate(subsection["subsubsections"]):
                                name, content = subsubsection[title_key], SurveyManager._to_one_line(subsubsection)
                                # string += f"    ### Subsubsection-{k+1} [{name}]: {content}\n"
                                string += f"#### {name}\n{content}\n"
        # conclusion
        try:
            content = SurveyManager._to_one_line(current_survey["conclusion"])
            string += f"## Conclusion\n{content}\n"
        except:
            string += f"## Conclusion\nNone\n"
        return string

    @staticmethod
    def _abbr_one_line(string, abbr=True):
        if isinstance(string, dict):
            if "content" in string and string["content"]:
                return SurveyManager._abbr_one_line(string["content"], abbr=abbr)
            elif "plan" in string:
                return "[PLAN] " + string["plan"].replace("\n", " ").strip()
            else:
                return ""
        else:
            if not string:
                return ""
            else:
                if abbr and len(string) > 50:
                    return "[OK] " + string.replace("\n", " ").strip()[:50] + "..."
                else:
                    return "[OK] " + string.replace("\n", " ").strip()

    @staticmethod
    def convert_survey_dict_to_abbr_str(current_survey):
        string = ""
        if current_survey == {}:
            return "There is no survey."
        # title
        try:
            content = SurveyManager._abbr_one_line(current_survey["title"], abbr=False)
            string += f"# Title: {content}\n"
        except:
            string += f"# Title: None\n"
        # abstract
        try:
            content = SurveyManager._abbr_one_line(current_survey["abstract"], abbr=False)
            string += f"# Abstract: {content}\n"
        except:
            string += f"# Abstract: None\n"
        # introduction
        try:
            content = SurveyManager._abbr_one_line(current_survey["introduction"])
            string += f"# Introduction: {content}\n"
        except:
            string += f"# Introduction: None\n"
        # sections
        if "sections" in current_survey:
            for i, section in enumerate(current_survey["sections"]):
                title_key = "name" if "name" in section else "title"
                name, content = section[title_key], SurveyManager._abbr_one_line(section)
                string += f"# Section-{i+1} [{name}]: {content}\n"
                if "subsections" in section:
                    for j, subsection in enumerate(section["subsections"]):
                        name, content = subsection[title_key], SurveyManager._abbr_one_line(subsection)
                        string += f"  ## Subsection-{j+1} [{name}]: {content}\n"
                        if "subsubsections" in subsection:
                            for k, subsubsection in enumerate(subsection["subsubsections"]):
                                name, content = subsubsection[title_key], SurveyManager._abbr_one_line(subsubsection)
                                string += f"    ### Subsubsection-{k+1} [{name}]: {content}\n"
        # conclusion
        try:
            content = SurveyManager._abbr_one_line(current_survey["conclusion"])
            string += f"# Conclusion: {content}\n"
        except:
            string += f"# Conclusion: None\n"
        return string

    @staticmethod
    def update_one_section(sections, i, content):
        # i -= 1
        if i >= 0 and i <= (len(sections) - 1):
            sections[i]["content"] = content
            return True
        else:
            # print("update fail!")
            return False

    @staticmethod
    def update_current_survey(current_survey, answer) -> bool:
        """
        update_pos: "section-i/subsection-j/subsubsection-k"
        """
        # if answer == {}:
        #     return True
        try:
            update_pos, content = answer["update"], answer["content"]
            if update_pos == "plan":
                # current_survey = content
                if current_survey == {}:
                    for k, v in content.items():
                        current_survey[k] = copy.deepcopy(v)
                else:
                    return False
            elif update_pos in ["conclusion", "abstract"]:
                if update_pos not in current_survey:
                    # print("update fail!")
                    return False
                current_survey[update_pos] = content
            elif update_pos == "introduction":
                if update_pos not in current_survey:
                    # print("update fail!")
                    return False
                current_survey[update_pos] = {"content": content}
            else:
                keys = update_pos.split("/")
                if len(keys) == 1:
                    # Section-?
                    i = int(keys[0].lower().split("section-")[-1]) - 1
                    return SurveyManager.update_one_section(current_survey["sections"], i, content)
                elif len(keys) == 2:
                    # Section-?/Subsection-?
                    i = int(keys[0].lower().split("section-")[-1]) - 1
                    j = int(keys[1].lower().split("subsection-")[-1]) - 1
                    try:
                        return SurveyManager.update_one_section(current_survey["sections"][i]["subsections"], j, content)
                    except:
                        # print("update fail!")
                        return False
                elif len(keys) == 3:
                    # Section-?/Subsection-?/Subsubsection-?
i = int(keys[0].lower().split("section-")[-1])-1 j = int(keys[1].lower().split("subsection-")[-1])-1 k = int(keys[2].lower().split("subsubsection-")[-1])-1 try: return SurveyManager.update_one_section(current_survey["sections"][i]["subsections"][j]["subsubsections"], k, content) # 禁用第四级 except: # print("update fail!") return False else: # print("update fail!") # print("unsupported update_pos keys") return False # raise ValueError("unsupported update_pos keys") except: # print("update fail!") return False # print("answer is not a valid json object.") # print(answer) # raise ValueError("answer is not a valid json object.") return True from prompts import * class PromptManger: system_prompt = SYSTEM_PROMPT_0415_BUFFER user_prompt_v0 = USER_PROMPT_v0_0424_BUFFER user_prompt = USER_PROMPT_0415_BUFFER class BufferManager: """ Used to manage prompts/responses generated during the Rollout phase, providing data support for subsequent training. batch_rollout_data = [ { query (or env_id): # Uniquely identifies a query or environment, [input parameter]. *running_id: # Uniquely identifies a single rollout. For cases where a query or environment is repeated multiple times, the query can be the same, but running_id will not repeat. state: { # Indicates whether the process is finished. "score": 0.0, "done": True / False "current_survey": dict # Structured data. } trajectory: [ # Organizes all data into a multi-turn interaction format. { step: int, 0~?, # The first step, usually includes some init_info or plan. original_response: str, The raw output from the model, which may have various formatting issues. answer_thought: str, # Encapsulated using the ... block. answer: { "original_str": str "update": str, "name": str, "content": str, "inclusions": list, # Extracted independently? } tool_call_thought: str, # Encapsulated using the ... block. tool_call: { "original_str": str, # Encapsulated using the ... block, used for tool invocation. In the survey setting, it is either "done" to end the task or "search". "tool_name": str # done or search. "keywords": list[str], Extracted search keywords from tool_call, otherwise none. } *papers: list[str], # Top-n papers retrieved via the search engine. Required if using the Agent-Summary-1 for collaborative optimization; otherwise, not needed. cites: list[str], # References cited by the model, which may include multiple citations. summarys: list[str], # Summaries of papers generated using Agent-Summary-1. Must include BIBKEY. *prompt_for_generator: str, # The prompt input to the generator at the current step. Required if using Agent-Summary-2 for generation and collaborative optimization; otherwise, not needed. }, ... ] }, ... 
] """ def __init__(self, prompts, repeat_n: int=1): # self.config = config self.step = 0 self.batch_rollout_data = [] self.running_ids = [] # active envs batch_size = prompts.batch['input_ids'].size(0) uids = prompts.non_tensor_batch['uid'] querys = prompts.non_tensor_batch['raw_prompt'].copy() ground_truths = prompts.non_tensor_batch['ground_truth'] # print(querys) new_querys = [] for i_batch in range(batch_size): raw_prompt_i_batch = querys[i_batch][-1]["content"] new_querys.append(raw_prompt_i_batch) querys = new_querys assert len(querys) == len(uids) for query, uid, ground_truth in zip(querys, uids, ground_truths): now_survey = {} for _ in range(repeat_n): self.batch_rollout_data.append({ "query": query, "uid": uid, "state": { # "score": 0.0, # only for debug # "format_score": None, # will update at last step "done": False, "current_survey": {} }, "trajectory": [], "history_messages": [], }) @staticmethod def _build_system_prompt(): prompt = PromptManger.system_prompt return prompt @staticmethod def _build_user_prompt_v0(query, current_survey): # query prompt = PromptManger.user_prompt_v0.replace("", query) # add template prompt = prompt.replace("", SurveyManager.convert_survey_dict_to_abbr_str(current_survey)) return prompt @staticmethod def _build_user_prompt(query, current_survey, trajs): last_traj = trajs[-1] # query prompt = PromptManger.user_prompt.replace("", query) # add current survey prompt = prompt.replace("", SurveyManager.convert_survey_dict_to_abbr_str(current_survey)) # current plan if last_traj["tool_call_thought"] == "": prompt = prompt.replace("", "Your last thought is not available, please give new plan") else: prompt = prompt.replace("", last_traj["tool_call_thought"]) prompt = prompt.replace("", json.dumps(last_traj["tool_call"])) # summarys for traj in reversed(trajs): if len(traj["summarys"]) > 0: break summary_num = len(traj["summarys"]) if summary_num == 0: prompt = prompt.replace("", "There is no result.") else: prompt = prompt.replace("", f"There are {summary_num} results:\n\n" + "\n\n".join(traj["summarys"])) return prompt @staticmethod def _build_user_prompt_force_correct(query, current_survey, trajs): if current_survey == {}: # gen plan now_section = "plan" # trajs[-1]["tool_call_thought"] = "Next I will provide the plan. 
" else: now_section = "" if isinstance(current_survey["abstract"],dict) and "content" not in current_survey["abstract"]: now_section = "abstract" elif "content" not in current_survey["introduction"]: now_section = "introduction" elif "sections" in current_survey: for section in current_survey["sections"]: if "content" not in section: now_section = "section-{}".format(current_survey["sections"].index(section) + 1) break elif "subsections" in section: for subsection in section["subsections"]: if "content" not in subsection: now_section = "section-{}/subsection-{}".format( current_survey["sections"].index(section) + 1, section["subsections"].index(subsection) + 1 ) break elif "subsubsections" in subsection: for subsubsection in subsection["subsubsections"]: if "content" not in subsubsection: now_section = "section-{}/subsection-{}/subsubsection-{}".format( current_survey["sections"].index(section) + 1, section["subsections"].index(subsection) + 1, subsection["subsubsections"].index(subsubsection) + 1 ) break if now_section: break if now_section: break elif isinstance(current_survey["conclusion"],dict) and "content" not in current_survey["conclusion"]: now_section = "conclusion" else: trajs[-1]["tool_call_thought"] = "Next I will finalize the survey." if now_section != "": trajs[-1]["tool_call_thought"] = f"Next I will provide {now_section}" for traj in reversed(trajs): if len(traj["summarys"]) > 0: break summary_num = len(traj["summarys"]) if now_section == "plan" and summary_num == 0: trajs[-1]["tool_call_thought"] = "I need to get enough information." return BufferManager._build_user_prompt(query, current_survey, trajs) @staticmethod def _check_finalize(query, current_survey, trajs): if current_survey == {}: # gen plan return False # trajs[-1]["tool_call_thought"] = "Next I will provide the plan. " else: now_section = "" if isinstance(current_survey["abstract"],dict) and "content" not in current_survey["abstract"]: now_section = "abstract" elif "content" not in current_survey["introduction"]: now_section = "introduction" elif "sections" in current_survey: for section in current_survey["sections"]: if "content" not in section: now_section = "section-{}".format(current_survey["sections"].index(section) + 1) break elif "subsections" in section: for subsection in section["subsections"]: if "content" not in subsection: now_section = "section-{}/subsection-{}".format( current_survey["sections"].index(section) + 1, section["subsections"].index(subsection) + 1 ) break elif "subsubsections" in subsection: for subsubsection in subsection["subsubsections"]: if "content" not in subsubsection: now_section = "section-{}/subsection-{}/subsubsection-{}".format( current_survey["sections"].index(section) + 1, section["subsections"].index(subsection) + 1, subsection["subsubsections"].index(subsubsection) + 1 ) break if now_section: break if now_section: break elif isinstance(current_survey["conclusion"],dict) and "content" not in current_survey["conclusion"]: now_section = "conclusion" # else: # trajs[-1]["tool_call_thought"] = "Next I will finalize the survey." 
if now_section != "": return False return True # rule-based method: query, plan, paragraphs -> prompt -> thought, paragraph, action def build_prompt_for_generator(self): total_messages = [] self.running_ids = [] for running_id, data in enumerate(self.batch_rollout_data): if data["state"]["done"]: pass else: if len(data["trajectory"]) == 0: # first prompt user_prompt = BufferManager._build_user_prompt_v0(data["query"], data["state"]["current_survey"]) else: if data["trajectory"][-1]["update_success"]: user_prompt = BufferManager._build_user_prompt(data["query"], data["state"]["current_survey"], data["trajectory"]) else: # user_prompt = data["history_messages"][-1][1]["content"] user_prompt = BufferManager._build_user_prompt_force_correct(data["query"], data["state"]["current_survey"], data["trajectory"]) messages = [ { "role": "system", "content": BufferManager._build_system_prompt(), }, { "role": "user", "content": user_prompt, } ] data["history_messages"].append(messages) total_messages.append(messages) self.running_ids.append(running_id) # update running ids return total_messages def update_all_scores(self, scores): assert len(scores) == len(self.batch_rollout_data) for score, log in zip(scores, self.batch_rollout_data): log["state"]["score"] = score def update_all_format_scores(self, scores): assert len(scores) == len(self.batch_rollout_data) for score, log in zip(scores, self.batch_rollout_data): log["state"]["format_score"] = score def update_trajectory(self, model_responses, env_feedbacks): """ model_response: original_response, thought, paragraph, tool_call, format_reward env_feedback: done, search_keywards, abstracts, outcome_reward """ assert len(self.running_ids) == len(model_responses) assert len(self.running_ids) == len(env_feedbacks) for running_id, response, feedback in zip(self.running_ids, model_responses, env_feedbacks): # update state self.batch_rollout_data[running_id]["state"]["done"] = feedback["done"] # if True, finalize the task update_success = False if response["true"]: if self.batch_rollout_data[running_id]["state"]["current_survey"] != {}: if len(response["answer"]) != 0: # no empty dict or start update_success = SurveyManager.update_current_survey( self.batch_rollout_data[running_id]["state"]["current_survey"], response["answer"]) else: # Search Then Write if len(response["answer"]) != 0 and "There is no result" not in self.batch_rollout_data[running_id]["history_messages"][-1][1]["content"]: update_success = SurveyManager.update_current_survey( self.batch_rollout_data[running_id]["state"]["current_survey"], response["answer"]) elif "There is no result" in self.batch_rollout_data[running_id]["history_messages"][-1][1]["content"] and len(response["answer"]) == 0: update_success = True self.batch_rollout_data[running_id]["trajectory"].append({ "step": self.step, "original_response": response["original_response"], "answer_thought": response["answer_thought"], "answer": response["answer"], "tool_call_thought": response["tool_call_thought"], "tool_call": response["tool_call"], "search_keywords": feedback["search_keywords"], "summarys": feedback["summarys"], "update_success": update_success and response["true"], }) self.batch_rollout_data[running_id]["history_messages"][-1].append({ "role": "assistant", "content": response["original_response"], }) if self.batch_rollout_data[running_id]["state"]["done"]: real_done = BufferManager._check_finalize(self.batch_rollout_data[running_id]["query"], self.batch_rollout_data[running_id]["state"]["current_survey"], 
self.batch_rollout_data[running_id]["trajectory"]) if not real_done: self.batch_rollout_data[running_id]["state"]["done"] = False @staticmethod def match_reference(text:str): reg = r"\\\w*cite(?!style)\w*\{(.+?)\}" placeholder_reg = re.compile(r"^#\d+$") reg_bibkeys = re.findall(reg, text) bibkeys = set() for bibkey in reg_bibkeys: single_bib = bibkey.split(",") for bib in single_bib: if not placeholder_reg.match(bib): bib = bib.strip() if bib and bib != "*": bibkeys.add(bib) reg = r"\\nocite{(.+?)\}" reg_bibkeys = re.findall(reg, text) for bibkey in reg_bibkeys: single_bib = bibkey.split(",") for bib in single_bib: if not placeholder_reg.match(bib): bib = bib.strip() if bib and bib != "*": bibkeys.remove(bib) ref_key_list = list(bibkeys) return ref_key_list @staticmethod def parse_generator_response(response): """ 1. 解析失败: step + 1, 重新生成, 给出提示 2. 解析成功: 2.1 tool_call == search(keywords) 发送post请求 2.2 tool_call == done 结束任务 **standard format** Current Update: [Your Thoughts]: str {"update": str, "content": str}: dict Next Plan: [Your Thoughts]: str {"tool": "search", "arguments": {}}: dict """ extracted_result = { "original_response": response } try: current_update = response.split("Current Update:")[-1].split("Next Plan:")[0] except: current_update = response # pattern think_pattern = r"(.*?)" answer_pattern = r"(.*?)" tool_pattern = r"(.*?)" # extract information from current_update think_match = re.search(think_pattern, current_update, re.DOTALL) # 多行提取 if think_match: think = think_match.group(1) think = think.strip() else: think = "" extracted_result["answer_thought"] = think answer_match = re.search(answer_pattern, current_update, re.DOTALL) # 多行提取 has_answer = False if answer_match: answer = answer_match.group(1) answer = answer.strip() try: answer = json.loads(answer) if not answer == {}: assert isinstance(answer["update"], str) answer["update"] = SurveyManager.parse_update_pos(answer["update"]) if answer["update"] == "plan": assert isinstance(answer["content"], dict) plan = answer["content"] assert isinstance(plan, dict) plan.pop("instruction",None) keys = ["abstract", "introduction", "conclusion","sections","title"] for key in keys: assert key in plan for key in plan: assert key in keys if key == "sections": assert isinstance(plan[key], list) for section in plan[key]: assert isinstance(section, dict) assert "plan" in section assert "title" in section assert isinstance(section["plan"], str) assert isinstance(section["title"], str) assert section["title"] != "Methodology" # 不能是Methodology,WIP if "subsections" in section: assert isinstance(section["subsections"], list) for subsection in section["subsections"]: assert isinstance(subsection, dict) assert "plan" in subsection assert "title" in subsection assert isinstance(subsection["plan"], str) assert isinstance(subsection["title"], str) if "subsubsections" in section: assert isinstance(subsection["subsubsections"], list) for subsubsection in subsection["subsubsections"]: assert isinstance(subsubsection, dict) assert "plan" in subsubsection assert "title" in subsubsection assert isinstance(subsubsection["plan"], str) assert isinstance(subsubsection["title"], str) elif key == "title": assert isinstance(plan[key], str) else: assert isinstance(plan[key], dict) assert "plan" in plan[key] if key not in ["abstract", "conclusion", "introduction"]: assert "title" in plan[key] else: assert isinstance(answer["content"], str) has_answer = True except: answer = {} else: answer = {} extracted_result["answer"] = answer # extract information from 

class BufferManager_V2(BufferManager):
    def __init__(self, querys, repeat_n=1):
        # self.config = config
        self.step = 0
        self.batch_rollout_data = []
        self.running_ids = []  # active envs
        for uid, query in enumerate(querys):
            print("CURRENT QUERY: ", query)
            for _ in range(repeat_n):
                self.batch_rollout_data.append({
                    "query": query,
                    "uid": f"query_{uid}",
                    "state": {
                        # "score": 0.0,  # only for debug
                        # "format_score": None,  # will update at last step
                        "done": False,
                        "current_survey": {}
                    },
                    "trajectory": [],
                    "history_messages": []
                })
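
# Optional manual smoke test for the illustrative demos above; running this file
# directly still requires the external prompts module because of the star import.
if __name__ == "__main__":
    _demo_survey_manager()
    _demo_parse_generator_response()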