Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, UploadFile, File, HTTPException, Request | |
from fastapi.middleware.cors import CORSMiddleware | |
import os | |
from rag_engine import RagEngine | |
from starlette.responses import JSONResponse | |
from starlette.status import HTTP_400_BAD_REQUEST | |
from fastapi.responses import StreamingResponse | |
import asyncio | |
# Directory that receives uploaded files; created at import time if missing.
UPLOAD_FOLDER = "uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# CORS policy applied to every route: any origin, method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm this is intended outside of local development.
_CORS_SETTINGS = {
    "allow_origins": ["*"],  # or specify your allowed origins
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Single module-level engine instance used by the request handlers below.
rag = RagEngine()
async def upload_pdf(file: UploadFile = File(...)):
    """Save an uploaded PDF under ``UPLOAD_FOLDER`` and index it with ``rag``.

    NOTE(review): no route decorator is visible above this function in this
    view of the file -- confirm the handler is registered on ``app``.

    Args:
        file: The multipart upload. Its filename must end in ``.pdf``
            (compared case-insensitively).

    Returns:
        A ``JSONResponse`` whose ``message`` confirms the upload and indexing.

    Raises:
        HTTPException: 400 when the upload has no ``.pdf`` filename, or when
            ``rag.index_pdf`` raises ``ValueError`` (its text becomes the
            response detail).
    """
    # The filename can be absent on a malformed multipart part; treat that
    # like any other non-PDF upload instead of crashing with a 500.
    raw_name = file.filename or ""
    if not raw_name.lower().endswith(".pdf"):
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail="Only pdf files are supported",
        )

    # Keep only the final path component so the client cannot choose a
    # directory outside UPLOAD_FOLDER.
    filename = os.path.basename(raw_name)
    filepath = os.path.join(UPLOAD_FOLDER, filename)

    # Copy the upload to disk in bounded chunks rather than holding the whole
    # file in memory at once.
    chunk_size = 1024 * 1024
    with open(filepath, "wb") as buffer:
        while True:
            chunk = await file.read(chunk_size)
            if not chunk:
                break
            buffer.write(chunk)

    # Index only after the file has been closed, so everything is flushed.
    try:
        rag.index_pdf(filepath)
    except ValueError as ve:
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST, detail=str(ve)
        ) from ve

    return JSONResponse(
        content={"message": f"file {filename} uploaded and indexed successfully"}
    )
async def stream_answer(request: Request):
    """Stream the answer to a question as Server-Sent Events.

    The request body must be a JSON object with a non-blank string
    ``question`` field.

    NOTE(review): no route decorator is visible above this function in this
    view of the file -- confirm the handler is registered on ``app``.

    Args:
        request: Incoming request whose JSON body carries the question.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream`` that
        emits one SSE event per item yielded by ``rag.stream_answer``.

    Raises:
        HTTPException: 400 when the body is not valid JSON, is not a JSON
            object, or when ``question`` is not a string or is blank.
    """
    try:
        data = await request.json()
    except ValueError as err:  # JSON decode errors are ValueError subclasses
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST, detail="Invalid JSON body"
        ) from err

    if not isinstance(data, dict):
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail="JSON body must be an object",
        )

    question = data.get("question", "")
    print(question)
    if not isinstance(question, str):
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail="Question must be a string",
        )
    if not question.strip():
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST, detail="Empty question"
        )

    # Obtain the token iterator up front, as the original generator
    # expression did, so a failure here surfaces before streaming starts.
    tokens = rag.stream_answer(question)

    def sse_events():
        # Plain (sync) generator on purpose: Starlette iterates sync
        # iterables in a worker thread, keeping the event loop free.
        for token in tokens:
            # Every line of an SSE payload needs its own "data: " prefix;
            # otherwise a newline inside a token corrupts the event.
            text = str(token).replace("\r\n", "\n").replace("\r", "\n")
            framed = "".join(f"data: {line}\n" for line in text.split("\n"))
            yield framed + "\n"

    return StreamingResponse(sse_events(), media_type="text/event-stream")