gaur3009 commited on
Commit
013bd0c
Β·
verified Β·
1 Parent(s): 057b4d2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +150 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,152 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
 
 
2
  import streamlit as st
3
+ import PyPDF2
4
+ import torch
5
+ import weaviate
6
+ from transformers import AutoTokenizer, AutoModel
7
+ from weaviate.classes.init import Auth
8
+ import cohere
9
 
10
+ # Load credentials from environment variables or hardcoded (replace with env vars in prod)
11
+ WEAVIATE_URL = "vgwhgmrlqrqqgnlb1avjaa.c0.us-west3.gcp.weaviate.cloud"
12
+ WEAVIATE_API_KEY = "7VoeYTjkOS4aHINuhllGpH4JPgE2QquFmSMn"
13
+ COHERE_API_KEY = "LEvCVeZkqZMW1aLYjxDqlstCzWi4Cvlt9PiysqT8"
14
+
15
+ # Connect to Weaviate
16
+ client = weaviate.connect_to_weaviate_cloud(
17
+ cluster_url=WEAVIATE_URL,
18
+ auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
19
+ headers={"X-Cohere-Api-Key": COHERE_API_KEY}
20
+ )
21
+
22
+ cohere_client = cohere.Client(COHERE_API_KEY)
23
+
24
+ # Load sentence-transformer model
25
+ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
26
+ model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
27
+
28
+ def load_pdf(file):
29
+ """Extract text from PDF file."""
30
+ reader = PyPDF2.PdfReader(file)
31
+ text = ''.join([page.extract_text() for page in reader.pages if page.extract_text()])
32
+ return text
33
+
34
+ def get_embeddings(text):
35
+ """Generate mean pooled embedding for the input text."""
36
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
37
+ with torch.no_grad():
38
+ embeddings = model(**inputs).last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
39
+ return embeddings
40
+
41
+ def upload_document_chunks(chunks):
42
+ """Insert document chunks into Weaviate collection with embeddings."""
43
+ doc_collection = client.collections.get("Document")
44
+ for chunk in chunks:
45
+ embedding = get_embeddings(chunk)
46
+ doc_collection.data.insert(
47
+ properties={"content": chunk},
48
+ vector=embedding.tolist()
49
+ )
50
+
51
+ def query_answer(query):
52
+ """Search for top relevant document chunks based on query embedding."""
53
+ query_embedding = get_embeddings(query)
54
+ results = client.collections.get("Document").query.near_vector(
55
+ near_vector=query_embedding.tolist(),
56
+ limit=3
57
+ )
58
+ return results.objects
59
+
60
+ def generate_response(context, query):
61
+ """Generate answer using Cohere model based on context and query."""
62
+ response = cohere_client.generate(
63
+ model='command',
64
+ prompt=f"Context: {context}\n\nQuestion: {query}\nAnswer:",
65
+ max_tokens=100
66
+ )
67
+ return response.generations[0].text.strip()
68
+
69
+ def qa_pipeline(pdf_file, query):
70
+ """Main pipeline for QA: parse PDF, embed chunks, query Weaviate, and generate answer."""
71
+ document_text = load_pdf(pdf_file)
72
+ document_chunks = [document_text[i:i+500] for i in range(0, len(document_text), 500)]
73
+
74
+ upload_document_chunks(document_chunks)
75
+ top_docs = query_answer(query)
76
+
77
+ context = ' '.join([doc.properties['content'] for doc in top_docs])
78
+ answer = generate_response(context, query)
79
+
80
+ return context, answer
81
+
82
+ # Streamlit UI
83
+ st.set_page_config(page_title="Interactive QA Bot", layout="wide")
84
+
85
+ st.markdown(
86
+ """
87
+ <div style="text-align: center; font-size: 28px; font-weight: bold; margin-bottom: 20px; color: #2D3748;">
88
+ πŸ“„ Interactive QA Bot πŸ”
89
+ </div>
90
+ <p style="text-align: center; font-size: 16px; color: #4A5568;">
91
+ Upload a PDF document, ask questions, and receive answers based on the document content.
92
+ </p>
93
+ <hr style="border: 1px solid #CBD5E0; margin: 20px 0;">
94
+ """, unsafe_allow_html=True
95
+ )
96
+
97
+ col1, col2 = st.columns([1, 2])
98
+
99
+ with col1:
100
+ pdf_file = st.file_uploader("πŸ“ Upload PDF", type=["pdf"])
101
+ query = st.text_input("❓ Ask a Question", placeholder="Enter your question here...")
102
+ submit = st.button("πŸ” Submit")
103
+
104
+ with col2:
105
+ doc_segments_output = st.empty()
106
+ answer_output = st.empty()
107
+
108
+ if submit:
109
+ if not pdf_file:
110
+ st.warning("Please upload a PDF file first.")
111
+ elif not query.strip():
112
+ st.warning("Please enter a question.")
113
+ else:
114
+ with st.spinner("Processing..."):
115
+ context, answer = qa_pipeline(pdf_file, query)
116
+ doc_segments_output.text_area("πŸ“œ Retrieved Document Segments", context, height=200)
117
+ answer_output.text_area("πŸ’¬ Answer", answer, height=80)
118
+
119
+ # Optional custom CSS for styling
120
+ st.markdown(
121
+ """
122
+ <style>
123
+ body {
124
+ background-color: #EDF2F7;
125
+ }
126
+ .stFileUploader > div > div > input {
127
+ background-color: #3182CE !important;
128
+ color: white !important;
129
+ padding: 8px !important;
130
+ border-radius: 5px !important;
131
+ }
132
+ button {
133
+ background-color: #3182CE !important;
134
+ color: white !important;
135
+ padding: 10px !important;
136
+ font-size: 16px !important;
137
+ border-radius: 5px !important;
138
+ cursor: pointer;
139
+ border: none !important;
140
+ }
141
+ button:hover {
142
+ background-color: #2B6CB0 !important;
143
+ }
144
+ textarea {
145
+ border: 2px solid #CBD5E0 !important;
146
+ border-radius: 8px !important;
147
+ padding: 10px !important;
148
+ background-color: #FAFAFA !important;
149
+ }
150
+ </style>
151
+ """, unsafe_allow_html=True
152
+ )