import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer from threading import Thread import spaces # Model configuration MODEL_ID = "yasserrmd/DentaInstruct-1.2B" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Initialize model and tokenizer print(f"Loading model {MODEL_ID}...") # Load tokenizer - try the fine-tuned model first, then base model try: tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) print(f"Loaded tokenizer from {MODEL_ID}") except Exception as e: print(f"Failed to load tokenizer from {MODEL_ID}: {e}") print("Using tokenizer from base LFM2 model...") try: tokenizer = AutoTokenizer.from_pretrained("LiquidAI/LFM2-1.2B") except Exception as e2: print(f"Failed to load LFM2 tokenizer: {e2}") print("Using fallback TinyLlama tokenizer...") tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") # Load model with proper dtype for efficiency model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, device_map="auto" if torch.cuda.is_available() else None ) if not torch.cuda.is_available(): model = model.to(DEVICE) # Set padding token if not set if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token def format_prompt(message, history): """Format the prompt for the model""" messages = [] # Add conversation history for user_msg, assistant_msg in history: messages.append({"role": "user", "content": user_msg}) if assistant_msg: messages.append({"role": "assistant", "content": assistant_msg}) # Add current message messages.append({"role": "user", "content": message}) # Apply chat template if hasattr(tokenizer, 'apply_chat_template'): prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) else: # Fallback formatting prompt = "" for msg in messages: if msg["role"] == "user": prompt += f"User: {msg['content']}\n" else: prompt += f"Assistant: {msg['content']}\n" prompt += "Assistant: " return prompt @spaces.GPU(duration=60) def generate_response_streaming( message, history, temperature=0.3, max_new_tokens=512, top_p=0.95, repetition_penalty=1.05, ): """Generate response from the model with streaming""" # Format the prompt prompt = format_prompt(message, history) # Tokenize input inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048) # Move to device and filter out token_type_ids if present model_inputs = {} for k, v in inputs.items(): if k != 'token_type_ids': # Filter out token_type_ids model_inputs[k] = v.to(model.device) # Set up the streamer streamer = TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=30.0 ) # Generation parameters generation_kwargs = dict( **model_inputs, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, do_sample=True, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, streamer=streamer, ) # Start generation in a separate thread thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() # Stream the response partial_response = "" for new_text in streamer: partial_response += new_text yield partial_response thread.join() # Question categories for the carousel QUESTION_CATEGORIES = { "Patient education": [ "What are the main types of dental cavities and how can I prevent them?", "Explain the stages of gum disease from gingivitis to periodontitis", "What should I expect during my first dental cleaning appointment?" ], "Procedures": [ "Walk me through the steps of a root canal treatment", "What's the difference between a crown and a veneer?", "How does the dental implant process work from start to finish?" ], "Preventative care advice": [ "What's the proper brushing technique for optimal plaque removal?", "How does fluoride protect teeth and is it safe for children?", "What foods should I avoid to maintain healthy teeth?" ], "Anatomy and terms": [ "Explain the anatomy of a tooth from crown to root", "What are the different types of teeth and their functions?", "What is the difference between enamel, dentin, and pulp?" ] } # Custom CSS for the redesigned interface custom_css = """ /* Reset and base styles */ * { box-sizing: border-box; } /* Header with credits */ .header-container { background: linear-gradient(135deg, #1e40af 0%, #3b82f6 50%, #60a5fa 100%); border-radius: 16px; padding: 32px; margin-bottom: 24px; color: white; text-align: center; box-shadow: 0 8px 32px rgba(30, 64, 175, 0.3); } .dark .header-container { background: linear-gradient(135deg, #1e3a8a 0%, #3730a3 50%, #4338ca 100%); } .header-title { font-size: 40px; font-weight: 800; margin-bottom: 8px; text-shadow: 0 2px 4px rgba(0,0,0,0.1); } .header-subtitle { font-size: 18px; opacity: 0.95; margin-bottom: 16px; } .header-credits { font-size: 14px; opacity: 0.9; margin-bottom: 12px; } .header-credits a { color: #fef3c7; text-decoration: none; font-weight: 500; } .header-credits a:hover { color: #fde68a; text-decoration: underline; } .social-icon { display: inline-block; margin-left: 8px; text-decoration: none; font-size: 18px; opacity: 0.85; transition: all 0.2s ease; vertical-align: middle; } .social-icon:hover { opacity: 1; transform: translateY(-2px); } /* Mini model card - skeuomorphic design */ .model-card { background: linear-gradient(145deg, #f8fafc 0%, #e2e8f0 100%); border: 1px solid #cbd5e1; border-radius: 16px; padding: 20px; margin-bottom: 24px; box-shadow: 0 10px 25px rgba(0,0,0,0.1), inset 0 1px 0 rgba(255,255,255,0.6); position: relative; overflow: hidden; } .dark .model-card { background: linear-gradient(145deg, #374151 0%, #1f2937 100%); border: 1px solid #4b5563; box-shadow: 0 10px 25px rgba(0,0,0,0.3), inset 0 1px 0 rgba(255,255,255,0.1); } .model-card::before { content: ''; position: absolute; top: 0; left: 0; right: 0; height: 2px; background: linear-gradient(90deg, #3b82f6, #8b5cf6, #ef4444, #f59e0b); } .model-card-title { font-size: 20px; font-weight: 700; color: #1e293b; margin-bottom: 12px; display: flex; align-items: center; gap: 8px; } .dark .model-card-title { color: #f1f5f9; } .model-stats { display: grid; grid-template-columns: repeat(auto-fit, minmax(120px, 1fr)); gap: 12px; margin-bottom: 16px; } .model-stat { background: rgba(59, 130, 246, 0.1); border: 1px solid rgba(59, 130, 246, 0.2); border-radius: 8px; padding: 8px 12px; text-align: center; } .dark .model-stat { background: rgba(59, 130, 246, 0.15); border: 1px solid rgba(59, 130, 246, 0.3); } .stat-value { font-weight: 700; font-size: 14px; color: #3b82f6; } .dark .stat-value { color: #60a5fa; } .stat-label { font-size: 11px; color: #64748b; text-transform: uppercase; letter-spacing: 0.5px; margin-top: 2px; } .dark .stat-label { color: #94a3b8; } .model-description { color: #475569; font-size: 14px; line-height: 1.6; } .dark .model-description { color: #cbd5e1; } /* Question carousel - right side */ .question-carousel { background: linear-gradient(145deg, #ffffff 0%, #f1f5f9 100%); border: 1px solid #e2e8f0; border-radius: 16px; padding: 20px; box-shadow: 0 4px 16px rgba(0,0,0,0.08), inset 0 1px 0 rgba(255,255,255,0.8); height: fit-content; position: sticky; top: 20px; } .dark .question-carousel { background: linear-gradient(145deg, #1f2937 0%, #111827 100%); border: 1px solid #374151; box-shadow: 0 4px 16px rgba(0,0,0,0.2), inset 0 1px 0 rgba(255,255,255,0.05); } .carousel-title { font-size: 18px; font-weight: 700; color: #1e293b; margin-bottom: 16px; text-align: center; } .dark .carousel-title { color: #f1f5f9; } .carousel-card { background: linear-gradient(135deg, #fafafa 0%, #f4f4f5 100%); border: 1px solid #e4e4e7; border-radius: 12px; padding: 16px; margin-bottom: 16px; box-shadow: 0 2px 8px rgba(0,0,0,0.06), inset 0 1px 0 rgba(255,255,255,0.7); transition: transform 0.2s, box-shadow 0.2s; } .dark .carousel-card { background: linear-gradient(135deg, #374151 0%, #2d3748 100%); border: 1px solid #4b5563; box-shadow: 0 2px 8px rgba(0,0,0,0.15), inset 0 1px 0 rgba(255,255,255,0.05); } .carousel-card:hover { transform: translateY(-2px); box-shadow: 0 4px 16px rgba(0,0,0,0.12), inset 0 1px 0 rgba(255,255,255,0.7); } .dark .carousel-card:hover { box-shadow: 0 4px 16px rgba(0,0,0,0.25), inset 0 1px 0 rgba(255,255,255,0.05); } .carousel-card-title { font-weight: 600; color: #3b82f6; margin-bottom: 12px; font-size: 15px; } .dark .carousel-card-title { color: #60a5fa; } .question-button { display: block; width: 100%; background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%); border: 1px solid #cbd5e1; border-radius: 8px; padding: 8px 12px; margin-bottom: 8px; font-size: 13px; color: #475569; text-align: left; cursor: pointer; transition: all 0.2s; box-shadow: 0 1px 3px rgba(0,0,0,0.05); } .dark .question-button { background: linear-gradient(135deg, #4b5563 0%, #374151 100%); border: 1px solid #6b7280; color: #d1d5db; } .question-button:hover { background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%); color: white; border-color: #3b82f6; transform: translateY(-1px); box-shadow: 0 2px 8px rgba(59, 130, 246, 0.3); } /* Loading animation */ @keyframes pulse { 0% { opacity: 1; } 50% { opacity: 0.5; } 100% { opacity: 1; } } .processing { animation: pulse 1.5s ease-in-out infinite; } /* Typing indicator */ @keyframes typing { 0%, 60%, 100% { opacity: 0.3; } 30% { opacity: 1; } } .typing-indicator { display: inline-block; animation: typing 1.4s infinite; } .question-button:last-child { margin-bottom: 0; } /* Main layout */ .main-layout { display: grid; grid-template-columns: 2fr 1fr; gap: 24px; margin-bottom: 24px; } @media (max-width: 1024px) { .main-layout { grid-template-columns: 1fr; } .question-carousel { position: static; } } /* Chat interface improvements */ .chat-container { background: linear-gradient(145deg, #ffffff 0%, #f8fafc 100%); border: 1px solid #e2e8f0; border-radius: 16px; padding: 20px; box-shadow: 0 4px 16px rgba(0,0,0,0.08), inset 0 1px 0 rgba(255,255,255,0.8); } .dark .chat-container { background: linear-gradient(145deg, #1f2937 0%, #111827 100%); border: 1px solid #374151; box-shadow: 0 4px 16px rgba(0,0,0,0.2), inset 0 1px 0 rgba(255,255,255,0.05); } /* Citation boxes */ .citation-section { margin-top: 32px; padding-top: 24px; border-top: 2px solid #e2e8f0; } .dark .citation-section { border-top: 2px solid #374151; } .citation-title { font-size: 20px; font-weight: 700; color: #1e293b; margin-bottom: 16px; text-align: center; } .dark .citation-title { color: #f1f5f9; } .citation-boxes { display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 16px; } .citation-box { background: linear-gradient(145deg, #f8fafc 0%, #e2e8f0 100%); border: 1px solid #cbd5e1; border-radius: 12px; padding: 16px; box-shadow: 0 4px 12px rgba(0,0,0,0.08), inset 0 1px 0 rgba(255,255,255,0.6); } .dark .citation-box { background: linear-gradient(145deg, #374151 0%, #1f2937 100%); border: 1px solid #4b5563; box-shadow: 0 4px 12px rgba(0,0,0,0.2), inset 0 1px 0 rgba(255,255,255,0.1); } .citation-box h4 { color: #3b82f6; font-weight: 600; margin-bottom: 8px; font-size: 16px; } .dark .citation-box h4 { color: #60a5fa; } .citation-content { background: #1f2937; color: #e5e7eb; padding: 12px; border-radius: 8px; font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; font-size: 12px; line-height: 1.4; overflow-x: auto; white-space: pre-wrap; word-break: break-all; margin-top: 8px; } .dark .citation-content { background: #111827; border: 1px solid #374151; } /* Advanced settings styling */ .advanced-settings { background: linear-gradient(145deg, #f1f5f9 0%, #e2e8f0 100%); border: 1px solid #cbd5e1; border-radius: 12px; margin: 16px 0; box-shadow: 0 2px 8px rgba(0,0,0,0.06), inset 0 1px 0 rgba(255,255,255,0.7); } .dark .advanced-settings { background: linear-gradient(145deg, #374151 0%, #1f2937 100%); border: 1px solid #4b5563; box-shadow: 0 2px 8px rgba(0,0,0,0.15), inset 0 1px 0 rgba(255,255,255,0.05); } /* Disclaimer styling */ .disclaimer-box { background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%); border: 2px solid #f59e0b; border-radius: 12px; padding: 16px 20px; margin: 20px 0; box-shadow: 0 4px 12px rgba(245, 158, 11, 0.2), inset 0 1px 0 rgba(255,255,255,0.5); position: relative; } .dark .disclaimer-box { background: linear-gradient(135deg, #92400e 0%, #78350f 100%); border: 2px solid #f59e0b; color: #fef3c7; box-shadow: 0 4px 12px rgba(245, 158, 11, 0.3), inset 0 1px 0 rgba(255,255,255,0.1); } .disclaimer-box::before { content: ''; position: absolute; left: 0; top: 0; bottom: 0; width: 4px; background: #f59e0b; border-radius: 2px 0 0 2px; } .disclaimer-title { font-weight: 600; color: #92400e; margin-bottom: 8px; display: flex; align-items: center; gap: 8px; font-size: 15px; } .dark .disclaimer-title { color: #fbbf24; } .disclaimer-text { color: #78350f; font-size: 14px; line-height: 1.5; } .dark .disclaimer-text { color: #fef3c7; } /* Button improvements */ .gr-button { border-radius: 8px !important; font-weight: 600 !important; transition: all 0.2s ease !important; } .gr-button-primary { background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%) !important; border: none !important; color: white !important; box-shadow: 0 2px 4px rgba(59, 130, 246, 0.3) !important; } .gr-button-primary:hover { background: linear-gradient(135deg, #2563eb 0%, #1d4ed8 100%) !important; transform: translateY(-1px) !important; box-shadow: 0 4px 12px rgba(59, 130, 246, 0.4) !important; } /* Responsive design */ @media (max-width: 768px) { .header-title { font-size: 28px; } .model-stats { grid-template-columns: repeat(2, 1fr); } .social-icon { font-size: 16px; } .citation-boxes { grid-template-columns: 1fr; } } """ # Create the Gradio interface with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo: # Header with credits and social links gr.HTML( """
Advanced AI assistant for dental education and oral health information
Training dataset used for fine-tuning the dental knowledge base.
The fine-tuned model used in this demonstration.