# Hugging Face Space -- status at capture time: running on ZeroGPU ("Running on Zero").
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import spaces | |
# Checkpoint served by this Space and the device it should run on.
MODEL_ID = "yasserrmd/DentaInstruct-1.2B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading model {MODEL_ID}...")


def _load_tokenizer():
    """Return a tokenizer, trying the fine-tuned repo, then LFM2, then TinyLlama.

    Each failure is reported on stdout before the next candidate is tried.
    A failure of the last candidate propagates to the caller.
    """
    try:
        loaded = AutoTokenizer.from_pretrained(MODEL_ID)
        print(f"Loaded tokenizer from {MODEL_ID}")
        return loaded
    except Exception as e:
        print(f"Failed to load tokenizer from {MODEL_ID}: {e}")
        print("Using tokenizer from base LFM2 model...")
        try:
            return AutoTokenizer.from_pretrained("LiquidAI/LFM2-1.2B")
        except Exception as e2:
            print(f"Failed to load LFM2 tokenizer: {e2}")
            print("Using fallback TinyLlama tokenizer...")
            return AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")


tokenizer = _load_tokenizer()

# bfloat16 with automatic device placement on GPU; float32 on CPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
    device_map="auto" if DEVICE == "cuda" else None,
)
if DEVICE == "cpu":
    model = model.to(DEVICE)

# generate() needs a pad token id; reuse EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
def format_prompt(message, history):
    """Build the text prompt for the model from the chat so far.

    Args:
        message: The user's current question.
        history: Iterable of ``(user_msg, assistant_msg)`` pairs from earlier
            turns, or ``None``/empty for a fresh conversation. Pairs whose
            assistant reply is empty contribute only the user turn.

    Returns:
        The prompt string, ending with the cue for the assistant's next turn.
    """
    messages = []
    for user_msg, assistant_msg in history or []:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Every transformers tokenizer exposes apply_chat_template, so checking for
    # the method alone never reaches the fallback; the call raises when no
    # template is configured. Test for an actual template instead.
    has_template = hasattr(tokenizer, "apply_chat_template") and getattr(
        tokenizer, "chat_template", None
    )
    if has_template:
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    # Plain-text fallback: one "Speaker: text" line per turn.
    turns = [
        f"User: {msg['content']}\n"
        if msg["role"] == "user"
        else f"Assistant: {msg['content']}\n"
        for msg in messages
    ]
    return "".join(turns) + "Assistant: "
@spaces.GPU  # ZeroGPU attaches a GPU only while a decorated function runs
def generate_response(
    message,
    history,
    temperature=0.3,
    max_new_tokens=512,
    top_p=0.95,
    repetition_penalty=1.05,
):
    """Generate the assistant's reply to ``message``.

    Args:
        message: The user's current question.
        history: Earlier ``(user_msg, assistant_msg)`` pairs.
        temperature: Sampling temperature passed to ``model.generate``.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The decoded reply with the prompt and special tokens removed.
    """
    max_prompt_tokens = 2048

    prompt = format_prompt(message, history)
    encoded = tokenizer(prompt, return_tensors="pt")

    # Keep the NEWEST tokens when the conversation is too long. The tokenizer's
    # own truncation cuts from the right by default, which would drop the
    # current question and the generation cue. token_type_ids is skipped
    # because generate() is not given it here.
    model_inputs = {
        name: tensor[:, -max_prompt_tokens:].to(model.device)
        for name, tensor in encoded.items()
        if name != "token_type_ids"
    }
    prompt_length = model_inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **model_inputs,
            max_new_tokens=int(max_new_tokens),  # UI sliders may deliver floats
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # The output contains prompt + continuation; decode only the continuation.
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
# Example questions grouped by topic, used to showcase the model.
EXAMPLE_CATEGORIES = {
    "Patient Education": [
        "What are the main types of dental cavities and how can I prevent them?",
        "Explain the stages of gum disease from gingivitis to periodontitis",
        "What should I expect during my first dental cleaning appointment?",
    ],
    "Treatment Procedures": [
        "Walk me through the steps of a root canal treatment",
        "What's the difference between a crown and a veneer?",
        "How does the dental implant process work from start to finish?",
    ],
    "Oral Health & Prevention": [
        "What's the proper brushing technique for optimal plaque removal?",
        "How does fluoride protect teeth and is it safe for children?",
        "What foods should I avoid to maintain healthy teeth?",
    ],
    "Paediatric Dentistry": [
        "When should a child have their first dental visit?",
        "Explain the tooth eruption timeline in children",
        "How can parents help prevent early childhood cavities?",
    ],
    "Emergency & Post-Care": [
        "What should I do if I knock out a permanent tooth?",
        "How should I care for my mouth after wisdom tooth extraction?",
        "What are signs of a dental infection that needs immediate attention?",
    ],
}

# One single-element row per question, in category order: the shape
# gr.Examples expects when it feeds a single input component.
EXAMPLES = [
    [question]
    for questions in EXAMPLE_CATEGORIES.values()
    for question in questions
]
# Stylesheet passed to gr.Blocks(css=...). Every light-mode rule has a
# `.dark` counterpart so the page stays readable in Gradio's dark theme.
custom_css = """
/* Improved disclaimer box with proper dark mode support */
.disclaimer-box {
    background: linear-gradient(135deg, #fff9e6 0%, #fff3cd 100%);
    border: 2px solid #f0ad4e;
    border-radius: 10px;
    padding: 16px 20px;
    margin: 20px 0;
    font-size: 14px;
    line-height: 1.6;
    position: relative;
    overflow: hidden;
}
/* Dark mode disclaimer */
.dark .disclaimer-box {
    background: linear-gradient(135deg, #3d2f1f 0%, #4a3a28 100%);
    border: 2px solid #d4a574;
    color: #ffd9b3;
}
.disclaimer-box::before {
    content: '';
    position: absolute;
    left: 0;
    top: 0;
    bottom: 0;
    width: 4px;
    background: #f0ad4e;
}
.dark .disclaimer-box::before {
    background: #d4a574;
}
.disclaimer-title {
    font-weight: 600;
    color: #d58512;
    margin-bottom: 8px;
    display: flex;
    align-items: center;
    gap: 8px;
}
.dark .disclaimer-title {
    color: #ffa500;
}
.disclaimer-text {
    color: #856404;
}
.dark .disclaimer-text {
    color: #ffd9b3;
}
/* Model capabilities cards */
.capability-cards {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
    gap: 16px;
    margin: 20px 0;
}
.capability-card {
    background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
    border: 1px solid #dee2e6;
    border-radius: 8px;
    padding: 16px;
    transition: transform 0.2s, box-shadow 0.2s;
}
.dark .capability-card {
    background: linear-gradient(135deg, #2b2b2b 0%, #1f1f1f 100%);
    border: 1px solid #404040;
}
.capability-card:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(0,0,0,0.1);
}
.dark .capability-card:hover {
    box-shadow: 0 4px 12px rgba(255,255,255,0.1);
}
.capability-title {
    font-weight: 600;
    color: #495057;
    margin-bottom: 8px;
    font-size: 16px;
}
.dark .capability-title {
    color: #e9ecef;
}
.capability-description {
    color: #6c757d;
    font-size: 14px;
    line-height: 1.5;
}
.dark .capability-description {
    color: #adb5bd;
}
/* Stats badges */
.stats-container {
    display: flex;
    gap: 16px;
    flex-wrap: wrap;
    margin: 16px 0;
}
.stat-badge {
    background: linear-gradient(135deg, #e7f3ff 0%, #cfe2ff 100%);
    border: 1px solid #b6d4fe;
    border-radius: 20px;
    padding: 8px 16px;
    display: flex;
    align-items: center;
    gap: 8px;
}
.dark .stat-badge {
    background: linear-gradient(135deg, #1a3a52 0%, #0f2940 100%);
    border: 1px solid #2563eb;
}
.stat-label {
    color: #0066cc;
    font-weight: 500;
    font-size: 12px;
    text-transform: uppercase;
    letter-spacing: 0.5px;
}
.dark .stat-label {
    color: #60a5fa;
}
.stat-value {
    color: #004099;
    font-weight: 700;
    font-size: 14px;
}
.dark .stat-value {
    color: #93bbfc;
}
/* Improved button styling */
.gr-button-primary {
    background: linear-gradient(135deg, #0066cc 0%, #0052a3 100%) !important;
    border: none !important;
    color: white !important;
    font-weight: 600 !important;
    transition: all 0.3s ease !important;
}
.gr-button-primary:hover {
    background: linear-gradient(135deg, #0052a3 0%, #003d7a 100%) !important;
    transform: translateY(-1px);
    box-shadow: 0 4px 12px rgba(0, 102, 204, 0.3);
}
/* Chat improvements */
.gr-chatbot {
    border-radius: 12px !important;
    border: 1px solid #dee2e6 !important;
}
.dark .gr-chatbot {
    border: 1px solid #404040 !important;
}
/* Example section styling */
.example-category {
    margin-bottom: 12px;
    padding: 12px;
    background: #f8f9fa;
    border-radius: 8px;
}
.dark .example-category {
    background: #1f1f1f;
}
.example-category-title {
    font-weight: 600;
    color: #495057;
    margin-bottom: 8px;
    font-size: 14px;
    text-transform: uppercase;
    letter-spacing: 0.5px;
}
.dark .example-category-title {
    color: #e9ecef;
}
/* Header styling */
.main-header {
    background: linear-gradient(135deg, #0066cc 0%, #0052a3 100%);
    color: white;
    padding: 32px;
    border-radius: 12px;
    margin-bottom: 24px;
    text-align: center;
}
.dark .main-header {
    background: linear-gradient(135deg, #1e3a8a 0%, #1e40af 100%);
}
.header-title {
    font-size: 36px;
    font-weight: 700;
    margin-bottom: 12px;
}
.header-subtitle {
    font-size: 18px;
    opacity: 0.95;
    font-weight: 400;
}
/* Mobile responsiveness */
@media (max-width: 768px) {
    .capability-cards {
        grid-template-columns: 1fr;
    }
    .stats-container {
        flex-direction: column;
    }
    .stat-badge {
        width: 100%;
        justify-content: center;
    }
    .header-title {
        font-size: 28px;
    }
    .header-subtitle {
        font-size: 16px;
    }
}
"""
# Build the Gradio page: static header/info panels, the chat widgets,
# generation settings, example prompts, an "about" section, and the wiring
# between the widgets and the model.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    # Page header
    gr.HTML(
        """
        <div class="main-header">
            <h1 class="header-title">🦷 DentaInstruct-1.2B Demo</h1>
            <p class="header-subtitle">Advanced AI assistant for dental education and oral health information</p>
        </div>
        """
    )

    # Model statistics badges
    gr.HTML(
        """
        <div class="stats-container">
            <div class="stat-badge">
                <span class="stat-label">Model Size</span>
                <span class="stat-value">1.17B params</span>
            </div>
            <div class="stat-badge">
                <span class="stat-label">Base Model</span>
                <span class="stat-value">LFM2-1.2B</span>
            </div>
            <div class="stat-badge">
                <span class="stat-label">Training Data</span>
                <span class="stat-value">MIRIAD Dental</span>
            </div>
            <div class="stat-badge">
                <span class="stat-label">Response Time</span>
                <span class="stat-value">&lt; 2 seconds</span>
            </div>
        </div>
        """
    )

    # Medical disclaimer
    gr.HTML(
        """
        <div class="disclaimer-box">
            <div class="disclaimer-title">
                ⚠️ Educational Use Only - Important Medical Disclaimer
            </div>
            <div class="disclaimer-text">
                This AI model provides educational information about dental topics and is designed for learning purposes only.
                It is <strong>NOT</strong> a substitute for professional dental or medical advice, diagnosis, or treatment.
                Always seek the advice of your dentist or qualified healthcare provider with any questions about a medical condition or treatment.
            </div>
        </div>
        """
    )

    # Capability cards.
    # NOTE(review): the icon bytes for "Patient Education" and "Procedure
    # Details" were lost in extraction; 📚 and 🔍 are best guesses -- confirm.
    gr.HTML(
        """
        <h2 style="margin-top: 24px; margin-bottom: 16px;">What can DentaInstruct help you with?</h2>
        <div class="capability-cards">
            <div class="capability-card">
                <div class="capability-title">📚 Patient Education</div>
                <div class="capability-description">Clear explanations of dental conditions, treatments, and procedures in patient-friendly language</div>
            </div>
            <div class="capability-card">
                <div class="capability-title">🔍 Procedure Details</div>
                <div class="capability-description">Step-by-step breakdowns of common dental procedures from cleanings to complex treatments</div>
            </div>
            <div class="capability-card">
                <div class="capability-title">🛡️ Prevention Tips</div>
                <div class="capability-description">Evidence-based oral hygiene guidance and preventive care recommendations</div>
            </div>
            <div class="capability-card">
                <div class="capability-title">👶 Paediatric Dentistry</div>
                <div class="capability-description">Specialised information about children's dental development and care</div>
            </div>
            <div class="capability-card">
                <div class="capability-title">🚨 Emergency Guidance</div>
                <div class="capability-description">Educational information about dental emergencies and post-treatment care</div>
            </div>
            <div class="capability-card">
                <div class="capability-title">🦷 Anatomy & Terms</div>
                <div class="capability-description">Detailed explanations of dental anatomy and professional terminology</div>
            </div>
        </div>
        """
    )

    # Main chat interface
    with gr.Row():
        with gr.Column(scale=1):
            chatbot = gr.Chatbot(
                height=500,
                label="Dental Education Assistant",
                show_label=True,
                avatar_images=None,
                bubble_full_width=False,
                render_markdown=True,
            )
            with gr.Row():
                msg = gr.Textbox(
                    label="Your dental question",
                    placeholder="Ask about dental procedures, oral health, treatment options, or any dental topic...",
                    lines=3,
                    scale=4,
                    container=False,
                )
            with gr.Row():
                submit = gr.Button("Send Question", variant="primary", scale=1)
                clear = gr.Button("Clear Chat", scale=1)

    # Generation settings, collapsed by default
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            with gr.Column(scale=1):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                    label="Temperature",
                    info="Lower values (0.1-0.3) for factual responses, higher (0.7-1.0) for creative explanations",
                )
                max_new_tokens = gr.Slider(
                    minimum=64,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Response Length",
                    info="Maximum number of tokens in the response",
                )
            with gr.Column(scale=1):
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (Nucleus Sampling)",
                    info="Controls diversity of word choices",
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=1.5,
                    value=1.05,
                    step=0.05,
                    label="Repetition Penalty",
                    info="Reduces repetitive phrases in responses",
                )

    # Example questions
    with gr.Accordion("💡 Example Questions by Category", open=True):
        gr.Examples(
            examples=EXAMPLES[:8],  # Show first 8 examples
            inputs=msg,
            label="Quick Start Examples",
        )
        gr.Markdown(
            """
            ### More Example Categories:

            - **Patient Education**: Understanding conditions, prevention, and treatment basics
            - **Treatment Procedures**: Detailed explanations of dental procedures
            - **Oral Health & Prevention**: Daily care and preventive measures
            - **Paediatric Dentistry**: Children's dental health and development
            - **Emergency & Post-Care**: Urgent situations and aftercare instructions
            """
        )

    # About section.
    # NOTE(review): the icon bytes for the "Model Card" and "MIRIAD Dataset"
    # links were lost in extraction; 📄 and 📊 are best guesses -- confirm.
    gr.Markdown(
        """
        ---

        ## About DentaInstruct-1.2B

        DentaInstruct-1.2B is a specialised language model fine-tuned specifically for dental education and oral health information.
        Built on the efficient LFM2-1.2B architecture, it combines compact size with domain expertise to provide accurate,
        educational content about dentistry.

        ### Key Features:

        - **Specialised Training**: Fine-tuned on comprehensive dental educational content from the MIRIAD dataset
        - **Efficient Architecture**: 1.17B parameters optimised for fast response times
        - **Broad Coverage**: Knowledgeable about general dentistry, orthodontics, periodontics, endodontics, and more
        - **Educational Focus**: Designed to explain complex dental concepts in accessible language
        - **Multi-context Support**: Can handle patient education, professional discussions, and academic queries

        ### Technical Specifications:

        - **Architecture**: Transformer-based language model
        - **Base Model**: LiquidAI LFM2-1.2B
        - **Training Method**: Supervised fine-tuning on dental domain data
        - **Context Length**: 2048 tokens
        - **Inference**: Optimised for GPU acceleration with bfloat16 precision

        ### Use Cases:

        - Patient education materials and explanations
        - Dental student study assistance
        - Quick reference for dental terminology
        - Understanding treatment options and procedures
        - Oral health and hygiene guidance

        ### Important Considerations:

        - This model is for educational purposes only
        - Not intended for clinical decision-making
        - Information should be verified with professional sources
        - Always consult qualified dental professionals for personal health concerns

        ---

        **Model Creator**: [@yasserrmd](https://huggingface.co/yasserrmd) |
        **Space Developer**: [@chrisvoncsefalvay](https://huggingface.co/chrisvoncsefalvay) |
        **License**: Apache 2.0

        📄 [Model Card](https://huggingface.co/yasserrmd/DentaInstruct-1.2B) |
        📊 [MIRIAD Dataset](https://huggingface.co/datasets/miriad) |
        💬 [Report Issues](https://huggingface.co/spaces/chrisvoncsefalvay/dental-vqa-comparison/discussions)
        """
    )

    # Event handlers
    def respond(message, chat_history, temperature, max_new_tokens, top_p, repetition_penalty):
        """Answer ``message`` and return the updated textbox and chat values.

        Args:
            message: Text currently in the question box.
            chat_history: List of ``(user, assistant)`` pairs shown in the
                chatbot; ``None`` is treated as an empty conversation.
            temperature, max_new_tokens, top_p, repetition_penalty: Sampling
                settings forwarded to ``generate_response``.

        Returns:
            ``("", history)``: the textbox is cleared and the chat gains the
            new exchange (or is unchanged when ``message`` is blank).

        Raises:
            gr.Error: When generation fails. Gradio shows the message to the
                user and leaves the textbox and chat untouched.
        """
        # Work on a copy so the caller's list is never mutated in place.
        chat_history = list(chat_history or [])

        if not message or not message.strip():
            gr.Warning("Please enter a question")
            return "", chat_history

        try:
            response = generate_response(
                message,
                chat_history,
                temperature,
                max_new_tokens,
                top_p,
                repetition_penalty,
            )
        except Exception as e:
            # gr.Error is only displayed when raised; merely constructing it
            # would hide the failure from the user.
            raise gr.Error(f"An error occurred: {str(e)}") from e

        chat_history.append((message, response))
        return "", chat_history

    # Connect event handlers: Enter in the textbox and the Send button both
    # run `respond`; Clear resets the chat and the textbox.
    msg.submit(
        respond,
        [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
        [msg, chatbot],
        queue=True,
    )
    submit.click(
        respond,
        [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
        [msg, chatbot],
        queue=True,
    )
    clear.click(
        lambda: (None, ""),
        None,
        [chatbot, msg],
        queue=False,
    )
# Script entry point: enable request queueing, then start the web server.
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, with no public share link and
    # errors surfaced in the UI.
    launch_options = {
        "share": False,
        "show_error": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.queue(max_size=10)
    demo.launch(**launch_options)