# app.py
import gradio as gr
from unsloth import FastLanguageModel
import torch
from PIL import Image
from transformers import TextIteratorStreamer
from threading import Thread
import os
# --- Configuration ---
# 1. Base Model Name (must match the one used for training)
BASE_MODEL_NAME = "unsloth/gemma-3n-E4B-it"
# 2. Your PEFT (LoRA) Model Name on Hugging Face Hub
PEFT_MODEL_NAME = "lyimo/mosquito-breeding-detection"
# 3. Max sequence length (should match or exceed training setting)
MAX_SEQ_LENGTH = 2048
# --- Load Model and Tokenizer ---
print("Loading base model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,  # Auto-detect
    load_in_4bit=True,  # Match training setting
)
print("Loading LoRA adapters...")
# FastLanguageModel.get_peft_model() configures *new* adapters for training and does not
# accept a Hub repo name. To attach the already-trained adapters, load them with peft's PeftModel.
from peft import PeftModel
model = PeftModel.from_pretrained(model, PEFT_MODEL_NAME)
print("Setting up chat template...")
from unsloth.chat_templates import get_chat_template
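# get_chat_template() applies the Gemma-3 conversation format, so prompts built with
# apply_chat_template() below match the formatting used during fine-tuning.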
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")
print("Model and tokenizer loaded successfully!")
# --- Inference Function ---
def analyze_image(image, prompt):
    """
    Analyzes the image using the fine-tuned model and streams the output.
    """
    if image is None:
        # This function is a generator, so the message must be yielded, not returned.
        yield "Please upload an image."
        return
    # Gradio hands us a PIL image; write it to a temporary file so the chat message
    # can reference it by path.
    temp_image_path = "temp_uploaded_image.jpg"
    try:
        image.save(temp_image_path)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": temp_image_path},
                    {"type": "text", "text": prompt}
                ]
            }
        ]
        full_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        inputs = tokenizer(
            full_prompt,
            return_tensors="pt",
        ).to(model.device)
        # Use TextIteratorStreamer for simpler, more robust streaming
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # Define generation arguments
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=1024,
            # You can add other generation parameters here
            # temperature=0.7,
            # top_p=0.95,
            # do_sample=True
        )
        # Run generation in a separate thread to avoid blocking the UI
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        # Yield the generated text as it becomes available
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
    except Exception as e:
        error_msg = f"An error occurred during processing: {str(e)}"
        print(error_msg)
        yield error_msg
    finally:
        # Clean up the temporary image file
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)
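# Rough usage sketch for local testing without the UI (hypothetical image file):
#   from PIL import Image
#   for partial in analyze_image(Image.open("sample_pond.jpg"), "Any breeding sites here?"):
#       print(partial)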
# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# 🦟 Mosquito Breeding Site Detector")
    gr.Markdown("Upload an image and ask the AI to analyze it for potential mosquito breeding sites.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            prompt_input = gr.Textbox(
                label="Your Question",
                value="Can you analyze this image for mosquito breeding sites and recommend what to do?",
                lines=2
            )
            submit_btn = gr.Button("Analyze")
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Result", interactive=False, lines=15)
    # Connect the button to the function
    # The 'streaming=True' flag in Gradio 3 is deprecated. The streaming behavior
    # is now automatically handled by using a generator function (with 'yield').
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, prompt_input],
        outputs=output_text
    )
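# The queue is what lets Gradio stream partial results from a generator back to the browser.
# It is enabled by default in recent Gradio releases; calling it explicitly is a harmless
# assumption that also covers older versions.
demo.queue()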
# Launch the app
if __name__ == "__main__":
    demo.launch()