# Hugging Face Spaces status banner: Running on Zero
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
from threading import Thread | |
import spaces | |
# Model configuration: Hub repository to serve and the device used for inference.
MODEL_ID = "yasserrmd/DentaInstruct-1.2B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Everything below runs once at import time so that requests share one model.
print(f"Loading model {MODEL_ID}...")

# Tokenizer lookup order: the fine-tuned repo, then the LFM2 base model, then
# TinyLlama as a last resort.
# NOTE(review): nothing here establishes that the TinyLlama vocabulary matches
# this model — confirm that last fallback is really wanted.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    print(f"Loaded tokenizer from {MODEL_ID}")
except Exception as primary_error:
    print(f"Failed to load tokenizer from {MODEL_ID}: {primary_error}")
    print("Using tokenizer from base LFM2 model...")
    try:
        tokenizer = AutoTokenizer.from_pretrained("LiquidAI/LFM2-1.2B")
    except Exception as base_error:
        print(f"Failed to load LFM2 tokenizer: {base_error}")
        print("Using fallback TinyLlama tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Weights: bfloat16 with automatic device placement on CUDA, float32 otherwise.
# DEVICE was derived from torch.cuda.is_available() above, so it is reused here.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
    device_map="auto" if DEVICE == "cuda" else None,
)
if DEVICE != "cuda":
    # No device_map was given on this path, so place the model explicitly.
    model = model.to(DEVICE)

# generate() is handed a pad token id later; reuse EOS when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
def _plain_text_prompt(messages):
    """Render *messages* as ``User:``/``Assistant:`` lines ending with an open assistant turn."""
    turns = [
        f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}\n"
        for msg in messages
    ]
    return "".join(turns) + "Assistant: "


def format_prompt(message, history):
    """Build the text prompt for the model from the chat so far.

    Args:
        message: The new user message.
        history: Iterable of ``(user_msg, assistant_msg)`` pairs from earlier
            turns; ``None`` or empty means a fresh conversation. Pairs whose
            assistant part is empty contribute only the user turn.

    Returns:
        The prompt string. The tokenizer's chat template is used when it can
        be applied; otherwise a plain ``User:``/``Assistant:`` transcript that
        ends with ``"Assistant: "`` is returned.
    """
    messages = []
    for user_msg, assistant_msg in history or []:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Checking only for the method's existence is not enough: a tokenizer can
    # expose apply_chat_template without having a template configured, in which
    # case the call raises ValueError. Fall back to the plain format then, so
    # the fallback branch is actually reachable.
    apply_template = getattr(tokenizer, "apply_chat_template", None)
    if apply_template is not None:
        try:
            return apply_template(messages, tokenize=False, add_generation_prompt=True)
        except ValueError as exc:
            print(f"Chat template unavailable ({exc}); using plain-text prompt format")
    return _plain_text_prompt(messages)
def generate_response_streaming(
    message,
    history,
    temperature=0.3,
    max_new_tokens=512,
    top_p=0.95,
    repetition_penalty=1.05,
):
    """Generate a reply with the model, yielding the text accumulated so far.

    Args:
        message: The new user message.
        history: Earlier ``(user_msg, assistant_msg)`` pairs.
        temperature: Sampling temperature passed to ``model.generate``.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Repetition penalty passed to ``model.generate``.

    Yields:
        The partial response, growing with every chunk the streamer emits.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread is
            re-raised here once the stream has been drained.

    NOTE(review): this file imports ``spaces`` but no ``@spaces.GPU`` decorator
    is applied to this function — confirm whether the deployment needs one.
    """
    max_prompt_tokens = 2048

    prompt = format_prompt(message, history)

    # Tokenize in full and keep only the LAST max_prompt_tokens tokens. Cutting
    # from the end instead would discard the newest user turn and the
    # generation prompt, which are the parts the model needs most.
    encoded = tokenizer(prompt, return_tensors="pt")
    model_inputs = {
        name: tensor[:, -max_prompt_tokens:].to(model.device)
        for name, tensor in encoded.items()
        if name != "token_type_ids"  # not forwarded to generate()
    }

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=30.0,
    )

    generation_kwargs = dict(
        **model_inputs,
        max_new_tokens=int(max_new_tokens),  # UI sliders may deliver floats
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )

    worker_errors = []

    def _run_generation():
        """Run generate(); on failure record the error and close the stream."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # thread boundary: re-raised by the consumer below
            worker_errors.append(exc)
            # Without this the consumer would block until the streamer timeout
            # and surface a queue error instead of the real failure.
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    partial_response = ""
    for new_text in streamer:
        partial_response += new_text
        yield partial_response
    thread.join()

    if worker_errors:
        raise worker_errors[0]
# Starter questions for the quick-question panel, keyed by category title.
# Dict order is the order in which the categories are rendered.
QUESTION_CATEGORIES = {
    # General patient-facing explanations
    "Patient education": [
        "What are the main types of dental cavities and how can I prevent them?",
        "Explain the stages of gum disease from gingivitis to periodontitis",
        "What should I expect during my first dental cleaning appointment?",
    ],
    # Walk-throughs of specific treatments
    "Procedures": [
        "Walk me through the steps of a root canal treatment",
        "What's the difference between a crown and a veneer?",
        "How does the dental implant process work from start to finish?",
    ],
    # Day-to-day prevention
    "Preventative care advice": [
        "What's the proper brushing technique for optimal plaque removal?",
        "How does fluoride protect teeth and is it safe for children?",
        "What foods should I avoid to maintain healthy teeth?",
    ],
    # Terminology and structure
    "Anatomy and terms": [
        "Explain the anatomy of a tooth from crown to root",
        "What are the different types of teeth and their functions?",
        "What is the difference between enamel, dentin, and pulp?",
    ],
}
# Stylesheet passed to gr.Blocks(css=...). Selectors prefixed with `.dark` are
# the dark-mode overrides of the rule directly above them.
# Extraction residue (" | |") that had been embedded in every line of this
# literal has been removed; indentation inside the CSS is cosmetic only.
custom_css = """
/* Reset and base styles */
* {
    box-sizing: border-box;
}
/* Header with credits */
.header-container {
    background: linear-gradient(135deg, #1e40af 0%, #3b82f6 50%, #60a5fa 100%);
    border-radius: 16px;
    padding: 32px;
    margin-bottom: 24px;
    color: white;
    text-align: center;
    box-shadow: 0 8px 32px rgba(30, 64, 175, 0.3);
}
.dark .header-container {
    background: linear-gradient(135deg, #1e3a8a 0%, #3730a3 50%, #4338ca 100%);
}
.header-title {
    font-size: 40px;
    font-weight: 800;
    margin-bottom: 8px;
    text-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.header-subtitle {
    font-size: 18px;
    opacity: 0.95;
    margin-bottom: 16px;
}
.header-credits {
    font-size: 14px;
    opacity: 0.9;
    margin-bottom: 12px;
}
.header-credits a {
    color: #fef3c7;
    text-decoration: none;
    font-weight: 500;
}
.header-credits a:hover {
    color: #fde68a;
    text-decoration: underline;
}
.social-icon {
    display: inline-block;
    margin-left: 8px;
    text-decoration: none;
    font-size: 18px;
    opacity: 0.85;
    transition: all 0.2s ease;
    vertical-align: middle;
}
.social-icon:hover {
    opacity: 1;
    transform: translateY(-2px);
}
/* Mini model card - skeuomorphic design */
.model-card {
    background: linear-gradient(145deg, #f8fafc 0%, #e2e8f0 100%);
    border: 1px solid #cbd5e1;
    border-radius: 16px;
    padding: 20px;
    margin-bottom: 24px;
    box-shadow:
        0 10px 25px rgba(0,0,0,0.1),
        inset 0 1px 0 rgba(255,255,255,0.6);
    position: relative;
    overflow: hidden;
}
.dark .model-card {
    background: linear-gradient(145deg, #374151 0%, #1f2937 100%);
    border: 1px solid #4b5563;
    box-shadow:
        0 10px 25px rgba(0,0,0,0.3),
        inset 0 1px 0 rgba(255,255,255,0.1);
}
.model-card::before {
    content: '';
    position: absolute;
    top: 0;
    left: 0;
    right: 0;
    height: 2px;
    background: linear-gradient(90deg, #3b82f6, #8b5cf6, #ef4444, #f59e0b);
}
.model-card-title {
    font-size: 20px;
    font-weight: 700;
    color: #1e293b;
    margin-bottom: 12px;
    display: flex;
    align-items: center;
    gap: 8px;
}
.dark .model-card-title {
    color: #f1f5f9;
}
.model-stats {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(120px, 1fr));
    gap: 12px;
    margin-bottom: 16px;
}
.model-stat {
    background: rgba(59, 130, 246, 0.1);
    border: 1px solid rgba(59, 130, 246, 0.2);
    border-radius: 8px;
    padding: 8px 12px;
    text-align: center;
}
.dark .model-stat {
    background: rgba(59, 130, 246, 0.15);
    border: 1px solid rgba(59, 130, 246, 0.3);
}
.stat-value {
    font-weight: 700;
    font-size: 14px;
    color: #3b82f6;
}
.dark .stat-value {
    color: #60a5fa;
}
.stat-label {
    font-size: 11px;
    color: #64748b;
    text-transform: uppercase;
    letter-spacing: 0.5px;
    margin-top: 2px;
}
.dark .stat-label {
    color: #94a3b8;
}
.model-description {
    color: #475569;
    font-size: 14px;
    line-height: 1.6;
}
.dark .model-description {
    color: #cbd5e1;
}
/* Question carousel - right side */
.question-carousel {
    background: linear-gradient(145deg, #ffffff 0%, #f1f5f9 100%);
    border: 1px solid #e2e8f0;
    border-radius: 16px;
    padding: 20px;
    box-shadow:
        0 4px 16px rgba(0,0,0,0.08),
        inset 0 1px 0 rgba(255,255,255,0.8);
    height: fit-content;
    position: sticky;
    top: 20px;
}
.dark .question-carousel {
    background: linear-gradient(145deg, #1f2937 0%, #111827 100%);
    border: 1px solid #374151;
    box-shadow:
        0 4px 16px rgba(0,0,0,0.2),
        inset 0 1px 0 rgba(255,255,255,0.05);
}
.carousel-title {
    font-size: 18px;
    font-weight: 700;
    color: #1e293b;
    margin-bottom: 16px;
    text-align: center;
}
.dark .carousel-title {
    color: #f1f5f9;
}
.carousel-card {
    background: linear-gradient(135deg, #fafafa 0%, #f4f4f5 100%);
    border: 1px solid #e4e4e7;
    border-radius: 12px;
    padding: 16px;
    margin-bottom: 16px;
    box-shadow:
        0 2px 8px rgba(0,0,0,0.06),
        inset 0 1px 0 rgba(255,255,255,0.7);
    transition: transform 0.2s, box-shadow 0.2s;
}
.dark .carousel-card {
    background: linear-gradient(135deg, #374151 0%, #2d3748 100%);
    border: 1px solid #4b5563;
    box-shadow:
        0 2px 8px rgba(0,0,0,0.15),
        inset 0 1px 0 rgba(255,255,255,0.05);
}
.carousel-card:hover {
    transform: translateY(-2px);
    box-shadow:
        0 4px 16px rgba(0,0,0,0.12),
        inset 0 1px 0 rgba(255,255,255,0.7);
}
.dark .carousel-card:hover {
    box-shadow:
        0 4px 16px rgba(0,0,0,0.25),
        inset 0 1px 0 rgba(255,255,255,0.05);
}
.carousel-card-title {
    font-weight: 600;
    color: #3b82f6;
    margin-bottom: 12px;
    font-size: 15px;
}
.dark .carousel-card-title {
    color: #60a5fa;
}
.question-button {
    display: block;
    width: 100%;
    background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
    border: 1px solid #cbd5e1;
    border-radius: 8px;
    padding: 8px 12px;
    margin-bottom: 8px;
    font-size: 13px;
    color: #475569;
    text-align: left;
    cursor: pointer;
    transition: all 0.2s;
    box-shadow: 0 1px 3px rgba(0,0,0,0.05);
}
.dark .question-button {
    background: linear-gradient(135deg, #4b5563 0%, #374151 100%);
    border: 1px solid #6b7280;
    color: #d1d5db;
}
.question-button:hover {
    background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%);
    color: white;
    border-color: #3b82f6;
    transform: translateY(-1px);
    box-shadow: 0 2px 8px rgba(59, 130, 246, 0.3);
}
/* Loading animation */
@keyframes pulse {
    0% { opacity: 1; }
    50% { opacity: 0.5; }
    100% { opacity: 1; }
}
.processing {
    animation: pulse 1.5s ease-in-out infinite;
}
/* Typing indicator */
@keyframes typing {
    0%, 60%, 100% { opacity: 0.3; }
    30% { opacity: 1; }
}
.typing-indicator {
    display: inline-block;
    animation: typing 1.4s infinite;
}
.question-button:last-child {
    margin-bottom: 0;
}
/* Main layout */
.main-layout {
    display: grid;
    grid-template-columns: 2fr 1fr;
    gap: 24px;
    margin-bottom: 24px;
}
@media (max-width: 1024px) {
    .main-layout {
        grid-template-columns: 1fr;
    }
    .question-carousel {
        position: static;
    }
}
/* Chat interface improvements */
.chat-container {
    background: linear-gradient(145deg, #ffffff 0%, #f8fafc 100%);
    border: 1px solid #e2e8f0;
    border-radius: 16px;
    padding: 20px;
    box-shadow:
        0 4px 16px rgba(0,0,0,0.08),
        inset 0 1px 0 rgba(255,255,255,0.8);
}
.dark .chat-container {
    background: linear-gradient(145deg, #1f2937 0%, #111827 100%);
    border: 1px solid #374151;
    box-shadow:
        0 4px 16px rgba(0,0,0,0.2),
        inset 0 1px 0 rgba(255,255,255,0.05);
}
/* Citation boxes */
.citation-section {
    margin-top: 32px;
    padding-top: 24px;
    border-top: 2px solid #e2e8f0;
}
.dark .citation-section {
    border-top: 2px solid #374151;
}
.citation-title {
    font-size: 20px;
    font-weight: 700;
    color: #1e293b;
    margin-bottom: 16px;
    text-align: center;
}
.dark .citation-title {
    color: #f1f5f9;
}
.citation-boxes {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
    gap: 16px;
}
.citation-box {
    background: linear-gradient(145deg, #f8fafc 0%, #e2e8f0 100%);
    border: 1px solid #cbd5e1;
    border-radius: 12px;
    padding: 16px;
    box-shadow:
        0 4px 12px rgba(0,0,0,0.08),
        inset 0 1px 0 rgba(255,255,255,0.6);
}
.dark .citation-box {
    background: linear-gradient(145deg, #374151 0%, #1f2937 100%);
    border: 1px solid #4b5563;
    box-shadow:
        0 4px 12px rgba(0,0,0,0.2),
        inset 0 1px 0 rgba(255,255,255,0.1);
}
.citation-box h4 {
    color: #3b82f6;
    font-weight: 600;
    margin-bottom: 8px;
    font-size: 16px;
}
.dark .citation-box h4 {
    color: #60a5fa;
}
.citation-content {
    background: #1f2937;
    color: #e5e7eb;
    padding: 12px;
    border-radius: 8px;
    font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
    font-size: 12px;
    line-height: 1.4;
    overflow-x: auto;
    white-space: pre-wrap;
    word-break: break-all;
    margin-top: 8px;
}
.dark .citation-content {
    background: #111827;
    border: 1px solid #374151;
}
/* Advanced settings styling */
.advanced-settings {
    background: linear-gradient(145deg, #f1f5f9 0%, #e2e8f0 100%);
    border: 1px solid #cbd5e1;
    border-radius: 12px;
    margin: 16px 0;
    box-shadow:
        0 2px 8px rgba(0,0,0,0.06),
        inset 0 1px 0 rgba(255,255,255,0.7);
}
.dark .advanced-settings {
    background: linear-gradient(145deg, #374151 0%, #1f2937 100%);
    border: 1px solid #4b5563;
    box-shadow:
        0 2px 8px rgba(0,0,0,0.15),
        inset 0 1px 0 rgba(255,255,255,0.05);
}
/* Disclaimer styling */
.disclaimer-box {
    background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
    border: 2px solid #f59e0b;
    border-radius: 12px;
    padding: 16px 20px;
    margin: 20px 0;
    box-shadow:
        0 4px 12px rgba(245, 158, 11, 0.2),
        inset 0 1px 0 rgba(255,255,255,0.5);
    position: relative;
}
.dark .disclaimer-box {
    background: linear-gradient(135deg, #92400e 0%, #78350f 100%);
    border: 2px solid #f59e0b;
    color: #fef3c7;
    box-shadow:
        0 4px 12px rgba(245, 158, 11, 0.3),
        inset 0 1px 0 rgba(255,255,255,0.1);
}
.disclaimer-box::before {
    content: '';
    position: absolute;
    left: 0;
    top: 0;
    bottom: 0;
    width: 4px;
    background: #f59e0b;
    border-radius: 2px 0 0 2px;
}
.disclaimer-title {
    font-weight: 600;
    color: #92400e;
    margin-bottom: 8px;
    display: flex;
    align-items: center;
    gap: 8px;
    font-size: 15px;
}
.dark .disclaimer-title {
    color: #fbbf24;
}
.disclaimer-text {
    color: #78350f;
    font-size: 14px;
    line-height: 1.5;
}
.dark .disclaimer-text {
    color: #fef3c7;
}
/* Button improvements */
.gr-button {
    border-radius: 8px !important;
    font-weight: 600 !important;
    transition: all 0.2s ease !important;
}
.gr-button-primary {
    background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%) !important;
    border: none !important;
    color: white !important;
    box-shadow: 0 2px 4px rgba(59, 130, 246, 0.3) !important;
}
.gr-button-primary:hover {
    background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%) !important;
    transform: translateY(-1px) !important;
    box-shadow: 0 4px 12px rgba(59, 130, 246, 0.4) !important;
}
/* Responsive design */
@media (max-width: 768px) {
    .header-title {
        font-size: 28px;
    }
    .model-stats {
        grid-template-columns: repeat(2, 1fr);
    }
    .social-icon {
        font-size: 16px;
    }
    .citation-boxes {
        grid-template-columns: 1fr;
    }
}
"""
# Create the Gradio interface: static header/model-card/disclaimer panels, a
# chat column next to a column of quick-question buttons, sampling controls,
# citations, and the event wiring that connects them.
# NOTE(review): the nesting below was reconstructed from a whitespace-stripped
# source — confirm the placement of `status` and the accordion against the
# deployed layout.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    # Header with credits and social links
    gr.HTML(
        """
        <div class="header-container">
        <h1 class="header-title">🦷 DentaInstruct-1.2B Demo</h1>
        <p class="header-subtitle">Advanced AI assistant for dental education and oral health information</p>
        <div class="header-credits">
        Model by <a href="https://huggingface.co/yasserrmd" target="_blank">yasserrmd</a>
        <a href="https://github.com/YASSERRMD" target="_blank" class="social-icon" title="GitHub">🔗</a>
        <a href="https://www.linkedin.com/in/moyasser" target="_blank" class="social-icon" title="LinkedIn">💼</a>
        <span style="margin: 0 10px;">/</span>
        Space by <a href="https://huggingface.co/chrisvoncsefalvay" target="_blank">Chris von Csefalvay</a>
        <a href="https://github.com/chrisvoncsefalvay" target="_blank" class="social-icon" title="GitHub">🔗</a>
        <a href="https://twitter.com/epichrisis" target="_blank" class="social-icon" title="X">𝕏</a>
        <a href="https://chrisvoncsefalvay.com" target="_blank" class="social-icon" title="Website">🌐</a>
        </div>
        </div>
        """
    )

    # Mini model card with skeuomorphic design
    gr.HTML(
        """
        <div class="model-card">
        <div class="model-card-title">
        🧠 Model Information
        </div>
        <div class="model-stats">
        <div class="model-stat">
        <div class="stat-value">1.17B</div>
        <div class="stat-label">Parameters</div>
        </div>
        <div class="model-stat">
        <div class="stat-value">LFM2-1.2B</div>
        <div class="stat-label">Base Model</div>
        </div>
        <div class="model-stat">
        <div class="stat-value">MIRIAD</div>
        <div class="stat-label">Dataset</div>
        </div>
        <div class="model-stat">
        <div class="stat-value">2048</div>
        <div class="stat-label">Context Length</div>
        </div>
        </div>
        <div class="model-description">
        Specialised language model fine-tuned for dental education and oral health information.
        Built on efficient LFM2 architecture with supervised fine-tuning on comprehensive dental content.
        </div>
        </div>
        """
    )

    # Disclaimer box
    gr.HTML(
        """
        <div class="disclaimer-box">
        <div class="disclaimer-title">
        ⚠️ Educational Use Only - Important Medical Disclaimer
        </div>
        <div class="disclaimer-text">
        This AI model provides educational information about dental topics and is designed for learning purposes only.
        It is <strong>NOT</strong> a substitute for professional dental or medical advice, diagnosis, or treatment.
        Always seek the advice of your dentist or qualified healthcare provider with any questions about a medical condition or treatment.
        </div>
        </div>
        """
    )

    # Main layout with chat on left, carousel on right
    with gr.Row(elem_classes="main-layout"):
        # Left side - Chat interface
        with gr.Column(scale=2, elem_classes="chat-container"):
            chatbot = gr.Chatbot(
                height=500,
                label="Response",
                show_label=True,
                avatar_images=None,
                bubble_full_width=False,
                render_markdown=True,
            )
            with gr.Row():
                msg = gr.Textbox(
                    label="Your dental question",
                    placeholder="Ask about dental procedures, oral health, treatment options, or any dental topic...",
                    lines=3,
                    scale=4,
                    container=False,
                )
            with gr.Row():
                submit = gr.Button("Send Question", variant="primary", scale=1, elem_id="send-btn")
                clear = gr.Button("Clear Chat", scale=1)
            # Status indicator (hidden; not wired to any event in this block)
            status = gr.Textbox(value="", label="Status", visible=False)

        # Right side - Question carousel
        with gr.Column(scale=1):
            gr.HTML("""
            <div class="question-carousel">
            <div class="carousel-title">💡 Quick Questions</div>
            </div>
            """)
            # One group per category; each button is remembered together with
            # its question text so it can be wired up after the layout is built.
            question_buttons = []
            for category, questions in QUESTION_CATEGORIES.items():
                with gr.Group():
                    gr.HTML(f'<div class="carousel-card"><div class="carousel-card-title">{category}</div></div>')
                    for question in questions:
                        btn = gr.Button(
                            question,
                            variant="secondary",
                            size="sm",
                            elem_classes="question-button"
                        )
                        question_buttons.append((btn, question))

    # Advanced settings in collapsible section
    with gr.Accordion("⚙️ Advanced Settings", open=False, elem_classes="advanced-settings"):
        with gr.Row():
            with gr.Column(scale=1):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                    label="Temperature",
                    info="Lower values (0.1-0.3) for factual responses, higher (0.7-1.0) for creative explanations"
                )
                max_new_tokens = gr.Slider(
                    minimum=64,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Response Length",
                    info="Maximum number of tokens in the response"
                )
            with gr.Column(scale=1):
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (Nucleus Sampling)",
                    info="Controls diversity of word choices"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=1.5,
                    value=1.05,
                    step=0.05,
                    label="Repetition Penalty",
                    info="Reduces repetitive phrases in responses"
                )

    # Citation boxes section.
    # The BibTeX lines are kept flush-left on purpose: .citation-content is
    # styled with `white-space: pre-wrap`, so leading spaces would be rendered.
    gr.HTML(
        """
        <div class="citation-section">
        <div class="citation-title">📚 Citations</div>
        <div class="citation-boxes">
        <div class="citation-box">
        <h4>MIRIAD Dataset</h4>
        <p>Training dataset used for fine-tuning the dental knowledge base.</p>
        <div class="citation-content">@misc{miriad2024,
title={MIRIAD: A Multi-modal Instruction-following Dataset for Dentistry},
author={MIRIAD Team},
year={2024},
url={https://huggingface.co/datasets/miriad}
}</div>
        </div>
        <div class="citation-box">
        <h4>DentaInstruct-1.2B Model</h4>
        <p>The fine-tuned model used in this demonstration.</p>
        <div class="citation-content">@misc{dentainstruct2024,
title={DentaInstruct-1.2B: A Dental Education Language Model},
author={yasserrmd},
year={2024},
url={https://huggingface.co/yasserrmd/DentaInstruct-1.2B}
}</div>
        </div>
        </div>
        </div>
        """
    )

    # Event handlers
    def respond(message, chat_history, temperature, max_new_tokens, top_p, repetition_penalty):
        """Stream a model reply into the chat.

        This is a generator: every yield is a
        ``(textbox value, chat history, send-button update)`` triple matching
        the ``[msg, chatbot, submit]`` outputs it is wired to.

        Args:
            message: Text from the question box (may be empty or ``None``).
            chat_history: Current list of ``(user, assistant)`` pairs, or ``None``.
            temperature, max_new_tokens, top_p, repetition_penalty: Sampling
                settings forwarded to ``generate_response_streaming``.

        Raises:
            gr.Error: When generation fails, after the UI state was restored.
        """
        # Work on a copy so the component's own value is never mutated in place,
        # and tolerate a cleared chatbot that hands over None.
        chat_history = list(chat_history or [])

        if not message or not message.strip():
            gr.Warning("Please enter a question")
            # A generator's `return` value never reaches the UI, so the reset
            # state has to be yielded before returning.
            yield "", chat_history, gr.update(value="Send Question")
            return

        try:
            # Show initial processing state
            yield "", chat_history + [(message, "🔄 Starting...")], gr.update(value="⏳ Generating...")

            # Stream the response
            partial_response = ""
            for chunk in generate_response_streaming(
                message,
                chat_history,
                temperature,
                max_new_tokens,
                top_p,
                repetition_penalty
            ):
                partial_response = chunk
                # Update chat with partial response and typing indicator
                current_history = chat_history + [(message, partial_response + " ●")]
                yield "", current_history, gr.update(value="⏳ Generating...")

            # Final update with complete response
            chat_history.append((message, partial_response))
            yield "", chat_history, gr.update(value="Send Question")
        except Exception as e:
            # Put the question back and reset the button first, then surface
            # the failure: gr.Error is an exception and only shows when raised.
            yield message, chat_history, gr.update(value="Send Question")
            raise gr.Error(f"An error occurred: {str(e)}") from e

    # Connect event handlers with loading states
    msg.submit(
        respond,
        [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
        [msg, chatbot, submit],
        queue=True,
        show_progress="full"
    ).then(
        lambda: gr.update(interactive=True),
        None,
        [msg]
    )

    # Button path: lock the textbox, generate, then unlock it again.
    submit.click(
        lambda: gr.update(interactive=False),
        None,
        [msg]
    ).then(
        respond,
        [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
        [msg, chatbot, submit],
        queue=True,
        show_progress="full"
    ).then(
        lambda: gr.update(interactive=True),
        None,
        [msg]
    )

    clear.click(
        lambda: (None, ""),
        None,
        [chatbot, msg],
        queue=False
    )

    # Connect question button click handlers. The default argument binds each
    # question at definition time, so every button fills in its own text.
    for btn, question_text in question_buttons:
        btn.click(
            lambda q=question_text: q,
            None,
            msg,
            queue=False
        )
# Script entry point: enable the request queue (at most 10 waiting requests),
# then start the server on 0.0.0.0:7860 without a public share link.
if __name__ == "__main__":
    launch_options = {
        "share": False,
        "show_error": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.queue(max_size=10)
    demo.launch(**launch_options)