import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces

# Model configuration
MODEL_ID = "yasserrmd/DentaInstruct-1.2B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize model and tokenizer
print(f"Loading model {MODEL_ID}...")

# Load tokenizer - try the fine-tuned model, then the base model, then a generic fallback
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    print(f"Loaded tokenizer from {MODEL_ID}")
except Exception as e:
    print(f"Failed to load tokenizer from {MODEL_ID}: {e}")
    print("Using tokenizer from base LFM2 model...")
    try:
        tokenizer = AutoTokenizer.from_pretrained("LiquidAI/LFM2-1.2B")
    except Exception as e2:
        print(f"Failed to load LFM2 tokenizer: {e2}")
        print("Using fallback TinyLlama tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
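
# Note: the fallback tokenizers are a last resort. A tokenizer whose
# vocabulary differs from the model's training tokenizer (e.g. the
# TinyLlama fallback against an LFM2-based model) will likely yield
# degraded or garbled output.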

# Load model with proper dtype for efficiency
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None
)

if not torch.cuda.is_available():
    model = model.to(DEVICE)
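
# Optional: on memory-constrained GPUs the model could instead be loaded
# in 4 bits via bitsandbytes. A minimal sketch, assuming the bitsandbytes
# package is installed and a CUDA device is present (not used as written):
#
# from transformers import BitsAndBytesConfig
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID,
#     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#     device_map="auto",
# )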

# Set padding token if not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def format_prompt(message, history):
    """Format the prompt for the model"""
    messages = []
    
    # Add conversation history
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    
    # Add current message
    messages.append({"role": "user", "content": message})
    
    # Apply the tokenizer's chat template when one is defined. Checking
    # chat_template directly is more reliable than hasattr: recent
    # tokenizers all define apply_chat_template, but calling it raises
    # if no template is set.
    if getattr(tokenizer, "chat_template", None):
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        # Fallback formatting
        prompt = ""
        for msg in messages:
            if msg["role"] == "user":
                prompt += f"User: {msg['content']}\n"
            else:
                prompt += f"Assistant: {msg['content']}\n"
        prompt += "Assistant: "
    
    return prompt
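
# With the fallback formatting above, one prior exchange plus a new
# message produces a prompt of the form:
#   User: <first question>
#   Assistant: <first answer>
#   User: <new question>
#   Assistant: 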

@spaces.GPU(duration=60)
def generate_response(
    message,
    history,
    temperature=0.3,
    max_new_tokens=512,
    top_p=0.95,
    repetition_penalty=1.05,
):
    """Generate response from the model"""
    
    # Format the prompt
    prompt = format_prompt(message, history)
    
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    
    # Decode only the newly generated tokens, skipping the echoed prompt
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    
    return response
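
# Optional: a streaming variant is sketched below using transformers'
# TextIteratorStreamer. This is an illustrative alternative, not part of
# the app's current flow; the helper name and kwargs passthrough are
# assumptions for the sketch.
#
# from threading import Thread
# from transformers import TextIteratorStreamer
#
# def generate_response_streaming(message, history, **gen_kwargs):
#     prompt = format_prompt(message, history)
#     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#     streamer = TextIteratorStreamer(
#         tokenizer, skip_prompt=True, skip_special_tokens=True
#     )
#     # Run generation in a background thread; tokens arrive on the streamer
#     Thread(
#         target=model.generate,
#         kwargs=dict(**inputs, streamer=streamer, **gen_kwargs),
#     ).start()
#     partial = ""
#     for new_text in streamer:
#         partial += new_text
#         yield partial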

# Example questions
EXAMPLES = [
    ["What are the main types of dental cavities?"],
    ["Explain the process of root canal treatment"],
    ["What is the difference between gingivitis and periodontitis?"],
    ["How should I care for my teeth after a dental extraction?"],
    ["What are the benefits of fluoride in dental care?"],
    ["Explain the stages of tooth development in children"],
    ["What causes tooth sensitivity and how can it be treated?"],
    ["Describe the different types of dental fillings available"],
]

# Custom CSS for styling
custom_css = """
.disclaimer {
    background-color: #fff3cd;
    border: 1px solid #ffc107;
    border-radius: 5px;
    padding: 10px;
    margin-bottom: 15px;
}
"""

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown(
        """
        # Dental QA Model Comparison
        
        Interactive comparison of dental question-answering models. Currently featuring DentaInstruct-1.2B, a text-only model for dental education and oral health information.
        """
    )
    
    gr.HTML(
        """
        <div class="disclaimer">
        <strong>⚠️ Important Disclaimer:</strong><br>
        This model is for educational purposes only. It is NOT a substitute for professional dental care.
        Do not use this model for clinical diagnosis or treatment advice. Always consult a qualified dental professional.
        </div>
        """
    )
    
    chatbot = gr.Chatbot(
        height=400,
        label="Conversation"
    )
    
    msg = gr.Textbox(
        label="Your dental question",
        placeholder="Ask a question about dental health, procedures, or oral care...",
        lines=2
    )
    
    with gr.Row():
        submit = gr.Button("Send", variant="primary")
        clear = gr.Button("Clear")
    
    with gr.Accordion("Advanced Settings", open=False):
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.3,
            step=0.1,
            label="Temperature",
            info="Controls randomness in responses"
        )
        
        max_new_tokens = gr.Slider(
            minimum=64,
            maximum=1024,
            value=512,
            step=64,
            label="Max New Tokens",
            info="Maximum length of the response"
        )
        
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
            info="Nucleus sampling parameter"
        )
        
        repetition_penalty = gr.Slider(
            minimum=1.0,
            maximum=1.5,
            value=1.05,
            step=0.05,
            label="Repetition Penalty",
            info="Reduces repetition in responses"
        )
    
    gr.Examples(
        examples=EXAMPLES,
        inputs=msg,
        label="Example Questions"
    )
    
    gr.Markdown(
        """
        ## About This Model
        
        DentaInstruct-1.2B is a specialised language model fine-tuned on dental educational content.
        It's designed to provide educational information about dental health, procedures, and oral care.
        
        **Model Details:**
        - Base Model: LFM2-1.2B
        - Parameters: 1.17B
        - Training Data: Dental subset of MIRIAD dataset
        - Purpose: Educational dental information
        
        **Created by:** @yasserrmd | **Space by:** @chrisvoncsefalvay
        """
    )
    
    # Event handlers
    def respond(message, chat_history, temperature, max_new_tokens, top_p, repetition_penalty):
        # Skip generation for empty submissions
        if not message or not message.strip():
            return "", chat_history
        response = generate_response(
            message, 
            chat_history,
            temperature,
            max_new_tokens,
            top_p,
            repetition_penalty
        )
        chat_history.append((message, response))
        return "", chat_history
    
    msg.submit(
        respond,
        [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
        [msg, chatbot]
    )
    
    submit.click(
        respond,
        [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
        [msg, chatbot]
    )
    
    clear.click(lambda: None, None, chatbot, queue=False)

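# Note: for deployments with concurrent users, enabling Gradio's request
# queue before launch (demo.queue().launch()) can serialize access to the
# GPU worker; whether it is needed depends on the hosting setup.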
if __name__ == "__main__":
    demo.launch()