ramysaidagieb committed on
Commit
faa82c9
·
verified ·
1 Parent(s): e6f0575

Update rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +28 -39
rag_pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- from sentence_transformers import CrossEncoder, SentenceTransformer
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import faiss
4
  import numpy as np
@@ -6,67 +6,56 @@ from typing import List, Dict
6
 
7
  class ArabicRAGSystem:
8
  def __init__(self):
9
- """Initialize with fallback models for Arabic support"""
10
- # Solution 1: Use reliable Arabic embedding model
11
- self.embedding_model = SentenceTransformer("UBC-NLP/AraBERT")
12
 
13
- # Solution 2: Fallback cross-encoder options
14
- try:
15
- self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2") # Multilingual fallback
16
- except:
17
- self.cross_encoder = None # System will work without it
18
-
19
- # Solution 3: Main Arabic LLM with error handling
20
  try:
21
  self.tokenizer = AutoTokenizer.from_pretrained("inception-mbzuai/jais-13b-chat")
22
  self.llm = AutoModelForCausalLM.from_pretrained("inception-mbzuai/jais-13b-chat")
23
  except:
24
- # Fallback to smaller Arabic model
25
- self.tokenizer = AutoTokenizer.from_pretrained("aubmindlab/aragpt2-base")
26
- self.llm = AutoModelForCausalLM.from_pretrained("aubmindlab/aragpt2-base")
27
-
28
- self.index = faiss.IndexFlatL2(768) # AraBERT uses 768-dim embeddings
 
 
29
 
30
  def generate_answer(self, question: str, documents: List[Dict],
31
- top_k: int = 5, temperature: float = 0.7) -> tuple:
32
- """Enhanced with fallback retrieval methods"""
33
  # Index documents
34
  texts = [doc["text"] for doc in documents]
35
  self.index.add(np.array(self.embedding_model.encode(texts)))
36
 
37
- # Two-phase retrieval with fallback
38
  query_embedding = self.embedding_model.encode([question])
39
- distances, indices = self.index.search(query_embedding, top_k*2)
40
-
41
- # Solution 4: Cross-encoder fallback logic
42
- if self.cross_encoder:
43
- pairs = [[question, documents[idx]["text"]] for idx in indices[0]]
44
- scores = self.cross_encoder.predict(pairs)
45
- top_indices = [indices[0][i] for i in np.argsort(scores)[-top_k:][::-1]]
46
- else:
47
- top_indices = indices[0][:top_k]
48
 
49
  # Prepare context with metadata
50
  context = "\n\n".join([
51
  f"المرجع: {documents[idx]['source']}\n"
52
  f"الصفحة: {documents[idx].get('page', 'N/A')}\n"
53
  f"النص: {documents[idx]['text']}\n"
54
- for idx in top_indices
55
  ])
56
 
57
- # Generation with error handling
58
  prompt = f"""
59
- نظام التحليل الديني العربي:
 
60
  السياق:
61
  {context}
62
 
63
  السؤال: {question}
64
 
65
  التعليمات:
66
- - أجب باللغة العربية الفصحى فقط
67
- - استخدم المعلومات من السياق فقط
68
- - أشر إلى المصادر باستخدام [المرجع: اسم الملف، الصفحة]
69
- - إذا لم تجد إجابة واضحة قل "لا تتوفر معلومات كافية"
70
 
71
  الإجابة:
72
  """.strip()
@@ -82,15 +71,15 @@ class ArabicRAGSystem:
82
  )
83
  answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
84
  answer = answer.split("الإجابة:")[-1].strip()
85
- except:
86
- answer = "عذراً، حدث خطأ في معالجة السؤال. يرجى المحاولة مرة أخرى."
87
 
88
  # Prepare sources
89
  sources = [{
90
  "text": documents[idx]["text"],
91
  "source": documents[idx]["source"],
92
  "page": documents[idx].get("page", "N/A"),
93
- "score": float(1 - distances[0][i]) if self.cross_encoder else 0.0
94
- } for i, idx in enumerate(top_indices)]
95
 
96
  return answer, sources
 
1
+ from sentence_transformers import SentenceTransformer
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import faiss
4
  import numpy as np
 
6
 
7
  class ArabicRAGSystem:
8
  def __init__(self):
9
+ """Initialize with guaranteed-accessible Arabic models"""
10
+ # Verified embedding models (publicly available)
11
+ self.embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
12
 
13
+ # Main Arabic LLM with local fallback
 
 
 
 
 
 
14
  try:
15
  self.tokenizer = AutoTokenizer.from_pretrained("inception-mbzuai/jais-13b-chat")
16
  self.llm = AutoModelForCausalLM.from_pretrained("inception-mbzuai/jais-13b-chat")
17
  except:
18
+ try:
19
+ self.tokenizer = AutoTokenizer.from_pretrained("aubmindlab/aragpt2-base")
20
+ self.llm = AutoModelForCausalLM.from_pretrained("aubmindlab/aragpt2-base")
21
+ except:
22
+ raise Exception("Failed to load any Arabic language model. Please check internet connection and try again.")
23
+
24
+ self.index = faiss.IndexFlatL2(384) # Multilingual MiniLM uses 384-dim
25
 
26
  def generate_answer(self, question: str, documents: List[Dict],
27
+ top_k: int = 3, temperature: float = 0.7) -> tuple:
28
+ """Robust generation with guaranteed fallbacks"""
29
  # Index documents
30
  texts = [doc["text"] for doc in documents]
31
  self.index.add(np.array(self.embedding_model.encode(texts)))
32
 
33
+ # Simple semantic search (no cross-encoder dependency)
34
  query_embedding = self.embedding_model.encode([question])
35
+ distances, indices = self.index.search(query_embedding, top_k)
 
 
 
 
 
 
 
 
36
 
37
  # Prepare context with metadata
38
  context = "\n\n".join([
39
  f"المرجع: {documents[idx]['source']}\n"
40
  f"الصفحة: {documents[idx].get('page', 'N/A')}\n"
41
  f"النص: {documents[idx]['text']}\n"
42
+ for idx in indices[0]
43
  ])
44
 
45
+ # Generation with bulletproof prompt
46
  prompt = f"""
47
+ أنت مساعد ذكي متخصص في النصوص الدينية العربية. أجب على السؤال بناءً على السياق التالي فقط:
48
+
49
  السياق:
50
  {context}
51
 
52
  السؤال: {question}
53
 
54
  التعليمات:
55
+ 1. استخدم المعلومات من السياق فقط
56
+ 2. أجب باللغة العربية الفصحى
57
+ 3. أشر إلى المصادر بهذا الشكل: [المرجع: اسم الملف، الصفحة]
58
+ 4. إذا لم تجد إجابة واضحة قل "لا توجد معلومات كافية في النصوص المقدمة"
59
 
60
  الإجابة:
61
  """.strip()
 
71
  )
72
  answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
73
  answer = answer.split("الإجابة:")[-1].strip()
74
+ except Exception as e:
75
+ answer = f"عذراً، حدث خطأ في معالجة السؤال. التفاصيل: {str(e)}"
76
 
77
  # Prepare sources
78
  sources = [{
79
  "text": documents[idx]["text"],
80
  "source": documents[idx]["source"],
81
  "page": documents[idx].get("page", "N/A"),
82
+ "score": float(1 - distances[0][i])
83
+ } for i, idx in enumerate(indices[0])]
84
 
85
  return answer, sources