ramysaidagieb committed
Commit b82ee60 · verified · 1 Parent(s): c90b40e

Update rag_pipeline.py

Files changed (1): rag_pipeline.py (+63, -51)
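The hunk below begins at line 6 of rag_pipeline.py, so the file's import header falls outside the diff. As a rough sketch (the committed file's actual imports are not shown, so this reconstruction is an assumption), the names used in both versions would resolve with:

    # Assumed import header -- reconstructed from the names used in the diff, not part of the commit
    import faiss                                                      # IndexFlatL2 vector index
    import numpy as np                                                # np.array / np.argsort
    from sentence_transformers import SentenceTransformer, CrossEncoder
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from typing import List, Dict                                     # visible as hunk context at line 6

In the listing below, the original implementation (removed lines, marked -) is shown first, followed by the version added by this commit (marked +).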
rag_pipeline.py CHANGED
@@ -6,79 +6,91 @@ from typing import List, Dict
 class ArabicRAGSystem:
     def __init__(self):
-        # Initialize models
-        self.embedding_model = SentenceTransformer("aubmindlab/bert-base-arabertv2")
-        self.cross_encoder = CrossEncoder("Arabic-Misc/roberta-base-arabic-camelbert-da-msa")
-        self.tokenizer = AutoTokenizer.from_pretrained("inception-mbzuai/jais-13b-chat")
-        self.llm = AutoModelForCausalLM.from_pretrained("inception-mbzuai/jais-13b-chat")
-        self.index = faiss.IndexFlatL2(768)
-
-    def _create_index(self, documents: List[Dict]):
-        texts = [doc["text"] for doc in documents]
-        embeddings = self.embedding_model.encode(texts)
-        self.index.add(np.array(embeddings))

     def generate_answer(self, question: str, documents: List[Dict],
                         top_k: int = 5, temperature: float = 0.7) -> tuple:
-        # Indexing phase
-        self._create_index(documents)

-        # Two-stage retrieval
         query_embedding = self.embedding_model.encode([question])
         distances, indices = self.index.search(query_embedding, top_k*2)

-        # Re-ranking with cross-encoder
-        pairs = [[question, documents[idx]["text"]] for idx in indices[0]]
-        scores = self.cross_encoder.predict(pairs)
-        ranked_indices = np.argsort(scores)[::-1][:top_k]

-        # Prepare context
         context = "\n\n".join([
-            f"المصدر: {documents[idx]['source']}\n"
-            f"الصفحة: {documents[idx]['page']}\n"
-            f"النص: {documents[idx]['text']}"
-            for idx in [indices[0][i] for i in ranked_indices]
         ])

-        # Generate answer
         prompt = f"""
-        أنت خبير في التحليل الديني. قم بالإجابة على السؤال التالي بناءً على السياق المقدم فقط:
-
         السياق:
         {context}

-        السؤال:
-        {question}

         التعليمات:
-        - أجب باللغة العربية الفصحى
-        - استخدم علامات التنسيق المناسبة
-        - أشر إلى المصادر باستخدام التنسيق [المصدر: اسم الملف، الصفحة: رقم]
-        - إذا لم توجد إجابة واضحة، قل "لا تتوفر معلومات كافية"

         الإجابة:
         """.strip()

-        inputs = self.tokenizer(prompt, return_tensors="pt")
-        outputs = self.llm.generate(
-            inputs.input_ids,
-            max_new_tokens=512,
-            temperature=temperature,
-            do_sample=True,
-            pad_token_id=self.tokenizer.eos_token_id
-        )
-
-        answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-        answer = answer.split("الإجابة:")[-1].strip()

         # Prepare sources
-        sources = []
-        for idx in [indices[0][i] for i in ranked_indices]:
-            sources.append({
-                "text": documents[idx]["text"],
-                "source": documents[idx]["source"],
-                "page": documents[idx]["page"],
-                "score": float(scores[idx])
-            })

         return answer, sources
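One bug in the removed version is worth flagging before the rewrite: scores is indexed by position in pairs, yet the sources loop reads scores[idx] with a document index, so the reported score is wrong (or out of range) whenever retrieval order differs from document order. A minimal correction, keeping the original variable names, would be:

    # Hypothetical fix for the removed sources loop: look up scores by pair position
    sources = []
    for i in ranked_indices:              # i is a position into pairs/scores
        idx = indices[0][i]               # idx is a position into documents
        sources.append({
            "text": documents[idx]["text"],
            "source": documents[idx]["source"],
            "page": documents[idx]["page"],
            "score": float(scores[i]),    # pair-position score, not scores[idx]
        })

The rewrite below sidesteps this by computing top_indices directly, though its score field has a quirk of its own (see the note after the diff).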
 
 class ArabicRAGSystem:
     def __init__(self):
+        """Initialize with fallback models for Arabic support"""
+        # Solution 1: Use reliable Arabic embedding model
+        self.embedding_model = SentenceTransformer("UBC-NLP/AraBERT")
+
+        # Solution 2: Fallback cross-encoder options
+        try:
+            self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # Multilingual fallback
+        except:
+            self.cross_encoder = None  # System will work without it
+
+        # Solution 3: Main Arabic LLM with error handling
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained("inception-mbzuai/jais-13b-chat")
+            self.llm = AutoModelForCausalLM.from_pretrained("inception-mbzuai/jais-13b-chat")
+        except:
+            # Fallback to smaller Arabic model
+            self.tokenizer = AutoTokenizer.from_pretrained("aubmindlab/aragpt2-base")
+            self.llm = AutoModelForCausalLM.from_pretrained("aubmindlab/aragpt2-base")
+
+        self.index = faiss.IndexFlatL2(768)  # AraBERT uses 768-dim embeddings

     def generate_answer(self, question: str, documents: List[Dict],
                         top_k: int = 5, temperature: float = 0.7) -> tuple:
+        """Enhanced with fallback retrieval methods"""
+        # Index documents
+        texts = [doc["text"] for doc in documents]
+        self.index.add(np.array(self.embedding_model.encode(texts)))

+        # Two-phase retrieval with fallback
         query_embedding = self.embedding_model.encode([question])
         distances, indices = self.index.search(query_embedding, top_k*2)

+        # Solution 4: Cross-encoder fallback logic
+        if self.cross_encoder:
+            pairs = [[question, documents[idx]["text"]] for idx in indices[0]]
+            scores = self.cross_encoder.predict(pairs)
+            top_indices = [indices[0][i] for i in np.argsort(scores)[-top_k:][::-1]]
+        else:
+            top_indices = indices[0][:top_k]

+        # Prepare context with metadata
         context = "\n\n".join([
+            f"المرجع: {documents[idx]['source']}\n"
+            f"الصفحة: {documents[idx].get('page', 'N/A')}\n"
+            f"النص: {documents[idx]['text']}\n"
+            for idx in top_indices
         ])

+        # Generation with error handling
         prompt = f"""
+        نظام التحليل الديني العربي:
         السياق:
         {context}

+        السؤال: {question}

         التعليمات:
+        - أجب باللغة العربية الفصحى فقط
+        - استخدم المعلومات من السياق فقط
+        - أشر إلى المصادر باستخدام [المرجع: اسم الملف، الصفحة]
+        - إذا لم تجد إجابة واضحة قل "لا تتوفر معلومات كافية"

         الإجابة:
         """.strip()

+        try:
+            inputs = self.tokenizer(prompt, return_tensors="pt")
+            outputs = self.llm.generate(
+                inputs.input_ids,
+                max_new_tokens=512,
+                temperature=temperature,
+                do_sample=True,
+                pad_token_id=self.tokenizer.eos_token_id
+            )
+            answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            answer = answer.split("الإجابة:")[-1].strip()
+        except:
+            answer = "عذراً، حدث خطأ في معالجة السؤال. يرجى المحاولة مرة أخرى."

         # Prepare sources
+        sources = [{
+            "text": documents[idx]["text"],
+            "source": documents[idx]["source"],
+            "page": documents[idx].get("page", "N/A"),
+            "score": float(1 - distances[0][i]) if self.cross_encoder else 0.0
+        } for i, idx in enumerate(top_indices)]

         return answer, sources
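For non-Arabic readers: the new prompt is headed "Arabic religious analysis system", and its instructions tell the model to answer in Modern Standard Arabic only, use only information from the context, cite sources as [المرجع: file name, page], and reply "لا تتوفر معلومات كافية" ("insufficient information is available") when no clear answer exists; the except branch returns "عذراً، حدث خطأ في معالجة السؤال. يرجى المحاولة مرة أخرى." ("Sorry, an error occurred while processing the question. Please try again.").

A few caveats survive the rewrite. The repo id UBC-NLP/AraBERT may not resolve on the Hub (AraBERT checkpoints are published under aubmindlab, while UBC-NLP publishes ARBERT/MARBERT), and wrapping a plain BERT checkpoint in SentenceTransformer falls back to mean pooling rather than a trained sentence encoder; cross-encoder/ms-marco-MiniLM-L-6-v2 is trained on English MS MARCO, so the "Multilingual fallback" comment oversells it for Arabic re-ranking. generate_answer also calls self.index.add(...) on every invocation, so repeated calls pile duplicate vectors into the same IndexFlatL2, and the score field computes 1 - distances[0][i] where i enumerates the re-ranked top_indices rather than the FAISS result order, on top of L2 distances being unbounded, so the value is not a calibrated similarity. A minimal usage sketch that works around the index-accumulation issue by resetting between calls (the sample documents and question are invented placeholders):

    # Hypothetical usage sketch -- sample data is made up for illustration
    rag = ArabicRAGSystem()

    docs = [
        {"text": "نص تجريبي أول", "source": "book1.pdf", "page": 12},  # "a first sample text"
        {"text": "نص تجريبي ثان", "source": "book2.pdf", "page": 3},   # "a second sample text"
    ]

    # Keep top_k*2 <= len(docs): IndexFlatL2.search pads missing hits with -1,
    # and documents[-1] would then silently pick the wrong document.
    rag.index.reset()  # clear vectors left over from any previous call
    answer, sources = rag.generate_answer("ما موضوع هذه النصوص؟", docs, top_k=1)  # "what are these texts about?"

    print(answer)
    for s in sources:
        print(s["source"], s["page"], round(s["score"], 3))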