# NOTE(review): removed scrape artifact — the original lines here read
# "Spaces:" / "Runtime error", which is Hugging Face Spaces page-status
# text captured by the extractor, not part of the application source.
| from fastapi import FastAPI, Request | |
| from fastapi.responses import JSONResponse, FileResponse, HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from rag_system import build_rag_chain, ask_question | |
| from vector_store import get_embeddings, load_vector_store | |
| from llm_loader import load_llama_model | |
| import uuid | |
| import os | |
| import shutil | |
| from urllib.parse import urljoin, quote | |
| from fastapi.responses import StreamingResponse | |
| import json | |
| import time | |
app = FastAPI()

# Directory for serving copied source documents; links to these files are
# returned to clients alongside answers.
os.makedirs("static/documents", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Global singletons built once at startup: embeddings model, vector store,
# the local LLaMA model, and the Korean-language RAG chain (top-7 retrieval).
embeddings = get_embeddings(device="cpu")
vectorstore = load_vector_store(embeddings, load_path="vector_db")
llm = load_llama_model()
qa_chain = build_rag_chain(llm, vectorstore, language="ko", k=7)

# Public base URL used to build absolute document download links
# (adjust to match the actual deployment).
BASE_URL = "http://220.124.155.35:8500"
class Question(BaseModel):
    """Request body carrying a single natural-language question."""

    # The user's question, passed verbatim to the RAG chain.
    question: str
def get_document_url(source_path):
    """Resolve a document source path to a publicly downloadable URL.

    Searches every subfolder of ./dataset for a file whose basename matches
    *source_path*, copies the match into static/documents (served by the
    static mount), and returns an absolute URL with a percent-encoded
    filename.

    Returns None when the path is empty/'N/A', no matching file is found,
    or the copy into the static directory fails.
    """
    if not source_path or source_path == 'N/A':
        return None
    filename = os.path.basename(source_path)
    dataset_root = os.path.join(os.getcwd(), "dataset")

    # Walk the whole dataset tree looking for a file with a matching name.
    found_path = None
    for root, _dirs, files in os.walk(dataset_root):
        if filename in files:
            found_path = os.path.join(root, filename)
            break
    if not found_path or not os.path.exists(found_path):
        return None

    # BUG FIX: the destination path contained a literal "(unknown)"
    # placeholder instead of the filename, so every document was copied
    # over the same static file and the returned URL never matched it.
    static_path = os.path.join("static", "documents", filename)
    try:
        shutil.copy2(found_path, static_path)
    except OSError:
        # Best effort: if the copy fails (missing dir, permissions), report
        # "no URL" instead of crashing the whole request.
        return None

    encoded_filename = quote(filename)
    return urljoin(BASE_URL, f"/static/documents/{encoded_filename}")
def create_download_link(url, filename):
    """Return a Markdown download link labelled with the document filename.

    BUG FIX: the link text was the literal placeholder "(unknown)"; it now
    shows the actual filename, matching the inline link format used by the
    chat endpoint. ("출처" means "source" — mojibake repaired.)
    """
    return f'출처: [{filename}]({url})'
@app.post("/ask")
def ask(question: Question):
    """RAG question-answering endpoint.

    Runs the question through the RAG chain and returns the answer plus
    metadata (path, page, content, download URL) for each source document.

    NOTE(review): the route decorator was missing in the pasted source,
    almost certainly lost during extraction; restored as POST /ask —
    confirm the path against the original deployment.
    """
    result = ask_question(qa_chain, question.question)

    # Collect source-document metadata plus a public download URL.
    sources = []
    for doc in result["source_documents"]:
        source_path = doc.metadata.get('source', 'N/A')
        has_path = source_path != 'N/A'
        sources.append({
            "source": source_path,
            "content": doc.page_content,
            "page": doc.metadata.get('page', 'N/A'),
            "document_url": get_document_url(source_path) if has_path else None,
            "filename": os.path.basename(source_path) if has_path else None,
        })

    # The model may echo the prompt; keep only the text after the last "A:".
    raw = result['result']
    answer = raw.split("A:")[-1].strip() if "A:" in raw else raw.strip()
    return {"answer": answer, "sources": sources}
@app.get("/v1/models")
def list_models():
    """OpenAI-compatible model listing: a single local model named "rag".

    NOTE(review): the route decorator was missing in the pasted source;
    restored as GET /v1/models per the OpenAI API convention — confirm.
    """
    return JSONResponse({
        "object": "list",
        "data": [
            {
                "id": "rag",
                "object": "model",
                "owned_by": "local",
            }
        ],
    })
def _strip_prompt_echo(text):
    """Keep only the text after the last "A:" marker, if one is present."""
    return text.split("A:")[-1].strip() if "A:" in text else text.strip()


def _sources_markdown(sources):
    """Build the deduplicated "참고 문서" (reference documents) Markdown block.

    Only sources with both a filename and a download URL are listed; each
    (filename, url) pair appears once. The original built this block twice
    with copy-pasted logic (non-stream and stream paths); factored out here.
    """
    lines = ["\n참고 문서:\n"]
    seen = set()
    for source in sources:
        key = (source['filename'], source['document_url'])
        if source['document_url'] and source['filename'] and key not in seen:
            lines.append(f"출처: [{source['filename']}]({source['document_url']})\n")
            seen.add(key)
    return "".join(lines)


@app.post("/v1/chat/completions")
async def openai_compatible_chat(request: Request):
    """OpenAI-compatible chat completion endpoint backed by the RAG chain.

    Answers the last user message; supports both a plain JSON response and
    SSE streaming ("stream": true), appending reference-document links at
    the end of the answer in both modes.

    NOTE(review): the route decorator was missing in the pasted source;
    restored as POST /v1/chat/completions per the OpenAI convention —
    confirm the path against the original deployment.
    """
    payload = await request.json()
    messages = payload.get("messages", [])
    user_input = messages[-1]["content"] if messages else ""
    stream = payload.get("stream", False)

    result = ask_question(qa_chain, user_input)
    answer = result['result']

    # Source-document metadata (same shape as the /ask endpoint).
    sources = []
    for doc in result["source_documents"]:
        source_path = doc.metadata.get('source', 'N/A')
        has_path = source_path != 'N/A'
        sources.append({
            "source": source_path,
            "content": doc.page_content,
            "page": doc.metadata.get('page', 'N/A'),
            "document_url": get_document_url(source_path) if has_path else None,
            "filename": os.path.basename(source_path) if has_path else None,
        })

    sources_md = _sources_markdown(sources)
    answer_main = _strip_prompt_echo(answer)

    if not stream:
        return JSONResponse({
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": answer_main + sources_md,
                },
                "finish_reason": "stop",
            }],
            "model": "rag",
        })

    def _sse_chunk(delta, finish_reason=None):
        """Format one SSE "data:" line in chat.completion.chunk shape."""
        chunk = {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion.chunk",
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(chunk)}\n\n"

    def event_stream():
        # Stream the answer body character by character.
        for char in answer_main:
            yield _sse_chunk({"content": char})
            time.sleep(0.005)  # pacing delay so clients render smoothly
        # Send the reference-document links once, at the end, but only if
        # any links were actually collected (bare header strips to the
        # header text alone).
        if sources_md.strip() != "참고 문서:":
            yield _sse_chunk({"content": sources_md})
        # Terminal chunk with finish_reason="stop".
        yield _sse_chunk({}, finish_reason="stop")

    return StreamingResponse(event_stream(), media_type="text/event-stream")