abdull4h committed on
Commit
9df1e5f
·
verified ·
1 Parent(s): ae5e187

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -92
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import time
3
  import logging
@@ -6,16 +7,14 @@ import re
6
  from datetime import datetime
7
  import numpy as np
8
  import pandas as pd
9
- import matplotlib.pyplot as plt
10
  from sentence_transformers import SentenceTransformer, util
11
  import faiss
12
  import torch
13
- import spaces
14
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
15
  import PyPDF2
16
  import io
17
 
18
- # Configure logging for debugging and monitoring
19
  logging.basicConfig(
20
  level=logging.INFO,
21
  format='%(asctime)s - %(levelname)s - %(message)s',
@@ -27,36 +26,23 @@ logger = logging.getLogger('Vision2030Assistant')
27
  has_gpu = torch.cuda.is_available()
28
  logger.info(f"GPU available: {has_gpu}")
29
 
 
30
  class Vision2030Assistant:
31
  def __init__(self):
32
- """Initialize the assistant with enhanced features"""
33
  logger.info("Initializing Vision 2030 Assistant...")
34
-
35
- # Load models with error handling
36
  self.load_embedding_models()
37
  self.load_language_model()
38
-
39
- # Initialize knowledge base and indices
40
  self._create_knowledge_base()
41
  self._create_indices()
42
-
43
- # Sample evaluation data
44
  self._create_sample_eval_data()
45
-
46
- # Metrics storage
47
  self.metrics = {"response_times": [], "user_ratings": [], "factual_accuracy": []}
48
-
49
- # Session management
50
- self.session_history = {}
51
-
52
- # PDF content flag
53
- self.has_pdf_content = False
54
-
55
  logger.info("Assistant initialized successfully")
56
 
57
- @spaces.GPU
58
  def load_embedding_models(self):
59
- """Load embedding models with fallback"""
60
  try:
61
  self.arabic_embedder = SentenceTransformer('CAMeL-Lab/bert-base-arabic-camelbert-ca')
62
  self.english_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
@@ -69,37 +55,37 @@ class Vision2030Assistant:
69
  self._fallback_embedding()
70
 
71
  def _fallback_embedding(self):
72
- """Fallback to simple embedding if model loading fails"""
73
  logger.warning("Using fallback embedding method")
74
- def simple_embed(text):
75
- import hashlib
76
- hash_obj = hashlib.md5(text.encode())
77
- np.random.seed(int(hash_obj.hexdigest(), 16) % 2**32)
78
- return np.random.randn(384).astype(np.float32)
79
-
80
  class SimpleEmbedder:
81
  def encode(self, text):
82
- return simple_embed(text)
83
-
 
 
84
  self.arabic_embedder = SimpleEmbedder()
85
  self.english_embedder = SimpleEmbedder()
86
 
87
- @spaces.GPU
88
  def load_language_model(self):
89
- """Load language model for advanced response generation"""
90
  try:
91
  self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
92
  self.model = AutoModelForCausalLM.from_pretrained("distilgpt2")
93
  if has_gpu:
94
  self.model = self.model.to('cuda')
95
- self.generator = pipeline('text-generation', model=self.model, tokenizer=self.tokenizer, device=0 if has_gpu else -1)
 
 
 
 
 
96
  logger.info("Language model loaded successfully")
97
  except Exception as e:
98
  logger.error(f"Failed to load language model: {e}")
99
  self.generator = None
100
 
101
  def _create_knowledge_base(self):
102
- """Create initial knowledge base"""
103
  self.english_texts = [
104
  "Vision 2030 is Saudi Arabia's strategic framework to reduce dependence on oil, diversify the economy, and develop public sectors.",
105
  "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation.",
@@ -113,11 +99,10 @@ class Vision2030Assistant:
113
  self.pdf_english_texts = []
114
  self.pdf_arabic_texts = []
115
 
116
- @spaces.GPU
117
  def _create_indices(self):
118
- """Create scalable FAISS indices"""
119
  try:
120
- # English index with IVF for scalability
121
  english_vectors = [self.english_embedder.encode(text) for text in self.english_texts]
122
  dim = len(english_vectors[0])
123
  nlist = max(1, len(english_vectors) // 10)
@@ -125,44 +110,50 @@ class Vision2030Assistant:
125
  self.english_index = faiss.IndexIVFFlat(quantizer, dim, nlist)
126
  self.english_index.train(np.array(english_vectors))
127
  self.english_index.add(np.array(english_vectors))
128
-
129
  # Arabic index
130
  arabic_vectors = [self.arabic_embedder.encode(text) for text in self.arabic_texts]
131
  self.arabic_index = faiss.IndexIVFFlat(quantizer, dim, nlist)
132
  self.arabic_index.train(np.array(arabic_vectors))
133
  self.arabic_index.add(np.array(arabic_vectors))
134
-
135
  logger.info("FAISS indices created successfully")
136
  except Exception as e:
137
  logger.error(f"Error creating indices: {e}")
138
 
139
  def _create_sample_eval_data(self):
140
- """Sample evaluation data"""
141
  self.eval_data = [
142
- {"question": "What are the key pillars of Vision 2030?", "lang": "en", "reference": "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation."},
143
- {"question": "ما هي الركائز الرئيسية لرؤية 2030؟", "lang": "ar", "reference": "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."}
 
 
 
 
144
  ]
145
 
146
- @spaces.GPU
147
  def retrieve_context(self, query, lang, session_id):
148
- """Retrieve context with session history integration"""
149
  try:
150
- # Incorporate session history
151
  history = self.session_history.get(session_id, [])
152
- history_context = " ".join([f"Q: {q} A: {a}" for q, a in history[-2:]]) # Last 2 interactions
153
-
154
- # Embed query
155
  embedder = self.arabic_embedder if lang == "ar" else self.english_embedder
156
  query_vec = embedder.encode(query)
157
-
158
- # Search appropriate index
159
- index = self.pdf_arabic_index if (lang == "ar" and self.has_pdf_content) else \
160
- self.pdf_english_index if (lang == "en" and self.has_pdf_content) else \
161
- self.arabic_index if lang == "ar" else self.english_index
162
- texts = self.pdf_arabic_texts if (lang == "ar" and self.has_pdf_content) else \
163
- self.pdf_english_texts if (lang == "en" and self.has_pdf_content) else \
164
- self.arabic_texts if lang == "ar" else self.english_texts
165
-
 
 
 
 
 
 
 
166
  D, I = index.search(np.array([query_vec]), k=2)
167
  context = "\n".join([texts[i] for i in I[0] if i >= 0]) + f"\nHistory: {history_context}"
168
  return context if context.strip() else "No relevant information found."
@@ -170,9 +161,8 @@ class Vision2030Assistant:
170
  logger.error(f"Retrieval error: {e}")
171
  return "Error retrieving context."
172
 
173
- @spaces.GPU
174
  def generate_response(self, query, session_id):
175
- """Generate advanced responses with error handling"""
176
  if not query.strip():
177
  return "Please enter a valid question."
178
 
@@ -188,9 +178,8 @@ class Vision2030Assistant:
188
  response = self.generator(prompt, max_length=150, num_return_sequences=1, do_sample=True, temperature=0.7)
189
  reply = response[0]['generated_text'].split("Answer:")[-1].strip()
190
  else:
191
- reply = context # Fallback
192
 
193
- # Update session history
194
  self.session_history.setdefault(session_id, []).append((query, reply))
195
  self.metrics["response_times"].append(time.time() - start_time)
196
  return reply
@@ -199,9 +188,9 @@ class Vision2030Assistant:
199
  return "Sorry, an error occurred. Please try again."
200
 
201
  def evaluate_factual_accuracy(self, response, reference):
202
- """Evaluate using semantic similarity"""
203
  try:
204
- embedder = self.english_embedder # Assuming reference is in English; extend for Arabic if needed
205
  response_vec = embedder.encode(response)
206
  reference_vec = embedder.encode(reference)
207
  similarity = util.cos_sim(response_vec, reference_vec).item()
@@ -210,9 +199,8 @@ class Vision2030Assistant:
210
  logger.error(f"Evaluation error: {e}")
211
  return 0.0
212
 
213
- @spaces.GPU
214
  def process_pdf(self, file):
215
- """Process PDF with scalability and error handling"""
216
  if not file:
217
  return "Please upload a PDF file."
218
 
@@ -222,56 +210,61 @@ class Vision2030Assistant:
222
  if not text.strip():
223
  return "No extractable text found in PDF."
224
 
225
- # Chunk text for scalability
226
  chunks = [text[i:i+300] for i in range(0, len(text), 300)]
227
  self.pdf_english_texts = [c for c in chunks if not any('\u0600' <= char <= '\u06FF' for char in c)]
228
  self.pdf_arabic_texts = [c for c in chunks if any('\u0600' <= char <= '\u06FF' for char in c)]
229
-
230
- # Batch process embeddings
231
- batch_size = 32
232
- for lang, texts, embedder in [("en", self.pdf_english_texts, self.english_embedder),
233
- ("ar", self.pdf_arabic_texts, self.arabic_embedder)]:
234
- if texts:
235
- vectors = []
236
- for i in range(0, len(texts), batch_size):
237
- batch = texts[i:i+batch_size]
238
- vectors.extend(embedder.encode(batch))
239
- dim = len(vectors[0])
240
- nlist = max(1, len(vectors) // 10)
241
- quantizer = faiss.IndexFlatL2(dim)
242
- index = faiss.IndexIVFFlat(quantizer, dim, nlist)
243
- index.train(np.array(vectors))
244
- index.add(np.array(vectors))
245
- setattr(self, f"pdf_{lang}_index", index)
246
-
 
 
247
  self.has_pdf_content = True
248
  return f"PDF processed: {len(self.pdf_english_texts)} English, {len(self.pdf_arabic_texts)} Arabic chunks."
249
  except Exception as e:
250
  logger.error(f"PDF processing error: {e}")
251
  return f"Error processing PDF: {e}"
252
 
253
- # Gradio Interface
254
  def create_interface():
 
255
  assistant = Vision2030Assistant()
256
-
257
  def chat(query, history, session_id):
258
  reply = assistant.generate_response(query, session_id)
259
  history.append((query, reply))
260
  return history, ""
261
-
262
  with gr.Blocks() as demo:
263
  gr.Markdown("# Vision 2030 Virtual Assistant")
264
- session_id = gr.State(value="user1") # Simple session ID; enhance with authentication
265
  chatbot = gr.Chatbot()
266
  msg = gr.Textbox(label="Ask a question")
267
  submit = gr.Button("Submit")
268
  pdf_upload = gr.File(label="Upload PDF", type="binary")
269
  upload_status = gr.Textbox(label="Upload Status")
270
-
271
  submit.click(chat, [msg, chatbot, session_id], [chatbot, msg])
272
  pdf_upload.upload(assistant.process_pdf, pdf_upload, upload_status)
273
-
274
  return demo
275
 
276
- demo = create_interface()
277
- demo.launch()
 
 
 
1
+ # Import necessary libraries
2
  import gradio as gr
3
  import time
4
  import logging
 
7
  from datetime import datetime
8
  import numpy as np
9
  import pandas as pd
 
10
  from sentence_transformers import SentenceTransformer, util
11
  import faiss
12
  import torch
 
13
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
  import PyPDF2
15
  import io
16
 
17
+ # Set up logging
18
  logging.basicConfig(
19
  level=logging.INFO,
20
  format='%(asctime)s - %(levelname)s - %(message)s',
 
26
  has_gpu = torch.cuda.is_available()
27
  logger.info(f"GPU available: {has_gpu}")
28
 
29
+ # Define the Vision2030Assistant class
30
  class Vision2030Assistant:
31
  def __init__(self):
32
+ """Initialize the Vision 2030 Assistant with models, knowledge base, and indices."""
33
  logger.info("Initializing Vision 2030 Assistant...")
 
 
34
  self.load_embedding_models()
35
  self.load_language_model()
 
 
36
  self._create_knowledge_base()
37
  self._create_indices()
 
 
38
  self._create_sample_eval_data()
 
 
39
  self.metrics = {"response_times": [], "user_ratings": [], "factual_accuracy": []}
40
+ self.session_history = {} # Dictionary to store session history
41
+ self.has_pdf_content = False # Flag to indicate if PDF content is available
 
 
 
 
 
42
  logger.info("Assistant initialized successfully")
43
 
 
44
  def load_embedding_models(self):
45
+ """Load Arabic and English embedding models with fallback mechanism."""
46
  try:
47
  self.arabic_embedder = SentenceTransformer('CAMeL-Lab/bert-base-arabic-camelbert-ca')
48
  self.english_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
55
  self._fallback_embedding()
56
 
57
  def _fallback_embedding(self):
58
+ """Fallback method for embedding models using a simple random vector approach."""
59
  logger.warning("Using fallback embedding method")
 
 
 
 
 
 
60
  class SimpleEmbedder:
61
  def encode(self, text):
62
+ import hashlib
63
+ hash_obj = hashlib.md5(text.encode())
64
+ np.random.seed(int(hash_obj.hexdigest(), 16) % 2**32)
65
+ return np.random.randn(384).astype(np.float32)
66
  self.arabic_embedder = SimpleEmbedder()
67
  self.english_embedder = SimpleEmbedder()
68
 
 
69
  def load_language_model(self):
70
+ """Load the DistilGPT-2 language model for response generation."""
71
  try:
72
  self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
73
  self.model = AutoModelForCausalLM.from_pretrained("distilgpt2")
74
  if has_gpu:
75
  self.model = self.model.to('cuda')
76
+ self.generator = pipeline(
77
+ 'text-generation',
78
+ model=self.model,
79
+ tokenizer=self.tokenizer,
80
+ device=0 if has_gpu else -1
81
+ )
82
  logger.info("Language model loaded successfully")
83
  except Exception as e:
84
  logger.error(f"Failed to load language model: {e}")
85
  self.generator = None
86
 
87
  def _create_knowledge_base(self):
88
+ """Initialize the knowledge base with basic Vision 2030 information."""
89
  self.english_texts = [
90
  "Vision 2030 is Saudi Arabia's strategic framework to reduce dependence on oil, diversify the economy, and develop public sectors.",
91
  "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation.",
 
99
  self.pdf_english_texts = []
100
  self.pdf_arabic_texts = []
101
 
 
102
  def _create_indices(self):
103
+ """Create FAISS indices for the initial knowledge base."""
104
  try:
105
+ # English index
106
  english_vectors = [self.english_embedder.encode(text) for text in self.english_texts]
107
  dim = len(english_vectors[0])
108
  nlist = max(1, len(english_vectors) // 10)
 
110
  self.english_index = faiss.IndexIVFFlat(quantizer, dim, nlist)
111
  self.english_index.train(np.array(english_vectors))
112
  self.english_index.add(np.array(english_vectors))
113
+
114
  # Arabic index
115
  arabic_vectors = [self.arabic_embedder.encode(text) for text in self.arabic_texts]
116
  self.arabic_index = faiss.IndexIVFFlat(quantizer, dim, nlist)
117
  self.arabic_index.train(np.array(arabic_vectors))
118
  self.arabic_index.add(np.array(arabic_vectors))
 
119
  logger.info("FAISS indices created successfully")
120
  except Exception as e:
121
  logger.error(f"Error creating indices: {e}")
122
 
123
  def _create_sample_eval_data(self):
124
+ """Create sample evaluation data for testing factual accuracy."""
125
  self.eval_data = [
126
+ {"question": "What are the key pillars of Vision 2030?",
127
+ "lang": "en",
128
+ "reference": "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation."},
129
+ {"question": "ما هي الركائز الرئيسية لرؤية 2030؟",
130
+ "lang": "ar",
131
+ "reference": "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."}
132
  ]
133
 
 
134
  def retrieve_context(self, query, lang, session_id):
135
+ """Retrieve relevant context based on the query and session history."""
136
  try:
 
137
  history = self.session_history.get(session_id, [])
138
+ history_context = " ".join([f"Q: {q} A: {a}" for q, a in history[-2:]])
 
 
139
  embedder = self.arabic_embedder if lang == "ar" else self.english_embedder
140
  query_vec = embedder.encode(query)
141
+
142
+ if lang == "ar":
143
+ if self.has_pdf_content and self.pdf_arabic_texts:
144
+ index = self.pdf_arabic_index
145
+ texts = self.pdf_arabic_texts
146
+ else:
147
+ index = self.arabic_index
148
+ texts = self.arabic_texts
149
+ else:
150
+ if self.has_pdf_content and self.pdf_english_texts:
151
+ index = self.pdf_english_index
152
+ texts = self.pdf_english_texts
153
+ else:
154
+ index = self.english_index
155
+ texts = self.english_texts
156
+
157
  D, I = index.search(np.array([query_vec]), k=2)
158
  context = "\n".join([texts[i] for i in I[0] if i >= 0]) + f"\nHistory: {history_context}"
159
  return context if context.strip() else "No relevant information found."
 
161
  logger.error(f"Retrieval error: {e}")
162
  return "Error retrieving context."
163
 
 
164
  def generate_response(self, query, session_id):
165
+ """Generate a response to the user's query using context and session history."""
166
  if not query.strip():
167
  return "Please enter a valid question."
168
 
 
178
  response = self.generator(prompt, max_length=150, num_return_sequences=1, do_sample=True, temperature=0.7)
179
  reply = response[0]['generated_text'].split("Answer:")[-1].strip()
180
  else:
181
+ reply = context
182
 
 
183
  self.session_history.setdefault(session_id, []).append((query, reply))
184
  self.metrics["response_times"].append(time.time() - start_time)
185
  return reply
 
188
  return "Sorry, an error occurred. Please try again."
189
 
190
  def evaluate_factual_accuracy(self, response, reference):
191
+ """Evaluate the factual accuracy of a response using semantic similarity."""
192
  try:
193
+ embedder = self.english_embedder # Assuming reference is in English for simplicity
194
  response_vec = embedder.encode(response)
195
  reference_vec = embedder.encode(reference)
196
  similarity = util.cos_sim(response_vec, reference_vec).item()
 
199
  logger.error(f"Evaluation error: {e}")
200
  return 0.0
201
 
 
202
  def process_pdf(self, file):
203
+ """Process an uploaded PDF file and update the knowledge base."""
204
  if not file:
205
  return "Please upload a PDF file."
206
 
 
210
  if not text.strip():
211
  return "No extractable text found in PDF."
212
 
213
+ # Split text into chunks
214
  chunks = [text[i:i+300] for i in range(0, len(text), 300)]
215
  self.pdf_english_texts = [c for c in chunks if not any('\u0600' <= char <= '\u06FF' for char in c)]
216
  self.pdf_arabic_texts = [c for c in chunks if any('\u0600' <= char <= '\u06FF' for char in c)]
217
+
218
+ # Create indices for PDF content
219
+ if self.pdf_english_texts:
220
+ english_vectors = [self.english_embedder.encode(text) for text in self.pdf_english_texts]
221
+ dim = len(english_vectors[0])
222
+ nlist = max(1, len(english_vectors) // 10)
223
+ quantizer = faiss.IndexFlatL2(dim)
224
+ self.pdf_english_index = faiss.IndexIVFFlat(quantizer, dim, nlist)
225
+ self.pdf_english_index.train(np.array(english_vectors))
226
+ self.pdf_english_index.add(np.array(english_vectors))
227
+
228
+ if self.pdf_arabic_texts:
229
+ arabic_vectors = [self.arabic_embedder.encode(text) for text in self.pdf_arabic_texts]
230
+ dim = len(arabic_vectors[0])
231
+ nlist = max(1, len(arabic_vectors) // 10)
232
+ quantizer = faiss.IndexFlatL2(dim)
233
+ self.pdf_arabic_index = faiss.IndexIVFFlat(quantizer, dim, nlist)
234
+ self.pdf_arabic_index.train(np.array(arabic_vectors))
235
+ self.pdf_arabic_index.add(np.array(arabic_vectors))
236
+
237
  self.has_pdf_content = True
238
  return f"PDF processed: {len(self.pdf_english_texts)} English, {len(self.pdf_arabic_texts)} Arabic chunks."
239
  except Exception as e:
240
  logger.error(f"PDF processing error: {e}")
241
  return f"Error processing PDF: {e}"
242
 
243
+ # Create the Gradio interface
244
  def create_interface():
245
+ """Set up the Gradio interface for chatting and PDF uploading."""
246
  assistant = Vision2030Assistant()
247
+
248
  def chat(query, history, session_id):
249
  reply = assistant.generate_response(query, session_id)
250
  history.append((query, reply))
251
  return history, ""
252
+
253
  with gr.Blocks() as demo:
254
  gr.Markdown("# Vision 2030 Virtual Assistant")
255
+ session_id = gr.State(value="user1") # Fixed session ID for simplicity
256
  chatbot = gr.Chatbot()
257
  msg = gr.Textbox(label="Ask a question")
258
  submit = gr.Button("Submit")
259
  pdf_upload = gr.File(label="Upload PDF", type="binary")
260
  upload_status = gr.Textbox(label="Upload Status")
261
+
262
  submit.click(chat, [msg, chatbot, session_id], [chatbot, msg])
263
  pdf_upload.upload(assistant.process_pdf, pdf_upload, upload_status)
264
+
265
  return demo
266
 
267
+ # Launch the interface
268
+ if __name__ == "__main__":
269
+ demo = create_interface()
270
+ demo.launch()