import os

import streamlit as st
import torch
import torch.nn.functional as F
import clip
from PIL import Image

from haystack import Pipeline
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.preprocessors import DocumentSplitter
from haystack.components.writers import DocumentWriter
from haystack.components.converters import PyPDFToDocument
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.joiners import DocumentJoiner
from haystack.components.rankers import TransformersSimilarityRanker
from haystack.components.builders import PromptBuilder
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator

# Initialize Streamlit session state
if "messages" not in st.session_state:
    st.session_state.messages = []
if "document_store" not in st.session_state:
    st.session_state.document_store = InMemoryDocumentStore()
    st.session_state.pipeline_initialized = False

# CLIP model initialization
device = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_DIR = "./new_data"


@st.cache_resource  # load the model once and share it across sessions
def load_clip_model():
    return clip.load("ViT-L/14", device=device)


model, preprocess = load_clip_model()


@st.cache_data  # re-read the gallery only when the cache is invalidated
def load_images():
    images = []
    if os.path.exists(IMAGE_DIR):
        image_files = [f for f in os.listdir(IMAGE_DIR) if f.lower().endswith(("png", "jpg", "jpeg"))]
        for image_file in image_files:
            image_path = os.path.join(IMAGE_DIR, image_file)
            image = Image.open(image_path).convert("RGB")
            images.append((image_file, image))
    return images


@st.cache_data
def encode_images(images):
    # Encode every gallery image once; features are L2-normalized so that a
    # plain dot product later equals cosine similarity.
    image_features = []
    for image_file, image in images:
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_feature = model.encode_image(image_input)
            image_feature = F.normalize(image_feature, dim=-1)
        image_features.append((image_file, image_feature))
    return image_features


def search_images_by_text(text_query, image_features, top_k=5):
    # Rank gallery images by cosine similarity to a CLIP text embedding.
    text_inputs = clip.tokenize([text_query]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_inputs)
        text_features = F.normalize(text_features, dim=-1)
    similarities = []
    for image_file, image_feature in image_features:
        similarity = torch.cosine_similarity(text_features, image_feature).item()
        similarities.append((image_file, similarity))
    similarities.sort(key=lambda x: x[1], reverse=True)
    return similarities[:top_k]


def search_images_by_image(query_image, image_features, top_k=5):
    # Rank gallery images by cosine similarity to a CLIP image embedding.
    query_input = preprocess(query_image).unsqueeze(0).to(device)
    with torch.no_grad():
        query_feature = model.encode_image(query_input)
        query_feature = F.normalize(query_feature, dim=-1)
    similarities = []
    for image_file, image_feature in image_features:
        similarity = torch.cosine_similarity(query_feature, image_feature).item()
        similarities.append((image_file, similarity))
    similarities.sort(key=lambda x: x[1], reverse=True)
    return similarities[:top_k]
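# Optional sketch (not wired into the UI below): the Python loops in the two
# search functions above can be collapsed into a single matrix-vector product
# once the per-image features are stacked into one tensor. This is a hedged
# alternative, assuming the (file_name, feature) tuples that encode_images()
# returns; because the features are already L2-normalized, a dot product is
# exactly cosine similarity.
def search_features_vectorized(query_feature, image_features, top_k=5):
    if not image_features:
        return []
    names = [name for name, _ in image_features]
    matrix = torch.cat([feature for _, feature in image_features], dim=0)  # (N, D)
    scores = (matrix @ query_feature.squeeze(0)).tolist()  # one matmul instead of N iterations
    ranked = sorted(zip(names, scores), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]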
# Custom CSS (the original stylesheet was lost in transit; minimal placeholder
# classes are kept so the layout hooks below still resolve)
st.markdown(
    """
    <style>
    .app-header { text-align: center; }
    .image-card { padding: 4px; }
    .score { font-weight: bold; }
    </style>
    """,
    unsafe_allow_html=True,
)

# Main app header
st.markdown(
    '<div class="app-header"><h1>Multi-Model Search & QA System</h1></div>',
    unsafe_allow_html=True,
)

# Sidebar for app selection and setup
with st.sidebar:
    st.header("Application Settings")
    app_mode = st.radio("Select Application Mode:", ["Document Q&A", "Image Search"])

    if app_mode == "Document Q&A":
        st.header("Document Setup")
        uploaded_file = st.file_uploader("Upload PDF Document", type=["pdf"])

        if uploaded_file and not st.session_state.pipeline_initialized:
            with open("temp.pdf", "wb") as f:
                f.write(uploaded_file.getvalue())

            # Indexing pipeline: PDF -> sentence-pair chunks -> embeddings -> document store
            document_embedder = SentenceTransformersDocumentEmbedder(model="BAAI/bge-small-en-v1.5")
            indexing_pipeline = Pipeline()
            indexing_pipeline.add_component("converter", PyPDFToDocument())
            indexing_pipeline.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2))
            indexing_pipeline.add_component("embedder", document_embedder)
            indexing_pipeline.add_component("writer", DocumentWriter(st.session_state.document_store))
            indexing_pipeline.connect("converter", "splitter")
            indexing_pipeline.connect("splitter", "embedder")
            indexing_pipeline.connect("embedder", "writer")

            # A Haystack component instance can belong to only one pipeline, so a
            # second set is created for the source-attribution pipeline below.
            text_embedder2 = SentenceTransformersTextEmbedder(model="BAAI/bge-small-en-v1.5")
            embedding_retriever2 = InMemoryEmbeddingRetriever(st.session_state.document_store)
            bm25_retriever2 = InMemoryBM25Retriever(st.session_state.document_store)
            document_joiner2 = DocumentJoiner()
            ranker2 = TransformersSimilarityRanker(model="BAAI/bge-reranker-base")

            with st.spinner("Processing document..."):
                try:
                    indexing_pipeline.run({"converter": {"sources": ["temp.pdf"]}})
                    st.success(f"Processed {st.session_state.document_store.count_documents()} document chunks")
                    st.session_state.pipeline_initialized = True

                    # Retrieval components for the question-answering pipeline
                    text_embedder = SentenceTransformersTextEmbedder(model="BAAI/bge-small-en-v1.5")
                    embedding_retriever = InMemoryEmbeddingRetriever(st.session_state.document_store)
                    bm25_retriever = InMemoryBM25Retriever(st.session_state.document_store)
                    document_joiner = DocumentJoiner()
                    ranker = TransformersSimilarityRanker(model="BAAI/bge-reranker-base")

                    template = """
Act as a senior customer care executive and help users sort out their queries.
Be polite and friendly. Answer the user's question based only on the context
below; do not make up an answer. Match the wording and tone of the documents
provided.

CONTEXT:
{% for document in documents %}
{{ document.content }}
{% endfor %}

Make sure to provide all the details. If the answer is not in the provided
context, just say "The answer is not available in the context." Do not give a
wrong answer. If the user asks for an external recommendation, just say
"Sorry, I can't help you with that."
Question: {{ question }}

Explain in detail.
"""
                    prompt_builder = PromptBuilder(template=template)

                    # Read the Gemini key from the environment; never hard-code API keys.
                    if "GOOGLE_API_KEY" not in os.environ:
                        st.error("Set the GOOGLE_API_KEY environment variable to enable the Gemini generator.")
                    generator = GoogleAIGeminiGenerator(model="gemini-pro")

                    # Question-answering pipeline: hybrid retrieval -> reranking -> prompt -> LLM
                    st.session_state.retrieval_pipeline = Pipeline()
                    st.session_state.retrieval_pipeline.add_component("text_embedder", text_embedder)
                    st.session_state.retrieval_pipeline.add_component("embedding_retriever", embedding_retriever)
                    st.session_state.retrieval_pipeline.add_component("bm25_retriever", bm25_retriever)
                    st.session_state.retrieval_pipeline.add_component("document_joiner", document_joiner)
                    st.session_state.retrieval_pipeline.add_component("ranker", ranker)
                    st.session_state.retrieval_pipeline.add_component("prompt_builder", prompt_builder)
                    st.session_state.retrieval_pipeline.add_component("llm", generator)

                    # Connect pipeline components
                    st.session_state.retrieval_pipeline.connect("text_embedder", "embedding_retriever")
                    st.session_state.retrieval_pipeline.connect("bm25_retriever", "document_joiner")
                    st.session_state.retrieval_pipeline.connect("embedding_retriever", "document_joiner")
                    st.session_state.retrieval_pipeline.connect("document_joiner", "ranker")
                    st.session_state.retrieval_pipeline.connect("ranker", "prompt_builder.documents")
                    st.session_state.retrieval_pipeline.connect("prompt_builder", "llm")

                    # Second hybrid-retrieval pipeline, used only to report the
                    # ranked source documents alongside the generated answer
                    st.session_state.hybrid_retrieval2 = Pipeline()
                    st.session_state.hybrid_retrieval2.add_component("text_embedder", text_embedder2)
                    st.session_state.hybrid_retrieval2.add_component("embedding_retriever", embedding_retriever2)
                    st.session_state.hybrid_retrieval2.add_component("bm25_retriever", bm25_retriever2)
                    st.session_state.hybrid_retrieval2.add_component("document_joiner", document_joiner2)
                    st.session_state.hybrid_retrieval2.add_component("ranker", ranker2)

                    st.session_state.hybrid_retrieval2.connect("text_embedder", "embedding_retriever")
                    st.session_state.hybrid_retrieval2.connect("bm25_retriever", "document_joiner")
                    st.session_state.hybrid_retrieval2.connect("embedding_retriever", "document_joiner")
                    st.session_state.hybrid_retrieval2.connect("document_joiner", "ranker")
                except Exception as e:
                    st.error(f"Error processing document: {str(e)}")
                finally:
                    if os.path.exists("temp.pdf"):
                        os.remove("temp.pdf")
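# For reference, the hybrid retrieval topology both pipelines above share
# (the QA pipeline additionally connects ranker -> prompt_builder -> llm):
#
#   query -> text_embedder -> embedding_retriever --\
#                                                    +-> document_joiner -> ranker
#   query -----------------> bm25_retriever --------/
#
# BM25 catches exact keyword matches that dense embeddings can miss, while the
# embedding retriever catches paraphrases; the bge-reranker cross-encoder then
# re-orders the joined candidates.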
# Main content area
if app_mode == "Document Q&A":
    st.markdown(
        '<div class="app-header"><h2>Document Q&A System</h2></div>',
        unsafe_allow_html=True,
    )

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input
    if prompt := st.chat_input("Ask a question about your document"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        if st.session_state.pipeline_initialized:
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        result = st.session_state.retrieval_pipeline.run(
                            {
                                "text_embedder": {"text": prompt},
                                "bm25_retriever": {"query": prompt},
                                "ranker": {"query": prompt},
                                "prompt_builder": {"question": prompt},
                            }
                        )
                        result2 = st.session_state.hybrid_retrieval2.run(
                            {
                                "text_embedder": {"text": prompt},
                                "bm25_retriever": {"query": prompt},
                                "ranker": {"query": prompt},
                            }
                        )

                        # Collect unique (file, page) pairs for source attribution
                        sources = []
                        for doc in result2["ranker"]["documents"]:
                            source = (doc.meta["file_path"], doc.meta["page_number"])
                            if source not in sources:
                                sources.append(source)

                        response = result["llm"]["replies"][0]
                        response = f"{response}\n\nSources: {sources}"
                        st.markdown(response)
                        st.session_state.messages.append({"role": "assistant", "content": response})
                    except Exception as e:
                        error_message = f"Error generating response: {str(e)}"
                        st.error(error_message)
                        st.session_state.messages.append({"role": "assistant", "content": error_message})
        else:
            with st.chat_message("assistant"):
                message = "Please upload a document first to start the conversation."
                st.warning(message)
                st.session_state.messages.append({"role": "assistant", "content": message})
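# --- Image Search mode ---
# Gallery loading and CLIP encoding below are cached by Streamlit
# (@st.cache_data), so widget-triggered reruns reuse the encoded features
# instead of re-encoding every image on each interaction.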
else:  # Image Search mode
    st.markdown(
        '<div class="app-header"><h2>Image Search System</h2></div>',
        unsafe_allow_html=True,
    )

    # Load and encode images
    images = load_images()
    image_features = encode_images(images)

    search_type = st.radio("Select Search Type:", ["Text-to-Image", "Image-to-Image"])

    if search_type == "Text-to-Image":
        query = st.text_input("Enter a text description to find similar images:")
        if query:
            results = search_images_by_text(query, image_features)
            st.write(f"Top results for query: **{query}**")
            cols = st.columns(3)
            for idx, (image_file, score) in enumerate(results):
                with cols[idx % 3]:
                    st.markdown('<div class="image-card">', unsafe_allow_html=True)
                    image_path = os.path.join(IMAGE_DIR, image_file)
                    image = Image.open(image_path)
                    st.image(image, caption=image_file)
                    st.markdown(f'<div class="score">Score: {score:.4f}</div>', unsafe_allow_html=True)
                    st.markdown("</div>", unsafe_allow_html=True)
    else:  # Image-to-Image search
        uploaded_image = st.file_uploader("Upload an image to find similar images:", type=["png", "jpg", "jpeg"])
        if uploaded_image is not None:
            query_image = Image.open(uploaded_image).convert("RGB")
            st.image(query_image, caption="Query Image", use_column_width=True)

            # Search and display results
            results = search_images_by_image(query_image, image_features)
            st.write("Top results for the uploaded image:")
            cols = st.columns(3)
            for idx, (image_file, score) in enumerate(results):
                with cols[idx % 3]:
                    st.markdown('<div class="image-card">', unsafe_allow_html=True)
                    image_path = os.path.join(IMAGE_DIR, image_file)
                    image = Image.open(image_path)
                    st.image(image, caption=image_file)
                    st.markdown(f'<div class="score">Score: {score:.4f}</div>', unsafe_allow_html=True)
                    st.markdown("</div>", unsafe_allow_html=True)

if __name__ == "__main__":
    # Create the image directory if it doesn't exist
    os.makedirs(IMAGE_DIR, exist_ok=True)
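# Usage sketch (file name assumed; package versions are not pinned by the original):
#   streamlit run app.py
# Dependencies: streamlit, torch, pillow, sentence-transformers, haystack-ai,
# google-ai-haystack, and OpenAI CLIP
# (pip install git+https://github.com/openai/CLIP.git).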