|
|
|
import gradio as gr |
|
import time |
|
import logging |
|
import os |
|
import re |
|
from datetime import datetime |
|
import numpy as np |
|
import pandas as pd |
|
from sentence_transformers import SentenceTransformer, util |
|
import faiss |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
import PyPDF2 |
|
import io |
|
|
|
|
|
# Emit INFO-and-above records to stderr as "<time> - <LEVEL> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger("Vision2030Assistant")

# Probed once at import time; the model loaders read this flag to decide
# whether to move models onto CUDA.
has_gpu = torch.cuda.is_available()
logger.info("GPU available: %s", has_gpu)
|
|
|
|
|
class Vision2030Assistant:
    """Bilingual (Arabic/English) retrieval-augmented assistant about Saudi Vision 2030.

    Queries are answered by retrieving the nearest passages from a small built-in
    knowledge base (or, after a PDF upload, from that PDF's text) using one embedding
    model and one FAISS index per language, then optionally generating an answer with
    DistilGPT-2. Per-session question/answer history is kept in memory.
    """

    # Number of passages retrieved per query.
    _TOP_K = 2
    # Tokens the language model may generate after the prompt.
    _MAX_NEW_TOKENS = 100
    # Slack kept free in the model's position window: decoding a truncated token
    # slice and re-encoding it can yield a few extra tokens.
    _PROMPT_TOKEN_MARGIN = 16

    class _HashEmbedder:
        """Deterministic stand-in embedder used when the real models fail to load.

        Each text maps to a pseudo-random vector seeded by a hash of the text, so equal
        texts always embed identically. The vectors carry no semantic meaning.
        """

        def __init__(self, dim=384):
            self.dim = dim

        def encode(self, text):
            """Return a float32 vector of length ``self.dim`` derived from *text*."""
            import hashlib

            digest = hashlib.sha256(text.encode("utf-8")).digest()
            # A private Generator, so numpy's global RNG state is left untouched
            # (the previous np.random.seed call reseeded it on every encode).
            rng = np.random.default_rng(int.from_bytes(digest, "big"))
            return rng.standard_normal(self.dim, dtype=np.float32)

    def __init__(self):
        """Load models, build the knowledge base and its indices, and reset metrics/state."""
        logger.info("Initializing Vision 2030 Assistant...")
        self.load_embedding_models()
        self.load_language_model()
        self._create_knowledge_base()
        self._create_indices()
        self._create_sample_eval_data()
        self.metrics = {"response_times": [], "user_ratings": [], "factual_accuracy": []}
        # session_id -> list of (query, reply) tuples.
        self.session_history = {}
        self.has_pdf_content = False
        logger.info("Assistant initialized successfully")

    def load_embedding_models(self):
        """Load the Arabic and English sentence embedders.

        On any failure both embedders are replaced by the hash-based fallback.
        """
        try:
            self.arabic_embedder = SentenceTransformer('CAMeL-Lab/bert-base-arabic-camelbert-ca')
            self.english_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
            if has_gpu:
                self.arabic_embedder = self.arabic_embedder.to('cuda')
                self.english_embedder = self.english_embedder.to('cuda')
            logger.info("Embedding models loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load embedding models: {e}")
            self._fallback_embedding()

    def _fallback_embedding(self):
        """Replace both embedders with the deterministic hash-based fallback."""
        logger.warning("Using fallback embedding method")
        self.arabic_embedder = self._HashEmbedder()
        self.english_embedder = self._HashEmbedder()

    def load_language_model(self):
        """Load DistilGPT-2 and wrap it in a text-generation pipeline.

        On failure ``self.generator`` is set to None, and replies fall back to the
        retrieved passages.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
            self.model = AutoModelForCausalLM.from_pretrained("distilgpt2")
            if has_gpu:
                self.model = self.model.to('cuda')
            self.generator = pipeline(
                'text-generation',
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if has_gpu else -1,
            )
            logger.info("Language model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load language model: {e}")
            self.generator = None

    def _create_knowledge_base(self):
        """Initialize the built-in Vision 2030 passages and empty PDF corpora."""
        self.english_texts = [
            "Vision 2030 is Saudi Arabia's strategic framework to reduce dependence on oil, diversify the economy, and develop public sectors.",
            "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation.",
            "NEOM is a planned smart city in Tabuk Province, a key Vision 2030 project.",
        ]
        self.arabic_texts = [
            "رؤية 2030 هي إطار استراتيجي لتقليل الاعتماد على النفط وتنويع الاقتصاد.",
            "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح.",
            "نيوم مدينة ذكية مخططة في تبوك، مشروع رئيسي لرؤية 2030.",
        ]
        self.pdf_english_texts = []
        self.pdf_arabic_texts = []
        self.pdf_english_index = None
        self.pdf_arabic_index = None

    @staticmethod
    def _build_index(texts, embedder):
        """Embed *texts* and return an exact (flat L2) FAISS index over them.

        Returns None when *texts* is empty. Each corpus gets its own index, sized from
        its own vectors, because the Arabic and English embedders need not share an
        output dimensionality. A flat index needs no training and searches
        exhaustively, which suits these small corpora.
        """
        if not texts:
            return None
        vectors = np.asarray([embedder.encode(t) for t in texts], dtype=np.float32)
        index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)
        return index

    def _create_indices(self):
        """Build one FAISS index per language over the built-in knowledge base.

        A failure for one language is logged and leaves that index as None, so
        retrieval in that language reports an error. It does not prevent the other
        language's index from being built.
        """
        self.english_index = None
        self.arabic_index = None
        try:
            self.english_index = self._build_index(self.english_texts, self.english_embedder)
        except Exception as e:
            logger.error(f"Error creating English index: {e}")
        try:
            self.arabic_index = self._build_index(self.arabic_texts, self.arabic_embedder)
        except Exception as e:
            logger.error(f"Error creating Arabic index: {e}")
        if self.english_index is not None and self.arabic_index is not None:
            logger.info("FAISS indices created successfully")

    def _create_sample_eval_data(self):
        """Create sample question/reference pairs for testing factual accuracy."""
        self.eval_data = [
            {"question": "What are the key pillars of Vision 2030?",
             "lang": "en",
             "reference": "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation."},
            {"question": "ما هي الركائز الرئيسية لرؤية 2030؟",
             "lang": "ar",
             "reference": "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."},
        ]

    @staticmethod
    def _contains_arabic(text):
        """Return True if *text* has any character in the Arabic block U+0600-U+06FF."""
        return any("\u0600" <= ch <= "\u06FF" for ch in text)

    @staticmethod
    def _chunk_text(text, size=300):
        """Split *text* into consecutive pieces of at most *size* characters."""
        return [text[i:i + size] for i in range(0, len(text), size)]

    def _select_corpus(self, lang):
        """Return the (index, texts) pair to search for *lang* ("ar" or anything else = English).

        Uploaded-PDF content for the language takes precedence over the built-in
        knowledge base.
        """
        if lang == "ar":
            if self.has_pdf_content and self.pdf_arabic_texts:
                return self.pdf_arabic_index, self.pdf_arabic_texts
            return self.arabic_index, self.arabic_texts
        if self.has_pdf_content and self.pdf_english_texts:
            return self.pdf_english_index, self.pdf_english_texts
        return self.english_index, self.english_texts

    def _retrieve_passages(self, query, lang, k=None):
        """Return up to *k* (default ``_TOP_K``) passages nearest to *query* for *lang*.

        Raises:
            RuntimeError: if the selected corpus has no search index.
        """
        k = self._TOP_K if k is None else k
        index, texts = self._select_corpus(lang)
        if index is None:
            raise RuntimeError(f"No search index available for language '{lang}'")
        k = min(k, index.ntotal)
        if k <= 0:
            return []
        embedder = self.arabic_embedder if lang == "ar" else self.english_embedder
        query_vec = np.asarray([embedder.encode(query)], dtype=np.float32)
        _, ids = index.search(query_vec, k)
        # FAISS pads missing results with -1.
        return [texts[i] for i in ids[0] if 0 <= i < len(texts)]

    def _history_context(self, session_id):
        """Format the last two (query, reply) turns of *session_id* as one line."""
        history = self.session_history.get(session_id, [])
        return " ".join(f"Q: {q} A: {a}" for q, a in history[-2:])

    def retrieve_context(self, query, lang, session_id):
        """Return retrieved passages plus recent session history as one string.

        Returns:
            str: newline-joined passages followed by ``"History: ..."``.
            ``"No relevant information found."`` is returned when nothing is
            retrieved, and ``"Error retrieving context."`` when retrieval fails
            (the error is logged).
        """
        try:
            passages = self._retrieve_passages(query, lang)
        except Exception as e:
            logger.error(f"Retrieval error: {e}")
            return "Error retrieving context."
        if not passages:
            return "No relevant information found."
        return "\n".join(passages) + f"\nHistory: {self._history_context(session_id)}"

    def _generate_answer(self, context, query):
        """Generate an answer with the language model; return "" if it produces nothing.

        The prompt is trimmed from the front to fit the model's position window. Only
        newly generated text, up to any model-invented follow-up "Question:", is kept.
        """
        prompt = f"Context: {context}\nQuestion: {query}\nAnswer:"
        max_positions = getattr(self.model.config, "n_positions", 1024)
        budget = max_positions - self._MAX_NEW_TOKENS - self._PROMPT_TOKEN_MARGIN
        token_ids = self.tokenizer.encode(prompt)
        if len(token_ids) > budget:
            # Keep the tail so the question and the "Answer:" cue survive truncation.
            prompt = self.tokenizer.decode(token_ids[-budget:])
        output = self.generator(
            prompt,
            max_new_tokens=self._MAX_NEW_TOKENS,  # max_length would count prompt tokens too
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
            return_full_text=False,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        generated = output[0]["generated_text"]
        return generated.split("Question:")[0].strip()

    def generate_response(self, query, session_id):
        """Answer *query* for *session_id* and record the turn in the session history.

        The query language is inferred from its characters (any Arabic character ->
        Arabic). If the language model is unavailable or yields no text, the
        retrieved passages themselves are returned.

        Returns:
            str: the reply. Blank input gets a validation message, and unexpected
            failures get a generic apology (the error is logged).
        """
        if not query or not query.strip():
            return "Please enter a valid question."

        start_time = time.perf_counter()
        try:
            lang = "ar" if self._contains_arabic(query) else "en"
            try:
                passages = self._retrieve_passages(query, lang)
            except Exception as e:
                logger.error(f"Retrieval error: {e}")
                passages = None

            # The retrieval outcome is tracked explicitly rather than by searching the
            # context string for "Error"/"No relevant". That check misfired whenever
            # retrieved text or earlier replies in the history contained those words.
            if passages is None:
                reply = "Error retrieving context."
            elif not passages:
                reply = "No relevant information found."
            else:
                retrieved = "\n".join(passages)
                reply = retrieved
                if self.generator is not None:
                    context = f"{retrieved}\nHistory: {self._history_context(session_id)}"
                    reply = self._generate_answer(context, query) or retrieved

            self.session_history.setdefault(session_id, []).append((query, reply))
            self.metrics["response_times"].append(time.perf_counter() - start_time)
            return reply
        except Exception as e:
            logger.error(f"Response generation error: {e}")
            return "Sorry, an error occurred. Please try again."

    def evaluate_factual_accuracy(self, response, reference, lang=None):
        """Score *response* against *reference* by cosine similarity of their embeddings.

        Args:
            response: the generated answer.
            reference: the ground-truth answer.
            lang: "ar" or "en". If omitted, it is inferred from *reference*, so Arabic
                references are embedded with the Arabic model.

        Returns:
            float: the cosine similarity, or 0.0 if embedding fails (the error is logged).
        """
        try:
            if lang is None:
                lang = "ar" if self._contains_arabic(reference) else "en"
            embedder = self.arabic_embedder if lang == "ar" else self.english_embedder
            response_vec = embedder.encode(response)
            reference_vec = embedder.encode(reference)
            return util.cos_sim(response_vec, reference_vec).item()
        except Exception as e:
            logger.error(f"Evaluation error: {e}")
            return 0.0

    def process_pdf(self, file):
        """Replace the PDF knowledge base with the text of an uploaded PDF.

        Args:
            file: raw PDF bytes, as wired from a ``gr.File(type="binary")`` upload.

        Returns:
            str: a status message describing the outcome.
        """
        if not file:
            return "Please upload a PDF file."

        try:
            pdf_reader = PyPDF2.PdfReader(io.BytesIO(file))
            # Newline between pages so words at page boundaries are not fused.
            text = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
            if not text.strip():
                return "No extractable text found in PDF."

            chunks = [c for c in self._chunk_text(text) if c.strip()]
            arabic_texts = [c for c in chunks if self._contains_arabic(c)]
            english_texts = [c for c in chunks if not self._contains_arabic(c)]

            # Build both indices before touching instance state. A failure then
            # cannot leave texts paired with an index from a different upload.
            english_index = self._build_index(english_texts, self.english_embedder)
            arabic_index = self._build_index(arabic_texts, self.arabic_embedder)

            self.pdf_english_texts, self.pdf_english_index = english_texts, english_index
            self.pdf_arabic_texts, self.pdf_arabic_index = arabic_texts, arabic_index
            self.has_pdf_content = True
            return f"PDF processed: {len(english_texts)} English, {len(arabic_texts)} Arabic chunks."
        except Exception as e:
            logger.error(f"PDF processing error: {e}")
            return f"Error processing PDF: {e}"
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI for chatting with the assistant and uploading PDFs.

    Returns:
        gr.Blocks: the app, not yet launched.
    """
    import uuid  # only needed to mint a per-browser-session id

    assistant = Vision2030Assistant()

    def chat(query, history, session_id):
        """Answer *query*, append the turn to the transcript, and clear the textbox."""
        history = list(history or [])
        reply = assistant.generate_response(query, session_id or "default")
        history.append((query, reply))
        return history, ""

    with gr.Blocks() as demo:
        gr.Markdown("# Vision 2030 Virtual Assistant")
        # Filled with a fresh id on page load (below). A constant default here
        # would make every user share one conversation history.
        session_id = gr.State(value=None)
        chatbot = gr.Chatbot()
        msg = gr.Textbox(label="Ask a question")
        submit = gr.Button("Submit")
        pdf_upload = gr.File(label="Upload PDF", type="binary")
        upload_status = gr.Textbox(label="Upload Status")

        submit.click(chat, [msg, chatbot, session_id], [chatbot, msg])
        # Pressing Enter in the textbox behaves like clicking Submit.
        msg.submit(chat, [msg, chatbot, session_id], [chatbot, msg])
        pdf_upload.upload(assistant.process_pdf, pdf_upload, upload_status)
        demo.load(lambda: uuid.uuid4().hex, None, session_id)

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: construct the Gradio app, then start serving it.
    demo = create_interface()
    demo.launch()