# app.py
import gradio as gr
from unsloth import FastLanguageModel
import torch
from PIL import Image
from transformers import TextIteratorStreamer
from threading import Thread
from peft import PeftModel
import os
# --- Configuration ---
# 1. Base Model Name (must match the one used for training)
BASE_MODEL_NAME = "unsloth/gemma-3n-E4B-it"
# 2. Your PEFT (LoRA) Model Name on Hugging Face Hub
# Set this to your own "username/repo-name" if you trained your own adapter
PEFT_MODEL_NAME = "lyimo/mosquito-breeding-detection"
# 3. Max sequence length (should match or exceed training setting)
MAX_SEQ_LENGTH = 2048
# --- Load Model and Tokenizer ---
print("Loading base model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,         # Auto-detect
    load_in_4bit=True,  # Match training setting
)
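# Note: load_in_4bit relies on bitsandbytes and assumes a CUDA GPU is available;
# on a CPU-only machine this would need to be set to False.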
print("Loading LoRA adapters...")
model = FastLanguageModel.get_peft_model(model, peft_model_name=PEFT_MODEL_NAME)
print("Setting up chat template...")
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")
print("Model and tokenizer loaded successfully!")
# --- Inference Function ---
def analyze_image(image, prompt):
"""
Analyzes the image using the fine-tuned model.
"""
if image is None:
return "Please upload an image."
# Save the uploaded image temporarily (or pass the PIL object, see notes)
# Unsloth's tokenizer often expects the image path during apply_chat_template
# for multimodal inputs.
temp_image_path = "temp_uploaded_image.jpg"
try:
image.save(temp_image_path) # Save PIL image from Gradio
        # Construct the multimodal chat message
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": temp_image_path},  # Pass the temporary path
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        # Apply the chat template and tokenize text + image in one step;
        # the multimodal tokenizer/processor resolves the image entry itself.
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
        # --- Generation ---
        # Stream tokens to the UI: run generate() in a background thread and read
        # incremental text from a TextIteratorStreamer in this generator function.
        text_tokenizer = getattr(tokenizer, "tokenizer", tokenizer)  # Unwrap the processor if present
        streamer = TextIteratorStreamer(
            text_tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=1024,
            streamer=streamer,
            # You can add other generation parameters here, e.g.:
            # temperature=0.7,
            # top_p=0.95,
            # do_sample=True,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        # Yield the growing answer so Gradio updates the textbox as tokens arrive
        output_text = ""
        for new_text in streamer:
            output_text += new_text
            yield output_text
        thread.join()
    except Exception as e:
        error_msg = f"An error occurred during processing: {str(e)}"
        print(error_msg)
        yield error_msg
    finally:
        # Clean up the temporary image file
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)

# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# 🦟 Mosquito Breeding Site Detector")
    gr.Markdown("Upload an image and ask the AI to analyze it for potential mosquito breeding sites.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            prompt_input = gr.Textbox(
                label="Your Question",
                value="Can you analyze this image for mosquito breeding sites and recommend what to do?",
                lines=2
            )
            submit_btn = gr.Button("Analyze")
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Result", interactive=False, lines=15)
    # Connect the button to the function. Gradio streams into the textbox automatically
    # because analyze_image is a generator that yields partial results.
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, prompt_input],
        outputs=output_text,
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
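# Rough dependency sketch for running the demo locally (an assumption, not pinned by this file;
# a deployed Space would list these in requirements.txt):
#   pip install unsloth gradio peft transformers torch pillow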