rkonan commited on
Commit
d8f78dc
·
1 Parent(s): fb1ee83

gestion des modèles

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +53 -2
  3. rag_model.py +22 -33
.gitignore CHANGED
@@ -1,6 +1,8 @@
1
  # Fichiers et dossiers à ignorer
2
  llamavenv/
3
  models/
 
4
  *.gguf
5
  __pycache__/
6
  *.pyc
 
 
1
  # Fichiers et dossiers à ignorer
2
  llamavenv/
3
  models/
4
+ chatbot-models/
5
  *.gguf
6
  __pycache__/
7
  *.pyc
8
+ vectordb/
app.py CHANGED
@@ -2,12 +2,63 @@ import streamlit as st
2
  from llama_cpp import Llama
3
  import os
4
  from rag_model import RAGEngine
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  st.set_page_config(page_title="Chatbot RAG local",page_icon="🤖")
7
 
 
 
 
 
8
  @st.cache_resource
9
  def load_rag_engine():
10
- rag = RAGEngine(model_path="models/Nous-Hermes-2-Mistral-7B-DPO.Q4_K_M.gguf")
11
  return rag
12
 
13
  rag=load_rag_engine()
@@ -18,6 +69,6 @@ user_input=st.text_area("Posez votre question :", height=100)
18
 
19
  if st.button("Envoyer") and user_input.strip():
20
  with st.spinner("Génération en cours..."):
21
- response = rag.ask(user_input,mode="docling")
22
  st.markdown("**Réponse :**")
23
  st.success(response)
 
2
  from llama_cpp import Llama
3
  import os
4
  from rag_model import RAGEngine
5
+ import logging
6
+ from huggingface_hub import hf_hub_download
7
+ import time
8
+
9
+
10
+ ENV = os.getenv("ENV", "space")
11
+
12
+ logger = logging.getLogger("Streamlit")
13
+ logger.setLevel(logging.INFO)
14
+ handler = logging.StreamHandler()
15
+ formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
16
+ handler.setFormatter(formatter)
17
+ logger.addHandler(handler)
18
+
19
+
20
+ logger.info(f"ENV :{ENV}")
21
+
22
+ #time.sleep(5)
23
+
24
+ if ENV == "local":
25
+ model_path = "chatbot-models/Nous-Hermes-2-Mistral-7B-DPO.Q4_K_M.gguf"
26
+ faiss_index_path="chatbot-models/vectordb_docling/index.faiss"
27
+ vectors_path="chatbot-models/vectordb_docling/chunks.pkl"
28
+
29
+ else:
30
+ # Télécharger le modèle GGUF
31
+ model_path = hf_hub_download(
32
+ repo_id="rkonan/chatbot-models",
33
+ filename="chatbot-models/Nous-Hermes-2-Mistral-7B-DPO.Q4_K_M.gguf",
34
+ repo_type="dataset"
35
+ )
36
+
37
+ # Télécharger les fichiers FAISS
38
+ faiss_index_path = hf_hub_download(
39
+ repo_id="rkonan/chatbot-models",
40
+ filename="chatbot-models/vectordb_docling/index.faiss",
41
+ repo_type="dataset"
42
+ )
43
+
44
+ vectors_path = hf_hub_download(
45
+ repo_id="rkonan/chatbot-models",
46
+ filename="chatbot-models/vectordb_docling/chunks.pkl",
47
+ repo_type="dataset"
48
+ )
49
+
50
+
51
+
52
 
53
  st.set_page_config(page_title="Chatbot RAG local",page_icon="🤖")
54
 
55
+
56
+
57
+
58
+
59
  @st.cache_resource
60
  def load_rag_engine():
61
+ rag = RAGEngine(model_path,vectors_path,faiss_index_path)
62
  return rag
63
 
64
  rag=load_rag_engine()
 
69
 
70
  if st.button("Envoyer") and user_input.strip():
71
  with st.spinner("Génération en cours..."):
72
+ response = rag.ask(user_input)
73
  st.markdown("**Réponse :**")
74
  st.success(response)
rag_model.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import pickle
3
  import textwrap
4
  import logging
5
- from typing import Dict, List
6
 
7
  import faiss
8
  import numpy as np
@@ -24,30 +24,20 @@ logger.addHandler(handler)
24
  MAX_TOKENS = 512
25
 
26
  class RAGEngine:
27
- def __init__(self, model_path: str, vector_modes: List[str] = ["docling"], model_threads: int = 4):
28
  logger.info("📦 Initialisation du moteur RAG...")
29
  self.llm = Llama(model_path=model_path, n_ctx=2048, n_threads=model_threads)
30
  self.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
31
- self.indexes: Dict[str, Dict] = {}
32
 
33
- for mode in vector_modes:
34
- vectordir = f"vectordb_{mode}" if mode != "sentence" else "vectordb"
35
- index_file = os.path.join(vectordir, "index.faiss")
36
- chunks_file = os.path.join(vectordir, "chunks.pkl")
37
 
38
- logger.info(f"📂 Chargement des données vectorielles pour le mode '{mode}' depuis {vectordir}")
39
- with open(chunks_file, "rb") as f:
40
- chunk_texts = pickle.load(f)
41
- nodes = [TextNode(text=chunk) for chunk in chunk_texts]
42
 
43
- faiss_index = faiss.read_index(index_file)
44
- vector_store = FaissVectorStore(faiss_index=faiss_index)
45
- index = VectorStoreIndex(nodes=nodes, embed_model=self.embed_model, vector_store=vector_store)
46
-
47
- self.indexes[mode] = {
48
- "nodes": nodes,
49
- "index": index
50
- }
51
 
52
  logger.info("✅ Moteur RAG initialisé avec succès.")
53
 
@@ -111,26 +101,25 @@ Question reformulée :"""
111
 
112
  return [n for _, n in ranked_nodes[:top_k]]
113
 
114
- def retrieve_context(self, question: str, mode: str, top_k: int = 3):
115
- logger.info(f"📥 Récupération du contexte pour le mode « {mode} »...")
116
- retriever = self.indexes[mode]["index"].as_retriever(similarity_top_k=top_k)
117
  retrieved_nodes = retriever.retrieve(question)
118
  reranked_nodes = self.rerank_nodes(question, retrieved_nodes, top_k)
119
  context = "\n\n".join(n.get_content()[:500] for n in reranked_nodes)
120
  return context, reranked_nodes
121
 
122
- def ask(self, question_raw: str, mode: str = "docling") -> str:
123
  logger.info(f"💬 Question reçue : {question_raw}")
124
  if len(question_raw.split()) <= 3:
125
- context_sample, _ = self.retrieve_context(question_raw, mode, 3)
126
- reformulated = self.reformulate_with_context( question_raw, context_sample)
127
  else:
128
- reformulated = self.reformulate_question( question_raw)
129
 
130
- print(f"📝 Question reformulée : {reformulated}")
131
- question = reformulated
132
- top_k = self.get_adaptive_top_k(question)
133
- context, _ = self.retrieve_context(question, mode, top_k)
134
 
135
  prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.
136
 
@@ -139,7 +128,7 @@ Si la réponse ne peut pas être déduite du contexte, indique : "Information no
139
  Contexte :
140
  {context}
141
 
142
- Question : {question}
143
  ### Réponse:"""
144
 
145
  output = self.llm(prompt, max_tokens=MAX_TOKENS, stop=["### Instruction:"], stream=False)
@@ -147,10 +136,10 @@ Question : {question}
147
  logger.info(f"🧠 Réponse générée : {response[:120]}{'...' if len(response) > 120 else ''}")
148
  return response
149
 
150
- def ask_stream(self, question: str, mode: str = "docling"):
151
  logger.info(f"💬 [Stream] Question reçue : {question}")
152
  top_k = self.get_adaptive_top_k(question)
153
- context, _ = self.retrieve_context(question, mode, top_k)
154
 
155
  prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.
156
 
 
2
  import pickle
3
  import textwrap
4
  import logging
5
+ from typing import List
6
 
7
  import faiss
8
  import numpy as np
 
24
  MAX_TOKENS = 512
25
 
26
  class RAGEngine:
27
+ def __init__(self, model_path: str, vector_path: str, index_path: str, model_threads: int = 4):
28
  logger.info("📦 Initialisation du moteur RAG...")
29
  self.llm = Llama(model_path=model_path, n_ctx=2048, n_threads=model_threads)
30
  self.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
31
 
 
 
 
 
32
 
33
+ logger.info(f"📂 Chargement des données vectorielles depuis {vector_path}")
34
+ with open(vector_path, "rb") as f:
35
+ chunk_texts = pickle.load(f)
36
+ nodes = [TextNode(text=chunk) for chunk in chunk_texts]
37
 
38
+ faiss_index = faiss.read_index(index_path)
39
+ vector_store = FaissVectorStore(faiss_index=faiss_index)
40
+ self.index = VectorStoreIndex(nodes=nodes, embed_model=self.embed_model, vector_store=vector_store)
 
 
 
 
 
41
 
42
  logger.info("✅ Moteur RAG initialisé avec succès.")
43
 
 
101
 
102
  return [n for _, n in ranked_nodes[:top_k]]
103
 
104
+ def retrieve_context(self, question: str, top_k: int = 3):
105
+ logger.info(f"📥 Récupération du contexte...")
106
+ retriever = self.index.as_retriever(similarity_top_k=top_k)
107
  retrieved_nodes = retriever.retrieve(question)
108
  reranked_nodes = self.rerank_nodes(question, retrieved_nodes, top_k)
109
  context = "\n\n".join(n.get_content()[:500] for n in reranked_nodes)
110
  return context, reranked_nodes
111
 
112
+ def ask(self, question_raw: str) -> str:
113
  logger.info(f"💬 Question reçue : {question_raw}")
114
  if len(question_raw.split()) <= 3:
115
+ context_sample, _ = self.retrieve_context(question_raw, top_k=3)
116
+ reformulated = self.reformulate_with_context(question_raw, context_sample)
117
  else:
118
+ reformulated = self.reformulate_question(question_raw)
119
 
120
+ logger.info(f"📝 Question reformulée : {reformulated}")
121
+ top_k = self.get_adaptive_top_k(reformulated)
122
+ context, _ = self.retrieve_context(reformulated, top_k)
 
123
 
124
  prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.
125
 
 
128
  Contexte :
129
  {context}
130
 
131
+ Question : {reformulated}
132
  ### Réponse:"""
133
 
134
  output = self.llm(prompt, max_tokens=MAX_TOKENS, stop=["### Instruction:"], stream=False)
 
136
  logger.info(f"🧠 Réponse générée : {response[:120]}{'...' if len(response) > 120 else ''}")
137
  return response
138
 
139
+ def ask_stream(self, question: str):
140
  logger.info(f"💬 [Stream] Question reçue : {question}")
141
  top_k = self.get_adaptive_top_k(question)
142
+ context, _ = self.retrieve_context(question, top_k)
143
 
144
  prompt = f"""### Instruction: En te basant uniquement sur le contexte ci-dessous, réponds à la question de manière précise et en français.
145