import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces

# Model configuration
MODEL_ID = "yasserrmd/DentaInstruct-1.2B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize model and tokenizer
print(f"Loading model {MODEL_ID}...")

# Load tokenizer - try the fine-tuned model first, then base model
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    print(f"Loaded tokenizer from {MODEL_ID}")
except Exception as e:
    print(f"Failed to load tokenizer from {MODEL_ID}: {e}")
    print("Using tokenizer from base LFM2 model...")
    try:
        tokenizer = AutoTokenizer.from_pretrained("LiquidAI/LFM2-1.2B")
    except Exception as e2:
        print(f"Failed to load LFM2 tokenizer: {e2}")
        print("Using fallback TinyLlama tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Load model with proper dtype for efficiency
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None
)

if not torch.cuda.is_available():
    model = model.to(DEVICE)

# Set padding token if not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def format_prompt(message, history):
    """Format the prompt for the model"""
    messages = []

    # Add conversation history
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add current message
    messages.append({"role": "user", "content": message})

    # Apply chat template
    if hasattr(tokenizer, 'apply_chat_template'):
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        # Fallback formatting
        prompt = ""
        for msg in messages:
            if msg["role"] == "user":
                prompt += f"User: {msg['content']}\n"
            else:
                prompt += f"Assistant: {msg['content']}\n"
        prompt += "Assistant: "

    return prompt


@spaces.GPU(duration=60)
def generate_response(
    message,
    history,
    temperature=0.3,
    max_new_tokens=512,
    top_p=0.95,
    repetition_penalty=1.05,
):
    """Generate response from the model"""

    # Format the prompt
    prompt = format_prompt(message, history)

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode response
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

    return response


# Example questions
EXAMPLES = [
    ["What are the main types of dental cavities?"],
    ["Explain the process of root canal treatment"],
    ["What is the difference between gingivitis and periodontitis?"],
    ["How should I care for my teeth after a dental extraction?"],
    ["What are the benefits of fluoride in dental care?"],
    ["Explain the stages of tooth development in children"],
    ["What causes tooth sensitivity and how can it be treated?"],
    ["Describe the different types of dental fillings available"],
]

# Custom CSS for styling
custom_css = """
.disclaimer {
    background-color: #fff3cd;
    border: 1px solid #ffc107;
    border-radius: 5px;
    padding: 10px;
    margin-bottom: 15px;
}
"""

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(),
               css=custom_css) as demo:
    gr.Markdown(
        """
        # Dental VQA Model Comparison

        Interactive comparison of dental visual question answering models.
        Currently featuring DentaInstruct-1.2B for dental education and oral health information.
        """
    )

    gr.HTML(
        """