|
import json
import logging
import os
import re
import shutil
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

from fastapi import HTTPException
from fastapi.responses import FileResponse, JSONResponse

from gpt_researcher import GPTResearcher
from gpt_researcher.document.document import DocumentLoader

from backend.utils import write_md_to_pdf, write_md_to_word, write_text_to_md
|
|
|
logging.basicConfig(level=logging.DEBUG) |
|
logger = logging.getLogger(__name__) |
|
|
|
class CustomLogsHandler:
    """Custom handler to capture streaming logs from the research process.

    Presents a websocket-compatible interface (``send_json``) so it can be
    handed to the researcher as its websocket; every payload is forwarded to
    the real websocket (when attached), kept in memory, and persisted to a
    JSON log file under /tmp/outputs.
    """

    def __init__(self, websocket, task: str):
        # Fix: self.logs was initialized but never populated; send_json now
        # appends every payload here for in-process inspection.
        self.logs: List[Dict[str, Any]] = []
        self.websocket = websocket
        sanitized_filename = sanitize_filename(f"task_{int(time.time())}_{task}")
        self.log_file = os.path.join("/tmp/outputs", f"{sanitized_filename}.json")
        self.timestamp = datetime.now().isoformat()

        # Seed the log file with a valid skeleton so every later
        # read-modify-write cycle in send_json finds parseable JSON.
        os.makedirs("/tmp/outputs", exist_ok=True)
        with open(self.log_file, 'w') as f:
            json.dump({
                "timestamp": self.timestamp,
                "events": [],
                "content": {
                    "query": "",
                    "sources": [],
                    "context": [],
                    "report": "",
                    "costs": 0.0
                }
            }, f, indent=2)

    async def send_json(self, data: Dict[str, Any]) -> None:
        """Store log data and send to websocket."""
        # Keep an in-memory copy of everything that flows through.
        self.logs.append(data)

        if self.websocket:
            await self.websocket.send_json(data)

        # Read-modify-write the persistent log file.
        with open(self.log_file, 'r') as f:
            log_data = json.load(f)

        if data.get('type') == 'logs':
            # Streaming log events are appended with their own timestamp.
            log_data['events'].append({
                "timestamp": datetime.now().isoformat(),
                "type": "event",
                "data": data
            })
        else:
            # Anything else is a content update (query/sources/context/report/...).
            log_data['content'].update(data)

        with open(self.log_file, 'w') as f:
            json.dump(log_data, f, indent=2)
        logger.debug(f"Log entry written to: {self.log_file}")
|
|
|
|
|
class Researcher:
    """One-shot research driver: runs GPTResearcher for a single query and
    renders the resulting report to pdf/docx/md files."""

    def __init__(self, query: str, report_type: str = "research_report"):
        self.query = query
        self.report_type = report_type

        # Unique id for this run: wall-clock timestamp plus query hash.
        self.research_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{hash(query)}"

        # Fix: CustomLogsHandler.__init__ requires (websocket, task); the
        # previous single-argument call raised TypeError. This entry point
        # has no websocket, so pass None and let the handler log to file only.
        self.logs_handler = CustomLogsHandler(None, query)
        self.researcher = GPTResearcher(
            query=query,
            report_type=report_type,
            websocket=self.logs_handler
        )

    async def research(self) -> dict:
        """Conduct research and return paths to generated files."""
        await self.researcher.conduct_research()
        report = await self.researcher.write_report()

        # Render the report to pdf/docx/md under a sanitized filename.
        sanitized_filename = sanitize_filename(f"task_{int(time.time())}_{self.query}")
        file_paths = await generate_report_files(report, sanitized_filename)

        # JSON log path, relative to the current working directory.
        json_relative_path = os.path.relpath(self.logs_handler.log_file)

        return {
            "output": {
                **file_paths,
                "json": json_relative_path
            }
        }
|
|
|
def sanitize_filename(filename: str) -> str:
    """Make a ``task_<timestamp>_<task>`` string safe to use as a filename.

    The task portion is truncated so the final name (plus the suffixes and
    extensions added later) stays within the common 255-character filesystem
    limit, and every character that is not alphanumeric, underscore,
    whitespace, or hyphen is removed.
    """
    prefix, timestamp, *task_parts = filename.split('_')
    task = '_'.join(task_parts)

    # Budget for the task text: 255 minus allowances for the report-type
    # suffix, file extension, timestamp, separators, and a safety margin.
    max_task_length = 255 - 8 - 5 - 10 - 6 - 10

    sanitized = f"{prefix}_{timestamp}_{task[:max_task_length]}"
    return re.sub(r"[^\w\s-]", "", sanitized).strip()
|
|
|
|
|
async def handle_start_command(websocket, data: str, manager):
    """Parse a ``start{...}`` command and run the full research pipeline.

    Streams the research through *manager*, then renders the report to
    pdf/docx/md and sends the resulting file paths back over *websocket*.
    """
    payload = json.loads(data[6:])
    (task, report_type, source_urls, document_urls,
     tone, headers, report_source) = extract_command_data(payload)

    # Guard clause: both fields are mandatory for a research run.
    if not task or not report_type:
        print("Error: Missing task or report_type")
        return

    # Mirror websocket traffic into a JSON log file.
    logs_handler = CustomLogsHandler(websocket, task)

    # Initialize the client-side log content structure.
    await logs_handler.send_json({
        "query": task,
        "sources": [],
        "context": [],
        "report": ""
    })

    sanitized_filename = sanitize_filename(f"task_{int(time.time())}_{task}")

    report = str(await manager.start_streaming(
        task,
        report_type,
        report_source,
        source_urls,
        document_urls,
        tone,
        websocket,
        headers
    ))
    file_paths = await generate_report_files(report, sanitized_filename)

    # Include the JSON log alongside the rendered report files.
    file_paths["json"] = os.path.relpath(logs_handler.log_file)
    await send_file_paths(websocket, file_paths)
|
|
|
|
|
async def handle_human_feedback(data: str):
    """Log a ``human_feedback{...}`` message; the JSON payload follows the
    14-character command prefix."""
    payload = json.loads(data[14:])
    print(f"Received human feedback: {payload}")
|
|
|
|
|
async def handle_chat(websocket, data: str, manager):
    """Forward a ``chat{...}`` message to the manager's chat handler."""
    payload = json.loads(data[4:])
    message = payload.get("message")
    print(f"Received chat message: {message}")
    await manager.chat(message, websocket)
|
|
|
async def generate_report_files(report: str, filename: str) -> Dict[str, str]:
    """Render the markdown *report* to pdf, docx and md files named
    *filename*; return a mapping of format -> written path."""
    paths: Dict[str, str] = {}
    paths["pdf"] = await write_md_to_pdf(report, filename)
    paths["docx"] = await write_md_to_word(report, filename)
    paths["md"] = await write_text_to_md(report, filename)
    return paths
|
|
|
|
|
async def send_file_paths(websocket, file_paths: Dict[str, str]):
    """Notify the client of the generated report file locations."""
    payload = {"type": "path", "output": file_paths}
    await websocket.send_json(payload)
|
|
|
|
|
def get_config_dict(
    langchain_api_key: str, openai_api_key: str, tavily_api_key: str,
    google_api_key: str, google_cx_key: str, bing_api_key: str,
    searchapi_api_key: str, serpapi_api_key: str, serper_api_key: str, searx_url: str
) -> Dict[str, str]:
    """Assemble the runtime configuration as a flat string dict.

    For the API-key entries, an explicit (truthy) argument wins over the
    environment variable of the same name; "" is the final fallback. The
    remaining entries come from the environment only, with fixed defaults.
    """
    overrides = {
        "LANGCHAIN_API_KEY": langchain_api_key,
        "OPENAI_API_KEY": openai_api_key,
        "TAVILY_API_KEY": tavily_api_key,
        "GOOGLE_API_KEY": google_api_key,
        "GOOGLE_CX_KEY": google_cx_key,
        "BING_API_KEY": bing_api_key,
        "SEARCHAPI_API_KEY": searchapi_api_key,
        "SERPAPI_API_KEY": serpapi_api_key,
        "SERPER_API_KEY": serper_api_key,
        "SEARX_URL": searx_url,
    }
    config = {key: value or os.getenv(key, "") for key, value in overrides.items()}
    config.update({
        "LANGCHAIN_TRACING_V2": os.getenv("LANGCHAIN_TRACING_V2", "true"),
        "DOC_PATH": os.getenv("DOC_PATH", "/tmp/my-docs"),
        "RETRIEVER": os.getenv("RETRIEVER", ""),
        "EMBEDDING_MODEL": os.getenv("OPENAI_EMBEDDING_MODEL", ""),
    })
    return config
|
|
|
|
|
def update_environment_variables(config: Dict[str, str]):
    """Export every key/value pair in *config* into the process environment."""
    os.environ.update(config)
|
|
|
|
|
async def handle_file_upload(file, DOC_PATH: str) -> Dict[str, str]:
    """Persist an uploaded file into DOC_PATH and reload the document index.

    *file* is expected to expose ``filename`` and a readable ``file`` stream
    (FastAPI UploadFile shape). Returns the original filename and saved path.
    """
    # basename() strips any client-supplied directory components.
    target = os.path.join(DOC_PATH, os.path.basename(file.filename))
    with open(target, "wb") as out:
        shutil.copyfileobj(file.file, out)
    print(f"File uploaded to {target}")

    # Re-run the loader so the new document is indexed immediately.
    await DocumentLoader(DOC_PATH).load()

    return {"filename": file.filename, "path": target}
|
|
|
|
|
async def handle_file_deletion(filename: str, DOC_PATH: str) -> JSONResponse:
    """Delete *filename* from DOC_PATH.

    The name is reduced to its basename so callers cannot traverse outside
    DOC_PATH. Returns 200 on success, 404 when the file does not exist.
    """
    file_path = os.path.join(DOC_PATH, os.path.basename(filename))
    try:
        # EAFP: attempt the removal directly. The previous exists()+remove()
        # pair was racy — the file could vanish between check and delete.
        os.remove(file_path)
    except FileNotFoundError:
        print(f"File not found: {file_path}")
        return JSONResponse(status_code=404, content={"message": "File not found"})
    print(f"File deleted: {file_path}")
    return JSONResponse(content={"message": "File deleted successfully"})
|
|
|
|
|
async def execute_multi_agents(manager) -> Any:
    """Run the multi-agent research flow over the first active websocket.

    Returns ``{"report": ...}`` on success, or a 400 JSONResponse when the
    manager has no active websocket connections.
    """
    # Use the first active connection as the output channel, if any.
    websocket = manager.active_connections[0] if manager.active_connections else None
    if websocket:
        # NOTE(review): run_research_task and stream_output are not defined or
        # imported anywhere in this file — presumably they come from
        # multi_agents.main; confirm the import exists or this raises NameError.
        report = await run_research_task("Is AI in a hype cycle?", websocket, stream_output)
        return {"report": report}
    else:
        return JSONResponse(status_code=400, content={"message": "No active WebSocket connection"})
|
|
|
|
|
async def handle_websocket_communication(websocket, manager):
    """Receive text commands from *websocket* and dispatch them until the
    connection is closed (i.e. receive_text raises)."""
    while True:
        data = await websocket.receive_text()
        # Commands are identified by their text prefix.
        if data.startswith("start"):
            await handle_start_command(websocket, data, manager)
            continue
        if data.startswith("human_feedback"):
            await handle_human_feedback(data)
            continue
        if data.startswith("chat"):
            await handle_chat(websocket, data, manager)
            continue
        print("Error: Unknown command or not enough parameters provided.")
|
|
|
|
|
def extract_command_data(json_data: Dict) -> tuple:
    """Pull the start-command fields out of *json_data* in a fixed order.

    Missing fields default to None, except 'headers' which defaults to {}.
    Order: (task, report_type, source_urls, document_urls, tone, headers,
    report_source).
    """
    fields = ("task", "report_type", "source_urls", "document_urls",
              "tone", "headers", "report_source")
    defaults = {"headers": {}}
    return tuple(json_data.get(name, defaults.get(name)) for name in fields)
|
|