import streamlit as st
import torch
import clip
from PIL import Image
import glob
import os
import numpy as np
import torch.nn.functional as F
from haystack import Pipeline
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.preprocessors import DocumentSplitter
from haystack.components.writers import DocumentWriter
from haystack.components.converters import PyPDFToDocument
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.joiners import DocumentJoiner
from haystack.components.rankers import TransformersSimilarityRanker
from haystack.components.builders import PromptBuilder
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator
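
# Streamlit app with two modes: (1) Document Q&A -- a Haystack hybrid-retrieval
# (BM25 + embedding) RAG pipeline over an uploaded PDF, answered by Gemini -- and
# (2) CLIP-based image search over a local gallery (text-to-image and image-to-image).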
# Initialize Streamlit session state
if "messages" not in st.session_state:
st.session_state.messages = []
if "document_store" not in st.session_state:
st.session_state.document_store = InMemoryDocumentStore()
st.session_state.pipeline_initialized = False
# CLIP Model initialization
device = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_DIR = "./new_data"
os.makedirs(IMAGE_DIR, exist_ok=True)  # ensure the image gallery directory exists before it is read
@st.cache_resource
def load_clip_model():
    return clip.load("ViT-L/14", device=device)

model, preprocess = load_clip_model()
@st.cache_data
def load_images():
    """Load all gallery images from IMAGE_DIR as RGB PIL images."""
    images = []
    if os.path.exists(IMAGE_DIR):
        image_files = [f for f in os.listdir(IMAGE_DIR) if f.lower().endswith((".png", ".jpg", ".jpeg"))]
        for image_file in image_files:
            image_path = os.path.join(IMAGE_DIR, image_file)
            image = Image.open(image_path).convert("RGB")
            images.append((image_file, image))
    return images
@st.cache_data
def encode_images(images):
    """Encode each image with CLIP and L2-normalize the embeddings."""
    image_features = []
    for image_file, image in images:
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_feature = model.encode_image(image_input)
            image_feature = F.normalize(image_feature, dim=-1)
        image_features.append((image_file, image_feature))
    return image_features
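
# The two search helpers below read the module-level `image_features` list,
# which is populated only in Image Search mode (see the main section below).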
def search_images_by_text(text_query, top_k=5):
    """Rank gallery images by cosine similarity to a CLIP text embedding."""
    text_inputs = clip.tokenize([text_query]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_inputs)
        text_features = F.normalize(text_features, dim=-1)
    similarities = []
    for image_file, image_feature in image_features:
        similarity = torch.cosine_similarity(text_features, image_feature).item()
        similarities.append((image_file, similarity))
    similarities.sort(key=lambda x: x[1], reverse=True)
    return similarities[:top_k]
def search_images_by_image(query_image, top_k=5):
    """Rank gallery images by cosine similarity to a CLIP image embedding."""
    query_input = preprocess(query_image).unsqueeze(0).to(device)
    with torch.no_grad():
        query_feature = model.encode_image(query_input)
        query_feature = F.normalize(query_feature, dim=-1)
    similarities = []
    for image_file, image_feature in image_features:
        similarity = torch.cosine_similarity(query_feature, image_feature).item()
        similarities.append((image_file, similarity))
    similarities.sort(key=lambda x: x[1], reverse=True)
    return similarities[:top_k]
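
# Note: since all embeddings are L2-normalized, cosine similarity reduces to a
# dot product, so for large galleries stacking the image embeddings into one
# tensor and ranking with a single matrix multiply would beat these per-image loops.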
# Custom CSS
st.markdown("""
<style>
.title {
font-size: 40px;
color: #FF4B4B;
font-weight: bold;
text-align: center;
}
.subtitle {
font-size: 24px;
color: #FF914D;
font-weight: bold;
margin-top: 30px;
}
.result-container {
border: 1px solid #ddd;
padding: 10px;
border-radius: 10px;
text-align: center;
margin-bottom: 10px;
}
.score-badge {
color: white;
background-color: #007BFF;
padding: 5px;
border-radius: 5px;
font-weight: bold;
}
</style>
""", unsafe_allow_html=True)
# Main App
st.markdown('<h1 class="title">Multi-Model Search & QA System</h1>', unsafe_allow_html=True)
# Sidebar for app selection and setup
with st.sidebar:
    st.header("Application Settings")
    app_mode = st.radio("Select Application Mode:", ["Document Q&A", "Image Search"])

    if app_mode == "Document Q&A":
        st.header("Document Setup")
        uploaded_file = st.file_uploader("Upload PDF Document", type=['pdf'])
        if uploaded_file and not st.session_state.pipeline_initialized:
            with open("temp.pdf", "wb") as f:
                f.write(uploaded_file.getvalue())

            # Initialize components
            document_embedder = SentenceTransformersDocumentEmbedder(model="BAAI/bge-small-en-v1.5")

            # Create indexing pipeline: PDF -> sentence chunks -> embeddings -> document store
            indexing_pipeline = Pipeline()
            indexing_pipeline.add_component("converter", PyPDFToDocument())
            indexing_pipeline.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2))
            indexing_pipeline.add_component("embedder", document_embedder)
            indexing_pipeline.add_component("writer", DocumentWriter(st.session_state.document_store))
            indexing_pipeline.connect("converter", "splitter")
            indexing_pipeline.connect("splitter", "embedder")
            indexing_pipeline.connect("embedder", "writer")

            text_embedder2 = SentenceTransformersTextEmbedder(model="BAAI/bge-small-en-v1.5")
            embedding_retriever2 = InMemoryEmbeddingRetriever(st.session_state.document_store)
            bm25_retriever2 = InMemoryBM25Retriever(st.session_state.document_store)
            document_joiner2 = DocumentJoiner()
            ranker2 = TransformersSimilarityRanker(model="BAAI/bge-reranker-base")
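            # These "2"-suffixed duplicates feed a separate source-attribution
            # pipeline: a Haystack component instance can only be added to one
            # Pipeline, so the RAG and attribution pipelines cannot share them.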
with st.spinner("Processing document..."):
try:
indexing_pipeline.run({"converter": {"sources": ["temp.pdf"]}})
st.success(f"Processed {st.session_state.document_store.count_documents()} document chunks")
st.session_state.pipeline_initialized = True
# Initialize retrieval components
text_embedder = SentenceTransformersTextEmbedder(model="BAAI/bge-small-en-v1.5")
embedding_retriever = InMemoryEmbeddingRetriever(st.session_state.document_store)
bm25_retriever = InMemoryBM25Retriever(st.session_state.document_store)
document_joiner = DocumentJoiner()
ranker = TransformersSimilarityRanker(model="BAAI/bge-reranker-base")
template = """
act as a senior customer care executive and help users sorting out their queries. Be polite and friendly. Answer the user's questions based on the below context only dont try to make up any answer make sure that create a good version of all the documents that u recived and make the answer complining to the question make user the you sound exactly same as the documents delow.:
CONTEXT:
{% for document in documents %}
{{ document.content }}
{% endfor %}
Make sure to provide all the details. If the answer is not in the provided context just say, 'answer is not available in the context'. Don't provide the wrong answer.
If the person asks any external recommendation just say 'sorry i can't help you with that'.
Question: {{question}}
explain in detail
"""
                    prompt_builder = PromptBuilder(template=template)

                    # The Gemini generator reads GOOGLE_API_KEY from the environment;
                    # never hard-code API keys in source.
                    if "GOOGLE_API_KEY" not in os.environ:
                        raise ValueError("GOOGLE_API_KEY is not set; add it to the environment or Streamlit secrets.")
                    generator = GoogleAIGeminiGenerator(model="gemini-pro")
                    # Create RAG pipeline: hybrid retrieval -> rerank -> prompt -> Gemini
                    st.session_state.retrieval_pipeline = Pipeline()
                    st.session_state.retrieval_pipeline.add_component("text_embedder", text_embedder)
                    st.session_state.retrieval_pipeline.add_component("embedding_retriever", embedding_retriever)
                    st.session_state.retrieval_pipeline.add_component("bm25_retriever", bm25_retriever)
                    st.session_state.retrieval_pipeline.add_component("document_joiner", document_joiner)
                    st.session_state.retrieval_pipeline.add_component("ranker", ranker)
                    st.session_state.retrieval_pipeline.add_component("prompt_builder", prompt_builder)
                    st.session_state.retrieval_pipeline.add_component("llm", generator)

                    # Connect pipeline components
                    st.session_state.retrieval_pipeline.connect("text_embedder", "embedding_retriever")
                    st.session_state.retrieval_pipeline.connect("bm25_retriever", "document_joiner")
                    st.session_state.retrieval_pipeline.connect("embedding_retriever", "document_joiner")
                    st.session_state.retrieval_pipeline.connect("document_joiner", "ranker")
                    st.session_state.retrieval_pipeline.connect("ranker", "prompt_builder.documents")
                    st.session_state.retrieval_pipeline.connect("prompt_builder", "llm")

                    # Source-attribution pipeline: same hybrid retrieval plus reranking, but it
                    # stops at the ranker so the ranked documents' metadata can be read directly
                    st.session_state.hybrid_retrieval2 = Pipeline()
                    st.session_state.hybrid_retrieval2.add_component("text_embedder", text_embedder2)
                    st.session_state.hybrid_retrieval2.add_component("embedding_retriever", embedding_retriever2)
                    st.session_state.hybrid_retrieval2.add_component("bm25_retriever", bm25_retriever2)
                    st.session_state.hybrid_retrieval2.add_component("document_joiner", document_joiner2)
                    st.session_state.hybrid_retrieval2.add_component("ranker", ranker2)
                    st.session_state.hybrid_retrieval2.connect("text_embedder", "embedding_retriever")
                    st.session_state.hybrid_retrieval2.connect("bm25_retriever", "document_joiner")
                    st.session_state.hybrid_retrieval2.connect("embedding_retriever", "document_joiner")
                    st.session_state.hybrid_retrieval2.connect("document_joiner", "ranker")
                except Exception as e:
                    st.error(f"Error processing document: {str(e)}")
                finally:
                    if os.path.exists("temp.pdf"):
                        os.remove("temp.pdf")
# Main content area
if app_mode == "Document Q&A":
    st.markdown('<h2 class="subtitle">Document Q&A System</h2>', unsafe_allow_html=True)

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input
    if prompt := st.chat_input("Ask a question about your document"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        if st.session_state.pipeline_initialized:
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        result = st.session_state.retrieval_pipeline.run(
                            {
                                "text_embedder": {"text": prompt},
                                "bm25_retriever": {"query": prompt},
                                "ranker": {"query": prompt},
                                "prompt_builder": {"question": prompt}
                            }
                        )
                        result2 = st.session_state.hybrid_retrieval2.run(
                            {
                                "text_embedder": {"text": prompt},
                                "bm25_retriever": {"query": prompt},
                                "ranker": {"query": prompt}
                            }
                        )

                        # Collect unique (file_path, page_number) sources from the ranked documents
                        sources = []
                        for doc in result2["ranker"]["documents"]:
                            source = (doc.meta["file_path"], doc.meta["page_number"])
                            if source not in sources:
                                sources.append(source)

                        response = result["llm"]["replies"][0]
                        response = f"{response}\n\nSources: {sources}"
                        st.markdown(response)
                        st.session_state.messages.append({"role": "assistant", "content": response})
                    except Exception as e:
                        error_message = f"Error generating response: {str(e)}"
                        st.error(error_message)
                        st.session_state.messages.append({"role": "assistant", "content": error_message})
        else:
            with st.chat_message("assistant"):
                message = "Please upload a document first to start the conversation."
                st.warning(message)
                st.session_state.messages.append({"role": "assistant", "content": message})
else:  # Image Search mode
    st.markdown('<h2 class="subtitle">Image Search System</h2>', unsafe_allow_html=True)

    # Load the gallery and compute (cached) CLIP embeddings
    images = load_images()
    image_features = encode_images(images)

    search_type = st.radio("Select Search Type:", ["Text-to-Image", "Image-to-Image"])

    if search_type == "Text-to-Image":
        query = st.text_input("Enter a text description to find similar images:")
        if query:
            results = search_images_by_text(query)
            st.write(f"Top results for query: **{query}**")
            cols = st.columns(3)
            for idx, (image_file, score) in enumerate(results):
                with cols[idx % 3]:
                    st.markdown('<div class="result-container">', unsafe_allow_html=True)
                    image_path = os.path.join(IMAGE_DIR, image_file)
                    image = Image.open(image_path)
                    st.image(image, caption=image_file)
                    st.markdown(f'<span class="score-badge">Score: {score:.4f}</span>', unsafe_allow_html=True)
                    st.markdown('</div>', unsafe_allow_html=True)
    else:  # Image-to-Image search
        uploaded_image = st.file_uploader("Upload an image to find similar images:", type=["png", "jpg", "jpeg"])
        if uploaded_image is not None:
            query_image = Image.open(uploaded_image).convert("RGB")
            st.image(query_image, caption="Query Image", use_column_width=True)

            # Search and display results
            results = search_images_by_image(query_image)
            st.write("Top results for the uploaded image:")
            cols = st.columns(3)
            for idx, (image_file, score) in enumerate(results):
                with cols[idx % 3]:
                    st.markdown('<div class="result-container">', unsafe_allow_html=True)
                    image_path = os.path.join(IMAGE_DIR, image_file)
                    image = Image.open(image_path)
                    st.image(image, caption=image_file)
                    st.markdown(f'<span class="score-badge">Score: {score:.4f}</span>', unsafe_allow_html=True)
                    st.markdown('</div>', unsafe_allow_html=True)