from sentence_transformers import CrossEncoder, SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import faiss
import numpy as np
from typing import List, Dict
class ArabicRAGSystem:
    def __init__(self):
        """Initialize with fallback models for Arabic support."""
        # Solution 1: Use a reliable Arabic embedding model. Loading a plain
        # BERT checkpoint through SentenceTransformer applies mean pooling.
        self.embedding_model = SentenceTransformer("aubmindlab/bert-base-arabertv02")

        # Solution 2: Fallback cross-encoder options. The mMARCO reranker is
        # multilingual and covers Arabic; the system still works without it.
        try:
            self.cross_encoder = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")
        except Exception:
            self.cross_encoder = None

        # Solution 3: Main Arabic LLM with error handling. Jais ships custom
        # model code, so trust_remote_code is required.
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                "inception-mbzuai/jais-13b-chat", trust_remote_code=True
            )
            self.llm = AutoModelForCausalLM.from_pretrained(
                "inception-mbzuai/jais-13b-chat", trust_remote_code=True
            )
        except Exception:
            # Fallback to a smaller Arabic model
            self.tokenizer = AutoTokenizer.from_pretrained("aubmindlab/aragpt2-base")
            self.llm = AutoModelForCausalLM.from_pretrained("aubmindlab/aragpt2-base")

        # Index width follows the embedding model (768 for AraBERT-base)
        dim = self.embedding_model.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatL2(dim)
    def generate_answer(self, question: str, documents: List[Dict],
                        top_k: int = 5, temperature: float = 0.7) -> tuple:
        """Enhanced with fallback retrieval methods."""
        # Index documents (reset first so repeated calls don't accumulate
        # duplicate vectors in the FAISS index)
        texts = [doc["text"] for doc in documents]
        self.index.reset()
        self.index.add(np.asarray(self.embedding_model.encode(texts), dtype="float32"))

        # Two-phase retrieval with fallback: over-retrieve, then rerank
        query_embedding = np.asarray(self.embedding_model.encode([question]), dtype="float32")
        n_candidates = min(top_k * 2, len(texts))
        distances, indices = self.index.search(query_embedding, n_candidates)

        # Solution 4: Cross-encoder fallback logic
        if self.cross_encoder:
            pairs = [[question, documents[idx]["text"]] for idx in indices[0]]
            scores = self.cross_encoder.predict(pairs)
            order = np.argsort(scores)[-top_k:][::-1]
            top_indices = [indices[0][i] for i in order]
            top_scores = [float(scores[i]) for i in order]
        else:
            top_indices = list(indices[0][:top_k])
            # Turn L2 distances into a rough similarity (smaller distance = higher score)
            top_scores = [float(1.0 / (1.0 + d)) for d in distances[0][:top_k]]

        # Prepare context with metadata (Arabic labels: المرجع = reference,
        # الصفحة = page, النص = text)
        context = "\n\n".join([
            f"المرجع: {documents[idx]['source']}\n"
            f"الصفحة: {documents[idx].get('page', 'N/A')}\n"
            f"النص: {documents[idx]['text']}\n"
            for idx in top_indices
        ])
        # Generation with error handling. The Arabic prompt instructs the
        # model to answer in Modern Standard Arabic, rely only on the given
        # context, cite sources as [reference: file name, page], and reply
        # "insufficient information available" when no clear answer exists.
        prompt = f"""
نظام التحليل الديني العربي:

السياق:
{context}

السؤال: {question}

التعليمات:
- أجب باللغة العربية الفصحى فقط
- استخدم المعلومات من السياق فقط
- أشر إلى المصادر باستخدام [المرجع: اسم الملف، الصفحة]
- إذا لم تجد إجابة واضحة قل "لا تتوفر معلومات كافية"

الإجابة:
""".strip()
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt")
            outputs = self.llm.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Keep only the text after the final "الإجابة:" ("Answer:") marker
            answer = answer.split("الإجابة:")[-1].strip()
        except Exception:
            # "Sorry, an error occurred while processing the question. Please try again."
            answer = "عذراً، حدث خطأ في معالجة السؤال. يرجى المحاولة مرة أخرى."
        # Prepare sources with the scores gathered during retrieval/reranking,
        # so each returned score matches its (possibly reordered) document
        sources = [{
            "text": documents[idx]["text"],
            "source": documents[idx]["source"],
            "page": documents[idx].get("page", "N/A"),
            "score": top_scores[i],
        } for i, idx in enumerate(top_indices)]

        return answer, sources
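

# Minimal usage sketch. The documents and question below are illustrative
# placeholders, and this assumes the checkpoints above download successfully
# (jais-13b-chat alone needs tens of GB of memory; the AraGPT2 fallback runs
# on far less).
if __name__ == "__main__":
    rag = ArabicRAGSystem()
    docs = [
        {"text": "نص تجريبي أول للفهرسة.", "source": "doc_a.pdf", "page": 3},
        {"text": "نص تجريبي ثانٍ للفهرسة.", "source": "doc_b.pdf", "page": 17},
    ]
    answer, sources = rag.generate_answer("ما موضوع النص؟", docs, top_k=2)
    print(answer)
    for src in sources:
        print(f'{src["source"]} (p. {src["page"]}) score={src["score"]:.3f}')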