# docQArag / app.py
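"""Streamlit RAG assistant: uploaded documents are chunked, indexed for hybrid
(BM25 + dense) retrieval, and answered with a 4-bit quantized Mistral-7B-Instruct
model. Run locally with `streamlit run app.py`."""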
import torch
import streamlit as st
import os
import tempfile
from langchain_community.document_loaders import (
TextLoader,
CSVLoader,
UnstructuredFileLoader
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain_community.vectorstores import FAISS  # dense index used in create_retriever (assumes faiss is installed)
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
pipeline
)
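# Assumed runtime dependencies (not pinned in this repo): streamlit, torch, transformers,
# accelerate, bitsandbytes, langchain, langchain-community, sentence-transformers,
# rank_bm25, faiss-cpu (or faiss-gpu), and unstructured for non-TXT/CSV uploads.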
# Configuration
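# CHUNK_SIZE / CHUNK_OVERLAP are measured in characters (the splitter uses len);
# MAX_NEW_TOKENS caps the length of each generated answer.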
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
EMBEDDING_MODEL = "thenlper/gte-large"
CHUNK_SIZE = 1024
CHUNK_OVERLAP = 128
MAX_NEW_TOKENS = 2048
# Initialize session state
if "messages" not in st.session_state:
st.session_state.messages = []
@st.cache_resource
def initialize_model():
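    # Quantize the 7B model to 4-bit NF4 (with double quantization) so it fits on a single GPU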
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True
)
# Load config first to modify RoPE params
from transformers import AutoConfig
config = AutoConfig.from_pretrained(
MODEL_NAME,
trust_remote_code=True
)
    # Normalize rope_scaling to the legacy {"type", "factor"} schema expected by older transformers releases
    if getattr(config, "rope_scaling", None):
        config.rope_scaling = {
            "type": config.rope_scaling.get("rope_type", "linear"),
            "factor": config.rope_scaling.get("factor", 8.0)
        }
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
config=config,
quantization_config=quantization_config,
device_map="auto",
trust_remote_code=True
)
return pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
device_map="auto",
max_new_tokens=MAX_NEW_TOKENS,
temperature=0.1
)
def process_uploaded_files(uploaded_files):
documents = []
with tempfile.TemporaryDirectory() as temp_dir:
for file in uploaded_files:
temp_path = os.path.join(temp_dir, file.name)
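            # Write the in-memory upload to disk so the path-based LangChain loaders can read it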
with open(temp_path, "wb") as f:
f.write(file.getbuffer())
try:
if file.name.endswith(".txt"):
loader = TextLoader(temp_path)
elif file.name.endswith(".csv"):
loader = CSVLoader(temp_path)
else:
loader = UnstructuredFileLoader(temp_path)
documents.extend(loader.load())
except Exception as e:
st.error(f"Error loading {file.name}: {str(e)}")
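    # Split every loaded document into overlapping chunks sized for retrieval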
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP,
length_function=len
)
return text_splitter.split_documents(documents)
def create_retriever(documents):
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={'device': 'cuda'},
        encode_kwargs={'normalize_embeddings': True}
    )
    top_k = st.session_state.get("top_k", 5)
    # Sparse keyword retriever
    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = top_k
    # Dense retriever over the GTE embeddings; FAISS is one reasonable index choice here.
    # Without it the embeddings built above would go unused.
    vector_store = FAISS.from_documents(documents, embeddings)
    vector_retriever = vector_store.as_retriever(search_kwargs={"k": top_k})
    # Hybrid retrieval: blend sparse and dense results with equal weight
    return EnsembleRetriever(
        retrievers=[bm25_retriever, vector_retriever],
        weights=[0.5, 0.5]
    )
def generate_response(query, retriever, generator):
docs = retriever.get_relevant_documents(query)
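    # Label each retrieved chunk as [DocN] so the model's citations map back to the sources shown in the UI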
context = "\n\n".join(
f"[Doc{i+1}] {doc.page_content}\nSource: {doc.metadata.get('source', 'unknown')}"
for i, doc in enumerate(docs)
)
prompt = f"""<s>[INST] You are a precision-focused research assistant tasked with answering queries based solely on the provided context.
**Context:**
{context}
**Query:**
{query}
**Response Instructions:**
- Write a detailed, coherent, and insightful article that fully addresses the query based on the provided context.
- Adhere to the following principles:
1. **Define the Core Subject**: Introduce and build the discussion logically around the main topic.
2. **Establish Connections**: Highlight relationships between ideas and concepts with reasoning and examples.
3. **Elaborate on Key Points**: Provide in-depth explanations and emphasize the significance of concepts.
4. **Maintain Objectivity**: Use only the context provided, avoiding speculation or external knowledge.
5. **Ensure Structure and Clarity**: Present information sequentially for a smooth narrative flow.
6. **Engage with Content**: Explore implicit meanings, resolve doubts, and address counterpoints logically.
7. **Provide Examples and Insights**: Use examples to clarify abstract ideas and offer actionable steps if applicable.
8. **Logical Depth**: Draw inferences, explain purposes, and refute opposing ideas when necessary.
- Cite sources explicitly as [Doc1], [Doc2], etc.
- If uncertain, state: "I cannot determine from the provided context."
Craft the response as a seamless, thorough, and authoritative explanation that naturally integrates all aspects of the query. [/INST]"""
response = generator(
prompt,
pad_token_id=generator.tokenizer.eos_token_id,
do_sample=True
)[0]['generated_text']
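    # The pipeline returns the prompt plus the completion; keep only the text after the final [/INST] tag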
return response.split("[/INST]")[-1].strip(), docs
# Streamlit UI
st.title("📚 Document-Based QA Assistant")
st.markdown("Upload your documents and ask questions!")
# Sidebar controls
with st.sidebar:
st.header("Configuration")
uploaded_files = st.file_uploader(
"Upload documents (TXT)",
type=["txt", "csv"],
accept_multiple_files=True
)
st.session_state.top_k = st.slider("Number of documents to retrieve", 3, 10, 5)
st.markdown("---")
st.markdown("Powered by Mistral-7B and LangChain")
# Main chat interface
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if "sources" in message:
with st.expander("View Sources"):
for i, doc in enumerate(message["sources"]):
st.markdown(f"**Doc{i+1}** ({doc.metadata.get('source', 'unknown')})")
st.info(doc.page_content)
# Process documents
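# Note: the retriever is built only once per session, so documents added after the first build are not indexed.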
if uploaded_files and "retriever" not in st.session_state:
with st.spinner("Processing documents..."):
documents = process_uploaded_files(uploaded_files)
st.session_state.retriever = create_retriever(documents)
st.session_state.generator = initialize_model()
if prompt := st.chat_input("Ask a question about your documents"):
# Add user message
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
# Generate response
if "retriever" not in st.session_state:
st.error("Please upload documents first!")
st.stop()
with st.spinner("Analyzing documents..."):
try:
response, sources = generate_response(
prompt,
st.session_state.retriever,
st.session_state.generator
)
# Add assistant response
st.session_state.messages.append({
"role": "assistant",
"content": response,
"sources": sources
})
# Display response
with st.chat_message("assistant"):
st.markdown(response)
with st.expander("View Document Sources"):
for i, doc in enumerate(sources):
st.markdown(f"**Doc{i+1}** ({doc.metadata.get('source', 'unknown')})")
st.info(doc.page_content)
except Exception as e:
st.error(f"Error generating response: {str(e)}")