# rag_pipeline.py
import time
import logging

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import chromadb

logger = logging.getLogger("RAG")
class RAGPipeline:
    def __init__(self):
        logger.info("[RAG] Loading the model and tokenizer...")
        # Multilingual sentence-embedding model (covers Arabic).
        self.embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
        self.chunk_embeddings = []
        self.chunks = []
        # chromadb.Client() is in-memory (ephemeral) by default; the old
        # Settings(chroma_db_impl="memory") argument is not a valid impl name.
        self.client = chromadb.Client()
        self.collection = self.client.get_or_create_collection(name="rag_collection")
        # Arabic GPT-2 model used for answer generation.
        self.tokenizer = AutoTokenizer.from_pretrained("aubmindlab/aragpt2-mega", trust_remote_code=True)
        self.lm = AutoModelForCausalLM.from_pretrained("aubmindlab/aragpt2-mega", trust_remote_code=True)
        self.lm.eval()
        logger.info("[RAG] Loading finished successfully.")
    def build_index(self, chunks):
        start_time = time.time()
        self.chunks = chunks
        self.chunk_embeddings = self.embedding_model.encode(chunks, show_progress_bar=True)
        logger.info(f"[RAG] Index built with embedding dimension {self.chunk_embeddings.shape[1]} in {time.time() - start_time:.2f} s.")
        # Add all chunks to the collection in one batched call instead of one add() per chunk.
        self.collection.add(
            documents=list(chunks),
            ids=[str(i) for i in range(len(chunks))],
            embeddings=self.chunk_embeddings.tolist(),
        )
    def retrieve(self, query, k=5):
        logger.info("[RAG] Retrieving the passages most relevant to the question...")
        query_embedding = self.embedding_model.encode([query])[0].tolist()
        results = self.collection.query(query_embeddings=[query_embedding], n_results=k)
        return results["documents"][0], results["ids"][0]
    def generate_answer(self, query):
        docs, ids = self.retrieve(query)
        context = "\n\n".join(docs)
        # Arabic prompt, kept in Arabic because the LM is an Arabic model.
        # It reads: "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
        prompt = f"السياق:\n{context}\n\nالسؤال: {query}\nالإجابة:"
        # Truncate the prompt so prompt + 200 new tokens fits GPT-2's 1024-token window.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024 - 200)
        with torch.no_grad():
            outputs = self.lm.generate(
                **inputs,
                max_new_tokens=200,
                pad_token_id=self.tokenizer.eos_token_id,  # GPT-2 has no pad token
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        answer = self.tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
        return answer, context
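

# --- Usage sketch (not part of the original pipeline) ---
# A minimal, hedged example of how this class is meant to be driven:
# index a few toy chunks, then ask a question. The sample chunks and
# question below are illustrative placeholders, not real data.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    rag = RAGPipeline()
    sample_chunks = [
        "القاهرة هي عاصمة مصر.",  # "Cairo is the capital of Egypt."
        "نهر النيل أطول أنهار العالم.",  # "The Nile is the longest river in the world."
    ]
    rag.build_index(sample_chunks)
    answer, context = rag.generate_answer("ما هي عاصمة مصر؟")  # "What is the capital of Egypt?"
    print(answer)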