|
|
|
|
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
from langdetect import detect |
|
from sentence_transformers import SentenceTransformer |
|
import faiss |
|
import numpy as np |
|
|
|
|
|
|
|
|
|
# Arabic generation stack: tokenizer and causal LM for the ALLaM checkpoint,
# wrapped in a text-generation pipeline used for Arabic queries.
print("Loading ALLaM-7B-Instruct-preview for Arabic...")
arabic_model_id = "ALLaM-AI/ALLaM-7B-Instruct-preview"
arabic_tokenizer = AutoTokenizer.from_pretrained(arabic_model_id)
arabic_model = AutoModelForCausalLM.from_pretrained(
    arabic_model_id,
    device_map="auto",
)
arabic_pipe = pipeline(
    task="text-generation",
    model=arabic_model,
    tokenizer=arabic_tokenizer,
)
|
|
|
|
|
|
|
|
|
# English generation stack: tokenizer and causal LM for the Mistral instruct
# checkpoint, wrapped in a text-generation pipeline used for non-Arabic queries.
print("Loading Mistral-7B-Instruct-v0.2 for English...")
english_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
english_tokenizer = AutoTokenizer.from_pretrained(english_model_id)
english_model = AutoModelForCausalLM.from_pretrained(
    english_model_id,
    device_map="auto",
)
english_pipe = pipeline(
    task="text-generation",
    model=english_model,
    tokenizer=english_tokenizer,
)
|
|
|
|
|
|
|
|
|
# Sentence embedders used for retrieval: one per supported language. Each
# language's documents and queries are encoded with the same embedder.
print("Loading Embedding Models for Retrieval...")
arabic_embedder = SentenceTransformer("CAMeL-Lab/bert-base-arabic-camelbert-ca")
english_embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Seed corpus for retrieval. Each entry holds the passage text and a language
# tag; every entry NOT tagged "en" is routed to the Arabic index.
documents = [
    {"text": "Vision 2030 aims to diversify the Saudi economy.", "lang": "en"},
    {"text": "رؤية 2030 تهدف إلى تنويع الاقتصاد السعودي.", "lang": "ar"},
]

# Split the corpus by language, preserving document order within each group so
# that FAISS row ids line up with positions in the *_texts lists.
english_texts = [entry["text"] for entry in documents if entry["lang"] == "en"]
arabic_texts = [entry["text"] for entry in documents if entry["lang"] != "en"]

# Encode each passage with the embedder that matches its language.
english_vectors = [english_embedder.encode(text) for text in english_texts]
arabic_vectors = [arabic_embedder.encode(text) for text in arabic_texts]

# One flat L2 index per language; the dimensionality is taken from the first
# embedding of that language.
english_index = faiss.IndexFlatL2(len(english_vectors[0]))
english_index.add(np.stack(english_vectors))

arabic_index = faiss.IndexFlatL2(len(arabic_vectors[0]))
arabic_index.add(np.stack(arabic_vectors))
|
|
|
|
|
|
|
|
|
def _build_arabic_prompt(context, user_input):
    """Return the Arabic few-shot prompt with the context and question filled in."""
    return (
        "أنت خبير في رؤية السعودية 2030.\n"
        f"إليك بعض المعلومات المهمة:\n{context}\n\n"
        "مثال:\n"
        "السؤال: ما هي ركائز رؤية 2030؟\n"
        "الإجابة: ركائز رؤية 2030 هي مجتمع حيوي، اقتصاد مزدهر، ووطن طموح.\n\n"
        "أجب عن سؤال المستخدم بشكل واضح ودقيق.\n"
        f"السؤال: {user_input}\n"
        "الإجابة:"
    )


def _build_english_prompt(context, user_input):
    """Return the English few-shot prompt with the context and question filled in."""
    return (
        "You are an expert on Saudi Arabia's Vision 2030.\n"
        f"Here is some relevant information:\n{context}\n\n"
        "Example:\n"
        "Question: What are the key pillars of Vision 2030?\n"
        "Answer: The key pillars are a vibrant society, a thriving economy, and an ambitious nation.\n\n"
        "Answer the user's question clearly and accurately.\n"
        f"Question: {user_input}\n"
        "Answer:"
    )


def _retrieve_context(query, embedder, index, texts):
    """Return the nearest stored passage for *query*, or "" if none is found."""
    query_vec = embedder.encode(query)
    # FAISS expects a 2-D float32 matrix of queries, hence the wrapping list
    # and the explicit dtype.
    _, ids = index.search(np.asarray([query_vec], dtype="float32"), k=1)
    best = int(ids[0][0])
    # A negative id is treated as "no result", as in the original check.
    return texts[best] if best >= 0 else ""


def _generate_reply(pipe, prompt):
    """Run *pipe* on *prompt* and return only the newly generated text."""
    outputs = pipe(
        prompt,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        # By default the pipeline returns prompt + completion; without this
        # flag the whole few-shot prompt would be echoed back to the user.
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()


def retrieve_and_generate(user_input):
    """Answer a Vision 2030 question using language-specific retrieval and generation.

    The input language is detected with ``langdetect``. Arabic input ("ar")
    is handled by the Arabic embedder, index and pipeline; everything else,
    including input whose language cannot be detected, goes through the
    English ones.

    Args:
        user_input: The user's question as a string.

    Returns:
        The generated answer text, without the prompt.
    """
    try:
        lang = detect(user_input)
    except Exception:
        # Detection fails on empty or feature-less text; fall back to English
        # instead of crashing. `Exception` (not a bare except) lets
        # KeyboardInterrupt/SystemExit propagate.
        lang = "en"

    if lang == "ar":
        print("Detected Arabic input")
        context = _retrieve_context(
            user_input, arabic_embedder, arabic_index, arabic_texts
        )
        prompt = _build_arabic_prompt(context, user_input)
        return _generate_reply(arabic_pipe, prompt)

    print("Detected English input")
    context = _retrieve_context(
        user_input, english_embedder, english_index, english_texts
    )
    prompt = _build_english_prompt(context, user_input)
    return _generate_reply(english_pipe, prompt)
|
|
|
|
|
|
|
|
|
# Gradio UI: a chat window, a text box for questions, and a clear button.
with gr.Blocks() as demo:
    gr.Markdown(
        "# Vision 2030 Virtual Assistant 🌍\n\n"
        "Supports Arabic & English queries about Vision 2030 "
        "(with RAG retrieval and improved prompting)."
    )
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Ask me anything about Vision 2030")
    clear = gr.Button("Clear")

    def chat(message, history):
        """Append one (question, answer) turn to the chat history.

        Args:
            message: Text submitted from the text box.
            history: Current chatbot history (a list of turns), or None.

        Returns:
            A ``(history, "")`` pair: the updated history for the chatbot and
            an empty string that clears the text box.
        """
        # Work on a copy so the caller's list is not mutated, and tolerate a
        # missing history.
        history = list(history or [])
        # Skip blank submissions rather than running the model on an empty
        # question.
        if not message or not message.strip():
            return history, ""
        reply = retrieve_and_generate(message)
        history.append((message, reply))
        return history, ""

    msg.submit(chat, [msg, chatbot], [chatbot, msg])
    # Returning None to the chatbot component resets the conversation.
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()
|
|