import os
import tempfile

import torch
import streamlit as st
from langchain_community.document_loaders import (
    TextLoader,
    CSVLoader,
    UnstructuredFileLoader
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline
)
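
# Model and chunking settings. Larger CHUNK_SIZE keeps more context per chunk;
# CHUNK_OVERLAP preserves continuity across chunk boundaries.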
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
EMBEDDING_MODEL = "thenlper/gte-large"
CHUNK_SIZE = 1024
CHUNK_OVERLAP = 128
MAX_NEW_TOKENS = 2048
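
# Streamlit reruns the whole script on every interaction, so chat history has
# to live in session state.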
if "messages" not in st.session_state: |
|
st.session_state.messages = [] |
|
|
|
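

# Loading the LLM is expensive; st.cache_resource keeps a single quantized
# copy alive across Streamlit reruns.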
@st.cache_resource
def initialize_model():
    # NF4 4-bit quantization (with double quantization) so the 7B model fits
    # on a single consumer GPU.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True
    )

    config = AutoConfig.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True
    )

    # Normalize rope_scaling only when it is actually set: on configs where
    # the attribute exists but is None, the .get() calls below would raise
    # AttributeError.
    if getattr(config, "rope_scaling", None):
        config.rope_scaling = {
            "type": config.rope_scaling.get("rope_type", "linear"),
            "factor": config.rope_scaling.get("factor", 8.0)
        }

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        config=config,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True
    )

    # The model is already dispatched by accelerate (device_map="auto" above),
    # so the pipeline is not given a device of its own.
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=0.1
    )
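

# Persist uploads to a temporary directory, load each file with a matching
# LangChain loader, and split the result into overlapping chunks.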
def process_uploaded_files(uploaded_files):
    documents = []
    with tempfile.TemporaryDirectory() as temp_dir:
        for file in uploaded_files:
            temp_path = os.path.join(temp_dir, file.name)
            with open(temp_path, "wb") as f:
                f.write(file.getbuffer())

            try:
                if file.name.endswith(".txt"):
                    loader = TextLoader(temp_path)
                elif file.name.endswith(".csv"):
                    loader = CSVLoader(temp_path)
                else:
                    # Fallback for any other file type that slips through.
                    loader = UnstructuredFileLoader(temp_path)
                documents.extend(loader.load())
            except Exception as e:
                st.error(f"Error loading {file.name}: {str(e)}")

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len
    )
    return text_splitter.split_documents(documents)
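

# Hybrid retriever: BM25 keyword matching plus dense embedding search. The
# original code created embeddings but never attached them to a retriever,
# leaving a single-member ensemble; a FAISS vector store is assumed here to
# complete the setup.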
def create_retriever(documents):
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={'device': 'cuda'},
        encode_kwargs={'normalize_embeddings': True}
    )

    top_k = st.session_state.get("top_k", 5)

    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = top_k

    # Dense retriever over the same chunks (FAISS assumed; see note above).
    faiss_retriever = FAISS.from_documents(documents, embeddings).as_retriever(
        search_kwargs={"k": top_k}
    )

    # Equal weighting between lexical and semantic matches.
    return EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.5, 0.5]
    )
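

# Retrieve supporting chunks, wrap them in a Mistral [INST] prompt, and
# generate a context-grounded answer.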
def generate_response(query, retriever, generator):
    # invoke() replaces the deprecated get_relevant_documents() in recent
    # LangChain releases.
    docs = retriever.invoke(query)
    context = "\n\n".join(
        f"[Doc{i+1}] {doc.page_content}\nSource: {doc.metadata.get('source', 'unknown')}"
        for i, doc in enumerate(docs)
    )

    prompt = f"""<s>[INST] You are a precision-focused research assistant tasked with answering queries based solely on the provided context.

**Context:**
{context}

**Query:**
{query}

**Response Instructions:**
- Write a detailed, coherent, and insightful article that fully addresses the query based on the provided context.
- Adhere to the following principles:
  1. **Define the Core Subject**: Introduce and build the discussion logically around the main topic.
  2. **Establish Connections**: Highlight relationships between ideas and concepts with reasoning and examples.
  3. **Elaborate on Key Points**: Provide in-depth explanations and emphasize the significance of concepts.
  4. **Maintain Objectivity**: Use only the context provided, avoiding speculation or external knowledge.
  5. **Ensure Structure and Clarity**: Present information sequentially for a smooth narrative flow.
  6. **Engage with Content**: Explore implicit meanings, resolve doubts, and address counterpoints logically.
  7. **Provide Examples and Insights**: Use examples to clarify abstract ideas and offer actionable steps if applicable.
  8. **Logical Depth**: Draw inferences, explain purposes, and refute opposing ideas when necessary.
- Cite sources explicitly as [Doc1], [Doc2], etc.
- If uncertain, state: "I cannot determine from the provided context."

Craft the response as a seamless, thorough, and authoritative explanation that naturally integrates all aspects of the query. [/INST]"""

    response = generator(
        prompt,
        pad_token_id=generator.tokenizer.eos_token_id,
        do_sample=True
    )[0]['generated_text']

    # The pipeline output includes the prompt itself; keep only the text after
    # the final [/INST] marker.
    return response.split("[/INST]")[-1].strip(), docs
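

# Page layout: header, sidebar configuration, chat history, then the chat loop.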
st.title("π Document-Based QA Assistant") |
|
st.markdown("Upload your documents and ask questions!") |
|
|
|
|
|
with st.sidebar:
    st.header("Configuration")
    uploaded_files = st.file_uploader(
        "Upload documents (TXT or CSV)",
        type=["txt", "csv"],
        accept_multiple_files=True
    )
    # The retriever is built once per upload (below), so this slider only
    # affects retrievers created afterwards.
    st.session_state.top_k = st.slider("Number of documents to retrieve", 3, 10, 5)
    st.markdown("---")
    st.markdown("Powered by Mistral-7B and LangChain")
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        if "sources" in message:
            with st.expander("View Sources"):
                for i, doc in enumerate(message["sources"]):
                    st.markdown(f"**Doc{i+1}** ({doc.metadata.get('source', 'unknown')})")
                    st.info(doc.page_content)
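
# Build the retriever and load the model the first time documents arrive.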
if uploaded_files and "retriever" not in st.session_state:
    with st.spinner("Processing documents..."):
        documents = process_uploaded_files(uploaded_files)
        st.session_state.retriever = create_retriever(documents)
        st.session_state.generator = initialize_model()
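
# Main chat loop: echo the user's turn, check that documents exist, answer.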
if prompt := st.chat_input("Ask a question about your documents"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    if "retriever" not in st.session_state:
        st.error("Please upload documents first!")
        st.stop()

    with st.spinner("Analyzing documents..."):
        try:
            response, sources = generate_response(
                prompt,
                st.session_state.retriever,
                st.session_state.generator
            )

            st.session_state.messages.append({
                "role": "assistant",
                "content": response,
                "sources": sources
            })

            with st.chat_message("assistant"):
                st.markdown(response)
                with st.expander("View Document Sources"):
                    for i, doc in enumerate(sources):
                        st.markdown(f"**Doc{i+1}** ({doc.metadata.get('source', 'unknown')})")
                        st.info(doc.page_content)

        except Exception as e:
            st.error(f"Error generating response: {str(e)}")