hugging2021 committed on
Commit
a88526d
·
verified ·
1 Parent(s): bd2e020

Update rag_server.py

Browse files
Files changed (1) hide show
  1. rag_server.py +29 -20
rag_server.py CHANGED
@@ -1,3 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, Request
2
  from fastapi.responses import JSONResponse, FileResponse, HTMLResponse
3
  from fastapi.staticfiles import StaticFiles
@@ -6,8 +17,6 @@ from rag_system import build_rag_chain, ask_question
6
  from vector_store import get_embeddings, load_vector_store
7
  from llm_loader import load_llama_model
8
  import uuid
9
- import os
10
- import shutil
11
  from urllib.parse import urljoin, quote
12
 
13
  from fastapi.responses import StreamingResponse
@@ -16,17 +25,17 @@ import time
16
 
17
  app = FastAPI()
18
 
19
- # 정적 파일 서빙을 위한 설정
20
  os.makedirs("static/documents", exist_ok=True)
21
  app.mount("/static", StaticFiles(directory="static"), name="static")
22
 
23
- # 전역 객체 준비
24
  embeddings = get_embeddings(device="cpu")
25
  vectorstore = load_vector_store(embeddings, load_path="vector_db")
26
  llm = load_llama_model()
27
- qa_chain = build_rag_chain(llm, vectorstore, language="ko", k=7)
28
 
29
- # 서버 URL 설정 (실제 환경에 맞게 수정 필요)
30
  BASE_URL = "http://220.124.155.35:8500"
31
 
32
  class Question(BaseModel):
@@ -37,7 +46,7 @@ def get_document_url(source_path):
37
  return None
38
  filename = os.path.basename(source_path)
39
  dataset_root = os.path.join(os.getcwd(), "dataset")
40
- # dataset 전체 하위 폴더에서 파일명 일치하는 파일 찾기
41
  found_path = None
42
  for root, dirs, files in os.walk(dataset_root):
43
  if filename in files:
@@ -51,13 +60,13 @@ def get_document_url(source_path):
51
  return urljoin(BASE_URL, f"/static/documents/{encoded_filename}")
52
 
53
  def create_download_link(url, filename):
54
- return f'출처: [{filename}]({url})'
55
 
56
  @app.post("/ask")
57
  def ask(question: Question):
58
  result = ask_question(qa_chain, question.question)
59
 
60
- # 소스 문서 정보 처리
61
  sources = []
62
  for doc in result["source_documents"]:
63
  source_path = doc.metadata.get('source', 'N/A')
@@ -100,7 +109,7 @@ async def openai_compatible_chat(request: Request):
100
  result = ask_question(qa_chain, user_input)
101
  answer = result['result']
102
 
103
- # 소스 문서 정보 처리
104
  sources = []
105
  for doc in result["source_documents"]:
106
  source_path = doc.metadata.get('source', 'N/A')
@@ -116,13 +125,13 @@ async def openai_compatible_chat(request: Request):
116
  }
117
  sources.append(source_info)
118
 
119
- # 소스 정보를 줄씩만 출력
120
- sources_md = "\n참고 문서:\n"
121
  seen = set()
122
  for source in sources:
123
  key = (source['filename'], source['document_url'])
124
  if source['document_url'] and source['filename'] and key not in seen:
125
- sources_md += f"출처: [{source['filename']}]({source['document_url']})\n"
126
  seen.add(key)
127
 
128
  final_answer = answer.split("A:")[-1].strip() if "A:" in answer else answer.strip()
@@ -143,9 +152,9 @@ async def openai_compatible_chat(request: Request):
143
  "model": "rag",
144
  })
145
 
146
- # 스트리밍 응답을 위한 generator
147
  def event_stream():
148
- # 답변 본문만 먼저 스트리밍
149
  answer_main = answer.split("A:")[-1].strip() if "A:" in answer else answer.strip()
150
  for char in answer_main:
151
  chunk = {
@@ -161,15 +170,15 @@ async def openai_compatible_chat(request: Request):
161
  }
162
  yield f"data: {json.dumps(chunk)}\n\n"
163
  time.sleep(0.005)
164
- # 참고 문서(다운로드 링크) 마지막에 번에 붙여서 전송
165
- sources_md = "\n참고 문서:\n"
166
  seen = set()
167
  for source in sources:
168
  key = (source['filename'], source['document_url'])
169
  if source['document_url'] and source['filename'] and key not in seen:
170
- sources_md += f"출처: [{source['filename']}]({source['document_url']})\n"
171
  seen.add(key)
172
- if sources_md.strip() != "참고 문서:":
173
  chunk = {
174
  "id": f"chatcmpl-{uuid.uuid4()}",
175
  "object": "chat.completion.chunk",
@@ -194,4 +203,4 @@ async def openai_compatible_chat(request: Request):
194
  yield f"data: {json.dumps(done)}\n\n"
195
  return
196
 
197
- return StreamingResponse(event_stream(), media_type="text/event-stream")
 
1
+ import os
2
+ import re
3
+ import glob
4
+ import time
5
+ from collections import defaultdict
6
+
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain_core.documents import Document
9
+ from langchain_community.embeddings import HuggingFaceEmbeddings
10
+ from langchain_community.vectorstores import FAISS
11
+
12
  from fastapi import FastAPI, Request
13
  from fastapi.responses import JSONResponse, FileResponse, HTMLResponse
14
  from fastapi.staticfiles import StaticFiles
 
17
  from vector_store import get_embeddings, load_vector_store
18
  from llm_loader import load_llama_model
19
  import uuid
 
 
20
  from urllib.parse import urljoin, quote
21
 
22
  from fastapi.responses import StreamingResponse
 
25
 
26
  app = FastAPI()
27
 
28
+ # Configuration for serving static files
29
  os.makedirs("static/documents", exist_ok=True)
30
  app.mount("/static", StaticFiles(directory="static"), name="static")
31
 
32
+ # Prepare global objects
33
  embeddings = get_embeddings(device="cpu")
34
  vectorstore = load_vector_store(embeddings, load_path="vector_db")
35
  llm = load_llama_model()
36
+ qa_chain = build_rag_chain(llm, vectorstore, language="en", k=7)
37
 
38
+ # Server URL configuration (adjust to match your actual environment)
39
  BASE_URL = "http://220.124.155.35:8500"
40
 
41
  class Question(BaseModel):
 
46
  return None
47
  filename = os.path.basename(source_path)
48
  dataset_root = os.path.join(os.getcwd(), "dataset")
49
+ # Search every subdirectory under dataset for a file whose name matches
50
  found_path = None
51
  for root, dirs, files in os.walk(dataset_root):
52
  if filename in files:
 
60
  return urljoin(BASE_URL, f"/static/documents/{encoded_filename}")
61
 
62
  def create_download_link(url, filename):
63
+ return f'Source: [{filename}]({url})'
64
 
65
  @app.post("/ask")
66
  def ask(question: Question):
67
  result = ask_question(qa_chain, question.question)
68
 
69
+ # Process source document information
70
  sources = []
71
  for doc in result["source_documents"]:
72
  source_path = doc.metadata.get('source', 'N/A')
 
109
  result = ask_question(qa_chain, user_input)
110
  answer = result['result']
111
 
112
+ # Process source document information
113
  sources = []
114
  for doc in result["source_documents"]:
115
  source_path = doc.metadata.get('source', 'N/A')
 
125
  }
126
  sources.append(source_info)
127
 
128
+ # Output source information one line at a time
129
+ sources_md = "\nReferences Documents:\n"
130
  seen = set()
131
  for source in sources:
132
  key = (source['filename'], source['document_url'])
133
  if source['document_url'] and source['filename'] and key not in seen:
134
+ sources_md += f"Source: [{source['filename']}]({source['document_url']})\n"
135
  seen.add(key)
136
 
137
  final_answer = answer.split("A:")[-1].strip() if "A:" in answer else answer.strip()
 
152
  "model": "rag",
153
  })
154
 
155
+ # Generator for streaming response
156
  def event_stream():
157
+ # Stream only the answer body first
158
  answer_main = answer.split("A:")[-1].strip() if "A:" in answer else answer.strip()
159
  for char in answer_main:
160
  chunk = {
 
170
  }
171
  yield f"data: {json.dumps(chunk)}\n\n"
172
  time.sleep(0.005)
173
+ # Send reference documents (download links) all at once at the end
174
+ sources_md = "\nReferences Documents:\n"
175
  seen = set()
176
  for source in sources:
177
  key = (source['filename'], source['document_url'])
178
  if source['document_url'] and source['filename'] and key not in seen:
179
+ sources_md += f"Source: [{source['filename']}]({source['document_url']})\n"
180
  seen.add(key)
181
+ if sources_md.strip() != "References Documents:":
182
  chunk = {
183
  "id": f"chatcmpl-{uuid.uuid4()}",
184
  "object": "chat.completion.chunk",
 
203
  yield f"data: {json.dumps(done)}\n\n"
204
  return
205
 
206
+ return StreamingResponse(event_stream(), media_type="text/event-stream")