chrisvoncsefalvay's picture
Major UI improvements: Enhanced design, fixed dark mode, better examples
0570cd5
raw
history blame
20.9 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
# Model configuration
MODEL_ID = "yasserrmd/DentaInstruct-1.2B"
# Tokenizer of the base model DentaInstruct was fine-tuned from. It is the only
# acceptable fallback: a tokenizer must share the model's vocabulary, otherwise
# the model receives (and emits) token ids that mean something else entirely.
BASE_TOKENIZER_ID = "LiquidAI/LFM2-1.2B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize model and tokenizer
print(f"Loading model {MODEL_ID}...")

# Load tokenizer - try the fine-tuned model first, then the base model.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    print(f"Loaded tokenizer from {MODEL_ID}")
except Exception as e:
    print(f"Failed to load tokenizer from {MODEL_ID}: {e}")
    print("Using tokenizer from base LFM2 model...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(BASE_TOKENIZER_ID)
    except Exception as e2:
        # Deliberately no further fallback (e.g. an unrelated TinyLlama
        # tokenizer): a vocabulary mismatch would let the app start but produce
        # nonsense answers, which is worse than failing at startup.
        raise RuntimeError(
            f"Could not load a tokenizer compatible with {MODEL_ID} "
            f"(tried {MODEL_ID} and {BASE_TOKENIZER_ID})"
        ) from e2

# Load model: bfloat16 + automatic placement on GPU, float32 on CPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
)
if not torch.cuda.is_available():
    # Without device_map the model is not placed automatically; move it explicitly.
    model = model.to(DEVICE)

# generate() needs a pad token id; reuse EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
def format_prompt(message, history):
    """Build the model prompt from the chat history plus the new user ``message``.

    ``history`` may be:
      * Gradio "tuples" format: a sequence of ``(user, assistant)`` pairs
        (the shape this app's chat handler appends), or
      * Gradio "messages" format: dicts with ``role`` and ``content`` keys.
    ``None`` is treated as an empty history. Empty turns are skipped, and so are
    non-text message contents.

    Returns the prompt string. It uses the tokenizer's chat template when one is
    configured (with the generation prompt appended). Otherwise it falls back to a
    plain ``User:`` / ``Assistant:`` transcript ending in an ``Assistant: `` cue.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages" format: keep only text turns from the two chat roles.
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        user_msg, assistant_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add current message
    messages.append({"role": "user", "content": message})

    # Every modern tokenizer *has* apply_chat_template, but in recent transformers
    # versions calling it raises when no template is configured. So test the
    # template itself; that also makes the fallback below reachable.
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    # Fallback formatting
    transcript = "".join(
        f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}\n"
        for msg in messages
    )
    return transcript + "Assistant: "
@spaces.GPU(duration=60)
def generate_response(
    message,
    history,
    temperature=0.3,
    max_new_tokens=512,
    top_p=0.95,
    repetition_penalty=1.05,
    max_prompt_tokens=2048,
):
    """Generate the assistant's reply to ``message`` given the chat ``history``.

    Args:
        message: The new user question.
        history: Previous turns, in any shape ``format_prompt`` accepts.
        temperature: Sampling temperature. A value <= 0 switches to greedy
            decoding, because sampling requires a strictly positive temperature.
        max_new_tokens: Upper bound on generated tokens. Coerced to ``int``
            because UI sliders may deliver floats.
        top_p: Nucleus-sampling probability mass. Ignored for greedy decoding.
        repetition_penalty: Penalty for repeating already-seen tokens.
        max_prompt_tokens: Maximum number of prompt tokens fed to the model.
            Longer prompts are trimmed from the *start*, so the newest turns
            (including the current question) are always kept.

    Returns:
        The decoded completion only (prompt excluded), special tokens removed.
    """
    prompt = format_prompt(message, history)

    # Tokenize without truncation and trim the *left* side ourselves. The
    # current question and the assistant cue sit at the end of the prompt, and
    # right-side truncation (the usual tokenizer default) would cut exactly them.
    encoded = tokenizer(prompt, return_tensors="pt")
    keep = max(int(max_prompt_tokens), 1)
    model_inputs = {
        name: tensor[:, -keep:].to(model.device)
        for name, tensor in encoded.items()
        # token_type_ids are dropped before generate(); presumably the model
        # does not accept them.
        if name != "token_type_ids"
    }
    prompt_length = model_inputs["input_ids"].shape[1]

    generation_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # transformers rejects a non-positive temperature when sampling.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**model_inputs, **generation_kwargs)

    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
# Categorised example questions for better showcase (category -> questions).
EXAMPLE_CATEGORIES = {
    "Patient Education": [
        "What are the main types of dental cavities and how can I prevent them?",
        "Explain the stages of gum disease from gingivitis to periodontitis",
        "What should I expect during my first dental cleaning appointment?",
    ],
    "Treatment Procedures": [
        "Walk me through the steps of a root canal treatment",
        "What's the difference between a crown and a veneer?",
        "How does the dental implant process work from start to finish?",
    ],
    "Oral Health & Prevention": [
        "What's the proper brushing technique for optimal plaque removal?",
        "How does fluoride protect teeth and is it safe for children?",
        "What foods should I avoid to maintain healthy teeth?",
    ],
    "Paediatric Dentistry": [
        "When should a child have their first dental visit?",
        "Explain the tooth eruption timeline in children",
        "How can parents help prevent early childhood cavities?",
    ],
    "Emergency & Post-Care": [
        "What should I do if I knock out a permanent tooth?",
        "How should I care for my mouth after wisdom tooth extraction?",
        "What are signs of a dental infection that needs immediate attention?",
    ],
}

# Flatten into the row format gr.Examples expects. There is a single input
# component, so each row is a one-element list. The comprehension keeps the
# category order and does not leak loop variables into the module namespace.
EXAMPLES = [
    [question]
    for questions in EXAMPLE_CATEGORIES.values()
    for question in questions
]
# Custom CSS for improved styling with proper dark mode support.
# This string is passed to gr.Blocks(css=...). Most rules come in pairs: a
# light-mode rule and a `.dark`-prefixed override used in dark mode. The custom
# classes (.main-header, .stat-badge, .disclaimer-box, .capability-card, ...)
# are the ones referenced by the gr.HTML snippets in the UI.
# NOTE(review): `.gr-button-primary` and `.gr-chatbot` look like Gradio 3-era
# class names. Newer Gradio releases may not emit them, in which case those
# rules have no effect — confirm against the installed Gradio version.
custom_css = """
/* Improved disclaimer box with proper dark mode support */
.disclaimer-box {
background: linear-gradient(135deg, #fff9e6 0%, #fff3cd 100%);
border: 2px solid #f0ad4e;
border-radius: 10px;
padding: 16px 20px;
margin: 20px 0;
font-size: 14px;
line-height: 1.6;
position: relative;
overflow: hidden;
}
/* Dark mode disclaimer */
.dark .disclaimer-box {
background: linear-gradient(135deg, #3d2f1f 0%, #4a3a28 100%);
border: 2px solid #d4a574;
color: #ffd9b3;
}
.disclaimer-box::before {
content: '';
position: absolute;
left: 0;
top: 0;
bottom: 0;
width: 4px;
background: #f0ad4e;
}
.dark .disclaimer-box::before {
background: #d4a574;
}
.disclaimer-title {
font-weight: 600;
color: #d58512;
margin-bottom: 8px;
display: flex;
align-items: center;
gap: 8px;
}
.dark .disclaimer-title {
color: #ffa500;
}
.disclaimer-text {
color: #856404;
}
.dark .disclaimer-text {
color: #ffd9b3;
}
/* Model capabilities cards */
.capability-cards {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
gap: 16px;
margin: 20px 0;
}
.capability-card {
background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
border: 1px solid #dee2e6;
border-radius: 8px;
padding: 16px;
transition: transform 0.2s, box-shadow 0.2s;
}
.dark .capability-card {
background: linear-gradient(135deg, #2b2b2b 0%, #1f1f1f 100%);
border: 1px solid #404040;
}
.capability-card:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(0,0,0,0.1);
}
.dark .capability-card:hover {
box-shadow: 0 4px 12px rgba(255,255,255,0.1);
}
.capability-title {
font-weight: 600;
color: #495057;
margin-bottom: 8px;
font-size: 16px;
}
.dark .capability-title {
color: #e9ecef;
}
.capability-description {
color: #6c757d;
font-size: 14px;
line-height: 1.5;
}
.dark .capability-description {
color: #adb5bd;
}
/* Stats badges */
.stats-container {
display: flex;
gap: 16px;
flex-wrap: wrap;
margin: 16px 0;
}
.stat-badge {
background: linear-gradient(135deg, #e7f3ff 0%, #cfe2ff 100%);
border: 1px solid #b6d4fe;
border-radius: 20px;
padding: 8px 16px;
display: flex;
align-items: center;
gap: 8px;
}
.dark .stat-badge {
background: linear-gradient(135deg, #1a3a52 0%, #0f2940 100%);
border: 1px solid #2563eb;
}
.stat-label {
color: #0066cc;
font-weight: 500;
font-size: 12px;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.dark .stat-label {
color: #60a5fa;
}
.stat-value {
color: #004099;
font-weight: 700;
font-size: 14px;
}
.dark .stat-value {
color: #93bbfc;
}
/* Improved button styling */
.gr-button-primary {
background: linear-gradient(135deg, #0066cc 0%, #0052a3 100%) !important;
border: none !important;
color: white !important;
font-weight: 600 !important;
transition: all 0.3s ease !important;
}
.gr-button-primary:hover {
background: linear-gradient(135deg, #0052a3 0%, #003d7a 100%) !important;
transform: translateY(-1px);
box-shadow: 0 4px 12px rgba(0, 102, 204, 0.3);
}
/* Chat improvements */
.gr-chatbot {
border-radius: 12px !important;
border: 1px solid #dee2e6 !important;
}
.dark .gr-chatbot {
border: 1px solid #404040 !important;
}
/* Example section styling */
.example-category {
margin-bottom: 12px;
padding: 12px;
background: #f8f9fa;
border-radius: 8px;
}
.dark .example-category {
background: #1f1f1f;
}
.example-category-title {
font-weight: 600;
color: #495057;
margin-bottom: 8px;
font-size: 14px;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.dark .example-category-title {
color: #e9ecef;
}
/* Header styling */
.main-header {
background: linear-gradient(135deg, #0066cc 0%, #0052a3 100%);
color: white;
padding: 32px;
border-radius: 12px;
margin-bottom: 24px;
text-align: center;
}
.dark .main-header {
background: linear-gradient(135deg, #1e3a8a 0%, #1e40af 100%);
}
.header-title {
font-size: 36px;
font-weight: 700;
margin-bottom: 12px;
}
.header-subtitle {
font-size: 18px;
opacity: 0.95;
font-weight: 400;
}
/* Mobile responsiveness */
@media (max-width: 768px) {
.capability-cards {
grid-template-columns: 1fr;
}
.stats-container {
flex-direction: column;
}
.stat-badge {
width: 100%;
justify-content: center;
}
.header-title {
font-size: 28px;
}
.header-subtitle {
font-size: 16px;
}
}
"""
# Create Gradio interface with improved design
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    # Professional header with gradient
    gr.HTML(
        """
        <div class="main-header">
            <h1 class="header-title">🦷 DentaInstruct-1.2B Demo</h1>
            <p class="header-subtitle">Advanced AI assistant for dental education and oral health information</p>
        </div>
        """
    )

    # Model statistics badges. The "<" in the response-time badge is written
    # as &lt; so the HTML parser cannot mistake it for the start of a tag.
    gr.HTML(
        """
        <div class="stats-container">
            <div class="stat-badge">
                <span class="stat-label">Model Size</span>
                <span class="stat-value">1.17B params</span>
            </div>
            <div class="stat-badge">
                <span class="stat-label">Base Model</span>
                <span class="stat-value">LFM2-1.2B</span>
            </div>
            <div class="stat-badge">
                <span class="stat-label">Training Data</span>
                <span class="stat-value">MIRIAD Dental</span>
            </div>
            <div class="stat-badge">
                <span class="stat-label">Response Time</span>
                <span class="stat-value">&lt; 2 seconds</span>
            </div>
        </div>
        """
    )

    # Improved disclaimer with better visibility
    gr.HTML(
        """
        <div class="disclaimer-box">
            <div class="disclaimer-title">
                ⚠️ Educational Use Only - Important Medical Disclaimer
            </div>
            <div class="disclaimer-text">
                This AI model provides educational information about dental topics and is designed for learning purposes only.
                It is <strong>NOT</strong> a substitute for professional dental or medical advice, diagnosis, or treatment.
                Always seek the advice of your dentist or qualified healthcare provider with any questions about a medical condition or treatment.
            </div>
        </div>
        """
    )

    # Model capabilities showcase
    gr.HTML(
        """
        <h2 style="margin-top: 24px; margin-bottom: 16px;">What can DentaInstruct help you with?</h2>
        <div class="capability-cards">
            <div class="capability-card">
                <div class="capability-title">📚 Patient Education</div>
                <div class="capability-description">Clear explanations of dental conditions, treatments, and procedures in patient-friendly language</div>
            </div>
            <div class="capability-card">
                <div class="capability-title">🔍 Procedure Details</div>
                <div class="capability-description">Step-by-step breakdowns of common dental procedures from cleanings to complex treatments</div>
            </div>
            <div class="capability-card">
                <div class="capability-title">🛡️ Prevention Tips</div>
                <div class="capability-description">Evidence-based oral hygiene guidance and preventive care recommendations</div>
            </div>
            <div class="capability-card">
                <div class="capability-title">👶 Paediatric Dentistry</div>
                <div class="capability-description">Specialised information about children's dental development and care</div>
            </div>
            <div class="capability-card">
                <div class="capability-title">🚨 Emergency Guidance</div>
                <div class="capability-description">Educational information about dental emergencies and post-treatment care</div>
            </div>
            <div class="capability-card">
                <div class="capability-title">🦷 Anatomy & Terms</div>
                <div class="capability-description">Detailed explanations of dental anatomy and professional terminology</div>
            </div>
        </div>
        """
    )

    # Main chat interface
    with gr.Row():
        with gr.Column(scale=1):
            chatbot = gr.Chatbot(
                height=500,
                label="Dental Education Assistant",
                show_label=True,
                avatar_images=None,
                bubble_full_width=False,
                render_markdown=True,
            )
            with gr.Row():
                msg = gr.Textbox(
                    label="Your dental question",
                    placeholder="Ask about dental procedures, oral health, treatment options, or any dental topic...",
                    lines=3,
                    scale=4,
                    container=False,
                )
            with gr.Row():
                submit = gr.Button("Send Question", variant="primary", scale=1)
                clear = gr.Button("Clear Chat", scale=1)

    # Advanced settings in a collapsible section
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            with gr.Column(scale=1):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                    label="Temperature",
                    info="Lower values (0.1-0.3) for factual responses, higher (0.7-1.0) for creative explanations",
                )
                max_new_tokens = gr.Slider(
                    minimum=64,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Response Length",
                    info="Maximum number of tokens in the response",
                )
            with gr.Column(scale=1):
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (Nucleus Sampling)",
                    info="Controls diversity of word choices",
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=1.5,
                    value=1.05,
                    step=0.05,
                    label="Repetition Penalty",
                    info="Reduces repetitive phrases in responses",
                )

    # Example questions organised by category: one gr.Examples per category, so
    # every curated question is reachable (previously only the first 8 of 15
    # were shown and the category grouping was not visible).
    with gr.Accordion("💡 Example Questions by Category", open=True):
        for category, questions in EXAMPLE_CATEGORIES.items():
            gr.Examples(
                examples=[[question] for question in questions],
                inputs=msg,
                label=category,
            )

    # About section with professional information. Gradio's Markdown component
    # dedents its value, so the indentation below does not render as code.
    gr.Markdown(
        """
        ---

        ## About DentaInstruct-1.2B

        DentaInstruct-1.2B is a specialised language model fine-tuned specifically for dental education and oral health information.
        Built on the efficient LFM2-1.2B architecture, it combines compact size with domain expertise to provide accurate,
        educational content about dentistry.

        ### Key Features:
        - **Specialised Training**: Fine-tuned on comprehensive dental educational content from the MIRIAD dataset
        - **Efficient Architecture**: 1.17B parameters optimised for fast response times
        - **Broad Coverage**: Knowledgeable about general dentistry, orthodontics, periodontics, endodontics, and more
        - **Educational Focus**: Designed to explain complex dental concepts in accessible language
        - **Multi-context Support**: Can handle patient education, professional discussions, and academic queries

        ### Technical Specifications:
        - **Architecture**: Transformer-based language model
        - **Base Model**: LiquidAI LFM2-1.2B
        - **Training Method**: Supervised fine-tuning on dental domain data
        - **Context Length**: 2048 tokens
        - **Inference**: Optimised for GPU acceleration with bfloat16 precision

        ### Use Cases:
        - Patient education materials and explanations
        - Dental student study assistance
        - Quick reference for dental terminology
        - Understanding treatment options and procedures
        - Oral health and hygiene guidance

        ### Important Considerations:
        - This model is for educational purposes only
        - Not intended for clinical decision-making
        - Information should be verified with professional sources
        - Always consult qualified dental professionals for personal health concerns

        ---

        **Model Creator**: [@yasserrmd](https://huggingface.co/yasserrmd) |
        **Space Developer**: [@chrisvoncsefalvay](https://huggingface.co/chrisvoncsefalvay) |
        **License**: Apache 2.0

        🔗 [Model Card](https://huggingface.co/yasserrmd/DentaInstruct-1.2B) |
        📊 [MIRIAD Dataset](https://huggingface.co/datasets/miriad) |
        💬 [Report Issues](https://huggingface.co/spaces/chrisvoncsefalvay/dental-vqa-comparison/discussions)
        """
    )

    # Event handlers
    def respond(message, chat_history, temperature, max_new_tokens, top_p, repetition_penalty):
        """Answer ``message`` and append the exchange to the chat history.

        Returns a ``(textbox_value, chat_history)`` pair for the ``[msg, chatbot]``
        outputs. The textbox is cleared after a successful answer; a blank
        question only triggers a warning.

        Raises:
            gr.Error: if generation fails. Gradio shows it to the user and
                applies no output updates, so the question stays in the textbox.
        """
        if not message or not message.strip():
            gr.Warning("Please enter a question")
            return "", chat_history

        # Work on a copy, and tolerate a None value from the Chatbot component.
        history = list(chat_history or [])
        try:
            response = generate_response(
                message,
                history,
                temperature,
                max_new_tokens,
                top_p,
                repetition_penalty,
            )
        except Exception as e:
            # A gr.Error only reaches the UI when it is *raised*; constructing it
            # without raising silently discarded the failure.
            raise gr.Error(f"An error occurred: {str(e)}") from e

        history.append((message, response))
        return "", history

    # Connect event handlers: Enter in the textbox and the Send button share
    # the same handler, inputs and outputs.
    chat_inputs = [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty]
    chat_outputs = [msg, chatbot]
    msg.submit(respond, chat_inputs, chat_outputs, queue=True)
    submit.click(respond, chat_inputs, chat_outputs, queue=True)

    # Reset to an empty list (not None) so the next turn starts from a list.
    clear.click(lambda: ([], ""), None, [chatbot, msg], queue=False)
# Launch configuration: cap the pending-request queue, then serve on all
# network interfaces at port 7860 with errors surfaced in the UI.
if __name__ == "__main__":
    launch_options = dict(
        share=False,
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
    )
    demo.queue(max_size=10)
    demo.launch(**launch_options)