rag22V1 / rag_pipeline.py
ramysaidagieb's picture
Upload 6 files
74e2822 verified
raw
history blame
1.65 kB
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import faiss
import numpy as np
from config import MODEL_CONFIG
class ArabicRAGSystem:
    """Retrieval-augmented generation (RAG) system for Arabic Q&A.

    Embeds document chunks with a SentenceTransformer, indexes them in a
    FAISS inner-product index, and answers questions with a text-generation
    LLM conditioned on the retrieved chunks.
    """

    def __init__(self):
        # Model identifiers come from the project-level MODEL_CONFIG dict.
        self.embedder = SentenceTransformer(MODEL_CONFIG["embedding_model"])
        self.llm = pipeline("text-generation", model=MODEL_CONFIG["llm"])
        self.index = None      # FAISS index; created lazily by build_index()
        self.documents = []    # chunk texts, row-parallel to the index

    # Fix: original annotated with `List[str]` but never imported
    # `typing.List`, which raises NameError when the class body executes.
    # Builtin generics (`list[str]`, Python 3.9+) need no import.
    def build_index(self, chunks: list[str]) -> None:
        """Create a FAISS index from document chunks.

        Args:
            chunks: Plain-text document chunks to embed and index.
        """
        self.documents = chunks
        embeddings = self.embedder.encode(chunks, show_progress_bar=True)
        # NOTE(review): IndexFlatIP ranks by raw inner product; for true
        # cosine similarity the embeddings would need L2-normalization
        # (faiss.normalize_L2) first — confirm the embedding model's
        # convention before relying on the ranking.
        self.index = faiss.IndexFlatIP(embeddings.shape[1])
        self.index.add(embeddings)

    def retrieve(self, query: str, k: int = 3) -> list[str]:
        """Return the k document chunks most similar to the query.

        Must be called after build_index(); otherwise self.index is None
        and the search call raises AttributeError.
        """
        query_embedding = self.embedder.encode([query])
        _distances, indices = self.index.search(query_embedding, k)
        return [self.documents[i] for i in indices[0]]

    def generate_answer(self, question: str, context: list[str]) -> str:
        """Generate an answer with the LLM, grounded in retrieved context.

        Args:
            question: The user's question (Arabic).
            context: Retrieved chunks injected into the prompt.

        Returns:
            The generated text with the echoed prompt stripped from the
            front.
        """
        # Fix: join outside the f-string — a backslash inside an f-string
        # replacement field ({'\n'.join(...)}) is a SyntaxError on
        # Python < 3.12 (lifted only by PEP 701).
        context_text = "\n".join(context)
        prompt = f"""استخدم المعلومات التالية للإجابة على السؤال:
السياق:
{context_text}
السؤال: {question}
الإجابة:"""
        result = self.llm(
            prompt,
            max_new_tokens=256,
            temperature=0.7,
            do_sample=True,
        )
        # The text-generation pipeline echoes the prompt at the start of
        # the output by default (return_full_text=True). Fix: strip only
        # that leading copy — the original `.replace(prompt, "")` would
        # also delete any later occurrence inside the answer.
        return result[0]["generated_text"].removeprefix(prompt)