Blaiseboy commited on
Commit
79cca78
·
verified ·
1 Parent(s): a8da25c

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +98 -0
  2. medical_chatbot.py +608 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from medical_chatbot import ColabBioGPTChatbot
4
+
5
+ # Instantiate the chatbot with CPU settings for HF Spaces
6
+ chatbot = ColabBioGPTChatbot(use_gpu=False, use_8bit=False)
7
+
8
+ medical_file_uploaded = False
9
+
10
+ def upload_and_initialize(file):
11
+ global medical_file_uploaded
12
+ if file is None:
13
+ return (
14
+ "❌ Please upload a medical .txt file.",
15
+ gr.Chatbot(visible=False),
16
+ gr.Textbox(visible=False),
17
+ gr.Button(visible=False)
18
+ )
19
+
20
+ # Handle the file path correctly for Gradio
21
+ file_path = file.name if hasattr(file, 'name') else file
22
+ success = chatbot.load_medical_data(file_path)
23
+
24
+ if success:
25
+ medical_file_uploaded = True
26
+ model_name = type(chatbot.model).__name__ if chatbot.model else "Fallback Model"
27
+ status = f"✅ Medical data processed successfully!\n📦 Model in use: {model_name}"
28
+ return (
29
+ status,
30
+ gr.Chatbot(visible=True),
31
+ gr.Textbox(visible=True),
32
+ gr.Button(visible=True)
33
+ )
34
+ else:
35
+ return (
36
+ "❌ Failed to process uploaded file.",
37
+ gr.Chatbot(visible=False),
38
+ gr.Textbox(visible=False),
39
+ gr.Button(visible=False)
40
+ )
41
+
42
+ def generate_response(user_input):
43
+ if not medical_file_uploaded:
44
+ return "⚠️ Please upload and initialize medical data first."
45
+ return chatbot.chat(user_input)
46
+
47
+ # Create the Gradio interface
48
+ with gr.Blocks(title="🩺 Pediatric Medical Assistant") as demo:
49
+ gr.Markdown("## 🩺 Pediatric Medical Assistant\nUpload a medical .txt file and start chatting.")
50
+
51
+ with gr.Row():
52
+ file_input = gr.File(label="📁 Upload Medical File", file_types=[".txt"])
53
+ upload_btn = gr.Button("📤 Upload and Initialize")
54
+
55
+ upload_output = gr.Textbox(label="System Status", interactive=False)
56
+
57
+ chatbot_ui = gr.Chatbot(label="🧠 Chat History", visible=False)
58
+ user_input = gr.Textbox(
59
+ placeholder="Ask a pediatric health question...",
60
+ lines=2,
61
+ show_label=False,
62
+ visible=False
63
+ )
64
+ submit_btn = gr.Button("Send", visible=False)
65
+
66
+ upload_btn.click(
67
+ fn=upload_and_initialize,
68
+ inputs=[file_input],
69
+ outputs=[upload_output, chatbot_ui, user_input, submit_btn]
70
+ )
71
+
72
+ def on_submit(user_message, chat_history):
73
+ if not user_message.strip():
74
+ return "", chat_history
75
+
76
+ bot_response = generate_response(user_message)
77
+ chat_history.append((user_message, bot_response))
78
+ return "", chat_history
79
+
80
+ user_input.submit(
81
+ fn=on_submit,
82
+ inputs=[user_input, chatbot_ui],
83
+ outputs=[user_input, chatbot_ui]
84
+ )
85
+ submit_btn.click(
86
+ fn=on_submit,
87
+ inputs=[user_input, chatbot_ui],
88
+ outputs=[user_input, chatbot_ui]
89
+ )
90
+
91
+ # Launch with proper settings for Hugging Face Spaces
92
+ if __name__ == "__main__":
93
+ demo.launch(
94
+ share=False, # Don't need share=True on HF Spaces
95
+ server_name="0.0.0.0", # Listen on all interfaces for HF Spaces
96
+ server_port=7860, # Standard port for HF Spaces
97
+ show_error=True # Show detailed errors for debugging
98
+ )
medical_chatbot.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Setup and Installation
2
+
3
+ import torch
4
+ print("🖥️ System Check:")
5
+ print(f"CUDA available: {torch.cuda.is_available()}")
6
+ if torch.cuda.is_available():
7
+ print(f"GPU device: {torch.cuda.get_device_name(0)}")
8
+ print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
9
+ else:
10
+ print("⚠️ No GPU detected - BioGPT will run on CPU")
11
+
12
+ print("\n🔧 Loading required packages...")
13
+
14
+ # Import Libraries
15
+
16
+ import os
17
+ import re
18
+ import torch
19
+ import warnings
20
+ import numpy as np
21
+ import faiss # FAISS for vector search
22
+ from transformers import (
23
+ AutoTokenizer,
24
+ AutoModelForCausalLM,
25
+ pipeline,
26
+ BitsAndBytesConfig
27
+ )
28
+ from sentence_transformers import SentenceTransformer
29
+ from typing import List, Dict, Optional
30
+ import time
31
+ from datetime import datetime
32
+ import json
33
+ import pickle
34
+
35
+ # Suppress warnings for cleaner output
36
+ warnings.filterwarnings('ignore')
37
+
38
+ print("📚 Libraries imported successfully!")
39
+ print(f"🔍 FAISS version: {faiss.__version__}")
40
+ print("🎯 Using FAISS for vector search")
41
+
42
+ # BioGPT Medical Chatbot Class
43
+
44
+ class ColabBioGPTChatbot:
45
+ def __init__(self, use_gpu=True, use_8bit=True):
46
+ """Initialize BioGPT chatbot optimized for deployment"""
47
+ print("🏥 Initializing Professional BioGPT Medical Chatbot...")
48
+
49
+ # Force CPU for HF Spaces if needed
50
+ self.device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"
51
+ self.use_8bit = use_8bit and torch.cuda.is_available()
52
+
53
+ print(f"🖥️ Using device: {self.device}")
54
+ if self.use_8bit:
55
+ print("💾 Using 8-bit quantization for memory efficiency")
56
+
57
+ # Setup components
58
+ self.setup_embeddings()
59
+ self.setup_faiss_index()
60
+ self.setup_biogpt()
61
+
62
+ # Conversation tracking
63
+ self.conversation_history = []
64
+ self.knowledge_chunks = []
65
+
66
+ print("✅ BioGPT Medical Chatbot ready for professional medical assistance!")
67
+
68
+ def setup_embeddings(self):
69
+ """Setup medical-optimized embeddings"""
70
+ print("🔧 Loading medical embeddings...")
71
+ try:
72
+ # Use a smaller, more efficient model for deployment
73
+ self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
74
+ self.embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
75
+ print(f"✅ Embeddings loaded (dimension: {self.embedding_dim})")
76
+ self.use_embeddings = True
77
+ except Exception as e:
78
+ print(f"⚠️ Embeddings failed: {e}")
79
+ self.embedding_model = None
80
+ self.embedding_dim = 384
81
+ self.use_embeddings = False
82
+
83
+ def setup_faiss_index(self):
84
+ """Setup faiss for CPU-based vector search"""
85
+ print("🔧 Setting up FAISS vector database...")
86
+ try:
87
+ print('Using CPU FAISS index for maximum compatibility')
88
+ self.faiss_index = faiss.IndexFlatIP(self.embedding_dim)
89
+ self.use_gpu_faiss = False
90
+ self.faiss_ready = True
91
+ self.collection = self.faiss_index
92
+ print("✅ FAISS CPU index initialized successfully")
93
+ except Exception as e:
94
+ print(f"❌ FAISS setup failed: {e}")
95
+ self.faiss_index = None
96
+ self.faiss_ready = False
97
+ self.collection = None
98
+
99
+ def setup_biogpt(self):
100
+ """Setup BioGPT model with optimizations for deployment"""
101
+ print("🧠 Loading BioGPT model...")
102
+
103
+ # Try BioGPT first, fallback to smaller models if needed
104
+ model_options = [
105
+ "microsoft/BioGPT-Large",
106
+ "microsoft/BioGPT", # Smaller version
107
+ "microsoft/DialoGPT-medium", # Fallback
108
+ "gpt2" # Final fallback
109
+ ]
110
+
111
+ for model_name in model_options:
112
+ try:
113
+ print(f" Attempting to load: {model_name}")
114
+
115
+ # Setup quantization config for memory efficiency
116
+ if self.use_8bit and "BioGPT" in model_name:
117
+ quantization_config = BitsAndBytesConfig(
118
+ load_in_8bit=True,
119
+ llm_int8_threshold=6.0,
120
+ llm_int8_has_fp16_weight=False,
121
+ )
122
+ else:
123
+ quantization_config = None
124
+
125
+ # Load tokenizer
126
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
127
+
128
+ # Set padding token
129
+ if self.tokenizer.pad_token is None:
130
+ self.tokenizer.pad_token = self.tokenizer.eos_token
131
+
132
+ # Load model with proper settings for deployment
133
+ start_time = time.time()
134
+
135
+ model_kwargs = {
136
+ "torch_dtype": torch.float16 if self.device == "cuda" else torch.float32,
137
+ "trust_remote_code": True,
138
+ "low_cpu_mem_usage": True, # Important for deployment
139
+ }
140
+
141
+ if quantization_config:
142
+ model_kwargs["quantization_config"] = quantization_config
143
+ model_kwargs["device_map"] = "auto"
144
+
145
+ self.model = AutoModelForCausalLM.from_pretrained(
146
+ model_name,
147
+ **model_kwargs
148
+ )
149
+
150
+ # Move to device if not using device_map
151
+ if self.device == "cuda" and quantization_config is None:
152
+ self.model = self.model.to(self.device)
153
+
154
+ load_time = time.time() - start_time
155
+ print(f"✅ {model_name} loaded successfully! ({load_time:.1f} seconds)")
156
+
157
+ # Test the model
158
+ self.test_model()
159
+ break # Success, exit the loop
160
+
161
+ except Exception as e:
162
+ print(f"❌ {model_name} loading failed: {e}")
163
+ if model_name == model_options[-1]: # Last option failed
164
+ print("❌ All models failed to load")
165
+ self.model = None
166
+ self.tokenizer = None
167
+ continue
168
+
169
+ def test_model(self):
170
+ """Test the loaded model with a simple query"""
171
+ print("🧪 Testing model...")
172
+ try:
173
+ test_prompt = "Fever in children can be caused by"
174
+ inputs = self.tokenizer(test_prompt, return_tensors="pt")
175
+
176
+ if self.device == "cuda":
177
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
178
+
179
+ with torch.no_grad():
180
+ outputs = self.model.generate(
181
+ **inputs,
182
+ max_new_tokens=20,
183
+ do_sample=True,
184
+ temperature=0.7,
185
+ pad_token_id=self.tokenizer.eos_token_id
186
+ )
187
+
188
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
189
+ print(f"✅ Model test successful!")
190
+ print(f" Test response: {response}")
191
+
192
+ except Exception as e:
193
+ print(f"⚠️ Model test failed: {e}")
194
+
195
+ def load_medical_data(self, file_path: str):
196
+ """Load and process medical data with progress tracking"""
197
+ print(f"📖 Loading medical data from {file_path}...")
198
+
199
+ try:
200
+ with open(file_path, 'r', encoding='utf-8') as f:
201
+ text = f.read()
202
+ print(f"📄 File loaded: {len(text):,} characters")
203
+ except FileNotFoundError:
204
+ print(f"❌ File {file_path} not found!")
205
+ return False
206
+ except Exception as e:
207
+ print(f"❌ Error loading file: {e}")
208
+ return False
209
+
210
+ # Create chunks optimized for medical content
211
+ print("📝 Creating medical-optimized chunks...")
212
+ chunks = self.create_medical_chunks(text)
213
+ print(f"📋 Created {len(chunks)} medical chunks")
214
+
215
+ self.knowledge_chunks = chunks
216
+
217
+ # Generate embeddings with progress and add to FAISS index
218
+ if self.use_embeddings and self.embedding_model and self.faiss_ready:
219
+ return self.generate_embeddings_with_progress(chunks)
220
+
221
+ print("✅ Medical data loaded (text search mode)")
222
+ return True
223
+
224
+ def create_medical_chunks(self, text: str, chunk_size: int = 400) -> List[Dict]:
225
+ """Create medically-optimized text chunks"""
226
+ chunks = []
227
+
228
+ # Split by medical sections first
229
+ medical_sections = self.split_by_medical_sections(text)
230
+
231
+ chunk_id = 0
232
+ for section in medical_sections:
233
+ if len(section.split()) > chunk_size:
234
+ # Split large sections by sentences
235
+ sentences = re.split(r'[.!?]+', section)
236
+ current_chunk = ""
237
+
238
+ for sentence in sentences:
239
+ sentence = sentence.strip()
240
+ if not sentence:
241
+ continue
242
+
243
+ if len(current_chunk.split()) + len(sentence.split()) < chunk_size:
244
+ current_chunk += sentence + ". "
245
+ else:
246
+ if current_chunk.strip():
247
+ chunks.append({
248
+ 'id': chunk_id,
249
+ 'text': current_chunk.strip(),
250
+ 'medical_focus': self.identify_medical_focus(current_chunk)
251
+ })
252
+ chunk_id += 1
253
+ current_chunk = sentence + ". "
254
+
255
+ if current_chunk.strip():
256
+ chunks.append({
257
+ 'id': chunk_id,
258
+ 'text': current_chunk.strip(),
259
+ 'medical_focus': self.identify_medical_focus(current_chunk)
260
+ })
261
+ chunk_id += 1
262
+ else:
263
+ chunks.append({
264
+ 'id': chunk_id,
265
+ 'text': section,
266
+ 'medical_focus': self.identify_medical_focus(section)
267
+ })
268
+ chunk_id += 1
269
+
270
+ return chunks
271
+
272
+ def split_by_medical_sections(self, text: str) -> List[str]:
273
+ """Split text by medical sections"""
274
+ # Look for medical section headers
275
+ section_patterns = [
276
+ r'\n\s*(?:SYMPTOMS?|TREATMENT|DIAGNOSIS|CAUSES?|PREVENTION|MANAGEMENT).*?\n',
277
+ r'\n\s*\d+\.\s+', # Numbered sections
278
+ r'\n\n+' # Paragraph breaks
279
+ ]
280
+
281
+ sections = [text]
282
+ for pattern in section_patterns:
283
+ new_sections = []
284
+ for section in sections:
285
+ splits = re.split(pattern, section, flags=re.IGNORECASE)
286
+ new_sections.extend([s.strip() for s in splits if len(s.strip()) > 100])
287
+ sections = new_sections
288
+
289
+ return sections
290
+
291
+ def identify_medical_focus(self, text: str) -> str:
292
+ """Identify the medical focus of a text chunk"""
293
+ text_lower = text.lower()
294
+
295
+ # Medical categories
296
+ categories = {
297
+ 'pediatric_symptoms': ['fever', 'cough', 'rash', 'vomiting', 'diarrhea'],
298
+ 'treatments': ['treatment', 'therapy', 'medication', 'antibiotics'],
299
+ 'diagnosis': ['diagnosis', 'diagnostic', 'symptoms', 'signs'],
300
+ 'emergency': ['emergency', 'urgent', 'serious', 'hospital'],
301
+ 'prevention': ['prevention', 'vaccine', 'immunization', 'avoid']
302
+ }
303
+
304
+ for category, keywords in categories.items():
305
+ if any(keyword in text_lower for keyword in keywords):
306
+ return category
307
+
308
+ return 'general_medical'
309
+
310
+ def generate_embeddings_with_progress(self, chunks: List[Dict]) -> bool:
311
+ """Generate embeddings with progress tracking and add to FAISS index"""
312
+ print("🔮 Generating medical embeddings and adding to FAISS index...")
313
+
314
+ if not self.embedding_model or not self.faiss_index:
315
+ print("❌ Embedding model or FAISS index not available.")
316
+ return False
317
+
318
+ try:
319
+ texts = [chunk['text'] for chunk in chunks]
320
+
321
+ # Generate embeddings in batches with progress
322
+ batch_size = 32
323
+ all_embeddings = []
324
+
325
+ for i in range(0, len(texts), batch_size):
326
+ batch_texts = texts[i:i+batch_size]
327
+ batch_embeddings = self.embedding_model.encode(batch_texts, show_progress_bar=False)
328
+ all_embeddings.extend(batch_embeddings)
329
+
330
+ # Show progress
331
+ progress = min(i + batch_size, len(texts))
332
+ print(f" Progress: {progress}/{len(texts)} chunks processed", end='\r')
333
+
334
+ print(f"\n ✅ Generated embeddings for {len(texts)} chunks")
335
+
336
+ # Add embeddings to FAISS index
337
+ print("💾 Adding embeddings to FAISS index...")
338
+ self.faiss_index.add(np.array(all_embeddings))
339
+
340
+ print("✅ Medical embeddings added to FAISS index successfully!")
341
+ return True
342
+
343
+ except Exception as e:
344
+ print(f"❌ Embedding generation or FAISS add failed: {e}")
345
+ return False
346
+
347
+ def retrieve_medical_context(self, query: str, n_results: int = 3) -> List[str]:
348
+ """Retrieve relevant medical context using embeddings or keyword search"""
349
+ if self.use_embeddings and self.embedding_model and self.faiss_ready:
350
+ try:
351
+ # Generate query embedding
352
+ query_embedding = self.embedding_model.encode([query])
353
+
354
+ # Search for similar content in FAISS index
355
+ distances, indices = self.faiss_index.search(np.array(query_embedding), n_results)
356
+
357
+ # Retrieve the corresponding chunks
358
+ context_chunks = [self.knowledge_chunks[i]['text'] for i in indices[0] if i != -1]
359
+
360
+ if context_chunks:
361
+ return context_chunks
362
+
363
+ except Exception as e:
364
+ print(f"⚠️ Embedding search failed: {e}")
365
+
366
+ # Fallback to keyword search
367
+ print("⚠️ Falling back to keyword search.")
368
+ return self.keyword_search_medical(query, n_results)
369
+
370
+ def keyword_search_medical(self, query: str, n_results: int) -> List[str]:
371
+ """Medical-focused keyword search"""
372
+ if not self.knowledge_chunks:
373
+ return []
374
+
375
+ query_words = set(query.lower().split())
376
+ chunk_scores = []
377
+
378
+ for chunk_info in self.knowledge_chunks:
379
+ chunk_text = chunk_info['text']
380
+ chunk_words = set(chunk_text.lower().split())
381
+
382
+ # Calculate relevance score
383
+ word_overlap = len(query_words.intersection(chunk_words))
384
+ base_score = word_overlap / len(query_words) if query_words else 0
385
+
386
+ # Boost medical content
387
+ medical_boost = 0
388
+ if chunk_info.get('medical_focus') in ['pediatric_symptoms', 'treatments', 'diagnosis']:
389
+ medical_boost = 0.5
390
+
391
+ final_score = base_score + medical_boost
392
+
393
+ if final_score > 0:
394
+ chunk_scores.append((final_score, chunk_text))
395
+
396
+ # Return top matches
397
+ chunk_scores.sort(reverse=True)
398
+ return [chunk for _, chunk in chunk_scores[:n_results]]
399
+
400
+ def generate_biogpt_response(self, context: str, query: str) -> str:
401
+ """Generate medical response using BioGPT only"""
402
+ if not self.model or not self.tokenizer:
403
+ return "⚠️ Medical AI model not available. This chatbot requires BioGPT for accurate medical information. Please check the setup or try restarting."
404
+
405
+ try:
406
+ # Create medical-focused prompt
407
+ prompt = f"""Medical Context: {context[:800]}
408
+
409
+ Question: {query}
410
+
411
+ Medical Answer:"""
412
+
413
+ # Tokenize input
414
+ inputs = self.tokenizer(
415
+ prompt,
416
+ return_tensors="pt",
417
+ truncation=True,
418
+ max_length=1024
419
+ )
420
+
421
+ # Move inputs to the correct device
422
+ if self.device == "cuda":
423
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
424
+
425
+ # Generate response
426
+ with torch.no_grad():
427
+ outputs = self.model.generate(
428
+ **inputs,
429
+ max_new_tokens=150,
430
+ do_sample=True,
431
+ temperature=0.7,
432
+ top_p=0.9,
433
+ pad_token_id=self.tokenizer.eos_token_id,
434
+ repetition_penalty=1.1
435
+ )
436
+
437
+ # Decode response
438
+ full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
439
+
440
+ # Extract just the generated part
441
+ if "Medical Answer:" in full_response:
442
+ generated_response = full_response.split("Medical Answer:")[-1].strip()
443
+ else:
444
+ generated_response = full_response[len(prompt):].strip()
445
+
446
+ # Clean up response
447
+ cleaned_response = self.clean_medical_response(generated_response)
448
+
449
+ return cleaned_response
450
+
451
+ except Exception as e:
452
+ print(f"⚠️ BioGPT generation failed: {e}")
453
+ return "⚠️ Unable to generate medical response. The medical AI model encountered an error. Please try rephrasing your question or contact support."
454
+
455
+ def clean_medical_response(self, response: str) -> str:
456
+ """Clean and format medical response"""
457
+ # Remove incomplete sentences and limit length
458
+ sentences = re.split(r'[.!?]+', response)
459
+ clean_sentences = []
460
+
461
+ for sentence in sentences:
462
+ sentence = sentence.strip()
463
+ if len(sentence) > 10 and not sentence.endswith(('and', 'or', 'but', 'however')):
464
+ clean_sentences.append(sentence)
465
+ if len(clean_sentences) >= 3: # Limit to 3 sentences
466
+ break
467
+
468
+ if clean_sentences:
469
+ cleaned = '. '.join(clean_sentences) + '.'
470
+ else:
471
+ cleaned = response[:200] + '...' if len(response) > 200 else response
472
+
473
+ return cleaned
474
+
475
+ def fallback_response(self, context: str, query: str) -> str:
476
+ """Fallback response when BioGPT fails"""
477
+ # Extract key sentences from context
478
+ sentences = [s.strip() for s in context.split('.') if len(s.strip()) > 20]
479
+
480
+ if sentences:
481
+ response = sentences[0] + '.'
482
+ if len(sentences) > 1:
483
+ response += ' ' + sentences[1] + '.'
484
+ else:
485
+ response = context[:300] + '...'
486
+
487
+ return response
488
+
489
+ def handle_conversational_interactions(self, query: str) -> Optional[str]:
490
+ """Handle comprehensive conversational interactions"""
491
+ query_lower = query.lower().strip()
492
+
493
+ # Use more specific patterns for greetings
494
+ greeting_patterns = [
495
+ r'^\s*(hello|hi|hey|hiya|howdy)\s*$',
496
+ r'^\s*(good morning|good afternoon|good evening|good day)\s*$',
497
+ r'^\s*(what\'s up|whats up|sup|yo)\s*$',
498
+ r'^\s*(greetings|salutations)\s*$',
499
+ r'^\s*(how are you|how are you doing|how\'s it going|hows it going)\s*$',
500
+ r'^\s*(good to meet you|nice to meet you|pleased to meet you)\s*$'
501
+ ]
502
+
503
+ for pattern in greeting_patterns:
504
+ if re.match(pattern, query_lower):
505
+ responses = [
506
+ "👋 Hello! I'm BioGPT, your professional medical AI assistant specialized in pediatric medicine. I'm here to provide evidence-based medical information. What health concern can I help you with today?",
507
+ "🏥 Hi there! I'm a medical AI assistant powered by BioGPT, trained on medical literature. I can help answer questions about children's health and medical conditions. How can I assist you?",
508
+ "👋 Greetings! I'm your AI medical consultant, ready to help with pediatric health questions using the latest medical knowledge. What would you like to know about?"
509
+ ]
510
+ return np.random.choice(responses)
511
+
512
+ # Handle thanks and other conversational patterns...
513
+ # (keeping the rest of the conversational handling as before)
514
+
515
+ # Return None if no conversational pattern matches
516
+ return None
517
+
518
+ def chat(self, query: str) -> str:
519
+ """Main chat function with BioGPT medical-only responses"""
520
+ if not query.strip():
521
+ return "Hello! I'm BioGPT, your professional medical AI assistant. How can I help you with pediatric medical questions today?"
522
+
523
+ # Handle comprehensive conversational interactions first
524
+ conversational_response = self.handle_conversational_interactions(query)
525
+ if conversational_response:
526
+ # Add to conversation history
527
+ self.conversation_history.append({
528
+ 'query': query,
529
+ 'response': conversational_response,
530
+ 'timestamp': datetime.now().isoformat(),
531
+ 'type': 'conversational'
532
+ })
533
+ return conversational_response
534
+
535
+ # Check if medical model is available
536
+ if not self.model or not self.tokenizer:
537
+ return "⚠️ **Medical AI Unavailable**: This chatbot requires BioGPT for accurate medical information. The medical model failed to load. Please contact support or try restarting the application."
538
+
539
+ if not self.knowledge_chunks:
540
+ return "Please load medical data first to access the medical knowledge base."
541
+
542
+ print(f"🔍 Processing medical query: {query}")
543
+
544
+ # Retrieve relevant medical context using FAISS or keyword search
545
+ start_time = time.time()
546
+ context = self.retrieve_medical_context(query)
547
+ retrieval_time = time.time() - start_time
548
+
549
+ if not context:
550
+ return "I don't have specific information about this topic in my medical database. Please consult with a healthcare professional for personalized medical advice."
551
+
552
+ print(f" 📚 Context retrieved ({retrieval_time:.2f}s)")
553
+
554
+ # Generate response with BioGPT
555
+ start_time = time.time()
556
+ main_context = '\n\n'.join(context)
557
+ response = self.generate_biogpt_response(main_context, query)
558
+ generation_time = time.time() - start_time
559
+
560
+ print(f" 🧠 Response generated ({generation_time:.2f}s)")
561
+
562
+ # Format final response
563
+ final_response = f"🩺 **Medical Information:** {response}\n\n⚠️ **Important:** This information is for educational purposes only. Always consult with qualified healthcare professionals for medical diagnosis, treatment, and personalized advice."
564
+
565
+ # Add to conversation history
566
+ self.conversation_history.append({
567
+ 'query': query,
568
+ 'response': final_response,
569
+ 'timestamp': datetime.now().isoformat(),
570
+ 'retrieval_time': retrieval_time,
571
+ 'generation_time': generation_time,
572
+ 'type': 'medical'
573
+ })
574
+
575
+ return final_response
576
+
577
+ def get_conversation_summary(self) -> Dict:
578
+ """Get conversation statistics"""
579
+ if not self.conversation_history:
580
+ return {"message": "No conversations yet"}
581
+
582
+ # Filter medical conversations for performance stats
583
+ medical_conversations = [h for h in self.conversation_history if h.get('type') == 'medical']
584
+
585
+ if not medical_conversations:
586
+ return {
587
+ "total_conversations": len(self.conversation_history),
588
+ "medical_conversations": 0,
589
+ "conversational_interactions": len(self.conversation_history),
590
+ "model_info": "BioGPT" if self.model and "BioGPT" in str(type(self.model)) else "Fallback Model",
591
+ "vector_search": "FAISS CPU" if self.faiss_ready else "Keyword Search",
592
+ "device": self.device
593
+ }
594
+
595
+ avg_retrieval_time = sum(h.get('retrieval_time', 0) for h in medical_conversations) / len(medical_conversations)
596
+ avg_generation_time = sum(h.get('generation_time', 0) for h in medical_conversations) / len(medical_conversations)
597
+
598
+ return {
599
+ "total_conversations": len(self.conversation_history),
600
+ "medical_conversations": len(medical_conversations),
601
+ "conversational_interactions": len(self.conversation_history) - len(medical_conversations),
602
+ "avg_retrieval_time": f"{avg_retrieval_time:.2f}s",
603
+ "avg_generation_time": f"{avg_generation_time:.2f}s",
604
+ "model_info": "BioGPT" if self.model and "BioGPT" in str(type(self.model)) else "Fallback Model",
605
+ "vector_search": "FAISS CPU" if self.faiss_ready else "Keyword Search",
606
+ "device": self.device,
607
+ "quantization": "8-bit" if self.use_8bit else "16-bit/32-bit"
608
+ }