Spaces:
Paused
Paused
optims
Browse files- app.py +13 -2
- rag_model_optimise.py +144 -0
app.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
import streamlit as st
|
2 |
from llama_cpp import Llama
|
3 |
import os
|
4 |
-
from rag_model import RAGEngine
|
|
|
|
|
5 |
import logging
|
6 |
from huggingface_hub import hf_hub_download
|
7 |
import time
|
@@ -64,9 +66,18 @@ st.set_page_config(page_title="Chatbot RAG local",page_icon="🤖")
|
|
64 |
|
65 |
@st.cache_resource
|
66 |
def load_rag_engine():
|
67 |
-
rag = RAGEngine(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
return rag
|
69 |
|
|
|
70 |
rag=load_rag_engine()
|
71 |
|
72 |
st.title("🤖 Chatbot LLM Local (CPU)")
|
|
|
1 |
import streamlit as st
|
2 |
from llama_cpp import Llama
|
3 |
import os
|
4 |
+
#from rag_model import RAGEngine
|
5 |
+
|
6 |
+
from rag_model_optimise import RAGEngine
|
7 |
import logging
|
8 |
from huggingface_hub import hf_hub_download
|
9 |
import time
|
|
|
66 |
|
67 |
@st.cache_resource
|
68 |
def load_rag_engine():
|
69 |
+
rag = RAGEngine(
|
70 |
+
model_path=model_path,
|
71 |
+
vector_path=vectors_path,
|
72 |
+
index_path=faiss_index_path,
|
73 |
+
model_threads=8 # ✅ plus rapide
|
74 |
+
)
|
75 |
+
|
76 |
+
# 🔥 Warmup pour éviter latence au 1er appel
|
77 |
+
rag.llm("Bonjour", max_tokens=1)
|
78 |
return rag
|
79 |
|
80 |
+
|
81 |
rag=load_rag_engine()
|
82 |
|
83 |
st.title("🤖 Chatbot LLM Local (CPU)")
|
rag_model_optimise.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
import textwrap
|
4 |
+
import logging
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import faiss
|
8 |
+
import numpy as np
|
9 |
+
from llama_cpp import Llama
|
10 |
+
from llama_index.core import VectorStoreIndex
|
11 |
+
from llama_index.core.schema import TextNode
|
12 |
+
from llama_index.vector_stores.faiss import FaissVectorStore
|
13 |
+
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
14 |
+
from sentence_transformers.util import cos_sim
|
15 |
+
|
16 |
+
# === Logger configuration ===
|
17 |
+
logger = logging.getLogger("RAGEngine")
|
18 |
+
logger.setLevel(logging.INFO)
|
19 |
+
handler = logging.StreamHandler()
|
20 |
+
formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
|
21 |
+
handler.setFormatter(formatter)
|
22 |
+
logger.addHandler(handler)
|
23 |
+
|
24 |
+
MAX_TOKENS = 512
|
25 |
+
|
26 |
+
class RAGEngine:
|
27 |
+
def __init__(self, model_path: str, vector_path: str, index_path: str, model_threads: int = 4):
|
28 |
+
logger.info("📦 Initialisation du moteur RAG...")
|
29 |
+
self.llm = Llama(model_path=model_path, n_ctx=2048, n_threads=model_threads)
|
30 |
+
self.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
31 |
+
|
32 |
+
# Warmup pour éviter le temps de latence initial
|
33 |
+
try:
|
34 |
+
self.llm("Bonjour", max_tokens=1)
|
35 |
+
except Exception as e:
|
36 |
+
logger.warning(f"Warmup LLM échoué : {e}")
|
37 |
+
|
38 |
+
logger.info(f"📂 Chargement des données vectorielles depuis {vector_path}")
|
39 |
+
with open(vector_path, "rb") as f:
|
40 |
+
chunk_texts = pickle.load(f)
|
41 |
+
nodes = [TextNode(text=chunk) for chunk in chunk_texts]
|
42 |
+
|
43 |
+
faiss_index = faiss.read_index(index_path)
|
44 |
+
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
45 |
+
self.index = VectorStoreIndex(nodes=nodes, embed_model=self.embed_model, vector_store=vector_store)
|
46 |
+
|
47 |
+
logger.info("✅ Moteur RAG initialisé avec succès.")
|
48 |
+
|
49 |
+
def reformulate_with_context(self, question: str, context_sample: str) -> str:
|
50 |
+
logger.info("🔁 Reformulation de la question avec contexte...")
|
51 |
+
prompt = f"""Tu es un assistant expert en machine learning. Ton rôle est de reformuler les questions utilisateur en tenant compte du contexte ci-dessous, extrait d’un rapport technique sur un projet de reconnaissance de maladies de plantes.
|
52 |
+
|
53 |
+
Ta mission est de transformer une question vague ou floue en une question précise et adaptée au contenu du rapport. Ne donne pas une interprétation hors sujet. Ne reformule pas en termes de produits commerciaux.
|
54 |
+
|
55 |
+
Contexte :
|
56 |
+
{context_sample}
|
57 |
+
|
58 |
+
Question initiale : {question}
|
59 |
+
Question reformulée :"""
|
60 |
+
output = self.llm(prompt, max_tokens=128, stop=["
|
61 |
+
"], stream=False)
|
62 |
+
reformulated = output["choices"][0]["text"].strip()
|
63 |
+
logger.info(f"📝 Reformulée avec contexte : {reformulated}")
|
64 |
+
return reformulated
|
65 |
+
|
66 |
+
def get_adaptive_top_k(self, question: str) -> int:
|
67 |
+
q = question.lower()
|
68 |
+
if len(q.split()) <= 7:
|
69 |
+
return 8
|
70 |
+
elif any(w in q for w in ["liste", "résume", "quels sont", "explique", "comment"]):
|
71 |
+
return 10
|
72 |
+
return 8
|
73 |
+
|
74 |
+
def rerank_nodes(self, question: str, retrieved_nodes, top_k: int = 3):
|
75 |
+
logger.info(f"🔍 Re-ranking des {len(retrieved_nodes)} chunks pour la question : « {question} »")
|
76 |
+
q_emb = self.embed_model.get_query_embedding(question)
|
77 |
+
scored_nodes = []
|
78 |
+
|
79 |
+
for node in retrieved_nodes:
|
80 |
+
chunk_text = node.get_content()
|
81 |
+
chunk_emb = self.embed_model.get_text_embedding(chunk_text)
|
82 |
+
score = float(np.dot(q_emb, chunk_emb))
|
83 |
+
scored_nodes.append((score, node))
|
84 |
+
|
85 |
+
ranked_nodes = sorted(scored_nodes, key=lambda x: x[0], reverse=True)
|
86 |
+
|
87 |
+
logger.info("📊 Chunks les plus pertinents :")
|
88 |
+
for i, (score, node) in enumerate(ranked_nodes[:top_k]):
|
89 |
+
chunk_preview = textwrap.shorten(node.get_content().replace("\n", " "), width=100)
|
90 |
+
logger.info(f"#{i+1} | Score: {score:.4f} | {chunk_preview}")
|
91 |
+
|
92 |
+
return [n for _, n in ranked_nodes[:top_k]]
|
93 |
+
|
94 |
+
def retrieve_context(self, question: str, top_k: int = 3):
|
95 |
+
logger.info(f"📥 Récupération du contexte...")
|
96 |
+
retriever = self.index.as_retriever(similarity_top_k=top_k)
|
97 |
+
retrieved_nodes = retriever.retrieve(question)
|
98 |
+
reranked_nodes = self.rerank_nodes(question, retrieved_nodes, top_k)
|
99 |
+
context = "\n\n".join(n.get_content()[:500] for n in reranked_nodes)
|
100 |
+
return context, reranked_nodes
|
101 |
+
|
102 |
+
def ask(self, question_raw: str) -> str:
|
103 |
+
logger.info(f"💬 Question reçue : {question_raw}")
|
104 |
+
context_sample, _ = self.retrieve_context(question_raw, top_k=3)
|
105 |
+
reformulated = self.reformulate_with_context(question_raw, context_sample)
|
106 |
+
|
107 |
+
logger.info(f"📝 Question reformulée : {reformulated}")
|
108 |
+
top_k = self.get_adaptive_top_k(reformulated)
|
109 |
+
context, _ = self.retrieve_context(reformulated, top_k)
|
110 |
+
|
111 |
+
prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.
|
112 |
+
|
113 |
+
Si la réponse ne peut pas être déduite du contexte, indique : "Information non présente dans le contexte."
|
114 |
+
|
115 |
+
Contexte :
|
116 |
+
{context}
|
117 |
+
|
118 |
+
Question : {reformulated}
|
119 |
+
### Réponse:"""
|
120 |
+
|
121 |
+
output = self.llm(prompt, max_tokens=MAX_TOKENS, stop=["### Instruction:"], stream=False)
|
122 |
+
response = output["choices"][0]["text"].strip().split("###")[0]
|
123 |
+
logger.info(f"🧠 Réponse générée : {response[:120]}{'...' if len(response) > 120 else ''}")
|
124 |
+
return response
|
125 |
+
|
126 |
+
def ask_stream(self, question: str):
|
127 |
+
logger.info(f"💬 [Stream] Question reçue : {question}")
|
128 |
+
top_k = self.get_adaptive_top_k(question)
|
129 |
+
context, _ = self.retrieve_context(question, top_k)
|
130 |
+
|
131 |
+
prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.
|
132 |
+
|
133 |
+
Si la réponse ne peut pas être déduite du contexte, indique : "Information non présente dans le contexte."
|
134 |
+
|
135 |
+
Contexte :
|
136 |
+
{context}
|
137 |
+
|
138 |
+
Question : {question}
|
139 |
+
### Réponse:"""
|
140 |
+
|
141 |
+
logger.info("📡 Début du streaming de la réponse...")
|
142 |
+
stream = self.llm(prompt, max_tokens=MAX_TOKENS, stop=["### Instruction:"], stream=True)
|
143 |
+
for chunk in stream:
|
144 |
+
print(chunk["choices"][0]["text"], end="", flush=True)
|