# app.py
import gradio as gr
from unsloth import FastModel
import torch
from PIL import Image
from transformers import TextIteratorStreamer
import threading
import os
# --- Configuration ---
# 1. Base Model Name (must match the one used for training)
BASE_MODEL_NAME = "unsloth/gemma-3n-E4B-it"
# 2. Your fine-tuned PEFT (LoRA) adapter repo on the Hugging Face Hub
PEFT_MODEL_NAME = "lyimo/mosquito-breeding-detection"
# 3. Max sequence length (should match or exceed the training setting)
MAX_SEQ_LENGTH = 2048
# --- Load Model and Tokenizer ---
print("Loading base model...")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=BASE_MODEL_NAME,
max_seq_length=MAX_SEQ_LENGTH,
dtype=None, # Auto-detect
load_in_4bit=True, # Match training setting
)
print("Loading LoRA adapters...")
model = FastLanguageModel.get_peft_model(model, peft_model_name=PEFT_MODEL_NAME)
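# Assumption: the installed Unsloth version exposes for_inference() on
# FastModel (as it does on FastLanguageModel); it switches the model to
# Unsloth's faster generation path.
FastModel.for_inference(model)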
print("Setting up chat template...")
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")
print("Model and tokenizer loaded successfully!")
# --- Inference Function ---
def analyze_image(image, prompt):
"""
Analyzes the image using the fine-tuned model.
"""
    if image is None:
        # A generator cannot `return` a value to Gradio; yield the message instead.
        yield "Please upload an image."
        return
    # Save the uploaded image to a temporary path; the chat template's image
    # entry below references the file by path.
    temp_image_path = "temp_uploaded_image.jpg"
try:
        image.convert("RGB").save(temp_image_path)  # JPEG needs RGB; uploads may be RGBA
# Construct messages
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": temp_image_path}, # Pass the temporary path
{"type": "text", "text": prompt}
]
}
]
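        # (Gemma-style chat templates generally accept several image entries
        # per message, should you ever need to send more than one photo.)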
        # Apply the chat template and tokenize in one step: with tokenize=True
        # and return_dict=True the processor returns the image's pixel values
        # alongside the input_ids. (Templating with tokenize=False and then
        # re-tokenizing only the text would silently drop the image.)
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
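        # Optional sanity check: decode the prompt to confirm the template
        # placed the image tokens where expected.
        # print(tokenizer.decode(inputs["input_ids"][0]))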
        # --- Generation (streamed) ---
        # TextIteratorStreamer collects decoded text as tokens are generated;
        # model.generate() runs in a background thread so this generator can
        # yield partial results to Gradio while generation is in progress.
        # (The streamer only needs a .decode(); Gemma processors delegate
        # decode() to their inner tokenizer.)
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=1024,
            streamer=streamer,
            # You can add other generation parameters here:
            # temperature=0.7,
            # top_p=0.95,
            # do_sample=True,
        )
        thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        # Re-yield the accumulated text so the textbox updates incrementally.
        output_text = ""
        for new_text in streamer:
            output_text += new_text
            yield output_text
        thread.join()
except Exception as e:
error_msg = f"An error occurred during processing: {str(e)}"
print(error_msg)
yield error_msg
finally:
# Clean up the temporary image file
if os.path.exists(temp_image_path):
os.remove(temp_image_path)
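# Note on the temp file above: a fixed filename can collide when two users
# submit at once. A sketch of a safer variant, assuming the standard library's
# tempfile module is acceptable here:
#
#   import tempfile
#   with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
#       image.convert("RGB").save(tmp.name)
#       temp_image_path = tmp.name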
# --- Gradio Interface ---
with gr.Blocks() as demo:
gr.Markdown("# 🦟 Mosquito Breeding Site Detector")
gr.Markdown("Upload an image and ask the AI to analyze it for potential mosquito breeding sites.")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Upload Image")
prompt_input = gr.Textbox(
label="Your Question",
value="Can you analyze this image for mosquito breeding sites and recommend what to do?",
lines=2
)
submit_btn = gr.Button("Analyze")
with gr.Column():
output_text = gr.Textbox(label="Analysis Result", interactive=False, lines=15)
    # Connect the button; since analyze_image is a generator, Gradio streams
    # each yielded value into the textbox automatically (click() itself has
    # no streaming parameter).
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, prompt_input],
        outputs=output_text,
    )
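# Streaming (generator) outputs run through Gradio's queue; recent Gradio
# versions enable it by default, but enabling it explicitly does no harm.
demo.queue()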
# Launch the app
if __name__ == "__main__":
demo.launch()