import gradio as gr
import os
import re
import torch
import PyPDF2
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
import spaces
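
# ---------------------------------------------------------------------------
# Vision 2030 RAG assistant for Hugging Face Spaces.
# Pipeline: upload PDFs -> extract text with PyPDF2 -> chunk and embed into a
# FAISS index -> retrieve the top-k chunks per query -> answer with
# ALLaM-7B-Instruct in the language of the question (Arabic or English).
# ---------------------------------------------------------------------------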
class Vision2030Assistant:
    """Conversational wrapper around the loaded model, tokenizer, and FAISS index."""

    def __init__(self, model, tokenizer, vector_store):
        self.model = model
        self.tokenizer = tokenizer
        self.vector_store = vector_store
        self.conversation_history = []

    def answer(self, user_query):
        """Answer a query with retrieval-augmented generation, citing sources."""
        language = detect_language(user_query)

        self.conversation_history.append({"role": "user", "content": user_query})

        # Fold the last few turns into the retrieval query so follow-up
        # questions still surface the relevant chunks.
        conversation_context = "\n".join([
            f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
            for msg in self.conversation_history[-6:]
        ])
        enhanced_query = f"{conversation_context}\n{user_query}"

        contexts = retrieve_context(enhanced_query, self.vector_store, top_k=5)
        response = generate_response(user_query, contexts, self.model, self.tokenizer, language)

        self.conversation_history.append({"role": "assistant", "content": response})

        # Append a deduplicated list of source filenames to the answer.
        sources = [ctx.get("source", "Unknown") for ctx in contexts]
        unique_sources = list(set(sources))
        if unique_sources:
            source_text = "\n\nSources: " + ", ".join([os.path.basename(src) for src in unique_sources])
            return response + source_text
        return response

    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation_history = []
        return "Conversation has been reset."

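
# ---------------------------------------------------------------------------
# Module-level helpers. The GPU-bound steps are standalone functions so they
# can carry the @spaces.GPU decorator, which requests a GPU only for the
# duration of the decorated call on ZeroGPU Spaces.
# ---------------------------------------------------------------------------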

def detect_language(text):
    """Detect whether text is primarily Arabic or English."""
    arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
    latin_chars = re.findall(r'[a-zA-Z]', text)
    # Compare letter counts rather than the Arabic share of the whole string,
    # so spaces, digits, and punctuation don't skew short queries.
    return "arabic" if len(arabic_chars) > len(latin_chars) else "english"
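
# For example, "رؤية 2030" has four Arabic letters and no Latin letters, so it
# is detected as Arabic, even though Arabic characters make up under half of
# the nine-character string once the digits and the space are counted.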


def retrieve_context(query, vector_store, top_k=5):
    """Retrieve the most relevant document chunks for a given query."""
    results = vector_store.similarity_search_with_score(query, k=top_k)

    contexts = []
    for doc, score in results:
        contexts.append({
            "content": doc.page_content,
            "source": doc.metadata.get("source", "Unknown"),
            # FAISS returns a distance here, so lower scores mean closer matches.
            "relevance_score": score,
        })

    return contexts


@spaces.GPU
def generate_response(query, contexts, model, tokenizer, language="auto"):
    """Generate a response using retrieved contexts with ALLaM-specific formatting."""
    if language == "auto":
        language = detect_language(query)

    # Instruct the model in the language of the question.
    if language == "arabic":
        instruction = (
            "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
            "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
        )
    else:
        instruction = (
            "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
            "If you don't know the answer, honestly say you don't know."
        )

    context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])

    # [INST] ... [/INST] prompt. The tokenizer adds the BOS token itself, and a
    # literal </s> after [/INST] would mark the sequence as finished before the
    # model starts answering, so neither special token is written here.
    prompt = f"""[INST] {instruction}

Context:
{context_text}

Question: {query} [/INST]"""

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad-token warning
        )

        full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the text generated after the prompt's closing [/INST].
        response = full_output.split("[/INST]")[-1].strip()
        if not response:
            response = full_output

        return response

    except Exception as e:
        print(f"Error during generation: {e}")
        return "I apologize, but I encountered an error while generating a response."
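

# ---------------------------------------------------------------------------
# Document ingestion. PyPDF2 only recovers embedded text; pages without a text
# layer (e.g. scanned images) come back empty, which is why both each page and
# the document as a whole are checked before a Document is created.
# ---------------------------------------------------------------------------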
def process_pdf_files(pdf_files):
    """Extract text from uploaded PDF files and wrap each one in a Document."""
    documents = []

    for pdf_file in pdf_files:
        # Gradio hands uploads over as temp files (or plain path strings in
        # newer versions); either way the path is all PyPDF2 needs, so the
        # file is read in place instead of being copied first.
        path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        filename = os.path.basename(path)
        try:
            text = ""
            with open(path, "rb") as file:
                reader = PyPDF2.PdfReader(file)
                for page in reader.pages:
                    page_text = page.extract_text()
                    if page_text:
                        text += page_text + "\n\n"

            if text.strip():
                doc = Document(
                    page_content=text,
                    metadata={"source": filename, "filename": filename}
                )
                documents.append(doc)
                print(f"Successfully processed: {filename}")
            else:
                print(f"Warning: no text extracted from {filename}")
        except Exception as e:
            print(f"Error processing {filename}: {e}")

    print(f"Processed {len(documents)} PDF documents")
    return documents
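

# ---------------------------------------------------------------------------
# Indexing. Chunks of ~500 characters with a 50-character overlap keep each
# embedding focused while preserving context across chunk boundaries, and the
# multilingual MPNet model embeds the Arabic and English documents into a
# single shared index.
# ---------------------------------------------------------------------------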
def create_vector_store(documents):
    """Create a vector store from documents"""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
    )

    chunks = []
    for doc in documents:
        doc_chunks = text_splitter.split_text(doc.page_content)
        chunks.extend([
            Document(page_content=chunk, metadata=doc.metadata)
            for chunk in doc_chunks
        ])

    print(f"Created {len(chunks)} chunks from {len(documents)} documents")

    embedding_function = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    )

    vector_store = FAISS.from_documents(chunks, embedding_function)
    return vector_store
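
# A minimal sketch of persisting the index between restarts, reusing the same
# HuggingFaceEmbeddings instance built above (the "vision2030_index" path is
# arbitrary, and recent langchain versions require
# allow_dangerous_deserialization=True on load because the docstore is pickled):
#
#     vector_store.save_local("vision2030_index")
#     vector_store = FAISS.load_local("vision2030_index", embedding_function,
#                                     allow_dangerous_deserialization=True)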


# Lazily initialized globals shared across the Gradio callbacks. A GPU is only
# attached inside @spaces.GPU calls, so the model is loaded on demand rather
# than at import time.
model = None
tokenizer = None
assistant = None


@spaces.GPU
def load_model_and_tokenizer():
    global model, tokenizer

    if model is not None and tokenizer is not None:
        return "Model already loaded"

    model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
    print(f"Loading model: {model_name}")

    try:
        # First attempt: AutoTokenizer with the slow (non-Rust) implementation.
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=False
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )

        return "Model loaded successfully with AutoTokenizer!"

    except Exception as e:
        error_msg = f"First loading attempt failed: {e}"
        print(error_msg)

        # Fallback: load the tokenizer explicitly as a LlamaTokenizer, and use
        # float16 weights in case bfloat16 was the problem.
        try:
            from transformers import LlamaTokenizer

            tokenizer = LlamaTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                device_map="auto",
            )

            return "Model loaded successfully with LlamaTokenizer!"
        except Exception as e2:
            return f"Both loading attempts failed. Error 1: {e}. Error 2: {e2}"
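

# ---------------------------------------------------------------------------
# Gradio callbacks. These mutate the module-level model/tokenizer/assistant so
# state created in the "Setup" tab is visible from the "Chat" tab.
# ---------------------------------------------------------------------------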
def process_pdfs(pdf_files):
    if not pdf_files:
        return "No files uploaded. Please upload PDF documents about Vision 2030."

    documents = process_pdf_files(pdf_files)
    if not documents:
        return "Failed to extract text from the uploaded PDFs."

    global assistant, model, tokenizer

    # Load the model on demand in case the user skipped Step 1.
    if model is None or tokenizer is None:
        load_status = load_model_and_tokenizer()
        if "successfully" not in load_status.lower():
            return f"Model loading failed: {load_status}"

    vector_store = create_vector_store(documents)
    assistant = Vision2030Assistant(model, tokenizer, vector_store)

    return f"Successfully processed {len(documents)} documents. The assistant is ready to use!"


@spaces.GPU
def answer_query(message, history):
    global assistant

    # gr.Chatbot expects the full history as a list of (user, assistant)
    # message pairs, so the new exchange is appended and the list returned.
    history = history or []

    if assistant is None:
        history.append((message, "Please upload and process Vision 2030 PDF documents first."))
        return history

    response = assistant.answer(message)
    history.append((message, response))
    return history


def reset_chat():
    global assistant

    if assistant is None:
        return "No active conversation to reset."

    return assistant.reset_conversation()


with gr.Blocks(title="Vision 2030 Virtual Assistant") as demo:
    gr.Markdown("# Vision 2030 Virtual Assistant")
    gr.Markdown("Ask questions about Saudi Vision 2030 goals, projects, and progress in Arabic or English.")

    with gr.Tab("Setup"):
        gr.Markdown("## Step 1: Load the Model")
        load_btn = gr.Button("Load ALLaM-7B Model", variant="primary")
        load_output = gr.Textbox(label="Load Status")
        load_btn.click(load_model_and_tokenizer, inputs=[], outputs=load_output)

        gr.Markdown("## Step 2: Upload Vision 2030 Documents")
        pdf_files = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDF Documents")
        process_btn = gr.Button("Process Documents", variant="primary")
        process_output = gr.Textbox(label="Processing Status")
        process_btn.click(process_pdfs, inputs=[pdf_files], outputs=process_output)

    with gr.Tab("Chat"):
        chatbot = gr.Chatbot(label="Conversation")
        message = gr.Textbox(
            label="Ask a question about Vision 2030 (in Arabic or English)",
            placeholder="What are the main goals of Vision 2030?",
            lines=2
        )
        submit_btn = gr.Button("Submit", variant="primary")
        reset_btn = gr.Button("Reset Conversation")

        gr.Markdown("### Example Questions")
        with gr.Row():
            with gr.Column():
                gr.Markdown("**English Questions:**")
                gr.Examples(
                    examples=[
                        "What is Saudi Vision 2030?",
                        "What are the economic goals of Vision 2030?",
                        "How does Vision 2030 support women's empowerment?",
                        "What environmental initiatives are part of Vision 2030?",
                        "What is the role of the Public Investment Fund in Vision 2030?"
                    ],
                    inputs=message
                )

            with gr.Column():
                gr.Markdown("**Arabic Questions:**")
                gr.Examples(
                    examples=[
                        "ما هي رؤية السعودية 2030؟",
                        "ما هي الأهداف الاقتصادية لرؤية 2030؟",
                        "كيف تدعم رؤية 2030 تمكين المرأة السعودية؟",
                        "ما هي مبادرات رؤية 2030 للحفاظ على البيئة؟",
                        "ما هي استراتيجية صندوق الاستثمارات العامة في رؤية 2030؟"
                    ],
                    inputs=message
                )

        reset_output = gr.Textbox(label="Reset Status", visible=False)
        submit_btn.click(answer_query, inputs=[message, chatbot], outputs=[chatbot])
        message.submit(answer_query, inputs=[message, chatbot], outputs=[chatbot])
        # Two listeners on the reset button: report the reset, then clear the log.
        reset_btn.click(reset_chat, inputs=[], outputs=[reset_output])
        reset_btn.click(lambda: None, inputs=[], outputs=[chatbot], postprocess=False)


demo.launch()
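
# Note: on a busy Space, demo.queue().launch() would serialize the GPU-bound
# calls instead of running them concurrently; plain launch() suffices for a
# single user.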