# Page-scrape residue kept as a comment (not part of the program):
# "chrisvoncsefalvay's picture" / commit message "Implement streaming inference display" / commit cdba9e2
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import spaces
# ---------------------------------------------------------------------------
# Model configuration
# ---------------------------------------------------------------------------
MODEL_ID = "yasserrmd/DentaInstruct-1.2B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------------------------------------------------------
# Tokenizer: fine-tuned repo first, then the base LFM2 repo, then TinyLlama.
# ---------------------------------------------------------------------------
print(f"Loading model {MODEL_ID}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
except Exception as primary_error:
    print(f"Failed to load tokenizer from {MODEL_ID}: {primary_error}")
    print("Using tokenizer from base LFM2 model...")
    try:
        tokenizer = AutoTokenizer.from_pretrained("LiquidAI/LFM2-1.2B")
    except Exception as base_error:
        print(f"Failed to load LFM2 tokenizer: {base_error}")
        print("Using fallback TinyLlama tokenizer...")
        # NOTE(review): a tokenizer from a different model family may not match
        # this model's vocabulary - confirm this last resort is intended.
        tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
else:
    print(f"Loaded tokenizer from {MODEL_ID}")

# ---------------------------------------------------------------------------
# Model: bfloat16 with automatic device placement on GPU, float32 on CPU.
# ---------------------------------------------------------------------------
_cuda_available = torch.cuda.is_available()
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if _cuda_available else torch.float32,
    device_map="auto" if _cuda_available else None,
)
if not _cuda_available:
    # Without a device map the weights must be placed explicitly.
    model = model.to(DEVICE)

# Reuse the end-of-sequence token for padding when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
def format_prompt(message, history):
    """Build the text prompt for one chat turn.

    Args:
        message: The new user message that the model should answer.
        history: Earlier turns as ``(user_msg, assistant_msg)`` pairs. ``None``
            or an empty sequence means there are no earlier turns. A falsy
            assistant reply (e.g. ``None``) is skipped.

    Returns:
        The prompt string. It comes from the tokenizer's chat template when one
        can be applied; otherwise it is a plain ``User:`` / ``Assistant:``
        transcript that ends with ``"Assistant: "``.
    """
    messages = []
    # ``history or ()`` lets callers pass None for an empty conversation.
    for user_msg, assistant_msg in history or ():
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    apply_template = getattr(tokenizer, "apply_chat_template", None)
    if apply_template is not None:
        try:
            return apply_template(messages, tokenize=False, add_generation_prompt=True)
        except ValueError as template_error:
            # The method exists on tokenizers that have no chat template
            # configured; in that case it raises instead of returning, so the
            # plain-text format below is the only way to still get a prompt.
            print(f"Chat template unavailable ({template_error}); using plain prompt format")

    # Fallback formatting: one "Role: content" line per message.
    transcript = "".join(
        f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}\n"
        for msg in messages
    )
    return transcript + "Assistant: "
@spaces.GPU(duration=60)
def generate_response_streaming(
    message,
    history,
    temperature=0.3,
    max_new_tokens=512,
    top_p=0.95,
    repetition_penalty=1.05,
):
    """Generate a reply from the model, yielding the text accumulated so far.

    Args:
        message: The new user message.
        history: Earlier turns, passed through to ``format_prompt``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.

    Yields:
        The response text generated so far (cumulative, not just the new piece).

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread is
            re-raised here once the stream has been drained.
    """
    max_prompt_tokens = 2048

    prompt = format_prompt(message, history)

    # Tokenize without truncation and keep the *last* tokens ourselves.
    # ``truncation=True`` cuts from the right by default, which would drop the
    # newest user message and the generation prompt on long conversations.
    encoded = tokenizer(prompt, return_tensors="pt")
    model_inputs = {
        name: tensor[:, -max_prompt_tokens:].to(model.device)
        for name, tensor in encoded.items()
        if name != "token_type_ids"  # filtered out before calling generate
    }

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=30.0,
    )

    generation_kwargs = dict(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )

    worker_errors = []

    def run_generation():
        """Run generation; on failure record the error and close the stream."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            # Without an end-of-stream signal the consumer loop below would
            # block until the streamer timeout instead of seeing the failure.
            streamer.end()

    # Generation runs in a separate thread so tokens can be consumed as they arrive.
    thread = Thread(target=run_generation)
    thread.start()

    partial_response = ""
    for new_text in streamer:
        partial_response += new_text
        yield partial_response

    thread.join()
    if worker_errors:
        raise worker_errors[0]
# Example questions shown as quick-fill buttons, grouped by category heading.
# Insertion order is the display order.
QUESTION_CATEGORIES = {
    "Patient education": [
        "What are the main types of dental cavities and how can I prevent them?",
        "Explain the stages of gum disease from gingivitis to periodontitis",
        "What should I expect during my first dental cleaning appointment?",
    ],
    "Procedures": [
        "Walk me through the steps of a root canal treatment",
        "What's the difference between a crown and a veneer?",
        "How does the dental implant process work from start to finish?",
    ],
    "Preventative care advice": [
        "What's the proper brushing technique for optimal plaque removal?",
        "How does fluoride protect teeth and is it safe for children?",
        "What foods should I avoid to maintain healthy teeth?",
    ],
    "Anatomy and terms": [
        "Explain the anatomy of a tooth from crown to root",
        "What are the different types of teeth and their functions?",
        "What is the difference between enamel, dentin, and pulp?",
    ],
}
# Custom CSS for the redesigned interface
custom_css = """
/* Reset and base styles */
* {
box-sizing: border-box;
}
/* Header with credits */
.header-container {
background: linear-gradient(135deg, #1e40af 0%, #3b82f6 50%, #60a5fa 100%);
border-radius: 16px;
padding: 32px;
margin-bottom: 24px;
color: white;
text-align: center;
box-shadow: 0 8px 32px rgba(30, 64, 175, 0.3);
}
.dark .header-container {
background: linear-gradient(135deg, #1e3a8a 0%, #3730a3 50%, #4338ca 100%);
}
.header-title {
font-size: 40px;
font-weight: 800;
margin-bottom: 8px;
text-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.header-subtitle {
font-size: 18px;
opacity: 0.95;
margin-bottom: 16px;
}
.header-credits {
font-size: 14px;
opacity: 0.9;
margin-bottom: 12px;
}
.header-credits a {
color: #fef3c7;
text-decoration: none;
font-weight: 500;
}
.header-credits a:hover {
color: #fde68a;
text-decoration: underline;
}
.social-icon {
display: inline-block;
margin-left: 8px;
text-decoration: none;
font-size: 18px;
opacity: 0.85;
transition: all 0.2s ease;
vertical-align: middle;
}
.social-icon:hover {
opacity: 1;
transform: translateY(-2px);
}
/* Mini model card - skeuomorphic design */
.model-card {
background: linear-gradient(145deg, #f8fafc 0%, #e2e8f0 100%);
border: 1px solid #cbd5e1;
border-radius: 16px;
padding: 20px;
margin-bottom: 24px;
box-shadow:
0 10px 25px rgba(0,0,0,0.1),
inset 0 1px 0 rgba(255,255,255,0.6);
position: relative;
overflow: hidden;
}
.dark .model-card {
background: linear-gradient(145deg, #374151 0%, #1f2937 100%);
border: 1px solid #4b5563;
box-shadow:
0 10px 25px rgba(0,0,0,0.3),
inset 0 1px 0 rgba(255,255,255,0.1);
}
.model-card::before {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
height: 2px;
background: linear-gradient(90deg, #3b82f6, #8b5cf6, #ef4444, #f59e0b);
}
.model-card-title {
font-size: 20px;
font-weight: 700;
color: #1e293b;
margin-bottom: 12px;
display: flex;
align-items: center;
gap: 8px;
}
.dark .model-card-title {
color: #f1f5f9;
}
.model-stats {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(120px, 1fr));
gap: 12px;
margin-bottom: 16px;
}
.model-stat {
background: rgba(59, 130, 246, 0.1);
border: 1px solid rgba(59, 130, 246, 0.2);
border-radius: 8px;
padding: 8px 12px;
text-align: center;
}
.dark .model-stat {
background: rgba(59, 130, 246, 0.15);
border: 1px solid rgba(59, 130, 246, 0.3);
}
.stat-value {
font-weight: 700;
font-size: 14px;
color: #3b82f6;
}
.dark .stat-value {
color: #60a5fa;
}
.stat-label {
font-size: 11px;
color: #64748b;
text-transform: uppercase;
letter-spacing: 0.5px;
margin-top: 2px;
}
.dark .stat-label {
color: #94a3b8;
}
.model-description {
color: #475569;
font-size: 14px;
line-height: 1.6;
}
.dark .model-description {
color: #cbd5e1;
}
/* Question carousel - right side */
.question-carousel {
background: linear-gradient(145deg, #ffffff 0%, #f1f5f9 100%);
border: 1px solid #e2e8f0;
border-radius: 16px;
padding: 20px;
box-shadow:
0 4px 16px rgba(0,0,0,0.08),
inset 0 1px 0 rgba(255,255,255,0.8);
height: fit-content;
position: sticky;
top: 20px;
}
.dark .question-carousel {
background: linear-gradient(145deg, #1f2937 0%, #111827 100%);
border: 1px solid #374151;
box-shadow:
0 4px 16px rgba(0,0,0,0.2),
inset 0 1px 0 rgba(255,255,255,0.05);
}
.carousel-title {
font-size: 18px;
font-weight: 700;
color: #1e293b;
margin-bottom: 16px;
text-align: center;
}
.dark .carousel-title {
color: #f1f5f9;
}
.carousel-card {
background: linear-gradient(135deg, #fafafa 0%, #f4f4f5 100%);
border: 1px solid #e4e4e7;
border-radius: 12px;
padding: 16px;
margin-bottom: 16px;
box-shadow:
0 2px 8px rgba(0,0,0,0.06),
inset 0 1px 0 rgba(255,255,255,0.7);
transition: transform 0.2s, box-shadow 0.2s;
}
.dark .carousel-card {
background: linear-gradient(135deg, #374151 0%, #2d3748 100%);
border: 1px solid #4b5563;
box-shadow:
0 2px 8px rgba(0,0,0,0.15),
inset 0 1px 0 rgba(255,255,255,0.05);
}
.carousel-card:hover {
transform: translateY(-2px);
box-shadow:
0 4px 16px rgba(0,0,0,0.12),
inset 0 1px 0 rgba(255,255,255,0.7);
}
.dark .carousel-card:hover {
box-shadow:
0 4px 16px rgba(0,0,0,0.25),
inset 0 1px 0 rgba(255,255,255,0.05);
}
.carousel-card-title {
font-weight: 600;
color: #3b82f6;
margin-bottom: 12px;
font-size: 15px;
}
.dark .carousel-card-title {
color: #60a5fa;
}
.question-button {
display: block;
width: 100%;
background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
border: 1px solid #cbd5e1;
border-radius: 8px;
padding: 8px 12px;
margin-bottom: 8px;
font-size: 13px;
color: #475569;
text-align: left;
cursor: pointer;
transition: all 0.2s;
box-shadow: 0 1px 3px rgba(0,0,0,0.05);
}
.dark .question-button {
background: linear-gradient(135deg, #4b5563 0%, #374151 100%);
border: 1px solid #6b7280;
color: #d1d5db;
}
.question-button:hover {
background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%);
color: white;
border-color: #3b82f6;
transform: translateY(-1px);
box-shadow: 0 2px 8px rgba(59, 130, 246, 0.3);
}
/* Loading animation */
@keyframes pulse {
0% { opacity: 1; }
50% { opacity: 0.5; }
100% { opacity: 1; }
}
.processing {
animation: pulse 1.5s ease-in-out infinite;
}
/* Typing indicator */
@keyframes typing {
0%, 60%, 100% { opacity: 0.3; }
30% { opacity: 1; }
}
.typing-indicator {
display: inline-block;
animation: typing 1.4s infinite;
}
.question-button:last-child {
margin-bottom: 0;
}
/* Main layout */
.main-layout {
display: grid;
grid-template-columns: 2fr 1fr;
gap: 24px;
margin-bottom: 24px;
}
@media (max-width: 1024px) {
.main-layout {
grid-template-columns: 1fr;
}
.question-carousel {
position: static;
}
}
/* Chat interface improvements */
.chat-container {
background: linear-gradient(145deg, #ffffff 0%, #f8fafc 100%);
border: 1px solid #e2e8f0;
border-radius: 16px;
padding: 20px;
box-shadow:
0 4px 16px rgba(0,0,0,0.08),
inset 0 1px 0 rgba(255,255,255,0.8);
}
.dark .chat-container {
background: linear-gradient(145deg, #1f2937 0%, #111827 100%);
border: 1px solid #374151;
box-shadow:
0 4px 16px rgba(0,0,0,0.2),
inset 0 1px 0 rgba(255,255,255,0.05);
}
/* Citation boxes */
.citation-section {
margin-top: 32px;
padding-top: 24px;
border-top: 2px solid #e2e8f0;
}
.dark .citation-section {
border-top: 2px solid #374151;
}
.citation-title {
font-size: 20px;
font-weight: 700;
color: #1e293b;
margin-bottom: 16px;
text-align: center;
}
.dark .citation-title {
color: #f1f5f9;
}
.citation-boxes {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
gap: 16px;
}
.citation-box {
background: linear-gradient(145deg, #f8fafc 0%, #e2e8f0 100%);
border: 1px solid #cbd5e1;
border-radius: 12px;
padding: 16px;
box-shadow:
0 4px 12px rgba(0,0,0,0.08),
inset 0 1px 0 rgba(255,255,255,0.6);
}
.dark .citation-box {
background: linear-gradient(145deg, #374151 0%, #1f2937 100%);
border: 1px solid #4b5563;
box-shadow:
0 4px 12px rgba(0,0,0,0.2),
inset 0 1px 0 rgba(255,255,255,0.1);
}
.citation-box h4 {
color: #3b82f6;
font-weight: 600;
margin-bottom: 8px;
font-size: 16px;
}
.dark .citation-box h4 {
color: #60a5fa;
}
.citation-content {
background: #1f2937;
color: #e5e7eb;
padding: 12px;
border-radius: 8px;
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
font-size: 12px;
line-height: 1.4;
overflow-x: auto;
white-space: pre-wrap;
word-break: break-all;
margin-top: 8px;
}
.dark .citation-content {
background: #111827;
border: 1px solid #374151;
}
/* Advanced settings styling */
.advanced-settings {
background: linear-gradient(145deg, #f1f5f9 0%, #e2e8f0 100%);
border: 1px solid #cbd5e1;
border-radius: 12px;
margin: 16px 0;
box-shadow:
0 2px 8px rgba(0,0,0,0.06),
inset 0 1px 0 rgba(255,255,255,0.7);
}
.dark .advanced-settings {
background: linear-gradient(145deg, #374151 0%, #1f2937 100%);
border: 1px solid #4b5563;
box-shadow:
0 2px 8px rgba(0,0,0,0.15),
inset 0 1px 0 rgba(255,255,255,0.05);
}
/* Disclaimer styling */
.disclaimer-box {
background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
border: 2px solid #f59e0b;
border-radius: 12px;
padding: 16px 20px;
margin: 20px 0;
box-shadow:
0 4px 12px rgba(245, 158, 11, 0.2),
inset 0 1px 0 rgba(255,255,255,0.5);
position: relative;
}
.dark .disclaimer-box {
background: linear-gradient(135deg, #92400e 0%, #78350f 100%);
border: 2px solid #f59e0b;
color: #fef3c7;
box-shadow:
0 4px 12px rgba(245, 158, 11, 0.3),
inset 0 1px 0 rgba(255,255,255,0.1);
}
.disclaimer-box::before {
content: '';
position: absolute;
left: 0;
top: 0;
bottom: 0;
width: 4px;
background: #f59e0b;
border-radius: 2px 0 0 2px;
}
.disclaimer-title {
font-weight: 600;
color: #92400e;
margin-bottom: 8px;
display: flex;
align-items: center;
gap: 8px;
font-size: 15px;
}
.dark .disclaimer-title {
color: #fbbf24;
}
.disclaimer-text {
color: #78350f;
font-size: 14px;
line-height: 1.5;
}
.dark .disclaimer-text {
color: #fef3c7;
}
/* Button improvements */
.gr-button {
border-radius: 8px !important;
font-weight: 600 !important;
transition: all 0.2s ease !important;
}
.gr-button-primary {
background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%) !important;
border: none !important;
color: white !important;
box-shadow: 0 2px 4px rgba(59, 130, 246, 0.3) !important;
}
.gr-button-primary:hover {
background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%) !important;
transform: translateY(-1px) !important;
box-shadow: 0 4px 12px rgba(59, 130, 246, 0.4) !important;
}
/* Responsive design */
@media (max-width: 768px) {
.header-title {
font-size: 28px;
}
.model-stats {
grid-template-columns: repeat(2, 1fr);
}
.social-icon {
font-size: 16px;
}
.citation-boxes {
grid-template-columns: 1fr;
}
}
"""
# Create the Gradio interface.
# Layout: header, model card, disclaimer, chat + quick questions, settings, citations.
# The HTML literals are kept flush-left so their text is emitted unchanged.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    # Header with credits and social links
    gr.HTML(
        """
<div class="header-container">
<h1 class="header-title">🦷 DentaInstruct-1.2B Demo</h1>
<p class="header-subtitle">Advanced AI assistant for dental education and oral health information</p>
<div class="header-credits">
Model by <a href="https://huggingface.co/yasserrmd" target="_blank">yasserrmd</a>
<a href="https://github.com/YASSERRMD" target="_blank" class="social-icon" title="GitHub">🔗</a>
<a href="https://www.linkedin.com/in/moyasser" target="_blank" class="social-icon" title="LinkedIn">💼</a>
<span style="margin: 0 10px;">/</span>
Space by <a href="https://huggingface.co/chrisvoncsefalvay" target="_blank">Chris von Csefalvay</a>
<a href="https://github.com/chrisvoncsefalvay" target="_blank" class="social-icon" title="GitHub">🔗</a>
<a href="https://twitter.com/epichrisis" target="_blank" class="social-icon" title="X">𝕏</a>
<a href="https://chrisvoncsefalvay.com" target="_blank" class="social-icon" title="Website">🌐</a>
</div>
</div>
"""
    )

    # Mini model card with skeuomorphic design
    gr.HTML(
        """
<div class="model-card">
<div class="model-card-title">
🧠 Model Information
</div>
<div class="model-stats">
<div class="model-stat">
<div class="stat-value">1.17B</div>
<div class="stat-label">Parameters</div>
</div>
<div class="model-stat">
<div class="stat-value">LFM2-1.2B</div>
<div class="stat-label">Base Model</div>
</div>
<div class="model-stat">
<div class="stat-value">MIRIAD</div>
<div class="stat-label">Dataset</div>
</div>
<div class="model-stat">
<div class="stat-value">2048</div>
<div class="stat-label">Context Length</div>
</div>
</div>
<div class="model-description">
Specialised language model fine-tuned for dental education and oral health information.
Built on efficient LFM2 architecture with supervised fine-tuning on comprehensive dental content.
</div>
</div>
"""
    )

    # Disclaimer box
    gr.HTML(
        """
<div class="disclaimer-box">
<div class="disclaimer-title">
⚠️ Educational Use Only - Important Medical Disclaimer
</div>
<div class="disclaimer-text">
This AI model provides educational information about dental topics and is designed for learning purposes only.
It is <strong>NOT</strong> a substitute for professional dental or medical advice, diagnosis, or treatment.
Always seek the advice of your dentist or qualified healthcare provider with any questions about a medical condition or treatment.
</div>
</div>
"""
    )

    # Main layout with chat on left, carousel on right
    with gr.Row(elem_classes="main-layout"):
        # Left side - Chat interface
        with gr.Column(scale=2, elem_classes="chat-container"):
            chatbot = gr.Chatbot(
                height=500,
                label="Response",
                show_label=True,
                avatar_images=None,
                bubble_full_width=False,
                render_markdown=True,
            )
            with gr.Row():
                msg = gr.Textbox(
                    label="Your dental question",
                    placeholder="Ask about dental procedures, oral health, treatment options, or any dental topic...",
                    lines=3,
                    scale=4,
                    container=False,
                )
            with gr.Row():
                submit = gr.Button("Send Question", variant="primary", scale=1, elem_id="send-btn")
                clear = gr.Button("Clear Chat", scale=1)
            # Status indicator (hidden; not wired to any event in this file)
            status = gr.Textbox(value="", label="Status", visible=False)

        # Right side - Question carousel
        with gr.Column(scale=1):
            gr.HTML("""
<div class="question-carousel">
<div class="carousel-title">💡 Quick Questions</div>
</div>
""")
            # One (button, question text) pair per quick question; wired up below.
            question_buttons = []
            for category, questions in QUESTION_CATEGORIES.items():
                with gr.Group():
                    gr.HTML(f'<div class="carousel-card"><div class="carousel-card-title">{category}</div></div>')
                    for question in questions:
                        btn = gr.Button(
                            question,
                            variant="secondary",
                            size="sm",
                            elem_classes="question-button",
                        )
                        question_buttons.append((btn, question))

    # Advanced settings in collapsible section
    with gr.Accordion("⚙️ Advanced Settings", open=False, elem_classes="advanced-settings"):
        with gr.Row():
            with gr.Column(scale=1):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                    label="Temperature",
                    info="Lower values (0.1-0.3) for factual responses, higher (0.7-1.0) for creative explanations",
                )
                max_new_tokens = gr.Slider(
                    minimum=64,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Response Length",
                    info="Maximum number of tokens in the response",
                )
            with gr.Column(scale=1):
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (Nucleus Sampling)",
                    info="Controls diversity of word choices",
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=1.5,
                    value=1.05,
                    step=0.05,
                    label="Repetition Penalty",
                    info="Reduces repetitive phrases in responses",
                )

    # Citation boxes section
    gr.HTML(
        """
<div class="citation-section">
<div class="citation-title">📚 Citations</div>
<div class="citation-boxes">
<div class="citation-box">
<h4>MIRIAD Dataset</h4>
<p>Training dataset used for fine-tuning the dental knowledge base.</p>
<div class="citation-content">@misc{miriad2024,
title={MIRIAD: A Multi-modal Instruction-following Dataset for Dentistry},
author={MIRIAD Team},
year={2024},
url={https://huggingface.co/datasets/miriad}
}</div>
</div>
<div class="citation-box">
<h4>DentaInstruct-1.2B Model</h4>
<p>The fine-tuned model used in this demonstration.</p>
<div class="citation-content">@misc{dentainstruct2024,
title={DentaInstruct-1.2B: A Dental Education Language Model},
author={yasserrmd},
year={2024},
url={https://huggingface.co/yasserrmd/DentaInstruct-1.2B}
}</div>
</div>
</div>
</div>
"""
    )

    # Event handlers
    def respond(message, chat_history, temperature, max_new_tokens, top_p, repetition_penalty):
        """Stream a model reply into the chat.

        Args:
            message: Text from the question box.
            chat_history: Current chatbot value as ``(user, assistant)`` pairs;
                ``None`` is treated as an empty conversation.
            temperature, max_new_tokens, top_p, repetition_penalty: Slider
                values forwarded to ``generate_response_streaming``.

        Yields:
            ``(textbox value, chat history, send-button update)`` tuples: a
            placeholder turn first, then one update per streamed chunk with a
            trailing marker, then the final history.

        Raises:
            gr.Error: When generation fails, after the question has been put
                back into the textbox and the button label restored.
        """
        chat_history = chat_history or []

        if not message or not message.strip():
            gr.Warning("Please enter a question")
            # This function is a generator: a `return <value>` is discarded,
            # so the reset state has to be yielded before stopping.
            yield "", chat_history, gr.update(value="Send Question")
            return

        try:
            # Show initial processing state
            yield "", chat_history + [(message, "🔄 Starting...")], gr.update(value="⏳ Generating...")

            # Stream the response; every chunk is the full text so far.
            partial_response = ""
            for partial_response in generate_response_streaming(
                message,
                chat_history,
                temperature,
                max_new_tokens,
                top_p,
                repetition_penalty,
            ):
                in_progress = chat_history + [(message, partial_response + " ●")]
                yield "", in_progress, gr.update(value="⏳ Generating...")

            # Final update with complete response (new list; the input is not mutated)
            yield "", chat_history + [(message, partial_response)], gr.update(value="Send Question")
        except Exception as e:
            # Restore the question and button first, then surface the error:
            # gr.Error is an exception and only has an effect when raised.
            yield message, chat_history, gr.update(value="Send Question")
            raise gr.Error(f"An error occurred: {e}") from e

    # Connect event handlers with loading states
    msg.submit(
        respond,
        [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
        [msg, chatbot, submit],
        queue=True,
        show_progress="full",
    ).then(
        lambda: gr.update(interactive=True),
        None,
        [msg],
    )

    # Button path: lock the textbox, generate, then unlock it again.
    submit.click(
        lambda: gr.update(interactive=False),
        None,
        [msg],
    ).then(
        respond,
        [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
        [msg, chatbot, submit],
        queue=True,
        show_progress="full",
    ).then(
        lambda: gr.update(interactive=True),
        None,
        [msg],
    )

    # Reset both the conversation and the question box.
    clear.click(
        lambda: (None, ""),
        None,
        [chatbot, msg],
        queue=False,
    )

    # Connect question button click handlers
    for btn, question_text in question_buttons:
        btn.click(
            # Default argument binds this iteration's text to this button.
            lambda q=question_text: q,
            None,
            msg,
            queue=False,
        )
# Launch configuration: only runs when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {
        "share": False,
        "show_error": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.queue(max_size=10)
    demo.launch(**launch_options)