|
|
|
import gradio as gr |
|
import time |
|
import logging |
|
import os |
|
import re |
|
from datetime import datetime |
|
import numpy as np |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
from sklearn.metrics import precision_recall_fscore_support, accuracy_score |
|
import PyPDF2 |
|
import io |
|
import json |
|
from langdetect import detect |
|
from sentence_transformers import SentenceTransformer |
|
import faiss |
|
import torch |
|
import spaces |
|
|
|
|
|
# Console logging for the whole app: INFO and above, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger('vision2030_assistant')

# CUDA availability is probed once at import; the model-loading and encoding
# code below reads this flag.
has_gpu = torch.cuda.is_available()
logger.info(f"GPU available: {has_gpu}")
|
|
|
class Vision2030Assistant: |
|
def __init__(self):
    """Set up models, knowledge base, search indices and metric stores.

    The build steps run in a fixed order: the FAISS indices are created
    only after the embedders and the knowledge-base texts exist.
    """
    logger.info("Initializing Vision 2030 Assistant...")

    # Each step fills in attributes that later steps read.
    build_steps = (
        self.load_embedding_models,
        self._create_enhanced_knowledge_base,
        self._create_sample_data,
        self._create_indices,
        self._create_sample_eval_data,
    )
    for step in build_steps:
        step()

    # Running totals collected while the assistant is in use.
    self.metrics = dict(response_times=[], user_ratings=[], factual_accuracy=[])
    self.response_history = []

    # generate_response() delegates to this retrieval-based generator when
    # none of its own keyword rules match.
    self.original_generate_response = self._basic_generate_response

    logger.info("Vision 2030 Assistant initialized successfully")
|
|
|
def _create_enhanced_knowledge_base(self):
    """Populate the built-in bilingual Vision 2030 knowledge base.

    Sets three attributes:

    * ``self.vision2030_knowledge`` - ``{category: {"en": [...], "ar": [...]}}``
    * ``self.english_texts`` - every English statement, in category order
    * ``self.arabic_texts`` - every Arabic statement, in category order

    The two flat lists are the corpora that get embedded and indexed.
    """
    logger.info("Creating enhanced Vision 2030 knowledge base")

    knowledge = {
        "general": {
            "en": [
                "Vision 2030 is Saudi Arabia's strategic framework to reduce dependence on oil, diversify the economy, and develop public sectors.",
                "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation.",
                "Vision 2030 was announced by Crown Prince Mohammed bin Salman in April 2016.",
                "The true wealth of Saudi Arabia, as mentioned in Vision 2030, is its people and their potential.",
            ],
            "ar": [
                "رؤية 2030 هي الإطار الاستراتيجي للمملكة العربية السعودية للحد من الاعتماد على النفط وتنويع الاقتصاد وتطوير القطاعات العامة.",
                "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح.",
                "تم الإعلان عن رؤية 2030 من قبل ولي العهد محمد بن سلمان في أبريل 2016.",
                "الثروة الحقيقية للمملكة العربية السعودية، كما ذكر في رؤية 2030، هي شعبها وإمكاناته.",
            ],
        },
        "real_wealth": {
            "en": [
                "The real wealth of Saudi Arabia, as emphasized in Vision 2030, is its people, particularly the youth.",
                "Vision 2030 recognizes that the Saudi people, with their strong values and capabilities, are the true wealth of the nation.",
                "The document states: 'Our people are our most valuable asset and the enablers of our success'.",
            ],
            "ar": [
                "الثروة الحقيقية للمملكة العربية السعودية، كما أكدت رؤية 2030، هي شعبها، وخاصة الشباب.",
                "تعترف رؤية 2030 بأن الشعب السعودي، بقيمه وقدراته القوية، هو الثروة الحقيقية للأمة.",
                "تنص الوثيقة على: 'شعبنا هو أثمن أصولنا وأساس نجاحنا'.",
            ],
        },
        "global_gateway": {
            "en": [
                "Saudi Arabia aims to strengthen its position as a global gateway by leveraging its strategic location between Asia, Europe, and Africa.",
                "The Kingdom plans to build a unique logistical hub connecting three continents and improve infrastructure to facilitate trade.",
                "Vision 2030 intends to establish special economic zones with competitive regulations to attract international investors.",
                "The plan includes enhancing seaports, building regional connectivity through railways, and expanding airports.",
            ],
            "ar": [
                "تهدف المملكة العربية السعودية إلى تعزيز مكانتها كبوابة عالمية من خلال الاستفادة من موقعها الاستراتيجي بين آسيا وأوروبا وأفريقيا.",
                "تخطط المملكة لبناء مركز لوجستي فريد يربط بين ثلاث قارات وتحسين البنية التحتية لتسهيل التجارة.",
                "تعتزم رؤية 2030 إنشاء مناطق اقتصادية خاصة ذات لوائح تنافسية لجذب المستثمرين الدوليين.",
                "تتضمن الخطة تعزيز الموانئ البحرية، وبناء الربط الإقليمي من خلال السكك الحديدية، وتوسيع المطارات.",
            ],
        },
        "tourism": {
            "en": [
                "Vision 2030 aims to develop tourism as a key non-oil sector, including religious, cultural, and leisure tourism.",
                "The plan includes developing the Red Sea as a world-class luxury tourist destination, with a focus on sustainability.",
                "Vision 2030 targets increasing tourism's contribution to GDP from 3% to 10% and hosting 100 million tourists annually by 2030.",
                "The Al-Ula region is being developed as a major archaeological and cultural tourism destination.",
            ],
            "ar": [
                "تهدف رؤية 2030 إلى تطوير السياحة كقطاع غير نفطي رئيسي، بما في ذلك السياحة الدينية والثقافية والترفيهية.",
                "تتضمن الخطة تطوير البحر الأحمر كوجهة سياحية فاخرة على مستوى عالمي، مع التركيز على الاستدامة.",
                "تستهدف رؤية 2030 زيادة مساهمة السياحة في الناتج المحلي الإجمالي من 3٪ إلى 10٪ واستضافة 100 مليون سائح سنويًا بحلول عام 2030.",
                "يتم تطوير منطقة العلا كوجهة سياحية أثرية وثقافية رئيسية.",
            ],
        },
        "youth": {
            "en": [
                "Vision 2030 recognizes youth as the Kingdom's most valuable resource, with 60% of the population under 30 years old.",
                "The plan aims to reduce youth unemployment from 30% to 7% through education reform and economic growth.",
                "Vision 2030 includes building a culture of entrepreneurship to harness the creative energy of Saudi youth.",
                "The plan supports youth development programs, sports initiatives, and enhanced educational opportunities.",
            ],
            "ar": [
                "تعترف رؤية 2030 بالشباب كأثمن موارد المملكة، حيث يشكلون 60٪ من السكان تحت سن 30 عامًا.",
                "تهدف الخطة إلى خفض بطالة الشباب من 30٪ إلى 7٪ من خلال إصلاح التعليم والنمو الاقتصادي.",
                "تتضمن رؤية 2030 بناء ثقافة ريادة الأعمال للاستفادة من الطاقة الإبداعية للشباب السعودي.",
                "تدعم الخطة برامج تنمية الشباب، والمبادرات الرياضية، وتعزيز الفرص التعليمية.",
            ],
        },
        "women": {
            "en": [
                "Vision 2030 aims to increase women's participation in the workforce from 22% to 30%.",
                "The plan supports women's rights and empowerment across economic, social, and political spheres.",
                "Vision 2030 has already resulted in policy changes allowing women to drive, travel independently, and participate more fully in public life.",
                "The plan includes initiatives to increase female leadership positions in both public and private sectors.",
            ],
            "ar": [
                "تهدف رؤية 2030 إلى زيادة مشاركة المرأة في القوى العاملة من 22٪ إلى 30٪.",
                "تدعم الخطة حقوق المرأة وتمكينها في المجالات الاقتصادية والاجتماعية والسياسية.",
                "أدت رؤية 2030 بالفعل إلى تغييرات في السياسات تسمح للمرأة بالقيادة، والسفر بشكل مستقل، والمشاركة بشكل أكبر في الحياة العامة.",
                "تتضمن الخطة مبادرات لزيادة المناصب القيادية النسائية في القطاعين العام والخاص.",
            ],
        },
        "projects": {
            "en": [
                "NEOM is a planned cross-border smart city in the Tabuk Province of northwestern Saudi Arabia, a key project of Vision 2030.",
                "The Red Sea Project is a Vision 2030 initiative to develop luxury tourism destinations across 50 islands off Saudi Arabia's Red Sea coast.",
                "Qiddiya is an entertainment mega-project being built in Riyadh as part of Vision 2030, intended to be the world's largest entertainment city.",
                "The Line is a revolutionary urban development project within NEOM featuring a 170 km-long linear city without cars or streets.",
                "AMAALA is an ultra-luxury tourism project on the Red Sea that focuses on wellness, healthy living, and meditation.",
            ],
            "ar": [
                "نيوم هي مدينة ذكية مخططة عبر الحدود في مقاطعة تبوك شمال غرب المملكة العربية السعودية، وهي مشروع رئيسي من رؤية 2030.",
                "مشروع البحر الأحمر هو مبادرة رؤية 2030 لتطوير وجهات سياحية فاخرة عبر 50 جزيرة قبالة ساحل البحر الأحمر السعودي.",
                "القدية هي مشروع ترفيهي ضخم يتم بناؤه في الرياض كجزء من رؤية 2030، ويهدف إلى أن يكون أكبر مدينة ترفيهية في العالم.",
                "ذا لاين هو مشروع تطوير حضري ثوري ضمن نيوم يتميز بمدينة خطية طولها 170 كم بدون سيارات أو شوارع.",
                "أمالا هو مشروع سياحي فائق الفخامة على البحر الأحمر يركز على العافية والحياة الصحية والتأمل.",
            ],
        },
        "economic_goals": {
            "en": [
                "Vision 2030 targets increasing the private sector's contribution to GDP from 40% to 65%.",
                "One goal of Vision 2030 is to increase foreign direct investment from 3.8% to 5.7% of GDP.",
                "Vision 2030 aims to raise the share of non-oil exports in non-oil GDP from 16% to 50%.",
                "The plan targets increasing SME contribution to GDP from 20% to 35%.",
                "Vision 2030 aims to lower the unemployment rate from 11.6% to 7%.",
            ],
            "ar": [
                "تستهدف رؤية 2030 زيادة مساهمة القطاع الخاص في الناتج المحلي الإجمالي من 40٪ إلى 65٪.",
                "أحد أهداف رؤية 2030 هو زيادة الاستثمار الأجنبي المباشر من 3.8٪ إلى 5.7٪ من الناتج المحلي الإجمالي.",
                "تهدف رؤية 2030 إلى رفع حصة الصادرات غير النفطية في الناتج المحلي الإجمالي غير النفطي من 16٪ إلى 50٪.",
                "تستهدف الخطة زيادة مساهمة المنشآت الصغيرة والمتوسطة في الناتج المحلي الإجمالي من 20٪ إلى 35٪.",
                "تهدف رؤية 2030 إلى خفض معدل البطالة من 11.6٪ إلى 7٪.",
            ],
        },
        "digital_transformation": {
            "en": [
                "Vision 2030 includes plans to develop the digital infrastructure and support for tech startups in Saudi Arabia.",
                "The plan aims to increase internet penetration to 95% of households in urban areas and 65% in rural areas.",
                "Vision 2030 focuses on building a digital economy, enhancing e-government services, and developing digital skills.",
                "The plan includes initiatives to position Saudi Arabia as a leader in the Fourth Industrial Revolution technologies.",
            ],
            "ar": [
                "تتضمن رؤية 2030 خططًا لتطوير البنية التحتية الرقمية ودعم الشركات الناشئة التكنولوجية في المملكة العربية السعودية.",
                "تهدف الخطة إلى زيادة انتشار الإنترنت إلى 95٪ من الأسر في المناطق الحضرية و 65٪ في المناطق الريفية.",
                "تركز رؤية 2030 على بناء اقتصاد رقمي، وتعزيز خدمات الحكومة الإلكترونية، وتطوير المهارات الرقمية.",
                "تتضمن الخطة مبادرات لوضع المملكة العربية السعودية كرائدة في تقنيات الثورة الصناعية الرابعة.",
            ],
        },
    }
    self.vision2030_knowledge = knowledge

    # Flatten per language; dict order keeps the categories in the order above.
    self.english_texts = [
        sentence for entry in knowledge.values() for sentence in entry["en"]
    ]
    self.arabic_texts = [
        sentence for entry in knowledge.values() for sentence in entry["ar"]
    ]

    logger.info(f"Created enhanced knowledge base: {len(self.english_texts)} English, {len(self.arabic_texts)} Arabic texts")
|
|
|
@spaces.GPU
def load_embedding_models(self):
    """Load the Arabic and English sentence-embedding models.

    Sets ``self.arabic_embedder`` and ``self.english_embedder``. When
    ``has_gpu`` is true both models are moved to ``'cuda'``. Any exception
    raised while loading or moving the models is logged and the fallback
    embedders are installed instead, so this method itself does not raise
    for model-loading problems.
    """
    logger.info("Loading embedding models...")

    # Attribute name -> model identifier, Arabic first as in the load order.
    model_ids = {
        "arabic_embedder": 'CAMeL-Lab/bert-base-arabic-camelbert-ca',
        "english_embedder": 'sentence-transformers/all-MiniLM-L6-v2',
    }

    try:
        for attr_name, model_id in model_ids.items():
            setattr(self, attr_name, SentenceTransformer(model_id))

        if has_gpu:
            for attr_name in model_ids:
                setattr(self, attr_name, getattr(self, attr_name).to('cuda'))
            logger.info("Models moved to GPU")

        logger.info("Embedding models loaded successfully")
    except Exception as e:
        logger.error(f"Error loading embedding models: {str(e)}")
        # Keep the assistant usable with hash-based stand-in embedders.
        self._create_fallback_embedders()
|
|
|
def _create_fallback_embedders(self):
    """Install hash-seeded stand-in embedders for when model loading fails.

    Sets ``self.arabic_embedder`` and ``self.english_embedder`` to objects
    exposing ``encode(text)``, which returns a ``float32`` vector of length
    384. The vector is pseudo-random but deterministic for a given text
    (seeded from the text's MD5 digest); it carries no semantic meaning.
    """
    import hashlib

    logger.warning("Using fallback embedding methods")

    class SimpleEmbedder:
        """Minimal embedder exposing the ``encode`` method the assistant calls."""

        def __init__(self, dim=384):
            self.dim = dim

        def encode(self, text):
            """Return a deterministic float32 vector of length ``dim`` for *text*."""
            digest = hashlib.md5(text.encode()).hexdigest()
            seed = int(digest, 16) % 2**32
            # A private RandomState yields the same numbers as seeding the
            # global generator, without resetting NumPy's global random
            # state for every other caller in the process.
            rng = np.random.RandomState(seed)
            return rng.randn(self.dim).astype(np.float32)

    self.arabic_embedder = SimpleEmbedder()
    self.english_embedder = SimpleEmbedder()
|
|
|
def _create_sample_data(self):
    """Hook for additional sample data; at present it only writes a log line."""
    logger.info("Creating additional sample data")
    return None
|
|
|
@spaces.GPU
def _create_indices(self):
    """(Re)build the FAISS L2 indices over the English and Arabic corpora.

    For each language this encodes every text with that language's
    embedder, stores the vectors in ``self.<lang>_vectors`` and, when there
    is at least one vector, stores a filled ``faiss.IndexFlatL2`` in
    ``self.<lang>_index``. Row *i* of an index corresponds to item *i* of
    the matching text list.

    A text whose encoding fails is logged and given a random placeholder
    vector so row positions stay aligned with the text list.

    Raises:
        Exception: any error raised while building the indices is logged
            and re-raised.
    """
    logger.info("Creating FAISS indices for text retrieval")

    def embed_all(texts, embedder, label):
        """Encode *texts* in order; failed items get a same-width placeholder."""
        vectors = []
        failed_positions = []
        for position, text in enumerate(texts):
            try:
                # Inference only, so gradient tracking is never needed.
                with torch.no_grad():
                    vectors.append(embedder.encode(text))
            except Exception as e:
                logger.error(f"Error encoding {label} text: {str(e)}")
                vectors.append(None)
                failed_positions.append(position)

        if failed_positions:
            # The placeholder must have the same width as the real vectors,
            # otherwise the rows cannot be stacked into one matrix. 384 is
            # used only when no text could be encoded at all.
            width = next((len(v) for v in vectors if v is not None), 384)
            for position in failed_positions:
                vectors[position] = np.random.randn(width).astype(np.float32)
        return vectors

    try:
        for label, prefix in (("English", "english"), ("Arabic", "arabic")):
            texts = getattr(self, f"{prefix}_texts")
            embedder = getattr(self, f"{prefix}_embedder")

            vectors = embed_all(texts, embedder, label)
            setattr(self, f"{prefix}_vectors", vectors)

            if not vectors:
                logger.warning(f"No {label} texts to index")
                continue

            # Fill the new index completely before publishing it on self,
            # so readers never see a half-built index.
            index = faiss.IndexFlatL2(len(vectors[0]))
            index.add(np.array(vectors))
            setattr(self, f"{prefix}_index", index)
            logger.info(f"Created {label} index with {len(vectors)} vectors")
    except Exception as e:
        logger.error(f"Error creating FAISS indices: {str(e)}")
        raise
|
|
|
def _create_sample_eval_data(self):
    """Define the built-in evaluation set in ``self.eval_data``.

    Each entry is a dict with the keys ``question``, ``lang`` (``"en"`` or
    ``"ar"``) and ``reference_answer``.
    """
    # (question, language code, reference answer)
    samples = (
        (
            "What are the key pillars of Vision 2030?",
            "en",
            "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation.",
        ),
        (
            "ما هي الركائز الرئيسية لرؤية 2030؟",
            "ar",
            "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح.",
        ),
        (
            "What is NEOM?",
            "en",
            "NEOM is a planned cross-border smart city in the Tabuk Province of northwestern Saudi Arabia, a key project of Vision 2030.",
        ),
        (
            "ما هو مشروع البحر الأحمر؟",
            "ar",
            "مشروع البحر الأحمر هو مبادرة رؤية 2030 لتطوير وجهات سياحية فاخرة عبر 50 جزيرة قبالة ساحل البحر الأحمر السعودي.",
        ),
        (
            "ما هي الثروة الحقيقية التي تعتز بها المملكة كما وردت في الرؤية؟",
            "ar",
            "الثروة الحقيقية للمملكة العربية السعودية، كما أكدت رؤية 2030، هي شعبها، وخاصة الشباب.",
        ),
        (
            "كيف تسعى المملكة إلى تعزيز مكانتها كبوابة للعالم؟",
            "ar",
            "تهدف المملكة العربية السعودية إلى تعزيز مكانتها كبوابة عالمية من خلال الاستفادة من موقعها الاستراتيجي بين آسيا وأوروبا وأفريقيا.",
        ),
    )
    self.eval_data = [
        {"question": question, "lang": lang, "reference_answer": reference}
        for question, lang, reference in samples
    ]
    logger.info(f"Created {len(self.eval_data)} sample evaluation examples")
|
|
|
@spaces.GPU
def retrieve_context(self, query, lang):
    """Return the two nearest knowledge-base texts for *query*.

    Args:
        query: The user's question text.
        lang: ``"ar"`` selects the Arabic embedder, index and corpus; any
            other value selects the English ones.

    Returns:
        The matched texts joined with a newline, or ``""`` if anything
        goes wrong (the error is logged, not raised).
    """
    started_at = time.time()

    try:
        prefix = "arabic" if lang == "ar" else "english"
        embedder = getattr(self, f"{prefix}_embedder")

        # Gradient tracking is switched off only for models that expose a
        # callable ``to`` attribute while a GPU is present; the encode call
        # is otherwise identical.
        is_movable_model = (
            has_gpu
            and hasattr(embedder, 'to')
            and callable(getattr(embedder, 'to'))
        )
        if is_movable_model:
            with torch.no_grad():
                query_vec = embedder.encode(query)
        else:
            query_vec = embedder.encode(query)

        index = getattr(self, f"{prefix}_index")
        _distances, neighbour_ids = index.search(np.array([query_vec]), k=2)

        corpus = getattr(self, f"{prefix}_texts")
        # Drop ids outside the corpus (including negative ones).
        matches = [corpus[i] for i in neighbour_ids[0] if 0 <= i < len(corpus)]
        context = "\n".join(matches)

        elapsed = time.time() - started_at
        logger.info(f"Retrieved context in {elapsed:.2f}s")
        return context
    except Exception as e:
        logger.error(f"Error retrieving context: {str(e)}")
        return ""
|
|
|
def _basic_generate_response(self, user_input):
    """Answer *user_input* with keyword rules, falling back to retrieval.

    The language is detected first (anything other than Arabic, or a
    detection failure, is treated as English). The input is then checked
    against an ordered list of keyword rules for that language; the first
    rule that matches supplies a canned answer. Only when no rule matches
    is the knowledge base searched via ``retrieve_context``.

    Every call that produces a reply appends the elapsed time to
    ``self.metrics["response_times"]`` and a record to
    ``self.response_history``.

    Args:
        user_input: The user's question text.

    Returns:
        The reply text, or ``""`` for empty or whitespace-only input.
    """
    if not user_input or user_input.strip() == "":
        return ""

    start_time = time.time()

    default_response = {
        "en": "I apologize, but I couldn't process your request properly. Please try again.",
        "ar": "أعتذر، لم أتمكن من معالجة طلبك بشكل صحيح. الرجاء المحاولة مرة أخرى."
    }

    # Bound up front so the outer error handler can always read it.
    lang = "en"

    try:
        try:
            lang = "ar" if detect(user_input) == "ar" else "en"
        except Exception:
            # Detection can fail on very short or symbol-only input.
            lang = "en"

        logger.info(f"Detected language: {lang}")

        # Ordered (keywords, answer) rules; the first match wins.
        if lang == "ar":
            haystack = user_input
            rules = (
                (("ركائز", "اركان"),
                 "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."),
                (("نيوم",),
                 "نيوم هي مدينة ذكية مخططة عبر الحدود في مقاطعة تبوك شمال غرب المملكة العربية السعودية، وهي مشروع رئيسي من رؤية 2030."),
                (("البحر الأحمر", "البحر الاحمر"),
                 "مشروع البحر الأحمر هو مبادرة رؤية 2030 لتطوير وجهات سياحية فاخرة عبر 50 جزيرة قبالة ساحل البحر الأحمر السعودي."),
                (("المرأة", "النساء"),
                 "تهدف رؤية 2030 إلى زيادة مشاركة المرأة في القوى العاملة من 22٪ إلى 30٪."),
                (("القدية",),
                 "القدية هي مشروع ترفيهي ضخم يتم بناؤه في الرياض كجزء من رؤية 2030، ويهدف إلى أن يكون أكبر مدينة ترفيهية في العالم."),
                (("ماهي", "ما هي"),
                 "رؤية 2030 هي الإطار الاستراتيجي للمملكة العربية السعودية للحد من الاعتماد على النفط وتنويع الاقتصاد وتطوير القطاعات العامة. الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."),
            )
            not_found = "لم أتمكن من العثور على معلومات كافية حول هذا السؤال."
        else:
            # Lower-case once; English keywords are matched case-insensitively.
            haystack = user_input.lower()
            # NOTE(review): "key" is a plain substring test, so it also
            # matches inputs such as "key projects" or "turkey" - confirm
            # this breadth is intended.
            rules = (
                (("pillar", "key"),
                 "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation."),
                (("neom",),
                 "NEOM is a planned cross-border smart city in the Tabuk Province of northwestern Saudi Arabia, a key project of Vision 2030."),
                (("red sea",),
                 "The Red Sea Project is a Vision 2030 initiative to develop luxury tourism destinations across 50 islands off Saudi Arabia's Red Sea coast."),
                (("women", "female"),
                 "Vision 2030 aims to increase women's participation in the workforce from 22% to 30%."),
                (("qiddiya",),
                 "Qiddiya is an entertainment mega-project being built in Riyadh as part of Vision 2030, intended to be the world's largest entertainment city."),
                (("what is",),
                 "Vision 2030 is Saudi Arabia's strategic framework to reduce dependence on oil, diversify the economy, and develop public sectors. The key pillars are a vibrant society, a thriving economy, and an ambitious nation."),
            )
            not_found = "I couldn't find enough information about this question."

        reply = next(
            (
                answer
                for keywords, answer in rules
                if any(keyword in haystack for keyword in keywords)
            ),
            None,
        )

        if reply is None:
            # Embedding and searching are the expensive steps, so they run
            # only when no canned answer applies.
            context = self.retrieve_context(user_input, lang)
            reply = context if context else not_found

    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        reply = default_response.get(lang, default_response["en"])

    response_time = time.time() - start_time
    self.metrics["response_times"].append(response_time)

    logger.info(f"Generated response in {response_time:.2f}s")

    interaction = {
        "timestamp": datetime.now().isoformat(),
        "user_input": user_input,
        "response": reply,
        "language": lang,
        "response_time": response_time
    }
    self.response_history.append(interaction)

    return reply
|
|
|
def generate_response(self, user_input):
    """Answer *user_input*, trying topic-specific rules before the basic generator.

    The language is detected (anything other than Arabic, or a detection
    failure, is treated as English). The input is checked against an
    ordered list of topic keywords for that language; the first topic that
    matches is answered with the first knowledge-base statement of that
    category, and the interaction is recorded via ``_record_metrics``.
    When no topic matches, or if anything in this method raises, the call
    is delegated to ``self.original_generate_response``.

    Args:
        user_input: The user's question text.

    Returns:
        The reply text, or ``""`` for empty or whitespace-only input.
    """
    if not user_input or user_input.strip() == "":
        return ""

    start_time = time.time()

    # Ordered (keywords, knowledge-base category) rules; the first match wins.
    arabic_topics = (
        (("الثروة الحقيقية", "أثمن", "ثروة"), "real_wealth"),
        (("بوابة للعالم", "مكانتها", "موقعها الاستراتيجي", "تعزيز مكانتها"), "global_gateway"),
        (("الشباب",), "youth"),
        (("المرأة", "النساء"), "women"),
        (("سياحة", "السياحة"), "tourism"),
    )
    english_topics = (
        (("real wealth", "valuable asset"), "real_wealth"),
        (("global gateway", "strategic location"), "global_gateway"),
        (("youth", "young"), "youth"),
        (("women", "female"), "women"),
        (("tourism", "tourist"), "tourism"),
    )

    try:
        try:
            lang = "ar" if detect(user_input) == "ar" else "en"
        except Exception:
            # Detection can fail on very short or symbol-only input.
            lang = "en"

        logger.info(f"Detected language: {lang}")

        if lang == "ar":
            haystack, topics = user_input, arabic_topics
        else:
            # English keywords are matched case-insensitively.
            haystack, topics = user_input.lower(), english_topics

        for keywords, category in topics:
            if any(keyword in haystack for keyword in keywords):
                response = self.vision2030_knowledge[category][lang][0]
                self._record_metrics(user_input, response, lang, start_time)
                return response

        # No topic rule applied; the basic generator records its own metrics.
        return self.original_generate_response(user_input)

    except Exception as e:
        logger.error(f"Error in enhanced generation: {str(e)}")
        return self.original_generate_response(user_input)
|
|
|
def _record_metrics(self, user_input, response, lang, start_time):
    """Log one finished interaction and store its timing.

    Appends the elapsed time (now minus *start_time*) to
    ``self.metrics["response_times"]`` and a full record of the exchange
    to ``self.response_history``.

    Args:
        user_input: The question that was asked.
        response: The reply that was produced.
        lang: Language code recorded with the interaction.
        start_time: ``time.time()`` value taken when handling began.
    """
    elapsed = time.time() - start_time
    self.metrics["response_times"].append(elapsed)

    logger.info(f"Generated response in {elapsed:.2f}s")

    self.response_history.append(
        dict(
            timestamp=datetime.now().isoformat(),
            user_input=user_input,
            response=response,
            language=lang,
            response_time=elapsed,
        )
    )
|
|
|
def evaluate_factual_accuracy(self, response, reference):
    """Score *response* by how many of the reference's keywords it contains.

    Both texts are lower-cased and split into word tokens; a small set of
    English and Arabic stop words is discarded. The score is the fraction
    of the reference's remaining words that also occur in the response.

    Args:
        response: The generated answer text.
        reference: The ground-truth answer text.

    Returns:
        A value between 0 and 1; ``0`` when the reference has no keywords
        left after stop-word removal.
    """
    ignored_words = {
        # English
        "the", "is", "a", "an", "and", "or", "of", "to", "in", "for", "with", "by", "on", "at",
        # Arabic
        "في", "من", "إلى", "على", "و", "هي", "هو", "عن", "مع",
    }

    def keywords_of(text):
        """Return the distinct non-stop-word tokens of *text*."""
        return set(re.findall(r'\b\w+\b', text.lower())) - ignored_words

    expected = keywords_of(reference)
    produced = keywords_of(response)

    if not expected:
        return 0
    return len(expected & produced) / len(expected)
|
|
|
@spaces.GPU
def evaluate_on_test_set(self):
    """Run the assistant over ``self.eval_data`` and summarise the scores.

    Each question is answered with ``generate_response`` and scored
    against its reference answer with ``evaluate_factual_accuracy``. Every
    score is also appended to ``self.metrics["factual_accuracy"]``.

    Returns:
        A dict with ``average_factual_accuracy``, ``average_response_time``
        and ``detailed_results`` (one dict per question with ``question``,
        ``reference``, ``response`` and ``factual_accuracy``). Both
        averages cover this run only, so they agree with
        ``detailed_results``.
    """
    logger.info("Running evaluation on test set")

    # Timings recorded from this position onwards belong to this run;
    # earlier entries come from previous runs or ordinary chat traffic.
    first_timing = len(self.metrics["response_times"])

    eval_results = []
    for example in self.eval_data:
        response = self.generate_response(example["question"])
        accuracy = self.evaluate_factual_accuracy(response, example["reference_answer"])

        eval_results.append({
            "question": example["question"],
            "reference": example["reference_answer"],
            "response": response,
            "factual_accuracy": accuracy
        })
        self.metrics["factual_accuracy"].append(accuracy)

    run_scores = [result["factual_accuracy"] for result in eval_results]
    run_timings = self.metrics["response_times"][first_timing:]

    avg_accuracy = sum(run_scores) / len(run_scores) if run_scores else 0
    avg_response_time = sum(run_timings) / len(run_timings) if run_timings else 0

    results = {
        "average_factual_accuracy": avg_accuracy,
        "average_response_time": avg_response_time,
        "detailed_results": eval_results
    }

    logger.info(f"Evaluation results: Factual accuracy = {avg_accuracy:.2f}, Avg response time = {avg_response_time:.2f}s")

    return results
|
|
|
def visualize_evaluation_results(self, results):
    """Plot per-question and per-language factual accuracy.

    Args:
        results: Dict with the keys ``detailed_results`` (list of dicts
            holding at least ``question`` and ``factual_accuracy``) and
            ``average_factual_accuracy``.

    Returns:
        A matplotlib figure with two stacked bar charts: accuracy per
        question, and mean accuracy per detected question language. Both
        charts show the overall average as a red line with a legend. When
        there are no detailed results the figure contains only a short
        notice.
    """
    df = pd.DataFrame(results["detailed_results"])
    overall = results["average_factual_accuracy"]

    if df.empty:
        # Nothing to chart; avoid a KeyError on the missing columns.
        fig = plt.figure(figsize=(12, 8))
        fig.text(0.5, 0.5, "No evaluation results to display", ha='center', va='center')
        return fig

    def language_label(question):
        """Classify *question* as "Arabic" or "English" without raising."""
        try:
            return "Arabic" if detect(question) == "ar" else "English"
        except Exception:
            # Same default as response generation: undetectable text
            # counts as English.
            return "English"

    fig, (ax_questions, ax_languages) = plt.subplots(2, 1, figsize=(12, 8))

    # Top chart: one bar per evaluated question.
    ax_questions.bar(range(len(df)), df["factual_accuracy"], color="skyblue")
    ax_questions.axhline(y=overall, color='r', linestyle='-', label=f"Avg: {overall:.2f}")
    ax_questions.set_xlabel("Question Index")
    ax_questions.set_ylabel("Factual Accuracy")
    ax_questions.set_title("Factual Accuracy by Question")
    ax_questions.set_ylim(0, 1.1)
    ax_questions.legend()

    # Bottom chart: mean accuracy per question language.
    df["language"] = df["question"].apply(language_label)
    lang_accuracy = df.groupby("language")["factual_accuracy"].mean()

    ax_languages.bar(lang_accuracy.index, lang_accuracy.values, color=["lightblue", "lightgreen"])
    ax_languages.axhline(y=overall, color='r', linestyle='-', label=f"Overall: {overall:.2f}")
    ax_languages.set_xlabel("Language")
    ax_languages.set_ylabel("Average Factual Accuracy")
    ax_languages.set_title("Factual Accuracy by Language")
    ax_languages.set_ylim(0, 1.1)
    # Without this call the "Overall" label above is never displayed.
    ax_languages.legend()

    # Print each language's mean just above its bar.
    for position, value in enumerate(lang_accuracy):
        ax_languages.text(position, value + 0.05, f"{value:.2f}", ha='center')

    fig.tight_layout()
    return fig
|
|
|
def record_user_feedback(self, user_input, response, rating, feedback_text=""):
    """Store a user's rating and comment for one response.

    The rating is appended to ``self.metrics["user_ratings"]``. The full
    feedback record (timestamp, question, response, rating and comment) is
    appended to ``self.metrics["user_feedback"]``, so the comment text is
    kept rather than dropped.

    Args:
        user_input: The question that was asked.
        response: The reply being rated.
        rating: The user's score for the reply.
        feedback_text: Optional free-text comment.

    Returns:
        ``True`` once the feedback has been recorded.
    """
    feedback = {
        "timestamp": datetime.now().isoformat(),
        "user_input": user_input,
        "response": response,
        "rating": rating,
        "feedback_text": feedback_text
    }

    self.metrics["user_ratings"].append(rating)
    # setdefault keeps this working for metrics dicts created without the
    # "user_feedback" key.
    self.metrics.setdefault("user_feedback", []).append(feedback)

    logger.info(f"Recorded user feedback: rating={rating}")

    return True
|
|
|
@spaces.GPU
def process_uploaded_pdf(self, file):
    """Extract text from an uploaded PDF and add it to the knowledge base.

    The text of all pages is split into chunks at blank lines. Each chunk
    is classified as Arabic or English (English when detection fails) and
    appended to ``self.arabic_texts`` or ``self.english_texts``; the FAISS
    indices are then rebuilt over the enlarged corpora.

    Args:
        file: The PDF content as a bytes-like object (it is wrapped in
            ``io.BytesIO``), or ``None`` when nothing was uploaded.

    Returns:
        A status message for the UI. Failures are reported in the message
        rather than raised.
    """
    if file is None:
        return "No file uploaded. Please select a PDF file."

    try:
        logger.info("Processing uploaded file")

        reader = PyPDF2.PdfReader(io.BytesIO(file))

        # Pages that yield no text (None or "") are skipped.
        page_texts = (page.extract_text() for page in reader.pages)
        full_text = "".join(text + "\n" for text in page_texts if text)

        if not full_text.strip():
            return "The uploaded PDF doesn't contain extractable text. Please try another file."

        # A blank line (optionally containing whitespace) separates chunks.
        chunks = [chunk.strip() for chunk in re.split(r'\n\s*\n', full_text) if chunk.strip()]

        english_chunks = []
        arabic_chunks = []
        for chunk in chunks:
            try:
                is_arabic = detect(chunk) == "ar"
            except Exception:
                # Undetectable chunks are filed as English.
                is_arabic = False
            (arabic_chunks if is_arabic else english_chunks).append(chunk)

        self.english_texts.extend(english_chunks)
        self.arabic_texts.extend(arabic_chunks)

        # Rebuild so the new chunks become searchable.
        self._create_indices()

        logger.info(f"Successfully processed PDF: {len(arabic_chunks)} Arabic chunks, {len(english_chunks)} English chunks")

        return f"✅ Successfully processed the PDF! Found {len(arabic_chunks)} Arabic and {len(english_chunks)} English text segments."

    except Exception as e:
        logger.error(f"Error processing PDF: {str(e)}")
        return f"❌ Error processing the PDF: {str(e)}. Please try another file."
|
|
|
|
|
def create_interface():
    """Build the Gradio UI around a new ``Vision2030Assistant``.

    The interface has three tabs: a chat tab with a feedback form, an
    evaluation tab that runs the built-in test set and plots the result,
    and a tab for uploading a PDF into the knowledge base.

    Returns:
        The assembled ``gr.Blocks`` app; it is not launched here.
    """
    assistant = Vision2030Assistant()

    def chat(message, history):
        """Append the assistant's answer to *history* and clear the textbox."""
        if message and message.strip() != "":
            answer = assistant.generate_response(message)
            history.append((message, answer))
        return history, ""

    def provide_feedback(history, rating, feedback_text):
        """Record a rating for the most recent exchange, if there is one."""
        if not history:
            return "No conversation found to rate."
        latest = history[-1]
        assistant.record_user_feedback(latest[0], latest[1], rating, feedback_text)
        return f"Thank you for your feedback! (Rating: {rating}/5)"

    @spaces.GPU
    def run_evaluation():
        """Run the test set and return (text summary, matplotlib figure)."""
        results = assistant.evaluate_on_test_set()

        # Leading/trailing empty items reproduce the newlines around the header.
        header_lines = [
            "",
            "Evaluation Results:",
            "------------------",
            f"Total questions evaluated: {len(results['detailed_results'])}",
            f"Overall factual accuracy: {results['average_factual_accuracy']:.2f}",
            f"Average response time: {results['average_response_time']:.4f} seconds",
            "",
            "Detailed Results:",
            "",
        ]
        parts = ["\n".join(header_lines)]

        for number, result in enumerate(results['detailed_results'], start=1):
            parts.append(f"\nQ{number}: {result['question']}\n")
            parts.append(f"Reference: {result['reference']}\n")
            parts.append(f"Response: {result['response']}\n")
            parts.append(f"Accuracy: {result['factual_accuracy']:.2f}\n")
            parts.append("-" * 40 + "\n")

        summary = "".join(parts)
        fig = assistant.visualize_evaluation_results(results)
        return summary, fig

    with gr.Blocks() as demo:
        gr.Markdown("# Vision 2030 Virtual Assistant 🌟")
        gr.Markdown("Ask questions about Saudi Arabia's Vision 2030 in both Arabic and English")

        with gr.Tab("Chat"):
            chatbot = gr.Chatbot(height=400)
            msg = gr.Textbox(label="Your Question", placeholder="Ask about Vision 2030...")
            with gr.Row():
                submit_btn = gr.Button("Submit")
                clear_btn = gr.Button("Clear Chat")

            gr.Markdown("### Provide Feedback")
            with gr.Row():
                rating = gr.Slider(minimum=1, maximum=5, step=1, value=3, label="Rate the Response (1-5)")
                feedback_text = gr.Textbox(label="Additional Comments (Optional)")
            feedback_btn = gr.Button("Submit Feedback")
            feedback_result = gr.Textbox(label="Feedback Status")

        with gr.Tab("Evaluation"):
            evaluate_btn = gr.Button("Run Evaluation on Test Set")
            eval_output = gr.Textbox(label="Evaluation Results", lines=20)
            eval_chart = gr.Plot(label="Evaluation Metrics")

        with gr.Tab("Upload PDF"):
            gr.Markdown("""
            ### Upload a Vision 2030 PDF Document
            Upload a PDF document to enhance the assistant's knowledge base.
            """)

            with gr.Row():
                # type="binary" hands the handler the file's raw bytes.
                file_input = gr.File(
                    label="Select PDF File",
                    file_types=[".pdf"],
                    type="binary",
                )

            with gr.Row():
                upload_btn = gr.Button("Process PDF", variant="primary")

            with gr.Row():
                upload_status = gr.Textbox(
                    label="Upload Status",
                    placeholder="Upload status will appear here...",
                    interactive=False,
                )

            gr.Markdown("""
            ### Notes:
            - The PDF should contain text that can be extracted (not scanned images)
            - After uploading, return to the Chat tab to ask questions about the uploaded content
            """)

        # Event wiring: Enter in the textbox and the Submit button both send.
        chat_io = dict(inputs=[msg, chatbot], outputs=[chatbot, msg])
        msg.submit(chat, **chat_io)
        submit_btn.click(chat, **chat_io)
        clear_btn.click(lambda: [], None, chatbot)
        feedback_btn.click(provide_feedback, [chatbot, rating, feedback_text], feedback_result)
        evaluate_btn.click(run_evaluation, None, [eval_output, eval_chart])
        upload_btn.click(assistant.process_uploaded_pdf, [file_input], [upload_status])

    return demo
|
|
|
|
|
# ``demo`` stays at module level so other tooling can import it by name.
demo = create_interface()

# Start the web server only when this file is run as a script, so importing
# the module does not block on launch().
if __name__ == "__main__":
    demo.launch()