# RAG-CHAT / app.py
# Source: Hugging Face Space by DHEIVER, commit 914f0c8 ("Update app.py"), 15.5 kB.
import gradio as gr
import os
import torch
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint
# Hugging Face access token for the hosted inference endpoints (None when HF_TOKEN is unset).
api_token = os.environ.get("HF_TOKEN")

# Available LLM models: Hugging Face Hub repository ids offered in the UI.
list_llm = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "deepseek-ai/deepseek-llm-7b-chat",
]
# Display names (repo id without the organisation prefix), index-aligned with list_llm.
list_llm_simple = list(map(os.path.basename, list_llm))
def load_doc(list_file_path):
    """
    Load PDF files and split their pages into overlapping text chunks.

    Args:
        list_file_path: iterable of filesystem paths to PDF files.

    Returns:
        List of LangChain Document chunks produced by a
        RecursiveCharacterTextSplitter (chunk_size=1024, chunk_overlap=64).
    """
    # Build every loader first, then read the pages in file order.
    pdf_loaders = [PyPDFLoader(path) for path in list_file_path]
    all_pages = [page for pdf_loader in pdf_loaders for page in pdf_loader.load()]
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return splitter.split_documents(all_pages)
def create_db(splits):
    """
    Embed document chunks and index them in a FAISS vector store.

    Args:
        splits: list of LangChain Document chunks (as returned by load_doc).

    Returns:
        A FAISS vector store built with HuggingFaceEmbeddings using its
        default model (no model_name is passed).
    """
    return FAISS.from_documents(splits, HuggingFaceEmbeddings())
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """
    Build a conversational retrieval chain over *vector_db* backed by a hosted model.

    Args:
        llm_model: Hugging Face Hub repo id passed to HuggingFaceEndpoint.
        temperature: sampling temperature for generation.
        max_tokens: forwarded to the endpoint as max_new_tokens.
        top_k: forwarded to the endpoint as the top_k sampling parameter.
        vector_db: vector store exposing as_retriever() (e.g. the FAISS store from create_db).
        progress: Gradio progress tracker; accepted but not used here.

    Returns:
        A ConversationalRetrievalChain ("stuff" chain type) with buffer memory
        that also returns the retrieved source documents.
    """
    endpoint = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        task="text-generation",
    )
    # The chain returns both "answer" and source documents, so the memory is told
    # explicitly which output key to record.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        endpoint,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=chat_memory,
        return_source_documents=True,
        verbose=False,
    )
def initialize_database(list_file_obj, progress=gr.Progress()):
    """
    Build the FAISS vector store from the uploaded PDF files.

    Args:
        list_file_obj: value of the gr.Files component. This is a list of
            uploaded files (objects exposing the file path as ``.name``, or plain
            path strings, depending on the Gradio version), or None when nothing
            has been uploaded.
        progress: Gradio progress tracker; accepted but not used here.

    Returns:
        (vector_db, status_message) on success.

    Raises:
        gr.Error: if no file was uploaded or no text could be extracted. The
            existing database state is left untouched in that case.
    """
    # Gradio passes None when the upload box is empty; iterating None would raise TypeError.
    list_file_path = [getattr(f, "name", f) for f in (list_file_obj or []) if f is not None]
    if not list_file_path:
        raise gr.Error("Please upload at least one PDF document before processing.")
    doc_splits = load_doc(list_file_path)
    # An empty chunk list (e.g. scanned PDFs without a text layer) cannot be indexed by FAISS.
    if not doc_splits:
        raise gr.Error("No extractable text was found in the uploaded documents.")
    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully!"
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """
    Create the question-answering chain for the model selected in the UI.

    Args:
        llm_option: index into list_llm (the model Radio uses type="index").
        llm_temperature: sampling temperature.
        max_tokens: maximum number of new tokens to generate.
        top_k: top-k sampling parameter.
        vector_db: vector store built by initialize_database, or None if the
            documents have not been processed yet.
        progress: Gradio progress tracker, forwarded to initialize_llmchain.

    Returns:
        (qa_chain, status_message).

    Raises:
        gr.Error: if no document database exists yet.
    """
    # Without this guard, initialize_llmchain fails on None.as_retriever() with an opaque AttributeError.
    if vector_db is None:
        raise gr.Error("Please process your documents before initializing the Analysis Assistant.")
    llm_name = list_llm[llm_option]
    print("Selected LLM model:", llm_name)
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Analysis Assistant initialized and ready!"
def format_chat_history(message, chat_history):
    """
    Flatten (user, assistant) turn pairs into alternating prefixed lines.

    Args:
        message: the current question; not used, kept for call compatibility.
        chat_history: iterable of (user_message, bot_message) pairs.

    Returns:
        List of strings, two per turn ("User: ..." then "Assistant: ..."),
        in conversation order.
    """
    return [
        line
        for user_turn, bot_turn in chat_history
        for line in (f"User: {user_turn}", f"Assistant: {bot_turn}")
    ]
def _source_fields(source_documents, count=3):
    """
    Flatten up to *count* retrieved documents into [text1, page1, text2, page2, ...].

    Page numbers are the document's "page" metadata plus one (1-based display);
    documents without a "page" entry get 0. Missing slots, when fewer than
    *count* documents were retrieved, are padded with ("", 0), the same values
    the UI uses when it resets the reference boxes.
    """
    fields = []
    for doc in list(source_documents)[:count]:
        page = doc.metadata.get("page")
        fields.append(doc.page_content.strip())
        fields.append(page + 1 if page is not None else 0)
    fields.extend(["", 0] * (count - len(fields) // 2))
    return fields


def conversation(qa_chain, message, history):
    """
    Answer *message* with the retrieval chain and update the chat UI.

    Args:
        qa_chain: chain built by initialize_LLM, or None if not initialized yet.
        message: the user's question.
        history: current Chatbot value, a list of (user, assistant) pairs
            (may be empty or None).

    Returns:
        A 9-tuple matching the event outputs: (qa_chain, update clearing the
        input box, new history, source1 text, source1 page, source2 text,
        source2 page, source3 text, source3 page).

    Raises:
        gr.Error: if the assistant has not been initialized.
    """
    if qa_chain is None:
        raise gr.Error("Please initialize the Analysis Assistant before asking questions.")
    history = list(history or [])
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    # Keep only the text after the last "Helpful Answer:" marker (presumably an echoed
    # prompt template); split() returns the whole answer when the marker is absent.
    response_answer = response["answer"].split("Helpful Answer:")[-1].strip()
    # The retriever may return fewer than three documents on small corpora; pad instead of IndexError.
    sources = _source_fields(response.get("source_documents") or [])
    new_history = history + [(message, response_answer)]
    return (qa_chain, gr.update(value=""), new_history, *sources)
def demo():
    """
    Build the MetroAssist AI Gradio interface and launch it.

    The left column uploads PDFs, builds the vector database and configures and
    initializes the model. The right column holds the chat, the three
    source-reference boxes and the query box. Event handlers connect the buttons
    to initialize_database, initialize_LLM and conversation. The function ends by
    calling ``queue().launch(debug=True)`` on the Blocks app.
    """
    # Enhanced theme with professional colors
    theme = gr.themes.Default(
        primary_hue="indigo",
        secondary_hue="blue",
        neutral_hue="slate",
        font=[gr.themes.GoogleFont("Roboto"), "system-ui", "sans-serif"],
    )
    css = """
.container { max-width: 1200px; margin: auto; }
.metadata { font-size: 0.9em; color: #666; }
.highlight { background-color: #f0f7ff; padding: 1em; border-radius: 8px; }
.warning { color: #e53e3e; }
.success { color: #38a169; }
"""
    with gr.Blocks(theme=theme, css=css) as app:
        # Per-session state: the vector store (document processing) and the QA chain (model initialization).
        vector_db = gr.State()
        qa_chain = gr.State()

        # Enhanced header
        gr.HTML(
            """
<div style='text-align: center; padding: 20px;'>
<h1 style='color: #1a365d; margin-bottom: 10px;'>MetroAssist AI - Expert in Metrology Report Analysis</h1>
<p style='color: #4a5568; font-size: 1.2em;'>Your intelligent assistant for advanced analysis of metrological documents</p>
</div>
"""
        )

        # Marketing and feature description.
        # Blank lines separate Markdown blocks; without them the "Security Note" is
        # rendered as a continuation of the last bullet.
        # NOTE(review): document chunks are sent to a hosted Hugging Face endpoint
        # (HuggingFaceEndpoint in initialize_llmchain); confirm the "Security Note"
        # wording is accurate for this deployment.
        gr.Markdown(
            """
### 🔍 Specialized Metrology Analysis

MetroAssist AI is a specialized assistant designed to revolutionize metrology report analysis.
Powered by cutting-edge AI technology, it offers:

* **Precise Analysis**: Detailed interpretation of measurements, calibrations, and compliance
* **Intelligent Contextualization**: Deep understanding of metrological standards and norms
* **Advanced Technical Support**: Assistance in complex instrument and measurement analyses
* **Rapid Processing**: Efficient analysis of multiple technical documents

⚠️ **Security Note**: Your documents are processed with complete security. We do not permanently store confidential data.
"""
        )

        with gr.Row():
            with gr.Column(scale=86):
                gr.Markdown(
                    """
### 📥 Step 1: Document Loading and Preparation
Upload your metrology reports for expert analysis.
"""
                )
                with gr.Row():
                    # Gradio documents extensions with a leading dot (".pdf"); a bare "pdf"
                    # is treated as a MIME category by some versions and rejects PDFs.
                    # NOTE(review): confirm the installed Gradio version accepts "info" here.
                    document = gr.Files(
                        height=300,
                        file_count="multiple",
                        file_types=[".pdf"],
                        interactive=True,
                        label="Upload Metrology Reports (PDF)",
                        info="Accepts multiple PDF files",
                    )
                with gr.Row():
                    db_btn = gr.Button(
                        "Process Documents",
                        variant="primary",
                        size="lg",
                    )
                with gr.Row():
                    db_progress = gr.Textbox(
                        value="Waiting for documents...",
                        show_label=False,
                        container=False,
                    )
                gr.Markdown(
                    """
### 🤖 Analysis Engine Configuration
Select and configure the AI model to best meet your needs.
"""
                )
                with gr.Row():
                    # type="index": handlers receive the position in list_llm, not the display name.
                    llm_btn = gr.Radio(
                        list_llm_simple,
                        label="Available AI Models",
                        value=list_llm_simple[0],
                        type="index",
                        info="Choose the most suitable model for your analysis",
                    )
                with gr.Row():
                    with gr.Accordion("Advanced Analysis Parameters", open=False):
                        with gr.Row():
                            slider_temperature = gr.Slider(
                                minimum=0.01,
                                maximum=1.0,
                                value=0.5,
                                step=0.1,
                                label="Analysis Precision",
                                info="Controls the balance between precision and creativity in analysis",
                                interactive=True,
                            )
                        with gr.Row():
                            # Maximum is 8192 (64 x 128), a multiple of the 128 step;
                            # the previous 9192 was presumably a typo.
                            slider_maxtokens = gr.Slider(
                                minimum=128,
                                maximum=8192,
                                value=4096,
                                step=128,
                                label="Response Length",
                                info="Defines the level of detail in analyses",
                                interactive=True,
                            )
                        with gr.Row():
                            slider_topk = gr.Slider(
                                minimum=1,
                                maximum=10,
                                value=3,
                                step=1,
                                label="Analysis Diversity",
                                info="Controls the variety of perspectives in analysis",
                                interactive=True,
                            )
                with gr.Row():
                    qachain_btn = gr.Button(
                        "Initialize Analysis Assistant",
                        variant="primary",
                        size="lg",
                    )
                with gr.Row():
                    llm_progress = gr.Textbox(
                        value="Waiting for initialization...",
                        show_label=False,
                    )
            with gr.Column(scale=200):
                gr.Markdown(
                    """
### 💬 Step 2: Expert Consultation and Analysis

Ask questions about your metrology reports. The assistant will provide detailed technical analyses.

**Suggested questions:**

- Analyze the calibration results of this instrument
- Verify compliance with technical standards
- Identify critical points in measurements
- Compare results with specified limits
- Evaluate measurement uncertainty
- Assess calibration intervals
"""
                )
                chatbot = gr.Chatbot(
                    height=505,
                    show_label=True,
                    container=True,
                    label="Metrology Analysis",
                )
                with gr.Accordion("Source Document References", open=False):
                    with gr.Row():
                        doc_source1 = gr.Textbox(
                            label="Technical Reference 1",
                            lines=2,
                            container=True,
                            scale=20,
                        )
                        source1_page = gr.Number(label="Page", scale=1)
                    with gr.Row():
                        doc_source2 = gr.Textbox(
                            label="Technical Reference 2",
                            lines=2,
                            container=True,
                            scale=20,
                        )
                        source2_page = gr.Number(label="Page", scale=1)
                    with gr.Row():
                        doc_source3 = gr.Textbox(
                            label="Technical Reference 3",
                            lines=2,
                            container=True,
                            scale=20,
                        )
                        source3_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Enter your question about the metrology report...",
                        container=True,
                        label="Your Query",
                    )
                with gr.Row():
                    submit_btn = gr.Button(
                        "Submit Query",
                        variant="primary",
                    )
                    clear_btn = gr.ClearButton(
                        [msg, chatbot],
                        value="Clear Conversation",
                        variant="secondary",
                    )

        # Footer
        gr.Markdown(
            """
---

### ℹ️ About MetroAssist AI

Developed for metrology professionals, engineers, and technicians who need precise
and reliable analysis of technical documents. Our tool uses advanced AI technology
to provide valuable insights and support decision-making in metrology.

**Specialized Features:**

- Detailed analysis of calibration certificates
- Interpretation of complex metrological data
- Verification of compliance with technical standards
- Decision support in metrological processes
- Uncertainty analysis and measurement traceability
- Quality control and measurement system analysis

*Version 1.0 - Updated 2024*
"""
        )

        # Event handlers
        # Components reset when a new assistant is initialized or the conversation is cleared.
        reset_targets = [chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page]
        # Shared by the Enter-key and button handlers; order must match conversation()'s
        # parameters and its 9-element return tuple.
        chat_inputs = [qa_chain, msg, chatbot]
        chat_outputs = [qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page]

        def reset_chat_view():
            """Return an empty chat history, blank reference texts and page 0 for each source slot."""
            return [None, "", 0, "", 0, "", 0]

        def clear_chain_memory(chain):
            """Clear the chain's conversation memory (no-op when no chain exists yet)."""
            memory = getattr(chain, "memory", None)
            if memory is not None:
                memory.clear()
            return chain

        db_btn.click(
            initialize_database,
            inputs=[document],
            outputs=[vector_db, db_progress],
        )
        qachain_btn.click(
            initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain, llm_progress],
        ).then(
            reset_chat_view,
            inputs=None,
            outputs=reset_targets,
            queue=False,
        )
        msg.submit(
            conversation,
            inputs=chat_inputs,
            outputs=chat_outputs,
            queue=False,
        )
        submit_btn.click(
            conversation,
            inputs=chat_inputs,
            outputs=chat_outputs,
            queue=False,
        )
        clear_btn.click(
            reset_chat_view,
            inputs=None,
            outputs=reset_targets,
            queue=False,
        )
        # The chain keeps its own ConversationBufferMemory (created in initialize_llmchain),
        # so clearing only the Chatbot would leave earlier turns in the model's context.
        clear_btn.click(
            clear_chain_memory,
            inputs=[qa_chain],
            outputs=[qa_chain],
            queue=False,
        )
    app.queue().launch(debug=True)


if __name__ == "__main__":
    demo()