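"""Gradio chat demo for DentaInstruct-1.2B.

Loads the model from the Hugging Face Hub, falling back to the
TinyLlama-1.1B-Chat tokenizer when the model's own tokenizer files
cannot be loaded (e.g. because they are corrupted).
"""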
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model configuration
MODEL_ID = "yasserrmd/DentaInstruct-1.2B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize model and tokenizer
print(f"Loading model {MODEL_ID}...")

try:
    # Try loading the tokenizer normally
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
except Exception as e:
    print(f"Failed to load tokenizer from {MODEL_ID}: {e}")
    print("Using fallback tokenizer from base model...")
    # Use the base model's tokenizer as fallback
    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
    trust_remote_code=True
)

if not torch.cuda.is_available():
    model = model.to(DEVICE)

# Set padding token if not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def format_prompt(message, history):
    """Format the prompt for the model"""
    messages = []
    # Add conversation history
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    # Add current message
    messages.append({"role": "user", "content": message})
    # Apply chat template
    if hasattr(tokenizer, 'apply_chat_template'):
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        # Fallback formatting
        prompt = ""
        for msg in messages:
            if msg["role"] == "user":
                prompt += f"User: {msg['content']}\n"
            else:
                prompt += f"Assistant: {msg['content']}\n"
        prompt += "Assistant: "
    return prompt
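
# Illustration (not executed): with the plain-text fallback above, a one-turn
# history plus a new question renders as:
#   User: <previous question>
#   Assistant: <previous answer>
#   User: <current message>
#   Assistant: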

def generate_response(
    message,
    history,
    temperature=0.7,
    max_new_tokens=512,
    top_p=0.95,
    repetition_penalty=1.1,
):
    """Generate a response from the model."""
    # Format the prompt
    prompt = format_prompt(message, history)
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    return response
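
# Illustrative call (assumes the model and tokenizer above loaded
# successfully):
#   generate_response("What is gingivitis?", history=[])
# returns only the newly generated text, because the prompt tokens are
# sliced off before decoding.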

# Example questions
EXAMPLES = [
    ["What are the main types of dental cavities?"],
    ["Explain the process of root canal treatment"],
    ["What is the difference between gingivitis and periodontitis?"],
    ["How should I care for my teeth after a dental extraction?"],
    ["What are the benefits of fluoride in dental care?"],
    ["Explain the stages of tooth development in children"],
    ["What causes tooth sensitivity and how can it be treated?"],
    ["Describe the different types of dental fillings available"],
]

# Custom CSS for styling
custom_css = """
.disclaimer {
    background-color: #fff3cd;
    border: 1px solid #ffc107;
    border-radius: 5px;
    padding: 10px;
    margin-bottom: 15px;
}
"""

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown(
        """
        # Dental VQA Model Comparison
        Interactive comparison of dental visual question answering models. Currently featuring DentaInstruct-1.2B for dental education and oral health information.
        """
    )
    gr.HTML(
        """
        <div class="disclaimer">
            <strong>⚠️ Important Disclaimer:</strong><br>
            This model is for educational purposes only. It is NOT a substitute for professional dental care.
            Do not use this model for clinical diagnosis or treatment advice. Always consult a qualified dental professional.
        </div>
        """
    )

    chatbot = gr.Chatbot(
        height=400,
        label="Conversation"
    )
    msg = gr.Textbox(
        label="Your dental question",
        placeholder="Ask a question about dental health, procedures, or oral care...",
        lines=2
    )

    with gr.Row():
        submit = gr.Button("Send", variant="primary")
        clear = gr.Button("Clear")

    with gr.Accordion("Advanced Settings", open=False):
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1,
            label="Temperature",
            info="Controls randomness in responses"
        )
        max_new_tokens = gr.Slider(
            minimum=64,
            maximum=1024,
            value=512,
            step=64,
            label="Max New Tokens",
            info="Maximum length of the response"
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
            info="Nucleus sampling parameter"
        )
        repetition_penalty = gr.Slider(
            minimum=1.0,
            maximum=1.5,
            value=1.1,
            step=0.05,
            label="Repetition Penalty",
            info="Reduces repetition in responses"
        )

    gr.Examples(
        examples=EXAMPLES,
        inputs=msg,
        label="Example Questions"
    )

    gr.Markdown(
        """
        ## About This Model

        DentaInstruct-1.2B is a specialised language model fine-tuned on dental educational content.
        It's designed to provide educational information about dental health, procedures, and oral care.

        **Model Details:**
        - Base Model: LFM2-1.2B
        - Parameters: 1.17B
        - Training Data: Dental subset of MIRIAD dataset
        - Purpose: Educational dental information

        **Created by:** @yasserrmd | **Space by:** @chrisvoncsefalvay
        """
    )

    # Event handlers
    def respond(message, chat_history, temperature, max_new_tokens, top_p, repetition_penalty):
        response = generate_response(
            message,
            chat_history,
            temperature,
            max_new_tokens,
            top_p,
            repetition_penalty
        )
        chat_history.append((message, response))
        return "", chat_history

    msg.submit(
        respond,
        [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
        [msg, chatbot]
    )
    submit.click(
        respond,
        [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
        [msg, chatbot]
    )
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()