import os
import json
import pickle
import textwrap
import logging
from typing import List, Optional, Dict, Any, Iterable, Tuple

import requests
import faiss
import numpy as np

from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.schema import TextNode
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from sentence_transformers.util import cos_sim
# === Logger configuration ===
logger = logging.getLogger("RAGEngine")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
handler.setFormatter(formatter)
if not logger.handlers:
    logger.addHandler(handler)

MAX_TOKENS = 64  # keep generations short on CPU-only machines
DEFAULT_STOPS = ["</s>", "\n\n", "\nQuestion:", "Question:"]
class OllamaClient:
    """
    Minimal Ollama client for /api/generate (text completion) with streaming support.
    """

    def __init__(self, model: str, host: Optional[str] = None, timeout: int = 300):
        self.model = model
        self.host = host or os.getenv("OLLAMA_HOST", "http://localhost:11434")
        self.timeout = timeout
        self._gen_url = self.host.rstrip("/") + "/api/generate"

    def generate(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        max_tokens: Optional[int] = None,
        stream: bool = False,
        options: Optional[Dict[str, Any]] = None,
        raw: bool = False,
    ) -> str | Iterable[str]:
        payload: Dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            "stream": stream,
        }
        if raw:
            payload["raw"] = True  # IMPORTANT: disables the Modelfile template
        # "stop" and "num_predict" are model parameters: Ollama expects them
        # inside "options", not at the top level of the payload.
        opts = dict(options or {})
        if stop:
            opts["stop"] = stop
        if max_tokens is not None:
            opts["num_predict"] = int(max_tokens)
        if opts:
            payload["options"] = opts
        logger.debug(f"POST {self._gen_url} (stream={stream})")
        if stream:
            # Streaming is delegated to a separate generator: a `yield` in this
            # method would turn it into a generator itself, and the non-streaming
            # branch would never actually return its string.
            return self._stream(payload)
        r = requests.post(self._gen_url, json=payload, timeout=self.timeout)
        r.raise_for_status()
        data = r.json()
        return data.get("response", "")

    def _stream(self, payload: Dict[str, Any]) -> Iterable[str]:
        with requests.post(self._gen_url, json=payload, stream=True, timeout=self.timeout) as r:
            r.raise_for_status()
            for line in r.iter_lines(decode_unicode=True):
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # While streaming, Ollama sends partial text in "response"
                if "response" in data and not data.get("done"):
                    yield data["response"]
                if data.get("done"):
                    break
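
# Usage sketch for OllamaClient (illustrative only; assumes a local Ollama
# server and an already-pulled model named "mistral" -- adjust to your setup):
#
#     client = OllamaClient(model="mistral")
#     print(client.generate("Dis bonjour.", max_tokens=32))
#
#     for chunk in client.generate("Compte jusqu'à cinq.", stream=True):
#         print(chunk, end="", flush=True)
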
class RAGEngine:
    def __init__(
        self,
        model_name: str,
        vector_path: str,
        index_path: str,
        model_threads: int = 4,
        ollama_host: Optional[str] = None,
        ollama_opts: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            model_name: e.g. "noushermes_rag"
            vector_path: pickle file with chunk texts (list[str])
            index_path: FAISS index path
            model_threads: forwarded as a hint to Ollama options
            ollama_host: override OLLAMA_HOST (default http://localhost:11434)
            ollama_opts: extra Ollama options (temperature, num_ctx, num_batch, num_thread)
        """
        logger.info(f"🔎 rag_model_ollama source: {__file__}")
        logger.info("📦 Initialisation du moteur RAG (Ollama)...")

        # Ollama options (CPU-friendly defaults)
        opts = dict(ollama_opts or {})
        opts.setdefault("temperature", 0.0)
        opts.setdefault("num_ctx", 512)
        opts.setdefault("num_batch", 16)
        if "num_thread" not in opts and model_threads:
            opts["num_thread"] = int(model_threads)
        self.llm = OllamaClient(model=model_name, host=ollama_host)
        self.ollama_opts = opts

        # Embedding model for retrieval / reranking
        self.embed_model = HuggingFaceEmbedding(model_name="intfloat/multilingual-e5-base")

        logger.info(f"📂 Chargement des données vectorielles depuis {vector_path}")
        with open(vector_path, "rb") as f:
            chunk_texts: List[str] = pickle.load(f)
        nodes = [TextNode(text=chunk) for chunk in chunk_texts]

        faiss_index = faiss.read_index(index_path)
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        # Attach the FAISS store through a StorageContext; VectorStoreIndex does
        # not take a `vector_store` keyword directly.
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        self.index = VectorStoreIndex(nodes=nodes, storage_context=storage_context, embed_model=self.embed_model)
        logger.info("✅ Moteur RAG (Ollama) initialisé avec succès.")
    # ---------------- LLM helpers (via Ollama) ----------------
    def _complete(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        max_tokens: int = MAX_TOKENS,
        raw: bool = True,
    ) -> str:
        text = self.llm.generate(
            prompt=prompt,
            stop=stop or DEFAULT_STOPS,
            max_tokens=max_tokens,
            stream=False,
            options=self.ollama_opts,
            raw=raw,  # always True to bypass the Modelfile template
        )
        # Defensive: join the chunks if a generator slips through despite stream=False
        try:
            if hasattr(text, "__iter__") and not isinstance(text, (str, bytes)):
                chunks = []
                for t in text:
                    if not isinstance(t, (str, bytes)):
                        continue
                    chunks.append(t)
                text = "".join(chunks)
        except Exception:
            pass
        return (text or "").strip()

    def _complete_stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        max_tokens: int = MAX_TOKENS,
        raw: bool = True,
    ) -> Iterable[str]:
        return self.llm.generate(
            prompt=prompt,
            stop=stop or DEFAULT_STOPS,
            max_tokens=max_tokens,
            stream=True,
            options=self.ollama_opts,
            raw=raw,  # always True to bypass the Modelfile template
        )
    # ---------------- Utilities ----------------
    def _is_greeting(self, text: str) -> bool:
        s = text.lower().strip()
        return s in {"bonjour", "salut", "hello", "bonsoir", "hi", "coucou", "yo"} or len(s.split()) <= 2

    def _decide_mode(self, scores: List[float], tau: float = 0.32, is_greeting: bool = False) -> str:
        if is_greeting:
            return "llm"
        top = scores[0] if scores else 0.0
        return "rag" if top >= tau else "llm"

    def _stream_with_local_stops(self, tokens: Iterable[str], stops: List[str]) -> Iterable[str]:
        """
        Cuts the stream locally as soon as a stop sequence appears, even if the
        server does not stop by itself. Text is emitted from a rolling buffer so
        that a stop split across two chunks is still caught and nothing is
        yielded twice.
        """
        buffer = ""
        emitted = 0
        # Hold back enough characters to cover a stop split across chunk boundaries
        holdback = max((len(s) for s in stops), default=1) - 1
        for chunk in tokens:
            buffer += chunk
            # Earliest stop sequence present in the buffer, if any
            hit_idx = min((i for i in (buffer.find(s) for s in stops) if i != -1), default=-1)
            if hit_idx != -1:
                if hit_idx > emitted:
                    yield buffer[emitted:hit_idx]
                return
            # Emit everything that can no longer be the start of a stop sequence
            safe_end = len(buffer) - holdback
            if safe_end > emitted:
                yield buffer[emitted:safe_end]
                emitted = safe_end
        # Stream ended without hitting a stop: flush the remainder
        if len(buffer) > emitted:
            yield buffer[emitted:]
    # ---------------- Retrieval + (optional) rerank ----------------
    def get_adaptive_top_k(self, question: str) -> int:
        q = question.lower()
        if len(q.split()) <= 7:
            top_k = 8
        elif any(w in q for w in ["liste", "résume", "quels sont", "explique", "comment"]):
            top_k = 10
        else:
            top_k = 8
        logger.info(f"🔢 top_k déterminé automatiquement : {top_k}")
        return top_k

    def rerank_nodes(self, question: str, retrieved_nodes, top_k: int = 3) -> Tuple[List[float], List[TextNode]]:
        logger.info(f"🔍 Re-ranking des {len(retrieved_nodes)} chunks pour la question : « {question} »")
        q_emb = self.embed_model.get_query_embedding(question)
        scored_nodes: List[Tuple[float, TextNode]] = []
        for node in retrieved_nodes:
            chunk_text = node.get_content()
            chunk_emb = self.embed_model.get_text_embedding(chunk_text)
            score = cos_sim(q_emb, chunk_emb).item()
            scored_nodes.append((score, node))
        ranked = sorted(scored_nodes, key=lambda x: x[0], reverse=True)
        logger.info("📊 Chunks les plus pertinents :")
        for i, (score, node) in enumerate(ranked[:top_k]):
            chunk_preview = textwrap.shorten(node.get_content().replace("\n", " "), width=100)
            logger.info(f"#{i+1} | Score: {score:.4f} | {chunk_preview}")
        top = ranked[:top_k]
        scores = [s for s, _ in top]
        nodes = [n for _, n in top]
        return scores, nodes

    def retrieve_context(self, question: str, top_k: int = 3) -> Tuple[str, List[TextNode], List[float]]:
        logger.info("📥 Récupération du contexte...")
        retriever = self.index.as_retriever(similarity_top_k=top_k)
        retrieved_nodes = retriever.retrieve(question)
        scores, nodes = self.rerank_nodes(question, retrieved_nodes, top_k)
        context = "\n\n".join(n.get_content()[:500] for n in nodes)
        return context, nodes, scores
    # ---------------- Public API ----------------
    def ask(self, question_raw: str, allow_fallback: bool = True) -> str:
        logger.info(f"💬 Question reçue : {question_raw}")
        is_hello = self._is_greeting(question_raw)

        # retrieval (skipped for greetings)
        context, scores = "", []
        if not is_hello:
            top_k = self.get_adaptive_top_k(question_raw)
            context, _, scores = self.retrieve_context(question_raw, top_k)

        # route between RAG and plain LLM
        mode = self._decide_mode(scores, tau=0.32, is_greeting=is_hello)
        logger.info(f"🧭 Mode choisi : {mode}")

        if mode == "rag":
            prompt = (
                "Instruction: Réponds uniquement à partir du contexte. "
                "Si la réponse n'est pas déductible, réponds exactement: \"Information non présente dans le contexte.\""
                "\n\nContexte :\n"
                f"{context}\n\n"
                f"Question : {question_raw}\n"
                "Réponse :"
            )
            resp = self._complete(
                prompt,
                stop=DEFAULT_STOPS,
                max_tokens=MAX_TOKENS,
                raw=True,  # bypass the Modelfile template
            ).strip()

            # fall back to the plain LLM if RAG found nothing in the context
            if allow_fallback and "Information non présente" in resp:
                logger.info("↪️ Fallback LLM‑pur (hors contexte)")
                prompt_llm = (
                    "Réponds brièvement et précisément en français.\n"
                    f"Question : {question_raw}\n"
                    "Réponse :"
                )
                resp = self._complete(
                    prompt_llm,
                    stop=DEFAULT_STOPS,
                    max_tokens=MAX_TOKENS,
                    raw=True,
                ).strip()

            ellipsis = "..." if len(resp) > 120 else ""
            logger.info(f"🧠 Réponse générée : {resp[:120]}{ellipsis}")
            return resp

        # plain LLM (greeting or weak retrieval score)
        prompt_llm = (
            "Réponds brièvement et précisément en français.\n"
            f"Question : {question_raw}\n"
            "Réponse :"
        )
        resp = self._complete(
            prompt_llm,
            stop=DEFAULT_STOPS,
            max_tokens=MAX_TOKENS,
            raw=True,
        ).strip()
        ellipsis = "..." if len(resp) > 120 else ""
        logger.info(f"🧠 Réponse générée : {resp[:120]}{ellipsis}")
        return resp
    def ask_stream(self, question: str, allow_fallback: bool = False) -> Iterable[str]:
        logger.info(f"💬 [Stream] Question reçue : {question}")
        is_hello = self._is_greeting(question)

        context, scores = "", []
        if not is_hello:
            top_k = self.get_adaptive_top_k(question)
            context, _, scores = self.retrieve_context(question, top_k)

        mode = self._decide_mode(scores, tau=0.32, is_greeting=is_hello)
        logger.info(f"🧭 Mode choisi (stream) : {mode}")
        stops = DEFAULT_STOPS

        if mode == "rag":
            prompt = (
                "Instruction: Réponds uniquement à partir du contexte. "
                "Si la réponse n'est pas déductible, réponds exactement: \"Information non présente dans le contexte.\""
                "\n\nContexte :\n"
                f"{context}\n\n"
                f"Question : {question}\n"
                "Réponse :"
            )
            logger.info("📡 Début du streaming de la réponse (RAG)...")
            tokens = self._complete_stream(
                prompt,
                stop=stops,
                max_tokens=MAX_TOKENS,
                raw=True,
            )
            # Local safeguard: cut the stream ourselves if a stop sequence shows up
            for t in self._stream_with_local_stops(tokens, stops):
                if t:
                    yield t
            logger.info("📡 Fin du streaming de la réponse (RAG).")
            return

        # plain LLM, streamed
        prompt_llm = (
            "Réponds brièvement et précisément en français.\n"
            f"Question : {question}\n"
            "Réponse :"
        )
        logger.info("📡 Début du streaming de la réponse (LLM pur)...")
        tokens = self._complete_stream(
            prompt_llm,
            stop=stops,
            max_tokens=MAX_TOKENS,
            raw=True,
        )
        for t in self._stream_with_local_stops(tokens, stops):
            if t:
                yield t
        logger.info("📡 Fin du streaming de la réponse (LLM pur).")