|
|
|
""" |
|
Enhanced implementation of the Vision 2030 Virtual Assistant that meets all project requirements: |
|
1. Implements proper NLP task structure (bilingual QA system) |
|
2. Adds comprehensive evaluation framework for quantitative and qualitative assessment |
|
3. Improves RAG implementation with better retrieval and document processing |
|
4. Adds user feedback collection for continuous improvement |
|
5. Includes structured logging and performance monitoring |
|
""" |
|
|
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
from langdetect import detect |
|
from sentence_transformers import SentenceTransformer |
|
import faiss |
|
import numpy as np |
|
import json |
|
import time |
|
import logging |
|
import os |
|
import re |
|
from datetime import datetime |
|
from sklearn.metrics import precision_recall_fscore_support, accuracy_score |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import PyPDF2 |
|
import io |
|
|
|
|
|
# Root logging setup: every record at INFO or above is written both to the
# "vision2030_assistant.log" file and to the console stream, using one shared
# "time - logger name - level - message" layout.
logging.basicConfig(
    handlers=[
        logging.FileHandler("vision2030_assistant.log"),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Named logger used by the rest of this module.
logger = logging.getLogger('vision2030_assistant')
|
|
|
class Vision2030Assistant:
    """Bilingual (Arabic/English) retrieval-augmented QA assistant for Vision 2030.

    For each language the assistant holds a text-generation pipeline, a
    sentence-embedding model, a list of text chunks and a FAISS index over the
    embeddings of those chunks.  A question is answered by detecting its
    language, retrieving the nearest chunks, and prompting the matching
    generation pipeline with them.  Response times, keyword-overlap accuracy
    scores and user ratings are collected in ``self.metrics``.
    """

    def __init__(self, pdf_path="vision2030.pdf", eval_data_path="evaluation_data.json"):
        """
        Initialize the Vision 2030 Assistant with models, knowledge base, and evaluation framework

        Args:
            pdf_path: Path to the Vision 2030 PDF document
            eval_data_path: Path to evaluation dataset
        """
        logger.info("Initializing Vision 2030 Assistant...")
        self.load_models()
        self.load_and_process_documents(pdf_path)
        self.setup_evaluation_framework(eval_data_path)
        self.response_history = []
        logger.info("Vision 2030 Assistant initialized successfully")

    # ------------------------------------------------------------------
    # Model loading
    # ------------------------------------------------------------------

    @staticmethod
    def _load_generator(model_id, label):
        """Load tokenizer, causal LM and text-generation pipeline for ``model_id``.

        Args:
            model_id: Hugging Face model identifier.
            label: Human-readable language name used in log messages.

        Returns:
            Tuple ``(tokenizer, model, pipeline)``.

        Raises:
            Exception: Whatever the underlying loaders raise (logged, then re-raised).
        """
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        except Exception as e:
            logger.exception("Error loading %s model: %s", label, e)
            raise
        logger.info(f"{label} model loaded successfully")
        return tokenizer, model, pipe

    def load_models(self):
        """Load language models and embedding models for both Arabic and English

        Sets ``arabic_*`` / ``english_*`` attributes for the model id,
        tokenizer, model, pipeline and embedder.  Any loading error is logged
        and re-raised.
        """
        logger.info("Loading language and embedding models...")

        self.arabic_model_id = "ALLaM-AI/ALLaM-7B-Instruct-preview"
        self.arabic_tokenizer, self.arabic_model, self.arabic_pipe = self._load_generator(
            self.arabic_model_id, "Arabic"
        )

        self.english_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
        self.english_tokenizer, self.english_model, self.english_pipe = self._load_generator(
            self.english_model_id, "English"
        )

        try:
            self.arabic_embedder = SentenceTransformer('CAMeL-Lab/bert-base-arabic-camelbert-ca')
            self.english_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        except Exception as e:
            logger.exception("Error loading embedding models: %s", e)
            raise
        logger.info("Embedding models loaded successfully")

    # ------------------------------------------------------------------
    # Knowledge base
    # ------------------------------------------------------------------

    def load_and_process_documents(self, pdf_path):
        """Load and process the Vision 2030 document from PDF

        Populates ``self.english_texts`` / ``self.arabic_texts`` from the PDF
        and then builds the FAISS indices.  The built-in sample data is used
        instead when the file is missing, cannot be processed, or yields no
        text at all.

        Args:
            pdf_path: Path to the Vision 2030 PDF document
        """
        logger.info(f"Processing Vision 2030 document from {pdf_path}")

        self.english_texts = []
        self.arabic_texts = []

        try:
            if os.path.exists(pdf_path):
                self._load_pdf_chunks(pdf_path)
                logger.info(
                    f"Processed {len(self.arabic_texts)} Arabic and "
                    f"{len(self.english_texts)} English chunks"
                )
                if not (self.arabic_texts or self.english_texts):
                    # Without any chunk no index could be built and every
                    # retrieval would come back empty.
                    logger.warning(f"No text extracted from {pdf_path}. Using fallback sample data.")
                    self._create_sample_data()
            else:
                logger.warning(f"PDF file not found at {pdf_path}. Using fallback sample data.")
                self._create_sample_data()
        except Exception as e:
            logger.exception("Error processing PDF: %s", e)
            logger.info("Using fallback sample data")
            self._create_sample_data()

        self._create_indices()

    def _load_pdf_chunks(self, pdf_path):
        """Read the PDF, split it into paragraph chunks and sort them by language."""
        with open(pdf_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            # `or ""` guards against a page whose extract_text() yields None.
            full_text = "\n".join((page.extract_text() or "") for page in reader.pages)

        # Paragraph boundary = a blank (or whitespace-only) line.
        chunks = [chunk.strip() for chunk in re.split(r'\n\s*\n', full_text) if chunk.strip()]

        for chunk in chunks:
            try:
                is_arabic = detect(chunk) == "ar"
            except Exception:
                # Language detection failed for this chunk: treat it as English.
                is_arabic = False
            (self.arabic_texts if is_arabic else self.english_texts).append(chunk)

    def _create_sample_data(self):
        """Create sample Vision 2030 data if PDF processing fails"""
        logger.info("Creating sample Vision 2030 data")

        self.english_texts = [
            "Vision 2030 is Saudi Arabia's strategic framework to reduce dependence on oil, diversify the economy, and develop public sectors.",
            "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation.",
            "The Saudi Public Investment Fund (PIF) plays a crucial role in Vision 2030 by investing in strategic sectors.",
            "NEOM is a planned cross-border smart city in the Tabuk Province of northwestern Saudi Arabia, a key project of Vision 2030.",
            "Vision 2030 aims to increase women's participation in the workforce from 22% to 30%.",
            "The Red Sea Project is a Vision 2030 initiative to develop luxury tourism destinations across 50 islands off Saudi Arabia's Red Sea coast.",
            "Qiddiya is a entertainment mega-project being built in Riyadh as part of Vision 2030.",
            "Vision 2030 targets increasing the private sector's contribution to GDP from 40% to 65%.",
            "One goal of Vision 2030 is to increase foreign direct investment from 3.8% to 5.7% of GDP.",
            "Vision 2030 includes plans to develop the digital infrastructure and support for tech startups in Saudi Arabia."
        ]

        self.arabic_texts = [
            "رؤية 2030 هي الإطار الاستراتيجي للمملكة العربية السعودية للحد من الاعتماد على النفط وتنويع الاقتصاد وتطوير القطاعات العامة.",
            "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح.",
            "يلعب صندوق الاستثمارات العامة السعودي دورًا محوريًا في رؤية 2030 من خلال الاستثمار في القطاعات الاستراتيجية.",
            "نيوم هي مدينة ذكية مخططة عبر الحدود في مقاطعة تبوك شمال غرب المملكة العربية السعودية، وهي مشروع رئيسي من رؤية 2030.",
            "تهدف رؤية 2030 إلى زيادة مشاركة المرأة في القوى العاملة من 22٪ إلى 30٪.",
            "مشروع البحر الأحمر هو مبادرة رؤية 2030 لتطوير وجهات سياحية فاخرة عبر 50 جزيرة قبالة ساحل البحر الأحمر السعودي.",
            "القدية هي مشروع ترفيهي ضخم يتم بناؤه في الرياض كجزء من رؤية 2030.",
            "تستهدف رؤية 2030 زيادة مساهمة القطاع الخاص في الناتج المحلي الإجمالي من 40٪ إلى 65٪.",
            "أحد أهداف رؤية 2030 هو زيادة الاستثمار الأجنبي المباشر من 3.8٪ إلى 5.7٪ من الناتج المحلي الإجمالي.",
            "تتضمن رؤية 2030 خططًا لتطوير البنية التحتية الرقمية والدعم للشركات الناشئة التكنولوجية في المملكة العربية السعودية."
        ]

    @staticmethod
    def _build_index(texts, embedder, label):
        """Embed ``texts`` and put the vectors into a flat L2 FAISS index.

        Args:
            texts: Text chunks to embed.
            embedder: Object whose ``encode(text)`` returns one vector per text.
            label: Language name used in log messages.

        Returns:
            Tuple ``(vectors, index)``; ``index`` is ``None`` when ``texts`` is empty.
        """
        vectors = [embedder.encode(text) for text in texts]
        if not vectors:
            logger.warning(f"No {label} texts to index")
            return vectors, None

        # FAISS works on float32 matrices; make the dtype explicit.
        matrix = np.asarray(vectors, dtype=np.float32)
        index = faiss.IndexFlatL2(matrix.shape[1])
        index.add(matrix)
        logger.info(f"Created {label} index with {len(vectors)} vectors")
        return vectors, index

    def _create_indices(self):
        """Create FAISS indices for fast text retrieval

        Sets ``english_vectors`` / ``english_index`` and ``arabic_vectors`` /
        ``arabic_index``.  An index attribute is ``None`` when the matching
        corpus is empty, so later lookups never hit a missing attribute.

        Raises:
            Exception: Any embedding/indexing error (logged, then re-raised).
        """
        logger.info("Creating FAISS indices for text retrieval")

        try:
            self.english_vectors, self.english_index = self._build_index(
                self.english_texts, self.english_embedder, "English"
            )
            self.arabic_vectors, self.arabic_index = self._build_index(
                self.arabic_texts, self.arabic_embedder, "Arabic"
            )
        except Exception as e:
            logger.exception("Error creating FAISS indices: %s", e)
            raise

    # ------------------------------------------------------------------
    # Evaluation data
    # ------------------------------------------------------------------

    def setup_evaluation_framework(self, eval_data_path):
        """Set up the evaluation framework with test data and metrics

        Resets ``self.metrics`` and the feedback log, then loads the
        evaluation examples from ``eval_data_path`` (JSON).  Sample examples
        are used when the file is missing or cannot be read.

        Args:
            eval_data_path: Path to evaluation dataset
        """
        logger.info("Setting up evaluation framework")

        self.metrics = {
            "response_times": [],
            "user_ratings": [],
            "retrieval_precision": [],
            "factual_accuracy": []
        }
        # Full feedback records (rating plus question/answer/comment).
        self.feedback_history = []

        try:
            if os.path.exists(eval_data_path):
                with open(eval_data_path, 'r', encoding='utf-8') as f:
                    self.eval_data = json.load(f)
                logger.info(f"Loaded {len(self.eval_data)} evaluation examples from {eval_data_path}")
            else:
                logger.warning(f"Evaluation data not found at {eval_data_path}. Creating sample evaluation data.")
                self._create_sample_eval_data()
        except Exception as e:
            logger.exception("Error loading evaluation data: %s", e)
            self._create_sample_eval_data()

    def _create_sample_eval_data(self):
        """Create sample evaluation data with ground truth"""
        self.eval_data = [
            {
                "question": "What are the key pillars of Vision 2030?",
                "lang": "en",
                "reference_answer": "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation."
            },
            {
                "question": "ما هي الركائز الرئيسية لرؤية 2030؟",
                "lang": "ar",
                "reference_answer": "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."
            },
            {
                "question": "What is NEOM?",
                "lang": "en",
                "reference_answer": "NEOM is a planned cross-border smart city in the Tabuk Province of northwestern Saudi Arabia, a key project of Vision 2030."
            },
            {
                "question": "ما هو مشروع البحر الأحمر؟",
                "lang": "ar",
                "reference_answer": "مشروع البحر الأحمر هو مبادرة رؤية 2030 لتطوير وجهات سياحية فاخرة عبر 50 جزيرة قبالة ساحل البحر الأحمر السعودي."
            }
        ]
        logger.info(f"Created {len(self.eval_data)} sample evaluation examples")

    # ------------------------------------------------------------------
    # Retrieval and generation
    # ------------------------------------------------------------------

    def retrieve_context(self, query, lang):
        """Retrieve relevant context for a query based on language

        Args:
            query: The user's question.
            lang: ``"ar"`` selects the Arabic corpus; anything else the English one.

        Returns:
            The two nearest chunks joined by a newline, or ``""`` when no
            index exists for the language or retrieval fails.
        """
        start_time = time.time()

        try:
            if lang == "ar":
                embedder = self.arabic_embedder
                index = getattr(self, "arabic_index", None)
                texts = self.arabic_texts
            else:
                embedder = self.english_embedder
                index = getattr(self, "english_index", None)
                texts = self.english_texts

            if index is None:
                logger.warning(f"No index available for language '{lang}'; continuing without context")
                return ""

            query_vec = np.asarray([embedder.encode(query)], dtype=np.float32)
            _distances, neighbours = index.search(query_vec, k=2)
            # Negative ids mark "no neighbour" slots (fewer chunks than k).
            context = "\n".join(texts[i] for i in neighbours[0] if 0 <= i < len(texts))

            retrieval_time = time.time() - start_time
            logger.info(f"Retrieved context in {retrieval_time:.2f}s")

            return context
        except Exception as e:
            logger.exception("Error retrieving context: %s", e)
            return ""

    @staticmethod
    def _detect_language(text):
        """Return ``"ar"`` for Arabic input and ``"en"`` for everything else.

        A failed detection also yields ``"en"``.
        """
        try:
            return "ar" if detect(text) == "ar" else "en"
        except Exception:
            return "en"

    @staticmethod
    def _build_prompt(lang, context, user_input):
        """Build the few-shot prompt for ``lang``.

        Returns:
            Tuple ``(prompt, answer_marker)`` where ``answer_marker`` is the
            label that ends the prompt ("Answer:" or its Arabic counterpart).
        """
        if lang == "ar":
            prompt = (
                "أنت خبير في رؤية السعودية 2030.\n"
                f"إليك بعض المعلومات المهمة:\n{context}\n\n"
                "مثال:\n"
                "السؤال: ما هي ركائز رؤية 2030؟\n"
                "الإجابة: ركائز رؤية 2030 هي مجتمع حيوي، اقتصاد مزدهر، ووطن طموح.\n\n"
                "أجب عن سؤال المستخدم بشكل واضح ودقيق، مستندًا إلى المعلومات المقدمة. إذا لم تكن المعلومات متوفرة، أوضح ذلك.\n"
                f"السؤال: {user_input}\n"
                "الإجابة:"
            )
            return prompt, "الإجابة:"

        prompt = (
            "You are an expert on Saudi Arabia's Vision 2030.\n"
            f"Here is some relevant information:\n{context}\n\n"
            "Example:\n"
            "Question: What are the key pillars of Vision 2030?\n"
            "Answer: The key pillars are a vibrant society, a thriving economy, and an ambitious nation.\n\n"
            "Answer the user's question clearly and accurately based on the provided information. If information is not available, make that clear.\n"
            f"Question: {user_input}\n"
            "Answer:"
        )
        return prompt, "Answer:"

    @staticmethod
    def _extract_answer(generated_text, prompt, answer_marker):
        """Strip the prompt from the model output and return only the answer.

        The prompt itself contains a worked example with its own answer
        marker, so searching for the *first* marker would return the example,
        the instructions and the question along with the real answer.

        Args:
            generated_text: Text returned by the generation pipeline.
            prompt: The exact prompt that was sent to the pipeline.
            answer_marker: Label that precedes the answer.

        Returns:
            The text after the prompt when the output starts with it;
            otherwise the text after the last marker; otherwise the output
            unchanged.
        """
        if generated_text.startswith(prompt):
            return generated_text[len(prompt):].strip()

        _head, found, tail = generated_text.rpartition(answer_marker)
        if found:
            return tail.strip()
        return generated_text

    def generate_response(self, user_input):
        """Generate a response to user input using the appropriate model and retrieval system

        Args:
            user_input: The user's question in Arabic or English.

        Returns:
            The generated answer, or a fixed apology in the detected language
            when generation fails.  The response time is appended to
            ``self.metrics["response_times"]`` and the interaction to
            ``self.response_history``.
        """
        start_time = time.time()

        default_response = {
            "en": "I apologize, but I couldn't process your request properly. Please try again.",
            "ar": "أعتذر، لم أتمكن من معالجة طلبك بشكل صحيح. الرجاء المحاولة مرة أخرى."
        }

        lang = self._detect_language(user_input)
        logger.info(f"Detected language: {lang}")

        try:
            context = self.retrieve_context(user_input, lang)
            input_text, answer_marker = self._build_prompt(lang, context, user_input)
            generator = self.arabic_pipe if lang == "ar" else self.english_pipe

            response = generator(input_text, max_new_tokens=256, do_sample=True, temperature=0.7)
            reply = self._extract_answer(response[0]['generated_text'], input_text, answer_marker)
        except Exception as e:
            logger.exception("Error generating response: %s", e)
            reply = default_response.get(lang, default_response["en"])

        response_time = time.time() - start_time
        self.metrics["response_times"].append(response_time)

        logger.info(f"Generated response in {response_time:.2f}s")

        interaction = {
            "timestamp": datetime.now().isoformat(),
            "user_input": user_input,
            "response": reply,
            "language": lang,
            "response_time": response_time
        }
        self.response_history.append(interaction)

        return reply

    # ------------------------------------------------------------------
    # Evaluation and feedback
    # ------------------------------------------------------------------

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of ``values``, or 0 for an empty sequence."""
        return sum(values) / len(values) if values else 0

    def evaluate_factual_accuracy(self, response, reference):
        """Simple evaluation of factual accuracy by keyword matching

        Args:
            response: Generated answer.
            reference: Ground-truth answer.

        Returns:
            Fraction of the distinct lower-cased words of ``reference`` that
            also occur in ``response``; 0 when ``reference`` has no words.
        """
        reference_words = set(re.findall(r'\b\w+\b', reference.lower()))
        if not reference_words:
            return 0

        response_words = set(re.findall(r'\b\w+\b', response.lower()))
        return len(reference_words & response_words) / len(reference_words)

    def evaluate_on_test_set(self):
        """Evaluate the assistant on the test set

        Every example in ``self.eval_data`` is answered and scored.  Scores
        are appended to ``self.metrics["factual_accuracy"]``, but the returned
        averages cover only this run, so they match ``detailed_results`` and
        are not skewed by earlier runs or by ordinary chat traffic.

        Returns:
            Dict with ``average_factual_accuracy``, ``average_response_time``
            and ``detailed_results`` (one entry per example).
        """
        logger.info("Running evaluation on test set")

        eval_results = []
        # Response times recorded from here on belong to this run.
        first_run_timing = len(self.metrics["response_times"])

        for example in self.eval_data:
            response = self.generate_response(example["question"])
            accuracy = self.evaluate_factual_accuracy(response, example["reference_answer"])

            eval_results.append({
                "question": example["question"],
                "reference": example["reference_answer"],
                "response": response,
                "factual_accuracy": accuracy
            })
            self.metrics["factual_accuracy"].append(accuracy)

        avg_accuracy = self._mean([result["factual_accuracy"] for result in eval_results])
        avg_response_time = self._mean(self.metrics["response_times"][first_run_timing:])

        results = {
            "average_factual_accuracy": avg_accuracy,
            "average_response_time": avg_response_time,
            "detailed_results": eval_results
        }

        logger.info(f"Evaluation results: Factual accuracy = {avg_accuracy:.2f}, Avg response time = {avg_response_time:.2f}s")

        return results

    def record_user_feedback(self, user_input, response, rating, feedback_text=""):
        """Record user feedback for a response

        The rating is appended to ``self.metrics["user_ratings"]`` and the
        full record to ``self.feedback_history``.

        Args:
            user_input: The question that was asked.
            response: The answer being rated.
            rating: Numeric rating given by the user.
            feedback_text: Optional free-text comment.

        Returns:
            Always ``True``.
        """
        feedback = {
            "timestamp": datetime.now().isoformat(),
            "user_input": user_input,
            "response": response,
            "rating": rating,
            "feedback_text": feedback_text
        }

        self.metrics["user_ratings"].append(rating)

        # Keep the whole record; previously it was built and then dropped.
        if not hasattr(self, "feedback_history"):
            self.feedback_history = []
        self.feedback_history.append(feedback)

        logger.info(f"Recorded user feedback: rating={rating}")

        return True

    def save_evaluation_metrics(self, output_path="evaluation_metrics.json"):
        """Save evaluation metrics to a file

        Writes the raw metric lists, their averages, the recorded feedback
        entries and a timestamp as JSON.

        Args:
            output_path: Destination file.

        Returns:
            ``True`` when the file was written, ``False`` on any error.
        """
        try:
            metrics = self.metrics
            payload = {
                "response_times": metrics["response_times"],
                "user_ratings": metrics["user_ratings"],
                "factual_accuracy": metrics["factual_accuracy"],
                "average_factual_accuracy": self._mean(metrics["factual_accuracy"]),
                "average_response_time": self._mean(metrics["response_times"]),
                "average_user_rating": self._mean(metrics["user_ratings"]),
                "user_feedback": getattr(self, "feedback_history", []),
                "timestamp": datetime.now().isoformat()
            }

            with open(output_path, 'w', encoding='utf-8') as f:
                # ensure_ascii=False keeps Arabic feedback text readable.
                json.dump(payload, f, indent=2, ensure_ascii=False)

            logger.info(f"Saved evaluation metrics to {output_path}")
            return True
        except Exception as e:
            logger.exception("Error saving evaluation metrics: %s", e)
            return False
|
|
|
|
|
def create_gradio_interface():
    """Build the Gradio UI (Chat, Evaluation and About tabs) for the assistant.

    A single ``Vision2030Assistant`` is created and shared by all callbacks.

    Returns:
        The ``gr.Blocks`` application; the caller is expected to launch it.
    """
    assistant = Vision2030Assistant()

    # (question, answer) pairs used to find the exchange being rated.
    # NOTE(review): this list lives in the enclosing scope, so it is shared by
    # every callback invocation rather than kept per browser session.
    conversation_history = []

    def chat(message, history):
        """Answer ``message`` and append the exchange to the chat display."""
        # The chatbot value may arrive as None before the first message.
        history = history or []
        if not message:
            return history, ""

        reply = assistant.generate_response(message)

        history.append((message, reply))
        # Remember the exchange so that feedback can be attached to it;
        # without this the feedback button never finds a conversation.
        conversation_history.append((message, reply))

        # Second output clears the input textbox.
        return history, ""

    def provide_feedback(message, rating, feedback_text):
        """Attach the rating and comment to the most recent exchange."""
        if conversation_history:
            last_question, last_answer = conversation_history[-1]
            assistant.record_user_feedback(last_question, last_answer, rating, feedback_text)
            return f"Thank you for your feedback! (Rating: {rating}/5)"
        return "No conversation found to rate."

    def clear_history():
        """Forget all exchanges and empty the chat display."""
        conversation_history.clear()
        return []

    def download_metrics():
        """Write the metrics file and return its path (None when saving failed)."""
        if assistant.save_evaluation_metrics():
            return "evaluation_metrics.json"
        return None

    def run_evaluation():
        """Run the test-set evaluation and format the averages for display."""
        results = assistant.evaluate_on_test_set()
        return f"Evaluation Results:\nFactual Accuracy: {results['average_factual_accuracy']:.2f}\nAverage Response Time: {results['average_response_time']:.2f}s"

    with gr.Blocks() as demo:
        gr.Markdown("# Vision 2030 Virtual Assistant 🌍\n\nAsk questions about Saudi Vision 2030 in Arabic or English")

        with gr.Tab("Chat"):
            chatbot = gr.Chatbot(show_label=False)
            msg = gr.Textbox(label="Ask me anything about Vision 2030", placeholder="Type your question here...")
            clear = gr.Button("Clear Conversation")

            with gr.Row():
                with gr.Column(scale=4):
                    feedback_text = gr.Textbox(label="Provide additional feedback (optional)")
                with gr.Column(scale=1):
                    rating = gr.Slider(label="Rate Response (1-5)", minimum=1, maximum=5, step=1, value=3)

            submit_feedback = gr.Button("Submit Feedback")
            feedback_result = gr.Textbox(label="Feedback Status")

            # Event wiring for the chat tab.
            msg.submit(chat, [msg, chatbot], [chatbot, msg])
            clear.click(clear_history, None, chatbot)
            submit_feedback.click(provide_feedback, [msg, rating, feedback_text], feedback_result)

        with gr.Tab("Evaluation"):
            eval_button = gr.Button("Run Evaluation on Test Set")
            eval_results = gr.Textbox(label="Evaluation Results")
            download_button = gr.Button("Download Metrics")
            download_file = gr.File(label="Download evaluation metrics as JSON")

            # Event wiring for the evaluation tab.
            eval_button.click(run_evaluation, None, eval_results)
            download_button.click(download_metrics, None, download_file)

        with gr.Tab("About"):
            gr.Markdown("""
            ## About Vision 2030 Virtual Assistant

            This assistant uses a combination of state-of-the-art language models to answer questions about Saudi Arabia's Vision 2030 strategic framework in both Arabic and English.

            ### Features:
            - Bilingual support (Arabic and English)
            - Retrieval-Augmented Generation (RAG) for factual accuracy
            - Evaluation framework for measuring performance
            - User feedback collection for continuous improvement

            ### Models Used:
            - Arabic: ALLaM-7B-Instruct-preview
            - English: Mistral-7B-Instruct-v0.2
            - Embeddings: CAMeL-Lab/bert-base-arabic-camelbert-ca and sentence-transformers/all-MiniLM-L6-v2

            This project demonstrates the application of advanced NLP techniques for multilingual question answering, particularly for Arabic language support.
            """)

    return demo
|
|
|
|
|
# Script entry point: build the interface and start serving it.
# Nothing is built or launched when this module is merely imported.
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch()