import sys
import subprocess


def install_package(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


try:
    import sentencepiece
    print("SentencePiece is already installed")
except ImportError:
    print("Installing SentencePiece...")
    install_package("sentencepiece==0.1.99")
    print("SentencePiece installed successfully")


import gradio as gr
import os
import re
import torch
import numpy as np
from pathlib import Path
import PyPDF2
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
import spaces


# Global state shared across the Gradio callbacks
model = None
tokenizer = None
assistant = None
model_type = "primary"


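# Vision2030Assistant ties the pieces together: it keeps a short conversation
# history, retrieves relevant chunks from the vector store, and asks the loaded
# model to answer in the user's language.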
class Vision2030Assistant:
    """Bilingual RAG assistant that answers Vision 2030 questions from a FAISS vector store."""

    def __init__(self, model, tokenizer, vector_store, model_type="primary"):
        self.model = model
        self.tokenizer = tokenizer
        self.vector_store = vector_store
        self.model_type = model_type
        self.conversation_history = []

    def answer(self, user_query):
        """Detect the query language, retrieve context, generate a response, and append sources."""
        language = detect_language(user_query)

        self.conversation_history.append({"role": "user", "content": user_query})

        # Use the most recent turns to enrich retrieval with conversational context.
        conversation_context = "\n".join([
            f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
            for msg in self.conversation_history[-6:]
        ])

        enhanced_query = f"{conversation_context}\n{user_query}"

        contexts = retrieve_context(enhanced_query, self.vector_store, top_k=5)

        if self.model_type == "primary":
            response = generate_response_primary(user_query, contexts, self.model, self.tokenizer, language)
        else:
            response = generate_response_fallback(user_query, contexts, self.model, self.tokenizer, language)

        self.conversation_history.append({"role": "assistant", "content": response})

        # List the unique source filenames so the user can see where the answer came from.
        sources = [ctx.get("source", "Unknown") for ctx in contexts]
        unique_sources = list(set(sources))

        if unique_sources:
            source_text = "\n\nSources: " + ", ".join([os.path.basename(src) for src in unique_sources])
            response_with_sources = response + source_text
        else:
            response_with_sources = response

        return response_with_sources

    def reset_conversation(self):
        """Reset the conversation history"""
        self.conversation_history = []
        return "Conversation has been reset."


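# Retrieval helpers: language detection and similarity search over the FAISS index.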
def detect_language(text):
    """Detect if text is primarily Arabic or English"""
    arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
    is_arabic = len(arabic_chars) > len(text) * 0.5
    return "arabic" if is_arabic else "english"


def retrieve_context(query, vector_store, top_k=5):
    """Retrieve most relevant document chunks for a given query"""
    results = vector_store.similarity_search_with_score(query, k=top_k)

    contexts = []
    for doc, score in results:
        contexts.append({
            "content": doc.page_content,
            "source": doc.metadata.get("source", "Unknown"),
            "relevance_score": score
        })

    return contexts


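# Generation with the primary model. The prompt below assumes a Llama-2-style
# [INST] ... [/INST] wrapper; adjust it if the ALLaM model card specifies a
# different chat template.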
@spaces.GPU
def generate_response_primary(query, contexts, model, tokenizer, language="auto"):
    """Generate a response using the ALLaM model"""
    if language == "auto":
        language = detect_language(query)

    if language == "arabic":
        instruction = (
            "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
            "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
        )
    else:
        instruction = (
            "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
            "If you don't know the answer, honestly say you don't know."
        )

    context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])

    prompt = f"""<s>[INST] {instruction}

Context:
{context_text}

Question: {query} [/INST]</s>"""

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.1
        )

        # Keep only the text generated after the instruction block.
        full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = full_output.split("[/INST]")[-1].strip()

        if not response:
            response = full_output

        return response

    except Exception as e:
        print(f"Error during generation: {e}")
        return "I apologize, but I encountered an error while generating a response."


@spaces.GPU
def generate_response_fallback(query, contexts, model, tokenizer, language="auto"):
    """Generate a response using the fallback model (BLOOM or mBART)"""
    if language == "auto":
        language = detect_language(query)

    if language == "arabic":
        system_prompt = (
            "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم السياق التالي للإجابة على السؤال: "
        )
    else:
        system_prompt = (
            "You are a virtual assistant for Saudi Vision 2030. Use the following context to answer the question: "
        )

    context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])

    prompt = f"{system_prompt}\n\nContext:\n{context_text}\n\nQuestion: {query}\n\nAnswer:"

    try:
        inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(model.device)

        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=inputs.input_ids.shape[1] + 512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

        # Slice off the prompt tokens; this assumes a causal LM (e.g. BLOOM) whose
        # output echoes the input. Seq2seq models such as mBART return only the
        # generated tokens.
        response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

        return response.strip()

    except Exception as e:
        print(f"Error during fallback generation: {e}")
        return "I apologize, but I encountered an error while generating a response with the fallback model."


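# Document ingestion: extract text from uploaded PDFs, split it into overlapping
# chunks, and index the chunks in a FAISS vector store with multilingual embeddings.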
def process_pdf_files(pdf_files):
    """Process uploaded PDF files and create documents"""
    documents = []

    for pdf_file in pdf_files:
        # Gradio may hand back either a tempfile-like object (with .name) or a plain path string.
        pdf_path = getattr(pdf_file, "name", pdf_file)
        filename = os.path.basename(pdf_path)
        try:
            text = ""
            with open(pdf_path, "rb") as file:
                reader = PyPDF2.PdfReader(file)
                for page in reader.pages:
                    page_text = page.extract_text()
                    if page_text:
                        text += page_text + "\n\n"

            if text.strip():
                doc = Document(
                    page_content=text,
                    metadata={"source": filename, "filename": filename}
                )
                documents.append(doc)
                print(f"Successfully processed: {filename}")
            else:
                print(f"Warning: No text extracted from {filename}")
        except Exception as e:
            print(f"Error processing {filename}: {e}")

    print(f"Processed {len(documents)} PDF documents")
    return documents


def create_vector_store(documents):
    """Create a vector store from documents"""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
    )

    chunks = []
    for doc in documents:
        doc_chunks = text_splitter.split_text(doc.page_content)
        chunks.extend([
            Document(page_content=chunk, metadata=doc.metadata)
            for chunk in doc_chunks
        ])

    print(f"Created {len(chunks)} chunks from {len(documents)} documents")

    embedding_function = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    )

    vector_store = FAISS.from_documents(chunks, embedding_function)
    return vector_store


def create_mock_documents():
    """Create mock documents about Vision 2030"""
    documents = []

    samples = [
        {
            "content": "رؤية السعودية 2030 هي خطة استراتيجية تهدف إلى تنويع الاقتصاد السعودي وتقليل الاعتماد على النفط مع تطوير قطاعات مختلفة مثل الصحة والتعليم والسياحة.",
            "source": "vision2030_overview_ar.txt"
        },
        {
            "content": "Saudi Vision 2030 is a strategic framework aiming to diversify Saudi Arabia's economy and reduce dependence on oil, while developing sectors like health, education, and tourism.",
            "source": "vision2030_overview_en.txt"
        },
        {
            "content": "تشمل الأهداف الاقتصادية لرؤية 2030 زيادة مساهمة القطاع الخاص من 40% إلى 65% من الناتج المحلي الإجمالي، ورفع نسبة الصادرات غير النفطية من 16% إلى 50% من الناتج المحلي الإجمالي غير النفطي، وخفض البطالة إلى 7%.",
            "source": "economic_goals_ar.txt"
        },
        {
            "content": "The economic goals of Vision 2030 include increasing private sector contribution from 40% to 65% of GDP, raising non-oil exports from 16% to 50%, and reducing unemployment from 11.6% to 7%.",
            "source": "economic_goals_en.txt"
        },
        {
            "content": "تركز رؤية 2030 على زيادة مشاركة المرأة في سوق العمل من 22% إلى 30% بحلول عام 2030، مع توفير فرص متساوية في التعليم والعمل.",
            "source": "women_empowerment_ar.txt"
        },
        {
            "content": "Vision 2030 emphasizes increasing women's participation in the workforce from 22% to 30% by 2030, while providing equal opportunities in education and employment.",
            "source": "women_empowerment_en.txt"
        }
    ]

    for sample in samples:
        doc = Document(
            page_content=sample["content"],
            metadata={"source": sample["source"], "filename": sample["source"]}
        )
        documents.append(doc)

    print(f"Created {len(documents)} mock documents")
    return documents


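# Model loading. ALLaM-7B is the primary model; BLOOM-7B1 and mBART can be
# loaded instead when ALLaM is unavailable.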
@spaces.GPU
def load_primary_model():
    """Load the ALLaM-7B model with error handling"""
    global model, tokenizer, model_type

    if model is not None and tokenizer is not None and model_type == "primary":
        return "Primary model (ALLaM-7B) already loaded"

    model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
    print(f"Loading primary model: {model_name}")

    try:
        # Confirm the SentencePiece backend is available; the slow tokenizer depends on it.
        import sentencepiece as spm
        print("SentencePiece imported successfully")

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=False
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )

        model_type = "primary"
        return "Primary model (ALLaM-7B) loaded successfully!"

    except Exception as e:
        error_msg = f"Primary model loading failed: {e}"
        print(error_msg)
        return error_msg


@spaces.GPU
def load_fallback_model():
    """Load the fallback model (BLOOM-7B1) when ALLaM fails"""
    global model, tokenizer, model_type

    if model is not None and tokenizer is not None and model_type == "fallback":
        return "Fallback model already loaded"

    try:
        print("Loading fallback model: BLOOM-7B1...")

        tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-7b1")
        model = AutoModelForCausalLM.from_pretrained(
            "bigscience/bloom-7b1",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            load_in_8bit=True
        )

        model_type = "fallback"
        return "Fallback model (BLOOM-7B1) loaded successfully!"
    except Exception as e:
        return f"Fallback model loading failed: {e}"


def load_mbart_model():
    """Load mBART as a second fallback option"""
    global model, tokenizer, model_type

    try:
        print("Loading mBART multilingual model...")

        model_name = "facebook/mbart-large-50-many-to-many-mmt"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            load_in_8bit=True
        )

        model_type = "mbart"
        return "mBART multilingual model loaded successfully!"
    except Exception as e:
        return f"mBART model loading failed: {e}"


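# Gradio callbacks: these functions back the buttons and chat box defined below.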
def process_pdfs(pdf_files):
    """Build the vector store and assistant from uploaded PDF files."""
    if not pdf_files:
        return "No files uploaded. Please upload PDF documents about Vision 2030."

    documents = process_pdf_files(pdf_files)

    if not documents:
        return "Failed to extract text from the uploaded PDFs."

    global assistant, model, tokenizer

    if model is None or tokenizer is None:
        return "Please load a model first (primary or fallback) before processing documents."

    vector_store = create_vector_store(documents)

    assistant = Vision2030Assistant(model, tokenizer, vector_store, model_type)

    return f"Successfully processed {len(documents)} documents. The assistant is ready to use!"


def use_mock_documents():
    """Use mock documents when no PDFs are available"""
    documents = create_mock_documents()

    global assistant, model, tokenizer

    if model is None or tokenizer is None:
        return "Please load a model first (primary or fallback) before using mock documents."

    vector_store = create_vector_store(documents)

    assistant = Vision2030Assistant(model, tokenizer, vector_store, model_type)

    return "Successfully initialized with mock Vision 2030 documents. The assistant is ready for testing!"


@spaces.GPU
def answer_query(message, history):
    global assistant

    history = history or []

    if assistant is None:
        history.append((message, "Please load a model and process documents first (or use mock documents for testing)."))
        return history

    response = assistant.answer(message)
    history.append((message, response))
    return history


def reset_chat():
    global assistant

    if assistant is None:
        return "No active conversation to reset."

    reset_message = assistant.reset_conversation()
    return reset_message


def restart_factory():
    return "Restarting the application... Please reload the page in a few seconds."


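# UI definition: a Setup tab for loading models and documents, and a Chat tab
# for asking questions in Arabic or English.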
with gr.Blocks(title="Vision 2030 Virtual Assistant") as demo:
    gr.Markdown("# Vision 2030 Virtual Assistant")
    gr.Markdown("Ask questions about Saudi Vision 2030 goals, projects, and progress in Arabic or English.")

    with gr.Tab("Setup"):
        gr.Markdown("## Step 1: Load a Model")
        with gr.Row():
            with gr.Column():
                primary_btn = gr.Button("Load ALLaM-7B Model (Primary)", variant="primary")
                primary_output = gr.Textbox(label="Primary Model Status")
                primary_btn.click(load_primary_model, inputs=[], outputs=primary_output)

            with gr.Column():
                fallback_btn = gr.Button("Load BLOOM-7B1 (Fallback)", variant="secondary")
                fallback_output = gr.Textbox(label="Fallback Model Status")
                fallback_btn.click(load_fallback_model, inputs=[], outputs=fallback_output)

            with gr.Column():
                mbart_btn = gr.Button("Load mBART (Alternative)", variant="secondary")
                mbart_output = gr.Textbox(label="mBART Model Status")
                mbart_btn.click(load_mbart_model, inputs=[], outputs=mbart_output)

        gr.Markdown("## Step 2: Prepare Documents")
        with gr.Row():
            with gr.Column():
                pdf_files = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDF Documents")
                process_btn = gr.Button("Process Documents", variant="primary")
                process_output = gr.Textbox(label="Processing Status")
                process_btn.click(process_pdfs, inputs=[pdf_files], outputs=process_output)

            with gr.Column():
                mock_btn = gr.Button("Use Mock Documents (for testing)", variant="secondary")
                mock_output = gr.Textbox(label="Mock Documents Status")
                mock_btn.click(use_mock_documents, inputs=[], outputs=mock_output)

        gr.Markdown("## Troubleshooting")
        restart_btn = gr.Button("Restart Application", variant="secondary")
        restart_output = gr.Textbox(label="Restart Status")
        restart_btn.click(restart_factory, inputs=[], outputs=restart_output)
        restart_btn.click(None, [], None, _js="() => {setTimeout(() => {location.reload()}, 5000)}")

    with gr.Tab("Chat"):
        chatbot = gr.Chatbot(label="Conversation", height=500)

        with gr.Row():
            message = gr.Textbox(
                label="Ask a question about Vision 2030 (in Arabic or English)",
                placeholder="What are the main goals of Vision 2030?",
                lines=2
            )
            submit_btn = gr.Button("Submit", variant="primary")

        reset_btn = gr.Button("Reset Conversation")

        gr.Markdown("### Example Questions")
        with gr.Row():
            with gr.Column():
                gr.Markdown("**English Questions:**")
                en_examples = gr.Examples(
                    examples=[
                        "What is Saudi Vision 2030?",
                        "What are the economic goals of Vision 2030?",
                        "How does Vision 2030 support women's empowerment?",
                        "What environmental initiatives are part of Vision 2030?",
                        "What is the role of the Public Investment Fund in Vision 2030?"
                    ],
                    inputs=message
                )

            with gr.Column():
                gr.Markdown("**Arabic Questions:**")
                ar_examples = gr.Examples(
                    examples=[
                        "ما هي رؤية السعودية 2030؟",
                        "ما هي الأهداف الاقتصادية لرؤية 2030؟",
                        "كيف تدعم رؤية 2030 تمكين المرأة السعودية؟",
                        "ما هي مبادرات رؤية 2030 للحفاظ على البيئة؟",
                        "ما هي استراتيجية صندوق الاستثمارات العامة في رؤية 2030؟"
                    ],
                    inputs=message
                )

        reset_output = gr.Textbox(label="Reset Status", visible=False)
        submit_btn.click(answer_query, inputs=[message, chatbot], outputs=[chatbot])
        message.submit(answer_query, inputs=[message, chatbot], outputs=[chatbot])
        reset_btn.click(reset_chat, inputs=[], outputs=[reset_output])
        # Clear the chatbot display when the conversation is reset.
        reset_btn.click(lambda: [], inputs=[], outputs=[chatbot])


demo.launch()