# app.py
import gradio as gr
from unsloth import FastLanguageModel
import torch
from PIL import Image
from transformers import TextIteratorStreamer
from peft import PeftModel  # used below to load the trained LoRA adapters
from threading import Thread
import os

# --- Configuration ---
# 1. Base Model Name (must match the one used for training)
BASE_MODEL_NAME = "unsloth/gemma-3n-E4B-it"
# 2. Your PEFT (LoRA) Model Name on Hugging Face Hub
PEFT_MODEL_NAME = "lyimo/mosquito-breeding-detection"
# 3. Max sequence length (should match or exceed training setting)
MAX_SEQ_LENGTH = 2048
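# Optional: these IDs could also be read from the environment so the Space can be
# repointed without editing the file (a sketch, not part of the original app; Spaces
# expose variables and secrets through os.environ):
# BASE_MODEL_NAME = os.environ.get("BASE_MODEL_NAME", BASE_MODEL_NAME)
# PEFT_MODEL_NAME = os.environ.get("PEFT_MODEL_NAME", PEFT_MODEL_NAME)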
# --- Load Model and Tokenizer ---
print("Loading base model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,         # Auto-detect
    load_in_4bit=True,  # Match training setting
)
print("Loading LoRA adapters...")
# Note: FastLanguageModel.get_peft_model() creates *new*, untrained adapters and is
# meant for training. To attach the already-trained adapters from the Hub, load them
# with peft instead.
model = PeftModel.from_pretrained(model, PEFT_MODEL_NAME)
# Put Unsloth into inference mode for faster generation.
FastLanguageModel.for_inference(model)
print("Setting up chat template...")
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")
print("Model and tokenizer loaded successfully!")
# --- Inference Function ---
def analyze_image(image, prompt):
    """
    Analyzes the uploaded image with the fine-tuned model and streams the output.
    """
    if image is None:
        yield "Please upload an image."
        return
    temp_image_path = "temp_uploaded_image.jpg"
    try:
        # JPEG cannot store an alpha channel, so convert (e.g. RGBA screenshots) first.
        image.convert("RGB").save(temp_image_path)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": temp_image_path},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        # Build the model inputs. This assumes the `tokenizer` returned by Unsloth is
        # the model's multimodal processor, so apply_chat_template can load the image
        # from the path and return pixel values alongside the input ids.
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
        # Use TextIteratorStreamer for simple, robust token-by-token streaming.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # Define generation arguments
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=1024,
            # Other generation parameters can be added here, e.g.:
            # temperature=0.7,
            # top_p=0.95,
            # do_sample=True,
        )
        # Run generation in a separate thread so reading the streamer does not block the UI.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        # Yield the accumulated text as new tokens become available.
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
    except Exception as e:
        error_msg = f"An error occurred during processing: {str(e)}"
        print(error_msg)
        yield error_msg
    finally:
        # Clean up the temporary image file.
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)
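# analyze_image is a generator: each yielded value is the full text produced so far,
# which is what lets Gradio update the output box incrementally. For local testing
# without the UI you could drive it directly (a sketch; "sample.jpg" is a hypothetical
# file, not shipped with this Space):
# for partial in analyze_image(Image.open("sample.jpg"), "Any mosquito breeding sites here?"):
#     last = partial
# print(last)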
# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# 🦟 Mosquito Breeding Site Detector")
    gr.Markdown("Upload an image and ask the AI to analyze it for potential mosquito breeding sites.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            prompt_input = gr.Textbox(
                label="Your Question",
                value="Can you analyze this image for mosquito breeding sites and recommend what to do?",
                lines=2
            )
            submit_btn = gr.Button("Analyze")
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Result", interactive=False, lines=15)
    # Connect the button to the function. No explicit streaming flag is needed:
    # because analyze_image is a generator (it uses 'yield'), Gradio streams its
    # output automatically.
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, prompt_input],
        outputs=output_text
    )
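# If several users may stream responses at once, you may want to enable Gradio's
# request queue before launching (a suggestion, not part of the original app):
# demo.queue(max_size=20)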
# Launch the app
if __name__ == "__main__":
    demo.launch()
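    # On Hugging Face Spaces the default launch() settings are fine; when running on
    # your own server you might bind explicitly instead (a suggestion, not required):
    # demo.launch(server_name="0.0.0.0", server_port=7860)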