import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces

# Model configuration
MODEL_ID = "yasserrmd/DentaInstruct-1.2B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize model and tokenizer
print(f"Loading model {MODEL_ID}...")

# Load tokenizer - try the fine-tuned model first, then base model
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    print(f"Loaded tokenizer from {MODEL_ID}")
except Exception as e:
    print(f"Failed to load tokenizer from {MODEL_ID}: {e}")
    print("Using tokenizer from base LFM2 model...")
    try:
        tokenizer = AutoTokenizer.from_pretrained("LiquidAI/LFM2-1.2B")
    except Exception as e2:
        print(f"Failed to load LFM2 tokenizer: {e2}")
        print("Using fallback TinyLlama tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Load model with proper dtype for efficiency
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None
)

if not torch.cuda.is_available():
    model = model.to(DEVICE)

# Set padding token if not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def format_prompt(message, history):
    """Format the prompt for the model"""
    messages = []

    # Add conversation history
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add current message
    messages.append({"role": "user", "content": message})

    # Apply chat template
    if hasattr(tokenizer, 'apply_chat_template'):
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        # Fallback formatting
        prompt = ""
        for msg in messages:
            if msg["role"] == "user":
                prompt += f"User: {msg['content']}\n"
            else:
                prompt += f"Assistant: {msg['content']}\n"
        prompt += "Assistant: "

    return prompt


@spaces.GPU(duration=60)
def generate_response(
    message,
    history,
    temperature=0.3,
    max_new_tokens=512,
    top_p=0.95,
    repetition_penalty=1.05,
):
    """Generate response from the model"""
    # Format the prompt
    prompt = format_prompt(message, history)

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (everything past the prompt)
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

    return response


# Example questions
EXAMPLES = [
    ["What are the main types of dental cavities?"],
    ["Explain the process of root canal treatment"],
    ["What is the difference between gingivitis and periodontitis?"],
    ["How should I care for my teeth after a dental extraction?"],
    ["What are the benefits of fluoride in dental care?"],
    ["Explain the stages of tooth development in children"],
    ["What causes tooth sensitivity and how can it be treated?"],
    ["Describe the different types of dental fillings available"],
]

# Custom CSS for styling
custom_css = """
.disclaimer {
    background-color: #fff3cd;
    border: 1px solid #ffc107;
    border-radius: 5px;
    padding: 10px;
    margin-bottom: 15px;
}
"""

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown(
        """
        # Dental VQA Model Comparison

        Interactive comparison of dental visual question answering models.
        Currently featuring DentaInstruct-1.2B for dental education and oral health information.
        """
    )

    gr.HTML(
        """
        <div class="disclaimer">
            <strong>⚠️ Important Disclaimer:</strong><br>
            This model is for educational purposes only. It is NOT a substitute for
            professional dental care. Do not use this model for clinical diagnosis or
            treatment advice. Always consult a qualified dental professional.
        </div>
""" ) chatbot = gr.Chatbot( height=400, label="Conversation" ) msg = gr.Textbox( label="Your dental question", placeholder="Ask a question about dental health, procedures, or oral care...", lines=2 ) with gr.Row(): submit = gr.Button("Send", variant="primary") clear = gr.Button("Clear") with gr.Accordion("Advanced Settings", open=False): temperature = gr.Slider( minimum=0.1, maximum=1.0, value=0.3, step=0.1, label="Temperature", info="Controls randomness in responses" ) max_new_tokens = gr.Slider( minimum=64, maximum=1024, value=512, step=64, label="Max New Tokens", info="Maximum length of the response" ) top_p = gr.Slider( minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p", info="Nucleus sampling parameter" ) repetition_penalty = gr.Slider( minimum=1.0, maximum=1.5, value=1.05, step=0.05, label="Repetition Penalty", info="Reduces repetition in responses" ) gr.Examples( examples=EXAMPLES, inputs=msg, label="Example Questions" ) gr.Markdown( """ ## About This Model DentaInstruct-1.2B is a specialised language model fine-tuned on dental educational content. It's designed to provide educational information about dental health, procedures, and oral care. **Model Details:** - Base Model: LFM2-1.2B - Parameters: 1.17B - Training Data: Dental subset of MIRIAD dataset - Purpose: Educational dental information **Created by:** @yasserrmd | **Space by:** @chrisvoncsefalvay """ ) # Event handlers def respond(message, chat_history, temperature, max_new_tokens, top_p, repetition_penalty): response = generate_response( message, chat_history, temperature, max_new_tokens, top_p, repetition_penalty ) chat_history.append((message, response)) return "", chat_history msg.submit( respond, [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty], [msg, chatbot] ) submit.click( respond, [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty], [msg, chatbot] ) clear.click(lambda: None, None, chatbot, queue=False) if __name__ == "__main__": demo.launch()