Update rag_pipeline.py
rag_pipeline.py  CHANGED  +41 -25
@@ -1,31 +1,47 @@
 from sentence_transformers import SentenceTransformer
-
-from …
-
-
-from utils import extract_text_from_files

 class RAGPipeline:
     def __init__(self):
-
-        self.embedding_model = …
-        self.…
-        self.…
-

-    def …
-
-
-
-
-        return f"[RAG] تم بناء الفهرس لـ {len(chunks)} مقاطع."  # "[RAG] Index built for {len(chunks)} chunks."

-    def …
-
-
-
-
-
-
-        answer = …
-        return answer, …
+# rag_pipeline.py
+import time
+import logging
+import numpy as np
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
 from sentence_transformers import SentenceTransformer
+import chromadb
+from chromadb.config import Settings
+
+logger = logging.getLogger("RAG")

 class RAGPipeline:
     def __init__(self):
+        logger.info("[RAG] جاري تحميل النموذج والمحول...")  # "[RAG] Loading the model and tokenizer..."
+        self.embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
+        self.chunk_embeddings = []
+        self.chunks = []
+        self.client = chromadb.Client(Settings(chroma_db_impl="memory", persist_directory=None))
+        self.collection = self.client.create_collection(name="rag_collection")
+        self.tokenizer = AutoTokenizer.from_pretrained("aubmindlab/aragpt2-mega", trust_remote_code=True)
+        self.lm = AutoModelForCausalLM.from_pretrained("aubmindlab/aragpt2-mega", trust_remote_code=True)
+        logger.info("[RAG] تم التحميل بنجاح.")  # "[RAG] Loaded successfully."
+
+    def build_index(self, chunks):
+        start_time = time.time()
+        self.chunks = chunks
+        self.chunk_embeddings = self.embedding_model.encode(chunks, show_progress_bar=True)
+        logger.info(f"[RAG] تم بناء الفهرس بأبعاد {self.chunk_embeddings.shape[1]} في {time.time() - start_time:.2f} ثانية.")  # "[RAG] Index built with dimension ... in ... seconds."
+        for i, chunk in enumerate(chunks):
+            self.collection.add(documents=[chunk], ids=[str(i)], embeddings=[self.chunk_embeddings[i].tolist()])

+    def retrieve(self, query, k=5):
+        logger.info("[RAG] استرجاع المقاطع الأكثر صلة بالسؤال...")  # "[RAG] Retrieving the passages most relevant to the question..."
+        query_embedding = self.embedding_model.encode([query])[0].tolist()
+        results = self.collection.query(query_embeddings=[query_embedding], n_results=k)
+        return results["documents"][0], results["ids"][0]

+    def generate_answer(self, query):
+        docs, ids = self.retrieve(query)
+        context = "\n\n".join(docs)
+        prompt = f"السياق:\n{context}\n\nالسؤال: {query}\nالإجابة:"  # "Context: ...\n\nQuestion: ...\nAnswer:"
+        inputs = self.tokenizer(prompt, return_tensors="pt")
+        with torch.no_grad():
+            outputs = self.lm.generate(**inputs, max_new_tokens=200)
+        answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return answer, context
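
A note on the new ChromaDB setup in __init__: chroma_db_impl is a legacy settings key that the chromadb 0.4 rewrite removed, so whether this constructor runs depends on the version the Space pins, which this diff does not show. A minimal sketch of the equivalent in-memory client on chromadb 0.4 or newer, assuming no persistence is needed:

import chromadb

# Ephemeral client keeps the collection in memory only,
# matching persist_directory=None in the committed code.
client = chromadb.EphemeralClient()
collection = client.get_or_create_collection(name="rag_collection")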
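
build_index inserts chunks with one collection.add call per chunk. chromadb's add accepts parallel lists, so the loop could collapse to a single batched insert. A sketch of the replacement loop body, using only attributes the method already sets (encode returns a NumPy array, hence tolist):

        # Replaces the per-chunk loop in build_index with one batched insert.
        self.collection.add(
            documents=chunks,
            ids=[str(i) for i in range(len(chunks))],
            embeddings=self.chunk_embeddings.tolist(),
        )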
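
generate_answer tokenizes the prompt with no length cap. aragpt2-mega is a GPT-2-family model, so prompt plus generation must fit a fixed context window (1024 tokens is the usual GPT-2 figure; treat that as an assumption here), and five retrieved chunks can overflow it. A hedged guard for the tokenizer call, reserving room for max_new_tokens=200:

        # Truncate the prompt so prompt + 200 generated tokens fit an
        # assumed 1024-token context window.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024 - 200)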
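
For completeness, a minimal driver for the updated class, assuming it is imported from rag_pipeline.py as committed; the sample chunks are hypothetical stand-ins for text extracted from uploaded files:

# demo.py - hypothetical end-to-end run, not part of this commit
from rag_pipeline import RAGPipeline

pipeline = RAGPipeline()  # downloads the MiniLM encoder and aragpt2-mega on first use

chunks = [
    "القاهرة هي عاصمة جمهورية مصر العربية.",  # "Cairo is the capital of Egypt."
    "يمر نهر النيل عبر مدينة القاهرة.",  # "The Nile runs through Cairo."
]
pipeline.build_index(chunks)

answer, context = pipeline.generate_answer("ما هي عاصمة مصر؟")  # "What is the capital of Egypt?"
print(answer)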