|
from contextlib import contextmanager

from codetiming import Timer


@contextmanager
def _timer(name: str, timing_raw: dict):
    """Time the wrapped block and store the elapsed seconds under `name`."""
    with Timer(name=name, logger=None) as timer:
        yield
    timing_raw[name] = timer.last
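# A minimal usage sketch for _timer (names are illustrative):
#   timing_raw = {}
#   with _timer("stage", timing_raw):
#       do_work()
#   print(timing_raw)  # {"stage": <elapsed seconds>}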
|
|
|
import argparse
import asyncio
import json
import re

import aiohttp
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from buffer import SurveyManager
from buffer import BufferManager_V2 as BufferManager

app = FastAPI()

# Allow any origin so a separately served frontend can reach the endpoints.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# WebSocket connections currently subscribed to progress updates.
active_connections = set()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    active_connections.add(websocket)
    try:
        # Keep the connection open; inbound messages are read and ignored.
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        # discard() is safe even if post_to_frontend already dropped the socket.
        active_connections.discard(websocket)
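# Client sketch (hypothetical, using the third-party `websockets` package):
#   async with websockets.connect("ws://localhost:<port>/ws") as ws:
#       update = json.loads(await ws.recv())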
|
|
|
async def post_to_frontend(payload):
    """Broadcast a JSON string to every connected WebSocket client."""
    print(f"Sending payload to frontend: {payload}")
    for ws in list(active_connections):
        try:
            await ws.send_text(payload)
        except Exception as e:
            print(f"Error sending to WebSocket: {e}")
            active_connections.discard(ws)
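# Payloads are the JSON documents assembled in rollout_with_env below:
# "markdown", "searchKeywords", "nowUpdate", "nextUpdate", "query", "papers".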
|
|
|
|
|
def write_to_json(data, path):
    with open(path, 'w', encoding='utf8') as f:
        f.write(json.dumps(data, ensure_ascii=False, indent=4))


class OriginalvLLMRollout:
    """Thin wrapper around a vLLM engine with fixed sampling parameters."""

    def __init__(self, model_name_or_path):
        self.rollout_model = LLM(
            model=model_name_or_path,
            tokenizer=model_name_or_path,
            gpu_memory_utilization=0.95,
            trust_remote_code=True,
        )
        self.sampling_params = SamplingParams(
            temperature=0.7,
            top_p=0.8,
            repetition_penalty=1.05,
            top_k=20,
            max_tokens=2748,
        )

    def generate(self, input_texts):
        """Complete a batch of raw prompt strings."""
        generated_texts = []
        completions = self.rollout_model.generate(input_texts, self.sampling_params, use_tqdm=False)
        for output in completions:
            generated_texts.append(output.outputs[0].text)
        return generated_texts

    def chat(self, input_messages):
        """Complete a batch of chat-format message lists."""
        generated_texts = []
        completions = self.rollout_model.chat(input_messages, self.sampling_params, use_tqdm=False)
        for output in completions:
            generated_texts.append(output.outputs[0].text)
        return generated_texts
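# Usage sketch (the model path is a placeholder):
#   rollout = OriginalvLLMRollout("/path/to/model")
#   texts = rollout.chat([[{"role": "user", "content": "Hello"}]])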
|
|
|
async def rollout_with_env(querys, batch_size, max_turns, model_path, url, deploy_port=None):
    """Run multi-turn rollouts against the retriever environment.

    Args:
        querys: list of query strings, one per survey to generate.
        batch_size: number of queries handled by each BufferManager.
        max_turns: upper bound on generator/environment turns per batch.
        model_path: model to load into vLLM.
        url: endpoint of the retriever service.
        deploy_port: if not None, stream intermediate results to the frontend.
    """
    # Split the queries into batches; the last batch may be smaller,
    # e.g. 5 queries with batch_size 2 give batches of sizes [2, 2, 1].
    n = len(querys) // batch_size
    batch_querys = []
    for i in range(n + 1):
        temp_data = querys[i * batch_size: (i + 1) * batch_size]
        if len(temp_data) > 0:
            batch_querys.append(temp_data)
    print("QUERY NUMBER with BATCH: ", [len(x) for x in batch_querys])

    # One vLLM engine and tokenizer shared across every batch.
    vllm_manager = OriginalvLLMRollout(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Accumulates finished rollout records across all batches.
    total_rollout_data = []
    for query_batch in batch_querys:
        # Fresh buffer per batch; it tracks trajectories and the turn counter.
        buffer_manager = BufferManager(query_batch)

        while True:
            if buffer_manager.step >= max_turns:
                break

            # Build next-turn prompts for every trajectory that is still active.
            messagess_todo = buffer_manager.build_prompt_for_generator()
            if len(messagess_todo) == 0:
                break

            timing_raw = {}
            with _timer('vllm sampling', timing_raw):
                # Run the blocking vLLM call in a worker thread so the event
                # loop (and the WebSocket endpoint) stays responsive.
                response_texts = await asyncio.to_thread(vllm_manager.chat, messagess_todo)

            # Parse tool calls and answer updates out of the raw responses.
            extracted_results = []
            for response_text in response_texts:
                result = BufferManager.parse_generator_response(response_text)
                extracted_results.append(result)

            payload = {
                "tool_calls": [x["tool_call"] for x in extracted_results]
            }
            # Request top-k=20 retrieval during the first three turns (step 0-2).
            if buffer_manager.step <= 2:
                payload["topk"] = 20
            with _timer('get env feedback', timing_raw):
                async with aiohttp.ClientSession() as session:
                    async with session.post(url, json=payload) as resp:
                        env_response_batched = await resp.json()

            with _timer('postprocessing', timing_raw):
                buffer_manager.update_trajectory(extracted_results, env_response_batched)
                buffer_manager.step += 1

            print(timing_raw)

            # When deployed, push the newest trajectory state to the frontend.
            if deploy_port is not None:
                last_data = buffer_manager.batch_rollout_data[-1]
                last_turn = last_data["trajectory"][-1]
                now_text = json_to_markdown(last_data)
                now_search_keywords = last_turn["search_keywords"]
                now_update = last_turn["answer_thought"]
                next_update = last_turn["tool_call_thought"]
                now_query = last_data["query"]
                trajs = last_data["trajectory"]
                updated_success = last_turn["update_success"]
                if updated_success:
                    # Use the most recent turn that produced paper summaries.
                    for traj in reversed(trajs):
                        if len(traj["summarys"]) > 0:
                            break
                    summary_num = len(traj["summarys"])
                    if summary_num == 0:
                        summary_text = "No summaries yet."
                    else:
                        summary_text = "\n".join(traj["summarys"])
                    frontend_payload = {
                        "markdown": now_text,
                        "searchKeywords": now_search_keywords,
                        "nowUpdate": now_update,
                        "nextUpdate": next_update,
                        "query": now_query,
                        "papers": summary_text
                    }
                    frontend_payload = json.dumps(frontend_payload, ensure_ascii=False)
                    try:
                        await post_to_frontend(frontend_payload)
                    except Exception as e:
                        print(f"Error posting to frontend: {e}")

        # Render each finished survey into plain text alongside its raw state.
        for item in buffer_manager.batch_rollout_data:
            item["survey_text"] = SurveyManager.convert_survey_dict_to_str(item["state"]["current_survey"])

        total_rollout_data.extend(buffer_manager.batch_rollout_data)
        del buffer_manager

    return total_rollout_data
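# Each rollout record, as consumed elsewhere in this file, carries at least:
# "query", "survey_text", "state"["current_survey"], and a "trajectory" list
# whose turns include "search_keywords", "answer_thought", "tool_call_thought",
# "update_success", and "summarys".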
|
|
|
|
|
def json_to_markdown(json_data):
    """Render a rollout record as Markdown with numbered arXiv references."""
    text = SurveyManager.convert_survey_dict_to_str(json_data["state"]["current_survey"])

    # Map every summarized bibkey to a Markdown link built from its arXiv id.
    all_summarys = {}
    for traj in json_data["trajectory"]:
        for item in traj["summarys"]:
            split_text = item.split("\n")
            bibkey = split_text[0].split(":")[1].strip()
            title_begin_index = item.find("Title:") + len("Title:")
            title_end_index = item.find("Abstract:")
            title = item[title_begin_index:title_end_index].strip()
            arxivid = bibkey.split("arxivid")[-1].strip()
            all_summarys[bibkey] = f"[{title}](https://arxiv.org/abs/{arxivid})"

    # Collect cited bibkeys in order of first appearance, skipping numeric
    # placeholders such as "#1" and the wildcard "*".
    reg = r"\\cite\{(.+?)\}"
    placeholder_reg = re.compile(r"^#\d+$")
    reg_bibkeys = re.findall(reg, text)
    bibkeys = []
    for bibkey in reg_bibkeys:
        for bib in bibkey.split(","):
            # Strip before the placeholder check so " #1" is also skipped.
            bib = bib.strip()
            if not placeholder_reg.match(bib):
                if bib and bib != "*" and bib not in bibkeys:
                    bibkeys.append(bib)

    bibkeys_index = {bibkey: i + 1 for i, bibkey in enumerate(bibkeys)}

    def replace_bibkey(match):
        # Rewrite one \cite{...} group as bracketed reference numbers.
        new_bibs = []
        for bib in match.group(1).split(","):
            bib = bib.strip()
            if not placeholder_reg.match(bib):
                if bib and bib != "*":
                    if bib in bibkeys_index:
                        new_bibs.append(f"{bibkeys_index[bib]}")
                    else:
                        print(f"Warning: {bib} not found in bibkeys")
        if len(new_bibs) > 0:
            return "[" + ",".join(new_bibs) + "]"
        return ""

    text = re.sub(reg, replace_bibkey, text)
    # Fall back to the raw bibkey for citations that never got a summary.
    reference_text = "\n\n".join(
        f"[{i}] {all_summarys.get(bibkey, bibkey)}" for bibkey, i in bibkeys_index.items()
    )
    text += "\n## References\n" + reference_text
    return text
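# The parsing above implies each entry of traj["summarys"] begins with a line
# like "<label>: <bibkey>" (the bibkey embedding "arxivid <id>"), followed by
# "Title: ..." and "Abstract: ..." sections; this format is inferred from the
# string handling here, not verified against buffer.py.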
|
|
|
async def test_surveyGen(model_path, out_path, querys, url, deploy_port=None):
    # batch_size=1 and max_turns=1000: handle one query at a time and let each
    # trajectory run until it finishes on its own.
    total_rollout_data = await rollout_with_env(querys, 1, 1000, model_path, url, deploy_port)
    all_md_texts = []
    for json_data in total_rollout_data:
        md_text = json_to_markdown(json_data)
        all_md_texts.append(md_text)

    all_md_texts = "\n\n".join(all_md_texts)
    with open(out_path, 'w', encoding='utf8') as f:
        f.write(all_md_texts)


class QueryRequest(BaseModel):
    query: str


@app.post("/generate_survey")
async def generate_survey(request: QueryRequest):
    global args

    model_path = args.model_path
    out_path = args.output_file
    querys = [request.query]
    url = args.retriver_url
    deploy_port = args.port
    try:
        await test_surveyGen(model_path, out_path, querys, url, deploy_port)
        return {"status": "success", "message": "Survey generated successfully."}
    except Exception as e:
        print(f"Error generating survey: {e}")
        return {"status": "error", "message": str(e)}
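# Example request (hypothetical query):
#   POST /generate_survey  with body {"query": "a survey topic"}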
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run survey generation with vLLM.")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model.")
    parser.add_argument("--query", type=str, required=True, help="Query to generate a survey for.")
    parser.add_argument("--output_file", type=str, required=True, help="Path to the output Markdown file.")
    parser.add_argument("--retriver_url", type=str, default="http://localhost:8400", help="URL of the retriever service.")
    parser.add_argument("--port", type=str, default=None, help="Port to serve the API on; omit to run a one-shot CLI generation instead.")
    args = parser.parse_args()

    if args.port is not None:
        # Server mode: expose /generate_survey and /ws on the given port.
        import uvicorn
        uvicorn.run(app, host="localhost", port=int(args.port))
    else:
        # CLI mode: generate a single survey for --query and exit.
        asyncio.run(
            test_surveyGen(
                model_path=args.model_path,
                out_path=args.output_file,
                querys=[args.query],
                url=args.retriver_url
            )
        )
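# Example invocations (script name, paths, and port are placeholders):
#   python rollout_app.py --model_path /path/to/model \
#       --query "LLM agents" --output_file survey.md
#   python rollout_app.py --model_path /path/to/model \
#       --query "LLM agents" --output_file survey.md --port 8000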
|
|