gestion des modèles
- .gitignore +2 -0
- app.py +53 -2
- rag_model.py +22 -33
.gitignore
CHANGED
@@ -1,6 +1,8 @@
 # Fichiers et dossiers à ignorer
 llamavenv/
 models/
+chatbot-models/
 *.gguf
 __pycache__/
 *.pyc
+vectordb/
app.py
CHANGED
@@ -2,12 +2,63 @@ import streamlit as st
 from llama_cpp import Llama
 import os
 from rag_model import RAGEngine
+import logging
+from huggingface_hub import hf_hub_download
+import time
+
+
+ENV = os.getenv("ENV", "space")
+
+logger = logging.getLogger("Streamlit")
+logger.setLevel(logging.INFO)
+handler = logging.StreamHandler()
+formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
+handler.setFormatter(formatter)
+logger.addHandler(handler)
+
+
+logger.info(f"ENV :{ENV}")
+
+#time.sleep(5)
+
+if ENV == "local":
+    model_path = "chatbot-models/Nous-Hermes-2-Mistral-7B-DPO.Q4_K_M.gguf"
+    faiss_index_path="chatbot-models/vectordb_docling/index.faiss"
+    vectors_path="chatbot-models/vectordb_docling/chunks.pkl"
+
+else:
+    # Télécharger le modèle GGUF
+    model_path = hf_hub_download(
+        repo_id="rkonan/chatbot-models",
+        filename="chatbot-models/Nous-Hermes-2-Mistral-7B-DPO.Q4_K_M.gguf",
+        repo_type="dataset"
+    )
+
+    # Télécharger les fichiers FAISS
+    faiss_index_path = hf_hub_download(
+        repo_id="rkonan/chatbot-models",
+        filename="chatbot-models/vectordb_docling/index.faiss",
+        repo_type="dataset"
+    )
+
+    vectors_path = hf_hub_download(
+        repo_id="rkonan/chatbot-models",
+        filename="chatbot-models/vectordb_docling/chunks.pkl",
+        repo_type="dataset"
+    )
+
+
+
 
 st.set_page_config(page_title="Chatbot RAG local",page_icon="🤖")
 
+
+
+
+
 @st.cache_resource
 def load_rag_engine():
-    rag = RAGEngine(model_path
+    rag = RAGEngine(model_path,vectors_path,faiss_index_path)
     return rag
 
 rag=load_rag_engine()
@@ -18,6 +69,6 @@ user_input=st.text_area("Posez votre question :", height=100)
 
 if st.button("Envoyer") and user_input.strip():
     with st.spinner("Génération en cours..."):
-        response = rag.ask(user_input
+        response = rag.ask(user_input)
     st.markdown("**Réponse :**")
     st.success(response)
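With this change, app.py picks its artifacts from the ENV variable: under the default "space" value the GGUF model and the FAISS files are fetched from the rkonan/chatbot-models dataset repo with hf_hub_download, while ENV=local expects them on disk under chatbot-models/. The key point is that hf_hub_download returns the path of a locally cached copy, so both branches hand plain file paths to RAGEngine. A minimal sketch of that behaviour (the print call and the comment about the cache location are illustrative, not part of the commit):

# Sketch only: hf_hub_download fetches one file from the Hub (here the same
# dataset repo as above), caches it locally, and returns the local path.
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="rkonan/chatbot-models",
    filename="chatbot-models/Nous-Hermes-2-Mistral-7B-DPO.Q4_K_M.gguf",
    repo_type="dataset",
)
print(model_path)  # a path inside the local Hugging Face cache

Since load_rag_engine is wrapped in @st.cache_resource, the engine (and therefore any download) is built once per process rather than on every Streamlit rerun.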
rag_model.py
CHANGED
@@ -2,7 +2,7 @@ import os
 import pickle
 import textwrap
 import logging
-from typing import
+from typing import List
 
 import faiss
 import numpy as np
@@ -24,30 +24,20 @@ logger.addHandler(handler)
 MAX_TOKENS = 512
 
 class RAGEngine:
-    def __init__(self, model_path: str,
+    def __init__(self, model_path: str, vector_path: str, index_path: str, model_threads: int = 4):
         logger.info("📦 Initialisation du moteur RAG...")
         self.llm = Llama(model_path=model_path, n_ctx=2048, n_threads=model_threads)
         self.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
-        self.indexes: Dict[str, Dict] = {}
 
-        for mode in vector_modes:
-            vectordir = f"vectordb_{mode}" if mode != "sentence" else "vectordb"
-            index_file = os.path.join(vectordir, "index.faiss")
-            chunks_file = os.path.join(vectordir, "chunks.pkl")
+        logger.info(f"📂 Chargement des données vectorielles depuis {vector_path}")
+        with open(vector_path, "rb") as f:
+            chunk_texts = pickle.load(f)
+        nodes = [TextNode(text=chunk) for chunk in chunk_texts]
 
-            self.indexes[mode] = {
-                "nodes": nodes,
-                "index": index
-            }
+        faiss_index = faiss.read_index(index_path)
+        vector_store = FaissVectorStore(faiss_index=faiss_index)
+        self.index = VectorStoreIndex(nodes=nodes, embed_model=self.embed_model, vector_store=vector_store)
 
         logger.info("✅ Moteur RAG initialisé avec succès.")
 
@@ -111,26 +101,25 @@ Question reformulée :"""
 
         return [n for _, n in ranked_nodes[:top_k]]
 
-    def retrieve_context(self, question: str,
-        logger.info(f"📥 Récupération du contexte
-        retriever = self.
+    def retrieve_context(self, question: str, top_k: int = 3):
+        logger.info(f"📥 Récupération du contexte...")
+        retriever = self.index.as_retriever(similarity_top_k=top_k)
         retrieved_nodes = retriever.retrieve(question)
         reranked_nodes = self.rerank_nodes(question, retrieved_nodes, top_k)
         context = "\n\n".join(n.get_content()[:500] for n in reranked_nodes)
         return context, reranked_nodes
 
-    def ask(self, question_raw: str
+    def ask(self, question_raw: str) -> str:
         logger.info(f"💬 Question reçue : {question_raw}")
         if len(question_raw.split()) <= 3:
-            context_sample, _ = self.retrieve_context(question_raw,
-            reformulated = self.reformulate_with_context(
+            context_sample, _ = self.retrieve_context(question_raw, top_k=3)
+            reformulated = self.reformulate_with_context(question_raw, context_sample)
         else:
-            reformulated = self.reformulate_question(
+            reformulated = self.reformulate_question(question_raw)
 
-        context, _ = self.retrieve_context(question, mode, top_k)
+        logger.info(f"📝 Question reformulée : {reformulated}")
+        top_k = self.get_adaptive_top_k(reformulated)
+        context, _ = self.retrieve_context(reformulated, top_k)
 
         prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.
 
@@ -139,7 +128,7 @@ Si la réponse ne peut pas être déduite du contexte, indique : "Information no
 Contexte :
 {context}
 
-Question : {question}
+Question : {reformulated}
 ### Réponse:"""
 
         output = self.llm(prompt, max_tokens=MAX_TOKENS, stop=["### Instruction:"], stream=False)
@@ -147,10 +136,10 @@ Question : {question}
         logger.info(f"🧠 Réponse générée : {response[:120]}{'...' if len(response) > 120 else ''}")
         return response
 
-    def ask_stream(self, question: str
+    def ask_stream(self, question: str):
         logger.info(f"💬 [Stream] Question reçue : {question}")
         top_k = self.get_adaptive_top_k(question)
-        context, _ = self.retrieve_context(question,
+        context, _ = self.retrieve_context(question, top_k)
 
         prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.
 
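The refactored RAGEngine now takes exactly two precomputed artifacts: chunks.pkl, a pickled list of chunk strings, and index.faiss, a FAISS index over their embeddings (the old multi-mode vectordb_{mode} loop is gone). The commit does not include the script that builds these files; the following is a hypothetical offline build, assuming plain-text chunks and the same sentence-transformers/all-MiniLM-L6-v2 model used at query time, with an inner-product FlatIP index as an illustrative choice:

# Hypothetical offline build of chunks.pkl and index.faiss (not part of this
# commit): one way to produce artifacts compatible with RAGEngine.__init__.
import os
import pickle

import faiss
from sentence_transformers import SentenceTransformer

chunks = ["premier extrait de document...", "deuxième extrait..."]  # placeholder chunks

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = model.encode(chunks, normalize_embeddings=True)  # float32 array, shape (n, 384)

index = faiss.IndexFlatIP(embeddings.shape[1])  # inner product on normalized vectors ≈ cosine
index.add(embeddings)

os.makedirs("chatbot-models/vectordb_docling", exist_ok=True)
faiss.write_index(index, "chatbot-models/vectordb_docling/index.faiss")
with open("chatbot-models/vectordb_docling/chunks.pkl", "wb") as f:
    pickle.dump(chunks, f)

Whatever actually produces the files, the shapes must line up with what __init__ expects: pickle.load must yield an iterable of strings, and faiss.read_index must find an index whose dimensionality matches the embedding model (384 for all-MiniLM-L6-v2).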