husseinelsaadi committed
Commit 0c4a8eb · Parent: 72f831c

chatbot updated

Files changed (1)
  1. chatbot/chatbot.py +154 -61
chatbot/chatbot.py CHANGED
@@ -36,15 +36,37 @@ def _init_hf_model() -> None:
 
     model_name = os.getenv("HF_CHATBOT_MODEL", DEFAULT_MODEL_NAME)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Initialize tokenizer with proper configuration
     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    # Try loading the model with proper error handling
     try:
         model = AutoModelForCausalLM.from_pretrained(model_name)
+        model_type = "causal"
     except Exception:
-        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+        try:
+            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+            model_type = "seq2seq"
+        except Exception as e:
+            print(f"Error loading model: {e}")
+            raise
+
+    # Move model to device
     model = model.to(device)
+    model.eval()  # Set to evaluation mode
+
+    # Ensure proper padding token configuration
     if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
+        if tokenizer.eos_token is not None:
+            tokenizer.pad_token = tokenizer.eos_token
+        else:
+            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+            model.resize_token_embeddings(len(tokenizer))
+
+    # Store model type for later use
+    model.model_type = model_type
+
     _hf_model = model
     _hf_tokenizer = tokenizer
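A quick way to sanity-check this loading path outside the app is the sketch below. Note that "distilgpt2" is a hypothetical stand-in; the app reads the real model name from HF_CHATBOT_MODEL:

# Standalone smoke test for the causal/seq2seq fallback above.
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

name = "distilgpt2"  # assumed stand-in for HF_CHATBOT_MODEL
tokenizer = AutoTokenizer.from_pretrained(name)
try:
    model, model_type = AutoModelForCausalLM.from_pretrained(name), "causal"
except Exception:
    model, model_type = AutoModelForSeq2SeqLM.from_pretrained(name), "seq2seq"
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2-style checkpoints ship without a pad token
model.eval()
print(model_type, type(model).__name__)  # e.g. causal GPT2LMHeadModel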
 
@@ -58,8 +80,10 @@ def _init_vector_store() -> None:
     import chromadb
     from chromadb.config import Settings
 
-    shutil.rmtree("/app/chatbot/chroma_db", ignore_errors=True)
+    # Clean up old database
+    shutil.rmtree(_chroma_db_dir, ignore_errors=True)
     os.makedirs(_chroma_db_dir, exist_ok=True)
+
     try:
         with open(_knowledge_base_path, encoding="utf-8") as f:
             raw_text = f.read()
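Since this function now wipes and rebuilds the store on every init, the collection it recreates (see the next hunk) can be spot-checked after a run with a minimal sketch like this; the persist_directory value is an assumption standing in for _chroma_db_dir:

# Ad-hoc check of the rebuilt Chroma collection, run after _init_vector_store()
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer

client = chromadb.Client(Settings(
    persist_directory="chatbot/chroma_db",  # assumed value of _chroma_db_dir
    anonymized_telemetry=False,
    is_persistent=True,
))
collection = client.get_collection("chatbot")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
query_embedding = embedder.encode(["How do I apply for a job?"])[0]
hits = collection.query(query_embeddings=[query_embedding.tolist()], n_results=3)
print(hits["documents"][0])  # top-3 matching chunks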
@@ -73,74 +97,143 @@ def _init_vector_store() -> None:
 
     splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=100)
     docs: List[str] = [doc.strip() for doc in splitter.split_text(raw_text) if doc.strip()]
+
+    # Initialize embedder
     embedder = SentenceTransformer("all-MiniLM-L6-v2")
     embeddings = embedder.encode(docs, show_progress_bar=False, batch_size=32)
+
+    # Initialize ChromaDB
     client = chromadb.Client(Settings(
         persist_directory=_chroma_db_dir,
         anonymized_telemetry=False,
         is_persistent=True,
     ))
-    collection = client.get_or_create_collection("chatbot")
+
+    # Create or recreate collection
     try:
-        existing = collection.get(limit=1)
-        if not existing.get("documents"):
-            raise ValueError("Empty Chroma DB")
-    except Exception:
-        ids = [f"doc_{i}" for i in range(len(docs))]
-        collection.add(documents=docs, embeddings=embeddings.tolist(), ids=ids)
+        client.delete_collection("chatbot")
+    except:
+        pass
+
+    collection = client.create_collection("chatbot")
+
+    # Add documents
+    ids = [f"doc_{i}" for i in range(len(docs))]
+    collection.add(documents=docs, embeddings=embeddings.tolist(), ids=ids)
 
     _chatbot_embedder = embedder
     _chatbot_collection = collection
 
 def get_chatbot_response(query: str) -> str:
-    if not query or not query.strip():
-        return "Please type a question about the Codingo platform."
-
-    _init_vector_store()
-    _init_hf_model()
-    embedder = _chatbot_embedder
-    collection = _chatbot_collection
-    model = _hf_model
-    tokenizer = _hf_tokenizer
-
-    import torch
-
-    query_embedding = embedder.encode([query])[0]
-    results = collection.query(query_embeddings=[query_embedding.tolist()], n_results=3)
-    retrieved_docs = results.get("documents", [[]])[0] if results else []
-    context = "\n".join(retrieved_docs[:3])
-
-    system_instruction = (
-        "You are LUNA AI, a helpful assistant for the Codingo recruitment "
-        "platform. Use the provided context to answer questions about "
-        "Codingo. If the question is not related to Codingo, politely "
-        "redirect the conversation. Keep responses concise and friendly."
-    )
-    prompt = f"{system_instruction}\n\nContext:\n{context}\n\nUser: {query}\nLUNA AI:"
-    inputs = tokenizer.encode(
-        prompt, return_tensors="pt", truncation=True, max_length=512, padding=True
-    ).to(model.device)
-
-    with torch.no_grad():
-        output_ids = model.generate(
-            inputs,
-            max_new_tokens=150,
-            num_beams=3,
-            do_sample=True,
-            temperature=0.7,
-            pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=tokenizer.eos_token_id,
-            early_stopping=True,
+    try:
+        if not query or not query.strip():
+            return "Please type a question about the Codingo platform."
+
+        # Clear GPU cache before processing
+        import torch
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        _init_vector_store()
+        _init_hf_model()
+
+        embedder = _chatbot_embedder
+        collection = _chatbot_collection
+        model = _hf_model
+        tokenizer = _hf_tokenizer
+
+        import torch
+
+        # Get relevant documents
+        query_embedding = embedder.encode([query])[0]
+        results = collection.query(query_embeddings=[query_embedding.tolist()], n_results=3)
+        retrieved_docs = results.get("documents", [[]])[0] if results else []
+        context = "\n".join(retrieved_docs[:3])
+
+        # Prepare the prompt based on model type
+        if hasattr(model, 'model_type') and model.model_type == "seq2seq":
+            # For seq2seq models like BlenderBot
+            prompt = f"Context: {context}\n\nUser: {query}\nAssistant:"
+        else:
+            # For causal models
+            system_instruction = (
+                "You are LUNA AI, a helpful assistant for the Codingo recruitment "
+                "platform. Use the provided context to answer questions about "
+                "Codingo. If the question is not related to Codingo, politely "
+                "redirect the conversation. Keep responses concise and friendly."
+            )
+            prompt = f"{system_instruction}\n\nContext:\n{context}\n\nUser: {query}\nLUNA AI:"
+
+        # Tokenize with proper handling
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=512,
+            padding=True,
+            return_attention_mask=True
         )
-
-    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-    if "LUNA AI:" in response:
-        response = response.split("LUNA AI:")[-1].strip()
-    elif prompt in response:
-        response = response.replace(prompt, "").strip()
-
-    return (
-        response
-        if response
-        else "I'm here to help you with questions about the Codingo platform. What would you like to know?"
-    )
+
+        # Move all tensors to the same device
+        inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+        # Generate response with error handling
+        with torch.no_grad():
+            try:
+                # Use different generation parameters based on model type
+                if hasattr(model, 'model_type') and model.model_type == "seq2seq":
+                    output_ids = model.generate(
+                        input_ids=inputs['input_ids'],
+                        attention_mask=inputs['attention_mask'],
+                        max_new_tokens=150,
+                        min_length=10,
+                        num_beams=3,
+                        do_sample=True,
+                        temperature=0.7,
+                        top_p=0.9,
+                        pad_token_id=tokenizer.pad_token_id,
+                        eos_token_id=tokenizer.eos_token_id,
+                        early_stopping=True,
+                    )
+                else:
+                    output_ids = model.generate(
+                        input_ids=inputs['input_ids'],
+                        attention_mask=inputs['attention_mask'],
+                        max_new_tokens=150,
+                        num_beams=3,
+                        do_sample=True,
+                        temperature=0.7,
+                        pad_token_id=tokenizer.pad_token_id,
+                        eos_token_id=tokenizer.eos_token_id,
+                    )
+            except Exception as e:
+                print(f"Generation error: {e}")
+                # Fallback to a simple response
+                return "I'm here to help you with questions about the Codingo platform. Could you please rephrase your question?"
+
+        # Decode the response
+        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+        # Clean up the response
+        if "Assistant:" in response:
+            response = response.split("Assistant:")[-1].strip()
+        elif "LUNA AI:" in response:
+            response = response.split("LUNA AI:")[-1].strip()
+        elif prompt in response:
+            response = response.replace(prompt, "").strip()
+
+        # Remove the input prompt if it's still in the response
+        if query in response:
+            response = response.split(query)[-1].strip()
+
+        return (
+            response
+            if response and len(response) > 5
+            else "I'm here to help you with questions about the Codingo platform. What would you like to know?"
+        )
+
+    except Exception as e:
+        print(f"Chatbot error: {e}")
+        import traceback
+        traceback.print_exc()
+        return "I apologize, but I'm having trouble processing your request. Please try again with a different question about Codingo."