# rag22v2 / rag_pipeline.py
# Author: ramysaidagieb
# Last commit: 6a5866e "Update rag_pipeline.py" (verified)
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import faiss
import numpy as np
from typing import List, Dict
class ArabicRAGSystem:
def __init__(
    self,
    embedding_model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    llm_name: str = "aubmindlab/aragpt2-base",
) -> None:
    """Load the embedding model, the causal LM and an empty FAISS index.

    Args:
        embedding_model_name: Hugging Face id of the SentenceTransformer used
            to embed texts. It is loaded on CPU. The default is the model
            the pipeline has always used.
        llm_name: Hugging Face id of the causal language model. The same id
            is used to load its tokenizer. The default is the model the
            pipeline has always used.

    The FAISS index width is taken from the loaded embedding model rather
    than hard-coded. Changing ``embedding_model_name`` therefore cannot
    leave the index with a dimension that no longer matches the vectors
    added to it.
    """
    # Embedding model, pinned to CPU.
    self.embedding_model = SentenceTransformer(embedding_model_name, device="cpu")

    # The tokenizer and the LM are loaded from the same checkpoint id.
    # use_safetensors=True asks from_pretrained to load .safetensors weights
    # instead of pickle-based .bin files.
    self.tokenizer = AutoTokenizer.from_pretrained(llm_name, use_safetensors=True)
    self.llm = AutoModelForCausalLM.from_pretrained(
        llm_name,
        use_safetensors=True,
        device_map="auto",   # let transformers/accelerate choose device placement
        torch_dtype="auto",  # use the dtype recorded in the checkpoint
    )

    # Exact (brute-force) L2-distance index sized to the embedding width.
    # The width was previously hard-coded to 384, which is only correct for
    # the default embedding model.
    dim = self.embedding_model.get_sentence_embedding_dimension()
    if dim is None:
        # Some models do not report their output size. In that case,
        # measure it from a single encoded probe string.
        probe = self.embedding_model.encode(["probe"], convert_to_numpy=True)
        dim = int(probe.shape[1])
    self.index = faiss.IndexFlatL2(dim)
def generate_answer(self, question: str, documents: List[Dict],
top_k: int = 3, temperature: float = 0.7) -> tuple:
"""Optimized generation with memory safety"""
# Convert documents to embeddings
texts = [doc["text"] for doc in documents]
embeddings = self.embedding_model.encode(texts, convert_to_numpy=True)
self.index.add(embeddings)
# Semantic search
query_embedding = self.embedding_model