File size: 11,982 Bytes
e7a5765
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302

import os
import pickle
import textwrap
import logging
from typing import List, Optional, Dict, Any, Iterable

import requests
import faiss
import numpy as np
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from sentence_transformers.util import cos_sim


# === Logger configuration ===
logger = logging.getLogger("RAGEngine")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
handler.setFormatter(formatter)
if not logger.handlers:
    logger.addHandler(handler)

#MAX_TOKENS = 512
MAX_TOKENS = 64


class OllamaClient:
    """
    Minimal Ollama client for /api/generate (text completion) with streaming support.
    Docs: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
    """
    def __init__(self, model: str, host: Optional[str] = None, timeout: int = 300):
        self.model = model
        self.host = host or os.getenv("OLLAMA_HOST", "http://localhost:11434")
        self.timeout = timeout
        self._gen_url = self.host.rstrip("/") + "/api/generate"

    def generate(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        max_tokens: Optional[int] = None,
        stream: bool = False,
        options: Optional[Dict[str, Any]] = None,
        raw: bool = False
    ) -> str | Iterable[str]:
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": stream,
        }
        if raw:
            payload["raw"]=True
        if stop:
            payload["stop"] = stop
        if max_tokens is not None:
            # Ollama uses "num_predict" for max new tokens
            payload["num_predict"] = int(max_tokens)
        if options:
            payload["options"] = options

        logger.debug(f"POST {self._gen_url} (stream={stream})")

        if stream:
            with requests.post(self._gen_url, json=payload, stream=True, timeout=self.timeout) as r:
                r.raise_for_status()
                for line in r.iter_lines(decode_unicode=True):
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                    except Exception:
                        # In case a broken line appears
                        continue
                    if "response" in data and data.get("done") is not True:
                        yield data["response"]
                    if data.get("done"):
                        break
            return

        # Non-streaming
        r = requests.post(self._gen_url, json=payload, timeout=self.timeout)
        r.raise_for_status()
        data = r.json()
        return data.get("response", "")


# Lazy import json to keep top clean
import json


class RAGEngine:
    def __init__(
        self,
        model_name: str,
        vector_path: str,
        index_path: str,
        model_threads: int = 4,
        ollama_host: Optional[str] = None,
        ollama_opts: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            model_name: e.g. "nous-hermes2:Q4_K_M" or "llama3.1:8b-instruct-q4_K_M"
            vector_path: pickle file with chunk texts list[str]
            index_path: FAISS index path
            model_threads: forwarded to Ollama via options.n_threads (if supported by the model)
            ollama_host: override OLLAMA_HOST (default http://localhost:11434)
            ollama_opts: extra Ollama options (e.g., temperature, top_p, num_gpu, num_thread)
        """
        logger.info(f"🔎 rag_model_ollama source: {__file__}")
        logger.info("📦 Initialisation du moteur RAG (Ollama)...")
        # Build options
        opts = dict(ollama_opts or {})
        # Common low-latency defaults; user can override via ollama_opts
        opts.setdefault("temperature", 0.1)
        # Try to pass thread hint if supported by the backend
        if "num_thread" not in opts and model_threads:
            opts["num_thread"] = int(model_threads)

        self.llm = OllamaClient(model=model_name, host=ollama_host)
        self.ollama_opts = opts

        #self.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")

        self.embed_model = HuggingFaceEmbedding(model_name="intfloat/multilingual-e5-base")
        logger.info(f"📂 Chargement des données vectorielles depuis {vector_path}")
        with open(vector_path, "rb") as f:
            chunk_texts = pickle.load(f)
        nodes = [TextNode(text=chunk) for chunk in chunk_texts]

        faiss_index = faiss.read_index(index_path)
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        self.index = VectorStoreIndex(nodes=nodes, embed_model=self.embed_model, vector_store=vector_store)

        logger.info("✅ Moteur RAG (Ollama) initialisé avec succès.")

    # ---------------- LLM helpers (via Ollama) ----------------

    def _complete(self, prompt: str, stop: Optional[List[str]] = None, max_tokens: int = 128,raw:bool=True) -> str:
        text = self.llm.generate(
            prompt=prompt,
            stop=stop,
            max_tokens=max_tokens,
            stream=False,
            options=self.ollama_opts,
            raw=raw
        )
        # Some Ollama setups may stream even when stream=False. Coerce generators to string.
        try:
            if hasattr(text, "__iter__") and not isinstance(text, (str, bytes)):
                chunks = []
                for t in text:
                    if not isinstance(t, (str, bytes)):
                        continue
                    chunks.append(t)
                text = "".join(chunks)
        except Exception:
            pass
        return (text or "").strip()

    def _complete_stream(self, prompt: str, stop: Optional[List[str]] = None, max_tokens: int = MAX_TOKENS,raw : bool =True):
        return self.llm.generate(
            prompt=prompt,
            stop=stop,
            max_tokens=max_tokens,
            stream=True,
            options=self.ollama_opts,
            raw=raw
        )

    # ---------------- Reformulation ----------------

    def reformulate_question(self, question: str) -> str:
        logger.info("🔁 Reformulation de la question (sans contexte)...")
        prompt = f"""Tu es un assistant expert chargé de clarifier des questions floues.

Transforme la question suivante en une question claire, explicite et complète, sans ajouter d'informations extérieures.

Question floue : {question}
Question reformulée :"""
        reformulated = self._complete(prompt, stop=["### Réponse:", "\n\n", "###"], max_tokens=128)
        logger.info(f"📝 Reformulée : {reformulated}")
        return reformulated.strip().split("###")[0]

    def reformulate_with_context(self, question: str, context_sample: str) -> str:
        logger.info("🔁 Reformulation de la question avec contexte...")
        prompt = f"""Tu es un assistant expert en machine learning. Ton rôle est de reformuler les questions utilisateur en tenant compte du contexte ci-dessous, extrait d’un rapport technique sur un projet de reconnaissance de maladies de plantes.

Ta mission est de transformer une question vague ou floue en une question précise et adaptée au contenu du rapport. Ne donne pas une interprétation hors sujet. Ne reformule pas en termes de produits commerciaux.

Contexte :
{context_sample}

Question initiale : {question}
Question reformulée :"""
        reformulated = self._complete(prompt, stop=["### Réponse:", "\n\n", "###"], max_tokens=128)
        logger.info(f"📝 Reformulée avec contexte : {reformulated}")
        return reformulated

    # ---------------- Retrieval ----------------

    def get_adaptive_top_k(self, question: str) -> int:
        q = question.lower()
        if len(q.split()) <= 7:
            top_k = 8
        elif any(w in q for w in ["liste", "résume", "quels sont", "explique", "comment"]):
            top_k = 10
        else:
            top_k = 8
        logger.info(f"🔢 top_k déterminé automatiquement : {top_k}")
        return top_k

    def rerank_nodes(self, question: str, retrieved_nodes, top_k: int = 3):
        logger.info(f"🔍 Re-ranking des {len(retrieved_nodes)} chunks pour la question : « {question} »")
        q_emb = self.embed_model.get_query_embedding(question)
        scored_nodes = []

        for node in retrieved_nodes:
            chunk_text = node.get_content()
            chunk_emb = self.embed_model.get_text_embedding(chunk_text)
            score = cos_sim(q_emb, chunk_emb).item()
            scored_nodes.append((score, node))

        ranked_nodes = sorted(scored_nodes, key=lambda x: x[0], reverse=True)

        logger.info("📊 Chunks les plus pertinents :")
        for i, (score, node) in enumerate(ranked_nodes[:top_k]):
            chunk_preview = textwrap.shorten(node.get_content().replace("\n", " "), width=100)
            logger.info(f"#{i+1} | Score: {score:.4f} | {chunk_preview}")

        return [n for _, n in ranked_nodes[:top_k]]

    def retrieve_context(self, question: str, top_k: int = 3):
        logger.info(f"📥 Récupération du contexte...")
        retriever = self.index.as_retriever(similarity_top_k=top_k)
        retrieved_nodes = retriever.retrieve(question)
        reranked_nodes = self.rerank_nodes(question, retrieved_nodes, top_k)
        context = "\n\n".join(n.get_content()[:500] for n in reranked_nodes)
        return context, reranked_nodes

    # ---------------- Public API ----------------

    def ask(self, question_raw: str) -> str:
        logger.info(f"💬 Question reçue : {question_raw}")
        context=""
        reformulate=False
        if reformulate :
            if len(question_raw.split()) <= 2:
                context_sample, _ = self.retrieve_context(question_raw, top_k=3)
                reformulated = self.reformulate_with_context(question_raw, context_sample)
            else:
                reformulated = self.reformulate_question(question_raw)

            logger.info(f"📝 Question reformulée : {reformulated}")
            top_k = self.get_adaptive_top_k(reformulated)
            context, _ = self.retrieve_context(reformulated, top_k)
        else:
            reformulated=question_raw


        prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.

Si la réponse ne peut pas être déduite du contexte, indique : "Information non présente dans le contexte."

Contexte :
{context}

Question : {reformulated}
### Réponse:"""

        response = self._complete(prompt, stop=["### Réponse:", "\n\n", "###"], max_tokens=MAX_TOKENS)
        response = response.strip().split("###")[0]
        ellipsis = "..." if len(response) > 120 else ""
        logger.info(f"🧠 Réponse générée : {response[:120]}{ellipsis}")
        return response

    def ask_stream(self, question: str):
        logger.info(f"💬 [Stream] Question reçue : {question}")
        top_k = self.get_adaptive_top_k(question)
        context, _ = self.retrieve_context(question, top_k)
        context="" #for test purpose

        prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.

Si la réponse ne peut pas être déduite du contexte, indique : "Information non présente dans le contexte."

Contexte :
{context}

Question : {question}
### Réponse:"""

        logger.info("📡 Début du streaming de la réponse...")
        for token in self._complete_stream(prompt,  stop=["### Réponse:", "\n\n", "###"], max_tokens=MAX_TOKENS,raw=False):
            yield token

        logger.info("📡 Fin du streaming de la réponse...")