import torch
import os
import tempfile

import streamlit as st
from langchain_community.document_loaders import (
    TextLoader,
    CSVLoader,
    UnstructuredFileLoader
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline
)

# Configuration
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
EMBEDDING_MODEL = "thenlper/gte-large"
CHUNK_SIZE = 1024
CHUNK_OVERLAP = 128
MAX_NEW_TOKENS = 2048

# Initialize session state for the chat history
if "messages" not in st.session_state:
    st.session_state.messages = []


@st.cache_resource
def initialize_model():
    """Load the 4-bit quantized Mistral model once and cache it across reruns."""
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True
    )

    # Load the config first so the RoPE scaling parameters can be normalized
    config = AutoConfig.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True
    )

    # Fix RoPE scaling configuration (guard against a missing or None attribute)
    if getattr(config, "rope_scaling", None):
        config.rope_scaling = {
            "type": config.rope_scaling.get("rope_type", "linear"),
            "factor": config.rope_scaling.get("factor", 8.0)
        }

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        config=config,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True
    )

    # The model is already placed on devices by device_map="auto" above,
    # so the pipeline does not need its own device argument.
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=0.1
    )


def process_uploaded_files(uploaded_files):
    """Save uploads to a temp dir, load each with a matching loader, and chunk them."""
    documents = []
    with tempfile.TemporaryDirectory() as temp_dir:
        for file in uploaded_files:
            temp_path = os.path.join(temp_dir, file.name)
            with open(temp_path, "wb") as f:
                f.write(file.getbuffer())

            try:
                if file.name.endswith(".txt"):
                    loader = TextLoader(temp_path)
                elif file.name.endswith(".csv"):
                    loader = CSVLoader(temp_path)
                else:
                    loader = UnstructuredFileLoader(temp_path)
                documents.extend(loader.load())
            except Exception as e:
                st.error(f"Error loading {file.name}: {str(e)}")

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len
    )
    return text_splitter.split_documents(documents)


def create_retriever(documents):
    """Build a hybrid retriever combining BM25 (keyword) and dense (embedding) search."""
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={'device': 'cuda'},
        encode_kwargs={'normalize_embeddings': True}
    )

    top_k = st.session_state.get("top_k", 5)

    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = top_k

    # The embeddings were previously created but never used, and the ensemble held
    # only BM25. Pairing BM25 with a FAISS dense retriever appears to be the intent
    # of importing HuggingFaceEmbeddings and EnsembleRetriever (requires faiss).
    faiss_retriever = FAISS.from_documents(documents, embeddings).as_retriever(
        search_kwargs={"k": top_k}
    )

    return EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.5, 0.5]
    )


def generate_response(query, retriever, generator):
    """Retrieve the top documents for the query and generate a context-grounded answer."""
    docs = retriever.get_relevant_documents(query)
    context = "\n\n".join(
        f"[Doc{i+1}] {doc.page_content}\nSource: {doc.metadata.get('source', 'unknown')}"
        for i, doc in enumerate(docs)
    )

    prompt = f"""[INST] You are a precision-focused research assistant tasked with answering queries based solely on the provided context.

**Context:**
{context}

**Query:** {query}

**Response Instructions:**
- Write a detailed, coherent, and insightful article that fully addresses the query based on the provided context.
- Adhere to the following principles:
    1. **Define the Core Subject**: Introduce and build the discussion logically around the main topic.
    2. **Establish Connections**: Highlight relationships between ideas and concepts with reasoning and examples.
    3. **Elaborate on Key Points**: Provide in-depth explanations and emphasize the significance of concepts.
    4. **Maintain Objectivity**: Use only the context provided, avoiding speculation or external knowledge.
    5. **Ensure Structure and Clarity**: Present information sequentially for a smooth narrative flow.
    6. **Engage with Content**: Explore implicit meanings, resolve doubts, and address counterpoints logically.
    7. **Provide Examples and Insights**: Use examples to clarify abstract ideas and offer actionable steps if applicable.
    8. **Logical Depth**: Draw inferences, explain purposes, and refute opposing ideas when necessary.
- Cite sources explicitly as [Doc1], [Doc2], etc.
- If uncertain, state: "I cannot determine from the provided context."

Craft the response as a seamless, thorough, and authoritative explanation that naturally integrates all aspects of the query. [/INST]"""

    response = generator(
        prompt,
        pad_token_id=generator.tokenizer.eos_token_id,
        do_sample=True
    )[0]['generated_text']

    # The pipeline returns the prompt plus the completion; keep only the answer.
    return response.split("[/INST]")[-1].strip(), docs


# Streamlit UI
st.title("📚 Document-Based QA Assistant")
st.markdown("Upload your documents and ask questions!")

# Sidebar controls
with st.sidebar:
    st.header("Configuration")
    uploaded_files = st.file_uploader(
        "Upload documents (TXT or CSV)",
        type=["txt", "csv"],
        accept_multiple_files=True
    )
    st.session_state.top_k = st.slider("Number of documents to retrieve", 3, 10, 5)
    st.markdown("---")
    st.markdown("Powered by Mistral-7B and LangChain")

# Main chat interface: replay the conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        if "sources" in message:
            with st.expander("View Sources"):
                for i, doc in enumerate(message["sources"]):
                    st.markdown(f"**Doc{i+1}** ({doc.metadata.get('source', 'unknown')})")
                    st.info(doc.page_content)

# Process documents once per session
if uploaded_files and "retriever" not in st.session_state:
    with st.spinner("Processing documents..."):
        documents = process_uploaded_files(uploaded_files)
        st.session_state.retriever = create_retriever(documents)
        st.session_state.generator = initialize_model()

if prompt := st.chat_input("Ask a question about your documents"):
    # Add user message
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate response
    if "retriever" not in st.session_state:
        st.error("Please upload documents first!")
        st.stop()

    with st.spinner("Analyzing documents..."):
        try:
            response, sources = generate_response(
                prompt,
                st.session_state.retriever,
                st.session_state.generator
            )

            # Add assistant response
            st.session_state.messages.append({
                "role": "assistant",
                "content": response,
                "sources": sources
            })

            # Display response
            with st.chat_message("assistant"):
                st.markdown(response)
                with st.expander("View Document Sources"):
                    for i, doc in enumerate(sources):
                        st.markdown(f"**Doc{i+1}** ({doc.metadata.get('source', 'unknown')})")
                        st.info(doc.page_content)
        except Exception as e:
            st.error(f"Error generating response: {str(e)}")
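
# Usage sketch (the filename below is an assumption; use whatever name this script is saved under):
#   streamlit run app.py
#
# Approximate runtime dependencies implied by the imports above: streamlit, torch,
# transformers, accelerate, bitsandbytes, langchain, langchain-community,
# sentence-transformers, rank_bm25, faiss-cpu (or faiss-gpu), and unstructured.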