Commit · 3e1e43f
Parent(s): 0bd189c
updated
app.py CHANGED
@@ -117,132 +117,78 @@ def init_hf_model() -> None:
 _chatbot_embedder = None
 _chatbot_collection = None

-def init_chatbot() -> None:
-    """Initialise the chatbot embedder and Chroma collection.
-
-    This function performs the
-    initialisation steps once. Subsequent calls will return immediately if
-    the global variables are already populated. The knowledge base is read
-    from ``CHATBOT_TXT_PATH``, split into overlapping chunks and encoded
-    using a lightweight sentence transformer. The resulting embeddings are
-    stored in a Chroma collection located at ``CHATBOT_DB_DIR``. We set
-    ``anonymized_telemetry=False`` to prevent any external network calls from
-    the Chroma client.
-    """
-    global _chatbot_embedder, _chatbot_collection
-    if _chatbot_embedder is not None and _chatbot_collection is not None:
+def init_hf_model() -> None:
+    """Initialise the Hugging Face conversational model and tokenizer."""
+    global _hf_model, _hf_tokenizer
+    if _hf_model is not None and _hf_tokenizer is not None:
         return
-    # Perform imports locally to avoid slowing down application startup. These
-    # libraries are heavy and only needed when the chatbot is used.
-    from langchain.text_splitter import RecursiveCharacterTextSplitter
-    from sentence_transformers import SentenceTransformer
-    import chromadb
-    from chromadb.config import Settings
-
-    # Ensure the persist directory exists. Chroma will create it if missing,
-    # but explicitly creating it avoids permission errors on some platforms.
-    os.makedirs(CHATBOT_DB_DIR, exist_ok=True)
-
-    # Read the raw FAQ text and split into overlapping chunks to improve
-    # retrieval granularity. The chunk size and overlap are tuned to
-    # accommodate the relatively small knowledge base.
-    with open(CHATBOT_TXT_PATH, encoding='utf-8') as f:
-        text = f.read()
-    splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=100)
-    docs = [doc.strip() for doc in splitter.split_text(text)]
-
-    # Load the sentence transformer. This model is small and runs quickly on
-    # CPU. If you wish to change the model, update the name here.
-    embedder = SentenceTransformer('all-MiniLM-L6-v2')
-    embeddings = embedder.encode(docs, show_progress_bar=False, batch_size=32)
-
-    # Initialise Chroma with an on‑disk persistent store. If the collection
-    # already exists and contains all documents, the add operation below will
-    # silently merge duplicates.
-    client = chromadb.Client(Settings(persist_directory=CHATBOT_DB_DIR, anonymized_telemetry=False))
-    collection = client.get_or_create_collection('chatbot')
-    ids = [f'doc_{i}' for i in range(len(docs))]
-    try:
-        # Attempt to query an existing document to see if the collection is
-        # populated. If this fails, we'll proceed to add all documents.
-        existing = collection.get(ids=ids[:1])
-        if not existing.get('documents'):
-            raise ValueError('No documents in collection')
-    except Exception:
-        collection.add(documents=docs, embeddings=embeddings, ids=ids)

-
-
+    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+    import torch
+
+    model_name = "facebook/blenderbot-400M-distill"
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
+
+    _hf_model = model
+    _hf_tokenizer = tokenizer
+

 def get_chatbot_response(query: str) -> str:
-    """Generate a reply to the user's query using Chroma and a Hugging Face model.
-
-    This function performs a two‑stage process to answer user questions. First
-    it ensures that the vector store and embedder are available via
-    ``init_chatbot()``, then embeds the query to retrieve the most relevant
-    context chunks from ``chatbot.txt`` using Chroma. Second, it calls
-    ``init_hf_model()`` to lazily load a conversational model from Hugging
-    Face. The retrieved context, together with a system instruction,
-    constitute the prompt for the model. The model is then run to
-    generate an answer. If the user asks a question unrelated to the
-    Codingo platform the system prompt instructs the model to refuse
-    politely.
-
-    Parameters
-    ----------
-    query: str
-        The user's input message.
-
-    Returns
-    -------
-    str
-        The assistant's reply.
-    """
-    # Ensure the embedding model and vector store are ready.
+    """Generate a reply to the user's query using Chroma + Hugging Face model."""
     init_chatbot()
     init_hf_model()
+
+    # Safety: prevent empty input
+    if not query or not query.strip():
+        return "Please type a question about the Codingo platform."
+
     embedder = _chatbot_embedder
     collection = _chatbot_collection
-
-
+    model = _hf_model
+    tokenizer = _hf_tokenizer
+    device = model.device
+
+    # Retrieve context from Chroma
     query_embedding = embedder.encode([query])[0]
     results = collection.query(query_embeddings=[query_embedding], n_results=3)
-    retrieved_docs = results.get('documents', [[]])[0]
+    retrieved_docs = results.get("documents", [[]])[0] if results else []
     context = "\n".join(retrieved_docs)
-
-    #
+
+    # System instruction
     system_prompt = (
         "You are a helpful assistant for the Codingo website. "
-        "Only answer questions about the Codingo platform. "
-        "If the question is unrelated, respond with: "
-        "\"I'm only trained to answer questions about the Codingo platform.\""
+        "Only answer questions relevant to the context provided. "
+        "If unrelated, reply: 'I'm only trained to answer questions about the Codingo platform.'"
     )
-
-    # Including the system prompt inline helps guide smaller conversational models.
+
     prompt = f"{system_prompt}\n\nContext:\n{context}\n\nQuestion: {query}\n\nAnswer:"
-
- [... old lines 224-245 are not preserved in this commit view ...]
+
+    # ✅ Safe tokenization with truncation to avoid CUDA indexing issues
+    inputs = tokenizer(
+        prompt,
+        return_tensors="pt",
+        truncation=True,
+        max_length=256,  # Prevents long inputs
+        padding=True
+    ).to(device)
+
+    try:
+        output_ids = model.generate(
+            **inputs,
+            max_length=200,
+            num_beams=3,
+            do_sample=False,
+            early_stopping=True
+        )
+        reply = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        if reply.startswith(prompt):
+            reply = reply[len(prompt):]
+        return reply.strip()
+    except Exception as e:
+        return f"Error generating response: {str(e)}"

 # Initialize Flask app
 app = Flask(
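For context, below is a minimal, hypothetical smoke test for the refactored chatbot pipeline. It is not part of this commit; it assumes app.py is importable and that CHATBOT_TXT_PATH and CHATBOT_DB_DIR (defined earlier in app.py) point at valid locations. The first call is slow because it downloads all-MiniLM-L6-v2 and facebook/blenderbot-400M-distill and builds the Chroma index.

# Hypothetical smoke test (not part of this commit). Assumes app.py is on the
# import path and its CHATBOT_TXT_PATH / CHATBOT_DB_DIR settings are valid.
from app import get_chatbot_response

if __name__ == "__main__":
    questions = [
        "How do I create an account on Codingo?",  # in-scope example question
        "What is the capital of France?",          # out-of-scope; should trigger the refusal reply
    ]
    for q in questions:
        print("Q:", q)
        print("A:", get_chatbot_response(q))
        print()

Because generation is wrapped in try/except, a model or tokenizer failure surfaces as an "Error generating response: ..." string rather than an unhandled exception in the Flask route that calls it.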