# open-webui-rag-system / rag_server.py
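"""FastAPI server for the open-webui RAG system.

Endpoints, backed by a FAISS vector store and a local Llama model:

  * POST /ask                  -- question answering with source metadata
  * GET  /v1/models            -- OpenAI-compatible model listing
  * POST /v1/chat/completions  -- OpenAI-compatible chat endpoint with
                                  optional SSE streaming (for Open WebUI)

Source documents cited in an answer are copied into static/documents/ and
exposed as download links.
"""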
import json
import os
import shutil
import time
import uuid
from urllib.parse import urljoin, quote

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from rag_system import build_rag_chain, ask_question
from vector_store import get_embeddings, load_vector_store
from llm_loader import load_llama_model

app = FastAPI()
# Configuration for serving static files
os.makedirs("static/documents", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
# Prepare global objects once at startup: embeddings, FAISS index, LLM, RAG chain
embeddings = get_embeddings(device="cpu")
vectorstore = load_vector_store(embeddings, load_path="vector_db")
llm = load_llama_model()
qa_chain = build_rag_chain(llm, vectorstore, language="en", k=7)  # retrieve top-7 chunks per question
# Server URL configuration (adjust to match your actual environment)
BASE_URL = "http://220.124.155.35:8500"
class Question(BaseModel):
question: str
def get_document_url(source_path):
    """Copy the cited document into static/documents/ and return its public URL."""
    if not source_path or source_path == 'N/A':
        return None
    filename = os.path.basename(source_path)
    dataset_root = os.path.join(os.getcwd(), "dataset")
    # Search the whole dataset tree for a file with a matching name.
    # Note: this walk runs on every request; cache the result if the dataset grows large.
    found_path = None
    for root, _dirs, files in os.walk(dataset_root):
        if filename in files:
            found_path = os.path.join(root, filename)
            break
    if not found_path or not os.path.exists(found_path):
        return None
    # Copy into the directory served by the /static mount so it can be downloaded.
    static_path = f"static/documents/{filename}"
    shutil.copy2(found_path, static_path)
    # URL-encode the filename so non-ASCII names yield valid links.
    encoded_filename = quote(filename)
    return urljoin(BASE_URL, f"/static/documents/{encoded_filename}")
def create_download_link(url, filename):
    """Format a single markdown download link for a source document."""
    return f'Source: [{filename}]({url})'
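
# Shared helpers, factored out of the previously duplicated inline code in the
# two endpoints below. Behavior matches the originals, except that the
# "Reference Documents" header is now omitted when there are no links.
def strip_prompt_echo(text):
    """Return only the answer body, dropping any echoed 'A:' prompt prefix."""
    return text.split("A:")[-1].strip() if "A:" in text else text.strip()

def collect_sources(result):
    """Collect path, page, content, and download URL for each retrieved document."""
    sources = []
    for doc in result["source_documents"]:
        source_path = doc.metadata.get('source', 'N/A')
        document_url = get_document_url(source_path) if source_path != 'N/A' else None
        sources.append({
            "source": source_path,
            "content": doc.page_content,
            "page": doc.metadata.get('page', 'N/A'),
            "document_url": document_url,
            "filename": os.path.basename(source_path) if source_path != 'N/A' else None,
        })
    return sources

def build_sources_markdown(sources):
    """Render a deduplicated markdown block of source download links."""
    seen = set()
    lines = []
    for source in sources:
        key = (source['filename'], source['document_url'])
        if source['document_url'] and source['filename'] and key not in seen:
            lines.append(create_download_link(source['document_url'], source['filename']))
            seen.add(key)
    if not lines:
        return ""
    return "\nReference Documents:\n" + "\n".join(lines) + "\n"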
@app.post("/ask")
def ask(question: Question):
result = ask_question(qa_chain, question.question)
# Process source document information
sources = []
for doc in result["source_documents"]:
source_path = doc.metadata.get('source', 'N/A')
document_url = get_document_url(source_path) if source_path != 'N/A' else None
source_info = {
"source": source_path,
"content": doc.page_content,
"page": doc.metadata.get('page', 'N/A'),
"document_url": document_url,
"filename": os.path.basename(source_path) if source_path != 'N/A' else None
}
sources.append(source_info)
return {
"answer": result['result'].split("A:")[-1].strip() if "A:" in result['result'] else result['result'].strip(),
"sources": sources
}
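
# Example request (hypothetical question; assumes the server is reachable on
# the port configured in BASE_URL):
#   curl -X POST http://localhost:8500/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What does the manual say about installation?"}'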
@app.get("/v1/models")
def list_models():
return JSONResponse({
"object": "list",
"data": [
{
"id": "rag",
"object": "model",
"owned_by": "local",
}
]
})
@app.post("/v1/chat/completions")
async def openai_compatible_chat(request: Request):
payload = await request.json()
messages = payload.get("messages", [])
user_input = messages[-1]["content"] if messages else ""
stream = payload.get("stream", False)
    result = ask_question(qa_chain, user_input)
    answer = strip_prompt_echo(result['result'])
    # Deduplicated download links for the cited documents, appended to the answer.
    sources_md = build_sources_markdown(collect_sources(result))
    final_answer = answer + sources_md
if not stream:
return JSONResponse({
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": final_answer
},
"finish_reason": "stop"
}],
"model": "rag",
})
    # Generator for the streaming response: emit the answer body character by
    # character, then the source links in one chunk, then the terminator.
    def event_stream():
        # All chunks of one completion share a single id, per the OpenAI protocol.
        completion_id = f"chatcmpl-{uuid.uuid4()}"
        for char in answer:
            chunk = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "choices": [{
                    "index": 0,
                    "delta": {"content": char},
                    "finish_reason": None
                }]
            }
            yield f"data: {json.dumps(chunk)}\n\n"
            time.sleep(0.005)  # small delay so clients render a typing effect
        # Send the reference documents (download links) all at once at the end.
        if sources_md:
            chunk = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "choices": [{
                    "index": 0,
                    "delta": {"content": sources_md},
                    "finish_reason": None
                }]
            }
            yield f"data: {json.dumps(chunk)}\n\n"
        # Final chunk with finish_reason="stop", then the SSE sentinel expected
        # by OpenAI-compatible clients.
        done = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        }
        yield f"data: {json.dumps(done)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")
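
# A minimal way to run the server directly (a sketch: assumes uvicorn is
# installed and that the port matches BASE_URL above; adjust for deployment).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8500)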