import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
# Model configuration
MODEL_ID = "yasserrmd/DentaInstruct-1.2B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize model and tokenizer
print(f"Loading model {MODEL_ID}...")
# Load tokenizer - try the fine-tuned model first, then base model
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print(f"Loaded tokenizer from {MODEL_ID}")
except Exception as e:
print(f"Failed to load tokenizer from {MODEL_ID}: {e}")
print("Using tokenizer from base LFM2 model...")
try:
tokenizer = AutoTokenizer.from_pretrained("LiquidAI/LFM2-1.2B")
except Exception as e2:
print(f"Failed to load LFM2 tokenizer: {e2}")
print("Using fallback TinyLlama tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
# Load model with proper dtype for efficiency
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None
)
if not torch.cuda.is_available():
model = model.to(DEVICE)
# Set padding token if not set
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
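# Optional sketch (not used by default): load the model with 4-bit quantization
# to reduce GPU memory. BitsAndBytesConfig is part of transformers, but the
# bitsandbytes dependency and this helper are assumptions for low-VRAM setups,
# not part of the original Space.
def load_model_4bit(model_id=MODEL_ID):
    from transformers import BitsAndBytesConfig
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    )
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quant_config,
        device_map="auto",
    )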
def format_prompt(message, history):
"""Format the prompt for the model"""
messages = []
# Add conversation history
for user_msg, assistant_msg in history:
messages.append({"role": "user", "content": user_msg})
if assistant_msg:
messages.append({"role": "assistant", "content": assistant_msg})
# Add current message
messages.append({"role": "user", "content": message})
# Apply chat template
if hasattr(tokenizer, 'apply_chat_template'):
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
else:
# Fallback formatting
prompt = ""
for msg in messages:
if msg["role"] == "user":
prompt += f"User: {msg['content']}\n"
else:
prompt += f"Assistant: {msg['content']}\n"
prompt += "Assistant: "
return prompt
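# Optional sketch (not wired in): drop the oldest turns so the templated prompt
# stays within the 2048-token window used by generate_response below. The
# threshold and this helper are assumptions, not part of the original Space.
def trim_history(message, history, max_prompt_tokens=2048):
    trimmed = list(history)
    while trimmed and len(tokenizer(format_prompt(message, trimmed))["input_ids"]) > max_prompt_tokens:
        trimmed.pop(0)  # remove the oldest (user, assistant) pair first
    return trimmed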
@spaces.GPU(duration=60)
def generate_response(
message,
history,
temperature=0.3,
max_new_tokens=512,
top_p=0.95,
repetition_penalty=1.05,
):
"""Generate response from the model"""
# Format the prompt
prompt = format_prompt(message, history)
# Tokenize input
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
# Generate response
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
# Decode response
response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
return response
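# Illustrative sketch (not wired into the UI below): a streaming variant of
# generate_response using transformers' TextIteratorStreamer, so partial text
# could be yielded to the chatbot as it is generated. The parameters mirror
# generate_response; the threading approach is an assumption, not part of the
# original Space.
from threading import Thread
from transformers import TextIteratorStreamer

@spaces.GPU(duration=60)
def generate_response_stream(message, history, temperature=0.3, max_new_tokens=512,
                             top_p=0.95, repetition_penalty=1.05):
    prompt = format_prompt(message, history)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Run generation in a background thread and yield the growing response
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial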
# Example questions
EXAMPLES = [
["What are the main types of dental cavities?"],
["Explain the process of root canal treatment"],
["What is the difference between gingivitis and periodontitis?"],
["How should I care for my teeth after a dental extraction?"],
["What are the benefits of fluoride in dental care?"],
["Explain the stages of tooth development in children"],
["What causes tooth sensitivity and how can it be treated?"],
["Describe the different types of dental fillings available"],
]
# Custom CSS for styling
custom_css = """
.disclaimer {
background-color: #fff3cd;
border: 1px solid #ffc107;
border-radius: 5px;
padding: 10px;
margin-bottom: 15px;
}
"""
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
gr.Markdown(
"""
# DentaInstruct-1.2B Dental Q&A
Interactive demo of DentaInstruct-1.2B, a language model for dental education and oral health information.
"""
)
gr.HTML(
"""
<div class="disclaimer">
<strong>⚠️ Important Disclaimer:</strong><br>
This model is for educational purposes only. It is NOT a substitute for professional dental care.
Do not use this model for clinical diagnosis or treatment advice. Always consult a qualified dental professional.
</div>
"""
)
chatbot = gr.Chatbot(
height=400,
label="Conversation"
)
msg = gr.Textbox(
label="Your dental question",
placeholder="Ask a question about dental health, procedures, or oral care...",
lines=2
)
with gr.Row():
submit = gr.Button("Send", variant="primary")
clear = gr.Button("Clear")
with gr.Accordion("Advanced Settings", open=False):
temperature = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.3,
step=0.1,
label="Temperature",
info="Controls randomness in responses"
)
max_new_tokens = gr.Slider(
minimum=64,
maximum=1024,
value=512,
step=64,
label="Max New Tokens",
info="Maximum length of the response"
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p",
info="Nucleus sampling parameter"
)
repetition_penalty = gr.Slider(
minimum=1.0,
maximum=1.5,
value=1.05,
step=0.05,
label="Repetition Penalty",
info="Reduces repetition in responses"
)
gr.Examples(
examples=EXAMPLES,
inputs=msg,
label="Example Questions"
)
gr.Markdown(
"""
## About This Model
DentaInstruct-1.2B is a specialised language model fine-tuned on dental educational content.
It's designed to provide educational information about dental health, procedures, and oral care.
**Model Details:**
- Base Model: LFM2-1.2B
- Parameters: 1.17B
- Training Data: Dental subset of MIRIAD dataset
- Purpose: Educational dental information
**Created by:** @yasserrmd | **Space by:** @chrisvoncsefalvay
"""
)
# Event handlers
def respond(message, chat_history, temperature, max_new_tokens, top_p, repetition_penalty):
response = generate_response(
message,
chat_history,
temperature,
max_new_tokens,
top_p,
repetition_penalty
)
chat_history.append((message, response))
return "", chat_history
msg.submit(
respond,
[msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
[msg, chatbot]
)
submit.click(
respond,
[msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
[msg, chatbot]
)
clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
demo.launch()