"""Vision 2030 Assistant: a bilingual (Arabic/English) RAG pipeline over the
Saudi Vision 2030 documents, with reference-based evaluation metrics and a
Gradio interface."""

import os
import re
import json
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path

# PDF parsing
import PyPDF2

# Language model and embeddings
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

# Retrieval stack (LangChain + FAISS)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain.embeddings import HuggingFaceEmbeddings

# Arabic text shaping (for display)
import arabic_reshaper
from bidi.algorithm import get_display

# Evaluation metrics and plotting
from rouge_score import rouge_scorer
import sacrebleu
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

# Web UI
import gradio as gr

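# Note: assumed PyPI names for the less obvious imports above (verify against
# your environment): PyPDF2, arabic-reshaper, python-bidi, rouge-score,
# sacrebleu, sentence-transformers, langchain, langchain-community,
# faiss-cpu (or faiss-gpu), and gradio. In newer LangChain releases,
# HuggingFaceEmbeddings lives in langchain_community.embeddings.

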
def safe_tokenize(text):
    """Pure regex tokenizer with no NLTK dependency."""
    if not text:
        return []

    # Pad punctuation with spaces so each mark becomes a standalone token.
    text = re.sub(r'([.,!?;:()\[\]{}"\'/\\])', r' \1 ', text)

    return [token for token in re.split(r'\s+', text.lower()) if token]


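# For example, safe_tokenize("Hello, world!") returns
# ['hello', ',', 'world', '!'].

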
def detect_language(text):
    """Detect whether text is primarily Arabic or English."""
    # Count characters in the Arabic Unicode block (U+0600-U+06FF).
    arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
    is_arabic = len(arabic_chars) > len(text) * 0.5
    return "arabic" if is_arabic else "english"


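# For example, detect_language("مرحبا") returns "arabic" (5 of 5 characters are
# Arabic), while detect_language("hello") returns "english".

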
def calculate_bleu(prediction, reference):
    """Calculate simplified BLEU scores without any NLTK dependency.

    Note: this uses unclipped n-gram precision and no brevity penalty,
    so it is an approximation of standard BLEU.
    """
    pred_tokens = safe_tokenize(prediction.lower())
    ref_tokens = [safe_tokenize(reference.lower())]

    if not pred_tokens or not ref_tokens[0]:
        return {"bleu_1": 0, "bleu_2": 0, "bleu_4": 0}

    def get_ngrams(tokens, n):
        return [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]

    # n-gram precision for n = 1..4 (always four entries).
    precisions = []
    for n in range(1, 5):
        if len(pred_tokens) < n:
            precisions.append(0)
            continue

        pred_ngrams = get_ngrams(pred_tokens, n)
        ref_ngrams = get_ngrams(ref_tokens[0], n)

        matches = sum(1 for ng in pred_ngrams if ng in ref_ngrams)

        if pred_ngrams:
            precisions.append(matches / len(pred_ngrams))
        else:
            precisions.append(0)

    # Geometric means of the n-gram precisions.
    return {
        "bleu_1": precisions[0],
        "bleu_2": (precisions[0] * precisions[1]) ** 0.5,
        "bleu_4": (precisions[0] * precisions[1] * precisions[2] * precisions[3]) ** 0.25
    }


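# Sanity check: a prediction identical to its reference scores 1.0 on every
# n-gram order it is long enough for, e.g.
#   calculate_bleu("the cat sat", "the cat sat")
#   -> {"bleu_1": 1.0, "bleu_2": 1.0, "bleu_4": 0.0}  (too short for 4-grams)

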
def calculate_meteor(prediction, reference):
    """Jaccard word-overlap score, used here as a lightweight METEOR stand-in."""
    pred_tokens = set(safe_tokenize(prediction.lower()))
    ref_tokens = set(safe_tokenize(reference.lower()))

    if not pred_tokens or not ref_tokens:
        return 0

    intersection = len(pred_tokens.intersection(ref_tokens))
    union = len(pred_tokens.union(ref_tokens))

    return intersection / union if union > 0 else 0


def calculate_f1_precision_recall(prediction, reference):
    """Calculate word-level F1, precision, and recall with the custom tokenizer."""
    pred_tokens = set(safe_tokenize(prediction.lower()))
    ref_tokens = set(safe_tokenize(reference.lower()))

    common = pred_tokens.intersection(ref_tokens)

    precision = len(common) / len(pred_tokens) if pred_tokens else 0
    recall = len(common) / len(ref_tokens) if ref_tokens else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0

    return {'precision': precision, 'recall': recall, 'f1': f1}


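# For example, calculate_f1_precision_recall("the cat", "the cat sat")
# -> {'precision': 1.0, 'recall': 0.6667, 'f1': 0.8}  (values shown rounded)

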
def evaluate_retrieval_quality(contexts, query, language):
    """Evaluate the quality of retrieved contexts.

    Only source_diversity is actually computed; language_match_ratio and mrr
    are fixed placeholders pending a real implementation.
    """
    return {
        'language_match_ratio': 1.0,  # placeholder
        'source_diversity': len(set(ctx.get('source', '') for ctx in contexts)) / max(1, len(contexts)),
        'mrr': 1.0  # placeholder
    }


def simple_process_pdfs(pdf_paths):
    """Process PDF documents and return document objects."""
    documents = []

    for pdf_path in pdf_paths:
        try:
            text = ""
            with open(pdf_path, 'rb') as file:
                reader = PyPDF2.PdfReader(file)
                for page in reader.pages:
                    page_text = page.extract_text()
                    if page_text:
                        text += page_text + "\n\n"

            if text.strip():
                doc = Document(
                    page_content=text,
                    metadata={"source": pdf_path, "filename": os.path.basename(pdf_path)}
                )
                documents.append(doc)
                print(f"Successfully processed: {pdf_path}")
            else:
                print(f"Warning: No text extracted from {pdf_path}")
        except Exception as e:
            print(f"Error processing {pdf_path}: {e}")

    print(f"Processed {len(documents)} PDF documents")
    return documents


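# Note: PyPDF2 is no longer maintained; its successor is the `pypdf` package,
# which exposes a near-identical PdfReader API if you need to swap it in.

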
def create_vector_store(documents):
    """Split documents into chunks and create a FAISS vector store."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
    )

    chunks = []
    for doc in documents:
        doc_chunks = text_splitter.split_text(doc.page_content)
        # Copy the metadata dict so chunks don't share one mutable object.
        chunks.extend([
            Document(page_content=chunk, metadata=dict(doc.metadata))
            for chunk in doc_chunks
        ])

    print(f"Created {len(chunks)} chunks from {len(documents)} documents")

    # Multilingual embeddings so Arabic and English queries share one index.
    embedding_function = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    )

    vector_store = FAISS.from_documents(chunks, embedding_function)

    return vector_store


def load_model_and_tokenizer():
    """Load the ALLaM-7B model and tokenizer with error handling."""
    model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
    print(f"Loading model: {model_name}")

    try:
        # First attempt: the slow tokenizer, which is more forgiving with
        # custom/remote tokenizer code.
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=False
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )

        print("Model loaded successfully with AutoTokenizer!")

    except Exception as e:
        print(f"First loading attempt failed: {e}")
        print("Trying alternative loading approach...")

        # Fallback: the model is Llama-based, so LlamaTokenizer usually works.
        from transformers import LlamaTokenizer

        tokenizer = LlamaTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # float16 for GPUs without bfloat16 support
            trust_remote_code=True,
            device_map="auto",
        )

        print("Model loaded successfully with LlamaTokenizer!")

    return model, tokenizer


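# Rough memory estimate: a 7B-parameter model in bfloat16/float16 needs about
# 14 GB of GPU memory for the weights alone; device_map="auto" will shard or
# offload across available devices if a single GPU is not enough.

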
def retrieve_context(query, vector_store, top_k=5):
    """Retrieve the most relevant document chunks for a given query."""
    results = vector_store.similarity_search_with_score(query, k=top_k)

    contexts = []
    for doc, score in results:
        contexts.append({
            "content": doc.page_content,
            "source": doc.metadata.get("source", "Unknown"),
            # With the default FAISS index this score is an L2 distance,
            # so lower values mean more relevant chunks.
            "relevance_score": score
        })

    return contexts


def generate_response(query, contexts, model, tokenizer, language="auto"):
    """Generate a response using retrieved contexts with ALLaM-specific formatting."""
    if language == "auto":
        language = detect_language(query)

    if language == "arabic":
        # Arabic system instruction; English gloss: "You are a virtual assistant
        # for Saudi Vision 2030. Use the following information to answer the
        # question. If you don't know the answer, honestly say you don't know."
        instruction = (
            "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
            "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
        )
    else:
        instruction = (
            "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
            "If you don't know the answer, honestly say you don't know."
        )

    context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])

    # Llama-style instruction format. The prompt must not end with </s>:
    # an end-of-sequence token in the prompt can stop generation immediately.
    prompt = f"""<s>[INST] {instruction}

Context:
{context_text}

Question: {query} [/INST]"""

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.1
        )

        full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the text generated after the instruction block.
        response = full_output.split("[/INST]")[-1].strip()

        if not response:
            response = full_output

        return response

    except Exception as e:
        print(f"Error during generation: {e}")
        return "I apologize, but I encountered an error while generating a response."


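# Sampling note: temperature=0.7 with top_p=0.9 trades determinism for variety;
# for reproducible evaluation runs, consider do_sample=False (greedy decoding).

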
class Vision2030Assistant:
    def __init__(self, model, tokenizer, vector_store):
        self.model = model
        self.tokenizer = tokenizer
        self.vector_store = vector_store
        self.conversation_history = []

    def answer(self, user_query):
        """Process a user query and return a response with sources."""
        language = detect_language(user_query)

        self.conversation_history.append({"role": "user", "content": user_query})

        # Keep only the last six messages so the retrieval query stays short.
        conversation_context = "\n".join([
            f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
            for msg in self.conversation_history[-6:]
        ])

        # Retrieve with conversation context, but generate from the raw query.
        enhanced_query = f"{conversation_context}\n{user_query}"

        contexts = retrieve_context(enhanced_query, self.vector_store, top_k=5)

        response = generate_response(user_query, contexts, self.model, self.tokenizer, language)

        self.conversation_history.append({"role": "assistant", "content": response})

        sources = [ctx.get("source", "Unknown") for ctx in contexts]
        unique_sources = list(set(sources))

        return response, unique_sources, contexts

    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation_history = []
        return "Conversation has been reset."


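# Typical usage, assuming the components have already been built:
#   model, tokenizer = load_model_and_tokenizer()
#   vector_store = create_vector_store(simple_process_pdfs(["some.pdf"]))
#   assistant = Vision2030Assistant(model, tokenizer, vector_store)
#   response, sources, contexts = assistant.answer("What is Saudi Vision 2030?")

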
comprehensive_evaluation_data = [
    # Overview
    {
        # "What is Saudi Vision 2030?"
        "query": "ما هي رؤية السعودية 2030؟",
        # "Saudi Vision 2030 is a strategic plan to diversify the Saudi economy
        # and reduce dependence on oil while developing sectors such as health,
        # education, and tourism."
        "reference": "رؤية السعودية 2030 هي خطة استراتيجية تهدف إلى تنويع الاقتصاد السعودي وتقليل الاعتماد على النفط مع تطوير قطاعات مختلفة مثل الصحة والتعليم والسياحة.",
        "category": "overview",
        "language": "arabic"
    },
    {
        "query": "What is Saudi Vision 2030?",
        "reference": "Saudi Vision 2030 is a strategic framework aiming to diversify Saudi Arabia's economy and reduce dependence on oil, while developing sectors like health, education, and tourism.",
        "category": "overview",
        "language": "english"
    },

    # Economic goals
    {
        # "What are the economic goals of Vision 2030?"
        "query": "ما هي الأهداف الاقتصادية لرؤية 2030؟",
        # "The economic goals include raising the private sector's contribution
        # to 65%, raising non-oil exports to 50% of non-oil GDP, and cutting
        # unemployment to 7%."
        "reference": "تشمل الأهداف الاقتصادية زيادة مساهمة القطاع الخاص إلى 65%، وزيادة الصادرات غير النفطية إلى 50% من الناتج المحلي غير النفطي، وخفض البطالة إلى 7%.",
        "category": "economic",
        "language": "arabic"
    },
    {
        "query": "What are the economic goals of Vision 2030?",
        "reference": "The economic goals of Vision 2030 include increasing private sector contribution from 40% to 65% of GDP, raising non-oil exports from 16% to 50%, and reducing unemployment from 11.6% to 7%.",
        "category": "economic",
        "language": "english"
    },

    # Social and cultural goals
    {
        # "How does Vision 2030 promote Saudi cultural heritage?"
        "query": "كيف تعزز رؤية 2030 الإرث الثقافي السعودي؟",
        # "Vision 2030 includes preserving national identity, registering
        # archaeological sites with UNESCO, and promoting cultural events."
        "reference": "تتضمن رؤية 2030 الحفاظ على الهوية الوطنية، تسجيل مواقع أثرية في اليونسكو، وتعزيز الفعاليات الثقافية.",
        "category": "social",
        "language": "arabic"
    },
    {
        "query": "How does Vision 2030 aim to improve quality of life?",
        "reference": "Vision 2030 plans to enhance quality of life by expanding sports facilities, promoting cultural activities, and boosting tourism and entertainment sectors.",
        "category": "social",
        "language": "english"
    }
]


def initialize_system():
    """Initialize the Vision 2030 Assistant system."""
    model_dir = "models"
    vector_store_dir = "vector_stores"
    pdf_dir = "pdf_data"

    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(vector_store_dir, exist_ok=True)
    os.makedirs(pdf_dir, exist_ok=True)

    # Source PDFs (English and Arabic editions of the Vision 2030 document).
    pdf_files = ["vision2030_docs/saudi_vision203.pdf", "vision2030_docs/saudi_vision2030_ar.pdf"]

    # Reuse a cached vector store if one exists; otherwise build it from the PDFs.
    if os.path.exists(os.path.join(vector_store_dir, "index.faiss")):
        print("Loading existing vector store...")
        embedding_function = HuggingFaceEmbeddings(
            model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
        )
        vector_store = FAISS.load_local(vector_store_dir, embedding_function)
    else:
        print("Creating new vector store...")
        documents = simple_process_pdfs(pdf_files)
        vector_store = create_vector_store(documents)
        vector_store.save_local(vector_store_dir)

    model, tokenizer = load_model_and_tokenizer()

    assistant = Vision2030Assistant(model, tokenizer, vector_store)

    return assistant


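# Note: recent langchain_community releases require
# FAISS.load_local(..., allow_dangerous_deserialization=True) because the index
# metadata is stored with pickle; add that flag if load_local raises.

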
def evaluate_response(query, response, reference):
    """Evaluate a single response against a reference."""
    # RougeScorer.score expects (target, prediction) in that order.
    rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    rouge_scores = rouge.score(reference, response)

    bleu_scores = calculate_bleu(response, reference)
    meteor = calculate_meteor(response, reference)
    word_metrics = calculate_f1_precision_recall(response, reference)

    evaluation_results = {
        "ROUGE-1": f"{rouge_scores['rouge1'].fmeasure:.4f}",
        "ROUGE-2": f"{rouge_scores['rouge2'].fmeasure:.4f}",
        "ROUGE-L": f"{rouge_scores['rougeL'].fmeasure:.4f}",
        "BLEU-1": f"{bleu_scores['bleu_1']:.4f}",
        "BLEU-4": f"{bleu_scores['bleu_4']:.4f}",
        "METEOR": f"{meteor:.4f}",
        "Word Precision": f"{word_metrics['precision']:.4f}",
        "Word Recall": f"{word_metrics['recall']:.4f}",
        "Word F1": f"{word_metrics['f1']:.4f}"
    }

    return evaluation_results


def run_conversation(assistant, query):
    """Run a query through the assistant and return the response."""
    response, sources, contexts = assistant.answer(query)
    return response, sources, contexts


def run_evaluation_on_sample(assistant, sample_index=0):
    """Run evaluation on a selected sample from the evaluation dataset."""
    if sample_index < 0 or sample_index >= len(comprehensive_evaluation_data):
        # Return the same 7-tuple shape as the success path.
        return "Invalid sample index", "", "", {}, [], "", ""

    sample = comprehensive_evaluation_data[sample_index]
    query = sample["query"]
    reference = sample["reference"]
    category = sample["category"]
    language = sample["language"]

    # Evaluate each sample from a fresh conversation state.
    assistant.reset_conversation()
    response, sources, contexts = assistant.answer(query)

    evaluation_results = evaluate_response(query, response, reference)

    return query, response, reference, evaluation_results, sources, category, language


def qualitative_evaluation_interface(assistant):
    """Create a Gradio interface for qualitative evaluation."""
    sample_options = [f"{i+1}. {item['query'][:50]}..." for i, item in enumerate(comprehensive_evaluation_data)]

    with gr.Blocks(title="Vision 2030 Assistant - Qualitative Evaluation") as interface:
        gr.Markdown("# Vision 2030 Assistant - Qualitative Evaluation")
        gr.Markdown("This interface allows you to evaluate the Vision 2030 Assistant on predefined samples or your own queries.")

        with gr.Tab("Sample Evaluation"):
            gr.Markdown("### Evaluate the assistant on predefined samples")

            sample_dropdown = gr.Dropdown(
                choices=sample_options,
                label="Select a sample query",
                value=sample_options[0] if sample_options else None
            )

            eval_button = gr.Button("Evaluate Sample")

            with gr.Row():
                with gr.Column():
                    sample_query = gr.Textbox(label="Query")
                    sample_category = gr.Textbox(label="Category")
                    sample_language = gr.Textbox(label="Language")

                with gr.Column():
                    sample_response = gr.Textbox(label="Assistant Response")
                    sample_reference = gr.Textbox(label="Reference Answer")
                    sample_sources = gr.Textbox(label="Sources Used")

            with gr.Row():
                metrics_display = gr.JSON(label="Evaluation Metrics")

        with gr.Tab("Custom Evaluation"):
            gr.Markdown("### Evaluate the assistant on your own query")

            custom_query = gr.Textbox(
                lines=3,
                placeholder="Enter your question about Saudi Vision 2030...",
                label="Your Query"
            )

            custom_reference = gr.Textbox(
                lines=3,
                placeholder="Enter a reference answer (optional)...",
                label="Reference Answer (Optional)"
            )

            custom_eval_button = gr.Button("Get Response and Evaluate")

            custom_response = gr.Textbox(label="Assistant Response")
            custom_sources = gr.Textbox(label="Sources Used")

            custom_metrics = gr.JSON(
                label="Evaluation Metrics (if reference provided)",
                visible=True
            )

        with gr.Tab("Conversation Mode"):
            gr.Markdown("### Have a conversation with the Vision 2030 Assistant")

            chatbot = gr.Chatbot(label="Conversation")

            conv_input = gr.Textbox(
                placeholder="Ask about Saudi Vision 2030...",
                label="Your message"
            )

            with gr.Row():
                conv_button = gr.Button("Send")
                reset_button = gr.Button("Reset Conversation")

            conv_sources = gr.Textbox(label="Sources Used")


        def handle_sample_selection(selection):
            if not selection:
                return "", "", "", {}, "", "", ""

            try:
                # Dropdown labels look like "3. <query>...", so parse out the index.
                index = int(selection.split(".")[0]) - 1
                query, response, reference, metrics, sources, category, language = run_evaluation_on_sample(assistant, index)
                sources_str = ", ".join(sources)
                return query, response, reference, metrics, sources_str, category, language
            except Exception:
                return "Error processing selection", "", "", {}, "", "", ""

        eval_button.click(
            handle_sample_selection,
            inputs=[sample_dropdown],
            outputs=[sample_query, sample_response, sample_reference, metrics_display,
                     sample_sources, sample_category, sample_language]
        )

        sample_dropdown.change(
            handle_sample_selection,
            inputs=[sample_dropdown],
            outputs=[sample_query, sample_response, sample_reference, metrics_display,
                     sample_sources, sample_category, sample_language]
        )


        def handle_custom_evaluation(query, reference):
            if not query:
                return "Please enter a query", "", {}

            # Start from a clean history so earlier turns don't bias retrieval.
            assistant.reset_conversation()

            response, sources, _ = assistant.answer(query)
            sources_str = ", ".join(sources)

            # Metrics only make sense when a reference answer is supplied.
            metrics = {}
            if reference:
                metrics = evaluate_response(query, response, reference)

            return response, sources_str, metrics

        custom_eval_button.click(
            handle_custom_evaluation,
            inputs=[custom_query, custom_reference],
            outputs=[custom_response, custom_sources, custom_metrics]
        )


        def handle_conversation(message, history):
            if not message:
                return history, "", ""

            response, sources, _ = assistant.answer(message)
            sources_str = ", ".join(sources)

            # Append the new turn as a [user, assistant] pair for gr.Chatbot.
            history = history + [[message, response]]

            return history, "", sources_str

        def reset_conv():
            result = assistant.reset_conversation()
            return [], result, ""

        conv_button.click(
            handle_conversation,
            inputs=[conv_input, chatbot],
            outputs=[chatbot, conv_input, conv_sources]
        )

        reset_button.click(
            reset_conv,
            inputs=[],
            outputs=[chatbot, conv_input, conv_sources]
        )

    return interface


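# Note: gr.Chatbot's list-of-pairs history format (used above) is deprecated in
# newer Gradio releases in favor of type="messages", a list of
# {"role": ..., "content": ...} dicts; the pair format still works on Gradio 4.x.

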
def main():
    try:
        assistant = initialize_system()
        interface = qualitative_evaluation_interface(assistant)
        interface.launch()
    except Exception as e:
        print(f"Error initializing system: {e}")

        # Capture the message now: `e` is unbound once the except block exits,
        # so the fallback UI's lambda must not reference it directly.
        error_message = str(e)
        gr.Interface(
            fn=lambda x: f"System initialization failed: {error_message}",
            inputs=gr.Textbox(placeholder="System failed to initialize"),
            outputs=gr.Textbox()
        ).launch()


if __name__ == "__main__":
    main()