# PDF-RAG / app.py
import os
from datetime import datetime
from typing import List

from chainlit.types import AskFileResponse
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
    AssistantRolePrompt,
)
from aimakerspace.openai_utils.embedding import EmbeddingModel
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
import chainlit as cl
from chainlit import user_session
from chainlit.element import Text
system_template = """\
Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)

user_prompt_template = """\
Context:
{context}
Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)
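
# For illustration (assuming UserRolePrompt performs str.format-style
# substitution), user_role_prompt.create_message(
#     context="Paris is the capital of France.",
#     question="What is the capital of France?")
# should yield a user message whose content reads:
#
#   Context:
#   Paris is the capital of France.
#   Question:
#   What is the capital of France?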

async def init_sidebar():
    """Render the upload sidebar. Called explicitly from on_chat_start below,
    since Chainlit keeps only one on_chat_start handler per app."""
    # Decorate the sidebar header
    await cl.Sidebar(
        cl.Text(content="πŸ“ **File Upload Section**", style="heading3"),
        cl.FilePicker(
            accept=[".pdf", ".txt"],
            max_size_mb=2,
            on_upload=handle_upload,
            label="πŸ“€ Upload PDF/TXT",
            description="Only files up to 2MB can be uploaded"
        ),
        cl.Separator(),
        cl.Text(content="πŸ” **Document Analysis Status**", style="heading4"),
        cl.ProgressRing(id="progress", visible=False),
        cl.Text(id="status", content="Waiting...", style="caption"),
        title="πŸ“š Document Q&A System",
        persistent=True  # πŸ‘ˆ keep the sidebar pinned
    ).send()

async def handle_upload(file: AskFileResponse):
    # Update the sidebar progress indicators (assumes the elements registered
    # with id="status" / id="progress" above are retrievable from the session)
    status = user_session.get("status")
    progress = user_session.get("progress")
    await status.update(content=f"πŸ” Analyzing {file.name}...")
    await progress.update(visible=True)

    try:
        # File processing logic
        texts = process_file(file)

        # Build the vector DB
        vector_db = VectorDatabase()
        vector_db = await vector_db.abuild_from_list(texts)

        # Store in the session
        user_session.set("vector_db", vector_db)

        # Update status
        await status.update(content=f"βœ… Processed {len(texts)} chunks!")
        await progress.update(visible=False)

        # Show a summary of the uploaded file
        await cl.Accordion(
            title="πŸ“„ Uploaded Document Info",
            content=[
                cl.Text(f"File name: {file.name}"),
                cl.Text(f"Size: {file.size / 1024:.1f} KB"),
                cl.Text(f"Analyzed at: {datetime.now().strftime('%H:%M:%S')}")
            ],
            expanded=False
        ).send()
    except Exception as e:
        await cl.Error(
            title="File processing error",
            content=str(e)
        ).send()

class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str):
        # Retrieve the top-k most similar chunks for the query
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)

        # Concatenate the retrieved chunk texts into a single context block
        context_prompt = ""
        for context in context_list:
            context_prompt += context[0] + "\n"

        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)

        # Wrap the LLM call in an async generator so tokens can be streamed
        async def generate_response():
            async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
                yield chunk

        return {"response": generate_response(), "context": context_list}
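
# A minimal consumption sketch (hypothetical names, assuming an already-built
# VectorDatabase `db`): the pipeline returns the token stream plus the raw
# (text, score) tuples so callers can render sources alongside the answer.
#
#   pipeline = RetrievalAugmentedQAPipeline(llm=ChatOpenAI(), vector_db_retriever=db)
#   result = await pipeline.arun_pipeline("What is this document about?")
#   async for token in result["response"]:
#       print(token, end="")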

text_splitter = CharacterTextSplitter()


def process_file(file: AskFileResponse):
    import tempfile
    import shutil

    print(f"Processing file: {file.name}")

    # Create a temporary file with the correct extension
    suffix = f".{file.name.split('.')[-1]}"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        # Copy the uploaded file content to the temporary file
        shutil.copyfile(file.path, temp_file.name)
        print(f"Created temporary file at: {temp_file.name}")

    # Create the appropriate loader
    if file.name.lower().endswith('.pdf'):
        loader = PDFLoader(temp_file.name)
    else:
        loader = TextFileLoader(temp_file.name)

    try:
        # Load and process the documents
        documents = loader.load_documents()
        texts = text_splitter.split_texts(documents)
        return texts
    finally:
        # Clean up the temporary file
        try:
            os.unlink(temp_file.name)
        except Exception as e:
            print(f"Error cleaning up temporary file: {e}")

@cl.on_chat_start
async def on_chat_start():
    # Render the sidebar UI first
    await init_sidebar()

    files = None

    # Wait for the user to upload a file
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a Text or PDF file to begin!",
            accept=["text/plain", "application/pdf"],
            max_size_mb=2,
            timeout=180,
        ).send()

    file = files[0]

    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    # Load and split the file
    texts = process_file(file)
    print(f"Processing {len(texts)} text chunks")

    # Create a dict vector store
    vector_db = VectorDatabase()
    vector_db = await vector_db.abuild_from_list(texts)

    chat_openai = ChatOpenAI()

    # Create a chain
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
        vector_db_retriever=vector_db,
        llm=chat_openai
    )

    # Let the user know that the system is ready
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", retrieval_augmented_qa_pipeline)

@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")

    # Run retrieval + generation for the user's query
    result = await chain.arun_pipeline(message.content)

    # Improved response styling
    msg = cl.Message(
        content="",
        actions=[
            cl.Action(name="source", value="πŸ“‘ View sources"),
            cl.Action(name="feedback", value="πŸ’¬ Leave feedback")
        ]
    )

    # Stream tokens into the message as they arrive
    async for token in result["response"]:
        await msg.stream_token(token)

    # Format the final message
    final_content = f"""
🧠 **AI Analysis Result**
{msg.content}
πŸ“Œ Referenced passages:
{chr(10).join([f'- {ctx[0][:50]}...' for ctx in result['context']])}
"""
    msg.content = final_content
    await msg.update()
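
# The "source" and "feedback" actions above need matching handlers to respond
# to clicks; a minimal sketch (an assumed addition, not part of the original
# app) using Chainlit's action_callback decorator:
#
# @cl.action_callback("source")
# async def on_source_action(action: cl.Action):
#     await cl.Message(content=f"Action selected: {action.value}").send()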