ramysaidagieb committed
Commit 6a5866e · verified · 1 Parent(s): faa82c9

Update rag_pipeline.py

Files changed (1):
  1. rag_pipeline.py +26 -72
rag_pipeline.py CHANGED
@@ -6,80 +6,34 @@ from typing import List, Dict
 
 class ArabicRAGSystem:
     def __init__(self):
-        """Initialize with guaranteed-accessible Arabic models"""
-        # Verified embedding models (publicly available)
-        self.embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
-
-        # Main Arabic LLM with local fallback
-        try:
-            self.tokenizer = AutoTokenizer.from_pretrained("inception-mbzuai/jais-13b-chat")
-            self.llm = AutoModelForCausalLM.from_pretrained("inception-mbzuai/jais-13b-chat")
-        except:
-            try:
-                self.tokenizer = AutoTokenizer.from_pretrained("aubmindlab/aragpt2-base")
-                self.llm = AutoModelForCausalLM.from_pretrained("aubmindlab/aragpt2-base")
-            except:
-                raise Exception("Failed to load any Arabic language model. Please check internet connection and try again.")
-
-        self.index = faiss.IndexFlatL2(384)  # Multilingual MiniLM uses 384-dim
+        """Initialize with dependency-safe models"""
+        # Verified working embedding model
+        self.embedding_model = SentenceTransformer(
+            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+            device="cpu"
+        )
+
+        # Load Arabic LLM with safe tokenizer settings
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            "aubmindlab/aragpt2-base",
+            use_safetensors=True
+        )
+        self.llm = AutoModelForCausalLM.from_pretrained(
+            "aubmindlab/aragpt2-base",
+            use_safetensors=True,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+
+        self.index = faiss.IndexFlatL2(384)  # Matching embedding dim
 
     def generate_answer(self, question: str, documents: List[Dict],
                         top_k: int = 3, temperature: float = 0.7) -> tuple:
-        """Robust generation with guaranteed fallbacks"""
-        # Index documents
+        """Optimized generation with memory safety"""
+        # Convert documents to embeddings
         texts = [doc["text"] for doc in documents]
-        self.index.add(np.array(self.embedding_model.encode(texts)))
-
-        # Simple semantic search (no cross-encoder dependency)
-        query_embedding = self.embedding_model.encode([question])
-        distances, indices = self.index.search(query_embedding, top_k)
-
-        # Prepare context with metadata
-        context = "\n\n".join([
-            f"المرجع: {documents[idx]['source']}\n"
-            f"الصفحة: {documents[idx].get('page', 'N/A')}\n"
-            f"النص: {documents[idx]['text']}\n"
-            for idx in indices[0]
-        ])
-
-        # Generation with bulletproof prompt
-        prompt = f"""
-أنت مساعد ذكي متخصص في النصوص الدينية العربية. أجب على السؤال بناءً على السياق التالي فقط:
-
-السياق:
-{context}
-
-السؤال: {question}
-
-التعليمات:
-1. استخدم المعلومات من السياق فقط
-2. أجب باللغة العربية الفصحى
-3. أشر إلى المصادر بهذا الشكل: [المرجع: اسم الملف، الصفحة]
-4. إذا لم تجد إجابة واضحة قل "لا توجد معلومات كافية في النصوص المقدمة"
-
-الإجابة:
-""".strip()
-
-        try:
-            inputs = self.tokenizer(prompt, return_tensors="pt")
-            outputs = self.llm.generate(
-                inputs.input_ids,
-                max_new_tokens=512,
-                temperature=temperature,
-                do_sample=True,
-                pad_token_id=self.tokenizer.eos_token_id
-            )
-            answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-            answer = answer.split("الإجابة:")[-1].strip()
-        except Exception as e:
-            answer = f"عذراً، حدث خطأ في معالجة السؤال. التفاصيل: {str(e)}"
-
-        # Prepare sources
-        sources = [{
-            "text": documents[idx]["text"],
-            "source": documents[idx]["source"],
-            "page": documents[idx].get("page", "N/A"),
-            "score": float(1 - distances[0][i])
-        } for i, idx in enumerate(indices[0])]
-
-        return answer, sources
+        embeddings = self.embedding_model.encode(texts, convert_to_numpy=True)
+        self.index.add(embeddings)
+
+        # Semantic search
+        query_embedding = self.embedding_model.encode([question])
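
For reference, a minimal sketch of how the updated class might be driven, assuming the post-commit generate_answer still returns the (answer, sources) pair that the removed implementation produced; the document texts, file names, and question below are placeholder values, not part of this commit:

    from rag_pipeline import ArabicRAGSystem

    rag = ArabicRAGSystem()  # loads the MiniLM embedder and aubmindlab/aragpt2-base

    # Each document dict needs "text" and "source"; "page" is optional.
    documents = [
        {"text": "النص الأول...", "source": "kitab1.pdf", "page": 12},
        {"text": "النص الثاني...", "source": "kitab2.pdf"},
    ]

    answer, sources = rag.generate_answer(
        question="ما موضوع النص؟",  # placeholder Arabic question
        documents=documents,
        top_k=2,            # number of passages to retrieve
        temperature=0.7,    # sampling temperature for generation
    )

    print(answer)
    # Pre-commit, each source dict carried "text", "source", "page", and "score".
    for src in sources:
        print(src["source"], src["page"], src["score"])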