# Spaces: Running on Zero
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# ---------------------------------------------------------------------------
# Model configuration
# ---------------------------------------------------------------------------
MODEL_ID = "yasserrmd/DentaInstruct-1.2B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------------------------------------------------------
# Tokenizer and model initialisation (runs once, at import time)
# ---------------------------------------------------------------------------
print(f"Loading model {MODEL_ID}...")

try:
    # Preferred path: the tokenizer published alongside the model weights.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
except Exception as err:
    # Best-effort recovery: report the failure and load a substitute tokenizer.
    # NOTE(review): the "About" text in this file names LFM2-1.2B as the base
    # model, while the substitute below is a TinyLlama tokenizer -- confirm the
    # two vocabularies are actually compatible.
    print(f"Failed to load tokenizer from {MODEL_ID}: {err}")
    print("Using fallback tokenizer from base model...")
    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

if torch.cuda.is_available():
    # GPU: half precision, placement delegated to device_map="auto".
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
else:
    # CPU: full precision, no device map, then an explicit move to DEVICE.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        device_map=None,
        trust_remote_code=True,
    )
    model = model.to(DEVICE)

# generate() is given pad_token_id later on, so make sure one exists by
# reusing the end-of-sequence token when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
def format_prompt(message, history):
    """Build the text prompt for one chat turn.

    Args:
        message: The new user message to answer.
        history: Iterable of ``(user_msg, assistant_msg)`` pairs from earlier
            turns. Pairs with a falsy ``assistant_msg`` contribute only the
            user message. ``None`` is treated as an empty history.

    Returns:
        The prompt string. When the module-level ``tokenizer`` can apply a
        chat template, that rendering (with the generation prompt appended) is
        returned; otherwise a plain ``User: ...`` / ``Assistant: ...``
        transcript ending in ``"Assistant: "`` is returned.
    """
    messages = []
    for user_msg, assistant_msg in history or []:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # A bare hasattr() check is not enough: the method exists on tokenizers
    # that have no chat template configured, and recent transformers releases
    # raise ValueError from it in that case. Try it, and fall through to the
    # plain transcript instead of crashing.
    apply_template = getattr(tokenizer, "apply_chat_template", None)
    if apply_template is not None:
        try:
            return apply_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        except ValueError:
            pass  # no usable chat template -> plain-text fallback below

    # Plain-text fallback: one labelled line per message, then an open
    # "Assistant: " cue for the model to continue.
    lines = []
    for msg in messages:
        speaker = "User" if msg["role"] == "user" else "Assistant"
        lines.append(f"{speaker}: {msg['content']}\n")
    lines.append("Assistant: ")
    return "".join(lines)
def generate_response(
    message,
    history,
    temperature=0.7,
    max_new_tokens=512,
    top_p=0.95,
    repetition_penalty=1.1,
):
    """Generate the model's reply to ``message`` given the chat ``history``.

    Args:
        message: The new user message.
        history: Earlier turns, passed through unchanged to ``format_prompt``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        max_new_tokens: Upper bound on generated tokens; coerced to ``int``
            because UI sliders may deliver a float.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        stripped), with special tokens removed.
    """
    max_prompt_tokens = 2048

    prompt = format_prompt(message, history)

    # A chat-template rendering may already start with the BOS token; letting
    # the tokenizer add special tokens again would duplicate it.
    bos = getattr(tokenizer, "bos_token", None)
    add_special_tokens = not (bos and prompt.startswith(bos))

    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=add_special_tokens,
    )

    # Keep the *last* max_prompt_tokens tokens. Truncating from the right (the
    # tokenizer default) would drop the newest user message and the generation
    # prompt once a conversation grows past the limit.
    inputs = {
        name: tensor[:, -max_prompt_tokens:].to(model.device)
        for name, tensor in encoded.items()
    }
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # outputs[0] is prompt + continuation; decode only the continuation.
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# Example questions: each entry is a one-element row, passed to gr.Examples
# with the question textbox as its single input.
EXAMPLES = [
    [question]
    for question in (
        "What are the main types of dental cavities?",
        "Explain the process of root canal treatment",
        "What is the difference between gingivitis and periodontitis?",
        "How should I care for my teeth after a dental extraction?",
        "What are the benefits of fluoride in dental care?",
        "Explain the stages of tooth development in children",
        "What causes tooth sensitivity and how can it be treated?",
        "Describe the different types of dental fillings available",
    )
]

# Custom CSS for the page; styles the ".disclaimer" box used by the warning
# banner.
custom_css = """
.disclaimer {
background-color: #fff3cd;
border: 1px solid #ffc107;
border-radius: 5px;
padding: 10px;
margin-bottom: 15px;
}
"""
# Create Gradio interface: header, disclaimer, chat area, sampling controls,
# example questions, model description, and the event wiring between them.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown(
        """
# Dental VQA Model Comparison
Interactive comparison of dental visual question answering models. Currently featuring DentaInstruct-1.2B for dental education and oral health information.
"""
    )
    gr.HTML(
        """
<div class="disclaimer">
<strong>⚠️ Important Disclaimer:</strong><br>
This model is for educational purposes only. It is NOT a substitute for professional dental care.
Do not use this model for clinical diagnosis or treatment advice. Always consult a qualified dental professional.
</div>
"""
    )

    chatbot = gr.Chatbot(
        height=400,
        label="Conversation",
    )
    msg = gr.Textbox(
        label="Your dental question",
        placeholder="Ask a question about dental health, procedures, or oral care...",
        lines=2,
    )
    with gr.Row():
        submit = gr.Button("Send", variant="primary")
        clear = gr.Button("Clear")

    # Sampling controls; their values are passed straight to generate_response.
    with gr.Accordion("Advanced Settings", open=False):
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1,
            label="Temperature",
            info="Controls randomness in responses",
        )
        max_new_tokens = gr.Slider(
            minimum=64,
            maximum=1024,
            value=512,
            step=64,
            label="Max New Tokens",
            info="Maximum length of the response",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
            info="Nucleus sampling parameter",
        )
        repetition_penalty = gr.Slider(
            minimum=1.0,
            maximum=1.5,
            value=1.1,
            step=0.05,
            label="Repetition Penalty",
            info="Reduces repetition in responses",
        )

    gr.Examples(
        examples=EXAMPLES,
        inputs=msg,
        label="Example Questions",
    )
    gr.Markdown(
        """
## About This Model
DentaInstruct-1.2B is a specialised language model fine-tuned on dental educational content.
It's designed to provide educational information about dental health, procedures, and oral care.
**Model Details:**
- Base Model: LFM2-1.2B
- Parameters: 1.17B
- Training Data: Dental subset of MIRIAD dataset
- Purpose: Educational dental information
**Created by:** @yasserrmd | **Space by:** @chrisvoncsefalvay
"""
    )

    # Event handlers
    def respond(message, chat_history, temperature, max_new_tokens, top_p, repetition_penalty):
        """Handle one submitted message.

        Args:
            message: Text currently in the question box.
            chat_history: List of ``(user, assistant)`` pairs shown in the
                chatbot; ``None`` is treated as an empty conversation.
            temperature, max_new_tokens, top_p, repetition_penalty: Current
                slider values, forwarded to ``generate_response``.

        Returns:
            A ``(textbox_value, chat_history)`` pair: an empty string to clear
            the question box, and the updated conversation.
        """
        # Work on a copy so the list handed in by the caller is not mutated.
        chat_history = list(chat_history or [])

        # Ignore empty / whitespace-only submissions instead of spending a
        # full generation pass on them.
        if not message or not message.strip():
            return "", chat_history

        response = generate_response(
            message,
            chat_history,
            temperature,
            max_new_tokens,
            top_p,
            repetition_penalty,
        )
        chat_history.append((message, response))
        return "", chat_history

    # Pressing Enter in the textbox and clicking "Send" trigger the same
    # handler with the same inputs and outputs.
    chat_inputs = [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty]
    chat_outputs = [msg, chatbot]
    msg.submit(respond, chat_inputs, chat_outputs)
    submit.click(respond, chat_inputs, chat_outputs)

    # "Clear" resets the chatbot value to None.
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()