import os
import re
import json
from tqdm import tqdm
from pathlib import Path
import spaces
import gradio as gr


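# Lightweight text helpers: a pure-regex tokenizer and a character-ratio
# Arabic/English detector. Neither requires NLTK or any other heavy dependency.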
def safe_tokenize(text):
    """Pure regex tokenizer with no NLTK dependency"""
    if not text:
        return []

    # Surround punctuation with spaces so each mark becomes its own token
    text = re.sub(r'([.,!?;:()\[\]{}"\'/\\])', r' \1 ', text)

    return [token for token in re.split(r'\s+', text.lower()) if token]


def detect_language(text):
    """Detect if text is primarily Arabic or English"""
    # Count characters in the Arabic Unicode block (U+0600-U+06FF); if they
    # make up more than half of the text, treat it as Arabic.
    arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
    is_arabic = len(arabic_chars) > len(text) * 0.5
    return "arabic" if is_arabic else "english"

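# For example, detect_language("ما هي رؤية 2030؟") returns "arabic", while
# detect_language("What is Vision 2030?") returns "english".
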
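# Bilingual sample question/reference pairs covering the main Vision 2030
# themes. The "Sample Questions" tab below lists the queries, and the pairs
# can also serve as a small evaluation set.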
comprehensive_evaluation_data = [
    {
        "query": "ما هي رؤية السعودية 2030؟",
        "reference": "رؤية السعودية 2030 هي خطة استراتيجية تهدف إلى تنويع الاقتصاد السعودي وتقليل الاعتماد على النفط مع تطوير قطاعات مختلفة مثل الصحة والتعليم والسياحة.",
        "category": "overview",
        "language": "arabic"
    },
    {
        "query": "What is Saudi Vision 2030?",
        "reference": "Saudi Vision 2030 is a strategic framework aiming to diversify Saudi Arabia's economy and reduce dependence on oil, while developing sectors like health, education, and tourism.",
        "category": "overview",
        "language": "english"
    },

    {
        "query": "ما هي الأهداف الاقتصادية لرؤية 2030؟",
        "reference": "تشمل الأهداف الاقتصادية زيادة مساهمة القطاع الخاص إلى 65%، وزيادة الصادرات غير النفطية إلى 50% من الناتج المحلي غير النفطي، وخفض البطالة إلى 7%.",
        "category": "economic",
        "language": "arabic"
    },
    {
        "query": "What are the economic goals of Vision 2030?",
        "reference": "The economic goals of Vision 2030 include increasing private sector contribution from 40% to 65% of GDP, raising non-oil exports from 16% to 50%, and reducing unemployment from 11.6% to 7%.",
        "category": "economic",
        "language": "english"
    },

    {
        "query": "كيف تعزز رؤية 2030 الإرث الثقافي السعودي؟",
        "reference": "تتضمن رؤية 2030 الحفاظ على الهوية الوطنية، تسجيل مواقع أثرية في اليونسكو، وتعزيز الفعاليات الثقافية.",
        "category": "social",
        "language": "arabic"
    },
    {
        "query": "How does Vision 2030 aim to improve quality of life?",
        "reference": "Vision 2030 plans to enhance quality of life by expanding sports facilities, promoting cultural activities, and boosting tourism and entertainment sectors.",
        "category": "social",
        "language": "english"
    }
]


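# Vision2030Service wires the whole RAG pipeline together: PDF ingestion,
# FAISS retrieval over multilingual embeddings, and answer generation with
# the ALLaM-7B instruction-tuned model.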
class Vision2030Service:
    def __init__(self):
        self.initialized = False
        self.model = None
        self.tokenizer = None
        self.vector_store = None
        self.conversation_history = []

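    # All heavy set-up (embeddings, FAISS index, model weights) is deferred to
    # initialize() so that GPU work happens inside the @spaces.GPU context.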
    @spaces.GPU
    def initialize(self):
        """Initialize the system - ALL GPU operations must happen here"""
        if self.initialized:
            return True

        try:
            # Import heavy dependencies lazily, inside the GPU context
            import torch
            import PyPDF2
            from transformers import AutoTokenizer, AutoModelForCausalLM
            from sentence_transformers import SentenceTransformer
            from langchain.text_splitter import RecursiveCharacterTextSplitter
            from langchain_community.vectorstores import FAISS
            from langchain.schema import Document
            from langchain.embeddings import HuggingFaceEmbeddings

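            # Source PDFs for the knowledge base (saudi_vision2030_ar.pdf is
            # presumably the Arabic edition)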
            pdf_files = ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]

            vector_store_dir = "vector_stores"
            os.makedirs(vector_store_dir, exist_ok=True)

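            # Reuse a previously saved FAISS index if one exists; otherwise
            # build it from the PDFs and persist it for future runs.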
            if os.path.exists(os.path.join(vector_store_dir, "index.faiss")):
                print("Loading existing vector store...")
                embedding_function = HuggingFaceEmbeddings(
                    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
                )
                self.vector_store = FAISS.load_local(vector_store_dir, embedding_function)
            else:
                print("Creating new vector store...")

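                # Extract raw text from each PDF with PyPDF2, one Document per file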
                documents = []
                for pdf_path in pdf_files:
                    if not os.path.exists(pdf_path):
                        print(f"Warning: {pdf_path} does not exist")
                        continue

                    print(f"Processing {pdf_path}...")
                    text = ""
                    with open(pdf_path, 'rb') as file:
                        reader = PyPDF2.PdfReader(file)
                        for page in reader.pages:
                            page_text = page.extract_text()
                            if page_text:
                                text += page_text + "\n\n"

                    if text.strip():
                        doc = Document(
                            page_content=text,
                            metadata={"source": pdf_path, "filename": os.path.basename(pdf_path)}
                        )
                        documents.append(doc)

                if not documents:
                    raise ValueError("No documents were processed successfully.")

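                # Split the documents into overlapping ~500-character chunks so
                # retrieval returns focused passages rather than whole PDFs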
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=500,
                    chunk_overlap=50,
                    separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
                )

                chunks = []
                for doc in documents:
                    doc_chunks = text_splitter.split_text(doc.page_content)
                    chunks.extend([
                        Document(page_content=chunk, metadata=doc.metadata)
                        for chunk in doc_chunks
                    ])

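                # Embed the chunks with a multilingual sentence-transformer and
                # save the resulting FAISS index to disk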
                embedding_function = HuggingFaceEmbeddings(
                    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
                )
                self.vector_store = FAISS.from_documents(chunks, embedding_function)
                self.vector_store.save_local(vector_store_dir)

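            # Load the ALLaM-7B instruction-tuned model and its (slow) tokenizer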
            model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                use_fast=False
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map="auto",
            )

            self.initialized = True
            return True

        except Exception as e:
            import traceback
            print(f"Initialization error: {e}")
            print(traceback.format_exc())
            return False

    @spaces.GPU
    def retrieve_context(self, query, top_k=5):
        """Retrieve contexts from vector store"""
        if not self.initialized:
            return []

        try:
            results = self.vector_store.similarity_search_with_score(query, k=top_k)

            contexts = []
            for doc, score in results:
                contexts.append({
                    "content": doc.page_content,
                    "source": doc.metadata.get("source", "Unknown"),
                    "relevance_score": score
                })

            return contexts
        except Exception as e:
            print(f"Error retrieving context: {e}")
            return []

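    # Generation: wrap the retrieved chunks in a bilingual instruction prompt
    # and sample a completion from the model.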
    @spaces.GPU
    def generate_response(self, query, contexts, language="auto"):
        """Generate response using the model"""
        import torch

        if not self.initialized or self.model is None or self.tokenizer is None:
            return "I'm still initializing. Please try again in a moment."

        try:
            if language == "auto":
                language = detect_language(query)

            # Match the system instruction to the language of the query
            if language == "arabic":
                instruction = (
                    "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
                    "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
                )
            else:
                instruction = (
                    "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
                    "If you don't know the answer, honestly say you don't know."
                )

            context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])

            prompt = f"""<s>[INST] {instruction}

Context:
{context_text}

Question: {query} [/INST]</s>"""

            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.1
            )

            full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Keep only the text generated after the [/INST] marker
            response = full_output.split("[/INST]")[-1].strip()

            if not response:
                response = full_output

            return response

        except Exception as e:
            import traceback
            print(f"Error generating response: {e}")
            print(traceback.format_exc())
            return "Sorry, I encountered an error while generating a response."

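    # Top-level Q&A flow used by the Gradio chat tab: retrieve with recent
    # conversation context, generate with the current question only.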
    @spaces.GPU
    def answer_question(self, query):
        """Process a user query and return a response with sources"""
        if not self.initialized:
            if not self.initialize():
                return "System initialization failed. Please check the logs.", []

        try:
            self.conversation_history.append({"role": "user", "content": query})

            # Use the last few turns of the conversation to enrich retrieval
            conversation_context = "\n".join([
                f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
                for msg in self.conversation_history[-6:]
            ])

            enhanced_query = f"{conversation_context}\n{query}"

            contexts = self.retrieve_context(enhanced_query, top_k=5)

            response = self.generate_response(query, contexts)

            self.conversation_history.append({"role": "assistant", "content": response})

            # Deduplicate source filenames for display
            sources = [ctx.get("source", "Unknown") for ctx in contexts]
            unique_sources = list(set(sources))

            return response, unique_sources
        except Exception as e:
            import traceback
            print(f"Error answering question: {e}")
            print(traceback.format_exc())
            return f"Sorry, I encountered an error: {str(e)}", []

    def reset_conversation(self):
        """Reset the conversation history"""
        self.conversation_history = []
        return "Conversation has been reset."


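# Gradio UI: a Chat tab backed by Vision2030Service, a System Status tab with
# initialization, PDF, and dependency checks, and a Sample Questions tab.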
def main():
    service = Vision2030Service()

    with gr.Blocks(title="Vision 2030 Assistant") as demo:
        gr.Markdown("# Vision 2030 Assistant")
        gr.Markdown("Ask questions about Saudi Vision 2030 in English or Arabic")

        with gr.Tab("Chat"):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Your question", placeholder="Ask about Vision 2030...")
            clear = gr.Button("Clear History")

            @spaces.GPU
            def respond(message, history):
                if not message:
                    return history, ""

                response, sources = service.answer_question(message)
                sources_text = ", ".join(sources) if sources else "No specific sources"

                full_response = f"{response}\n\nSources: {sources_text}"

                return history + [[message, full_response]], ""

            def reset_chat():
                service.reset_conversation()
                return [], "Conversation history has been reset."

            msg.submit(respond, [msg, chatbot], [chatbot, msg])
            clear.click(reset_chat, None, [chatbot, msg])

        with gr.Tab("System Status"):
            init_btn = gr.Button("Initialize System")
            status_box = gr.Textbox(label="Status", value="System not initialized")

            @spaces.GPU
            def initialize_system():
                success = service.initialize()
                if success:
                    return "System initialized successfully!"
                else:
                    return "System initialization failed. Check logs for details."

            init_btn.click(initialize_system, None, status_box)

            gr.Markdown("### PDF Status")
            pdf_btn = gr.Button("Check PDF Files")
            pdf_status = gr.Textbox(label="PDF Files")

            def check_pdfs():
                result = []
                for pdf_file in ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]:
                    if os.path.exists(pdf_file):
                        size = os.path.getsize(pdf_file) / (1024 * 1024)
                        result.append(f"{pdf_file}: Found ({size:.2f} MB)")
                    else:
                        result.append(f"{pdf_file}: Not found")
                return "\n".join(result)

            pdf_btn.click(check_pdfs, None, pdf_status)

            gr.Markdown("### Dependencies")
            sys_btn = gr.Button("Check Dependencies")
            sys_status = gr.Textbox(label="Dependencies Status")

            @spaces.GPU
            def check_dependencies():
                result = []

                try:
                    import torch
                    result.append(f"✓ PyTorch: {torch.__version__}")
                except ImportError:
                    result.append("✗ PyTorch: Not installed")

                try:
                    import transformers
                    result.append(f"✓ Transformers: {transformers.__version__}")
                except ImportError:
                    result.append("✗ Transformers: Not installed")

                try:
                    import sentencepiece
                    result.append("✓ SentencePiece: Installed")
                except ImportError:
                    result.append("✗ SentencePiece: Not installed")

                try:
                    import accelerate
                    result.append(f"✓ Accelerate: {accelerate.__version__}")
                except ImportError:
                    result.append("✗ Accelerate: Not installed")

                try:
                    import langchain
                    result.append(f"✓ LangChain: {langchain.__version__}")
                except ImportError:
                    result.append("✗ LangChain: Not installed")

                try:
                    import langchain_community
                    result.append(f"✓ LangChain Community: {langchain_community.__version__}")
                except ImportError:
                    result.append("✗ LangChain Community: Not installed")

                return "\n".join(result)

            sys_btn.click(check_dependencies, None, sys_status)

        with gr.Tab("Sample Questions"):
            gr.Markdown("### Sample Questions to Try")

            sample_questions = []
            for item in comprehensive_evaluation_data:
                sample_questions.append(item["query"])

            questions_md = "\n".join([f"- {q}" for q in sample_questions])
            gr.Markdown(questions_md)

    return demo


if __name__ == "__main__":
    demo = main()
    demo.queue()
    demo.launch()