|
import json |
|
import os |
|
from typing import Dict, List |
|
|
|
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect, File, UploadFile, Header |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.templating import Jinja2Templates |
|
from pydantic import BaseModel |
|
|
|
from backend.server.websocket_manager import WebSocketManager |
|
from backend.server.server_utils import ( |
|
get_config_dict, |
|
update_environment_variables, handle_file_upload, handle_file_deletion, |
|
execute_multi_agents, handle_websocket_communication |
|
) |
|
|
|
|
|
from gpt_researcher.utils.logging_config import setup_research_logging |
|
|
|
import logging |
|
|
|
|
|
# Unless the root logger already has handlers (in which case basicConfig is a
# no-op), route INFO-and-above records to a console stream formatted as
# "timestamp - LEVEL - message".
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger; propagation is kept on explicitly so its records are
# handled by the root logger's handlers.
logger = logging.getLogger(__name__)
logger.propagate = True
|
|
|
|
|
|
|
|
|
class ResearchRequest(BaseModel):
    """Request body describing a research job.

    All three fields are required strings.
    """

    # The research task / query text.
    task: str

    # Kind of report requested; accepted values are not defined in this file.
    report_type: str

    # Agent identifier; NOTE(review): valid values are defined elsewhere — confirm.
    agent: str
|
|
|
|
|
class ConfigRequest(BaseModel):
    """Request body carrying API keys and retriever/document settings.

    The first group of fields is required. Every search-provider and optional
    LLM-provider key defaults to an empty string, so clients may omit keys for
    providers they do not use.
    """

    # Required settings.
    ANTHROPIC_API_KEY: str
    TAVILY_API_KEY: str
    LANGCHAIN_TRACING_V2: str
    LANGCHAIN_API_KEY: str
    OPENAI_API_KEY: str
    DOC_PATH: str
    RETRIEVER: str

    # Optional search-provider credentials / endpoints
    # ('' presumably means "not configured" — confirm with the consumer).
    GOOGLE_API_KEY: str = ''
    GOOGLE_CX_KEY: str = ''
    BING_API_KEY: str = ''
    SEARCHAPI_API_KEY: str = ''
    SERPAPI_API_KEY: str = ''
    SERPER_API_KEY: str = ''
    SEARX_URL: str = ''

    # Optional LLM-provider keys. These were previously required, unlike every
    # other optional provider key above, so a payload that omitted them failed
    # validation. Defaulting to '' only relaxes validation, so any payload
    # accepted before is still accepted.
    XAI_API_KEY: str = ''
    DEEPSEEK_API_KEY: str = ''
|
|
|
|
|
|
|
# The FastAPI application hosting the frontend and the routes defined below.
app = FastAPI()

# Static assets: the whole ./frontend tree under /site, and its static/
# subfolder under /static. Mount order is kept as-is.
app.mount(path="/site", app=StaticFiles(directory="./frontend"), name="site")
app.mount(path="/static", app=StaticFiles(directory="./frontend/static"), name="static")

# Jinja2 templates (index.html is rendered by the "/" route) are loaded from ./frontend.
templates = Jinja2Templates(directory="./frontend")

# Single WebSocket manager shared by the multi-agent and WebSocket endpoints.
manager = WebSocketManager()

# CORS: only http://localhost:3000 (presumably the frontend dev server) may call
# this API cross-origin, with credentials, any method and any header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Directory for uploaded documents (listed, uploaded to and deleted from by the
# /files/ and /upload/ routes); overridable via the DOC_PATH environment variable.
DOC_PATH = os.getenv("DOC_PATH", "/tmp/my-docs")
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
def startup_event():
    """Prepare on-disk folders at application startup.

    Creates ./outputs and serves it at /outputs, then ensures DOC_PATH exists.
    """
    outputs_dir = "outputs"
    # The folder is created before mounting so the static handler has a
    # directory to serve.
    os.makedirs(outputs_dir, exist_ok=True)
    app.mount("/outputs", StaticFiles(directory=outputs_dir), name="outputs")
    os.makedirs(DOC_PATH, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
async def read_root(request: Request):
    """Render the frontend's index.html with no report preloaded."""
    page_context = {"request": request, "report": None}
    return templates.TemplateResponse("index.html", page_context)
|
|
|
|
|
@app.get("/files/")
async def list_files():
    """Return the names of the entries in DOC_PATH.

    Returns:
        ``{"files": [...]}`` with the names as returned by ``os.listdir``.
    """
    files = os.listdir(DOC_PATH)
    # Use the module logger (configured at import) instead of print so the
    # message gets the shared format/level handling; lazy %-args avoid
    # formatting when INFO is disabled.
    logger.info("Files in %s: %s", DOC_PATH, files)
    return {"files": files}
|
|
|
|
|
@app.post("/api/multi_agents")
async def run_multi_agents():
    """Start the multi-agent workflow with the shared WebSocket manager and return its result."""
    outcome = await execute_multi_agents(manager)
    return outcome
|
|
|
|
|
@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Hand an uploaded file to handle_file_upload with DOC_PATH as its destination."""
    destination = DOC_PATH
    result = await handle_file_upload(file, destination)
    return result
|
|
|
|
|
@app.delete("/files/{filename}")
async def delete_file(filename: str):
    """Delete ``filename`` from DOC_PATH via handle_file_deletion and return its result."""
    # NOTE(review): filename comes straight from the URL path; confirm that
    # handle_file_deletion rejects names such as ".." that would leave DOC_PATH.
    result = await handle_file_deletion(filename, DOC_PATH)
    return result
|
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one WebSocket session.

    Registers the socket with the shared ``manager`` and delegates the session
    to ``handle_websocket_communication``. If the session ends with a client
    disconnect, or with any other exception, the socket is released through
    ``manager.disconnect``.
    """
    await manager.connect(websocket)
    try:
        await handle_websocket_communication(websocket, manager)
    except WebSocketDisconnect as exc:
        logger.info("WebSocket disconnected (code=%s)", exc.code)
        await manager.disconnect(websocket)
    except Exception:
        # Top-level boundary for this connection. Previously only
        # WebSocketDisconnect released the socket, so any other error left it
        # registered in the manager. Log the traceback and release it.
        logger.exception("Unexpected error in WebSocket session")
        await manager.disconnect(websocket)
|
|