# app.py
import os
from threading import Thread

import gradio as gr
from unsloth import FastLanguageModel  # import unsloth before transformers so its patches apply
from unsloth.chat_templates import get_chat_template
import torch
from PIL import Image
from transformers import TextIteratorStreamer

# --- Configuration ---
# 1. Base model the adapter was trained on. Informational: the adapter repo's
#    adapter_config.json names the base model that is actually loaded, so keep
#    this value in sync with it.
BASE_MODEL_NAME = "unsloth/gemma-3n-E4B-it"
# 2. Your PEFT (LoRA) adapter on the Hugging Face Hub.
PEFT_MODEL_NAME = "lyimo/mosquito-breeding-detection"
# 3. Max sequence length (should match or exceed the training setting).
MAX_SEQ_LENGTH = 2048
# 4. Upper bound on the number of tokens generated per answer.
MAX_NEW_TOKENS = 1024
# 5. Question used as the Textbox default and when the user leaves it blank.
DEFAULT_PROMPT = "Can you analyze this image for mosquito breeding sites and recommend what to do?"

# --- Load Model and Tokenizer ---
print(f"Loading LoRA adapter {PEFT_MODEL_NAME} (base: {BASE_MODEL_NAME})...")
# Pointing from_pretrained at the adapter repo loads the base model named in
# the adapter config and attaches the *trained* LoRA weights.
# FastLanguageModel.get_peft_model is for creating new, untrained adapters
# before fine-tuning; it cannot load an adapter from the Hub.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=PEFT_MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,  # Auto-detect
    load_in_4bit=True,  # Match training setting
)

print("Setting up chat template...")
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")
print("Model and tokenizer loaded successfully!")


# --- Inference Helpers ---
def _build_inputs(image, prompt):
    """Tokenize a single-turn chat containing *image* and *prompt* for the model."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    # tokenize=True + return_dict=True lets the processor turn the image into
    # pixel inputs alongside the token ids. Rendering the template to a string
    # and re-tokenizing it keeps only the text, so the model never sees the image.
    return tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)


def _generate_in_background(generation_kwargs, streamer, errors):
    """Run model.generate; on failure record the exception and unblock *streamer*."""
    try:
        model.generate(**generation_kwargs)
    except Exception as exc:  # re-raised on the consumer side by analyze_image
        errors.append(exc)
        # Without an end-of-stream signal the consumer's `for ... in streamer`
        # loop would wait forever for tokens that never arrive.
        streamer.end()


# --- Inference Function ---
def analyze_image(image, prompt):
    """Stream the fine-tuned model's analysis of an uploaded image.

    Used as a Gradio generator handler: every yielded string is the full
    text produced so far, so the output box updates progressively.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` component, or
            None when nothing was uploaded.
        prompt: The user's question. A blank value falls back to
            DEFAULT_PROMPT.

    Yields:
        The accumulated generated text, or a user-facing error message.
    """
    if image is None:
        # `return value` inside a generator shows nothing in Gradio; the
        # message has to be yielded to reach the UI.
        yield "Please upload an image."
        return

    prompt = (prompt or "").strip() or DEFAULT_PROMPT
    errors = []
    try:
        # Normalise RGBA / palette / grayscale uploads to plain RGB.
        inputs = _build_inputs(image.convert("RGB"), prompt)

        # TextIteratorStreamer hands decoded text from the generation thread
        # to this generator as it is produced.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=MAX_NEW_TOKENS,
            # Other sampling options can be added here, e.g.
            # do_sample=True, temperature=0.7, top_p=0.95.
        )

        # Generate in a worker thread so tokens can be yielded while it runs.
        thread = Thread(
            target=_generate_in_background,
            args=(generation_kwargs, streamer, errors),
            daemon=True,
        )
        thread.start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

        thread.join()
        if errors:
            raise errors[0]
    except Exception as e:
        error_msg = f"An error occurred during processing: {e}"
        print(error_msg)
        yield error_msg


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# 🦟 Mosquito Breeding Site Detector")
    gr.Markdown("Upload an image and ask the AI to analyze it for potential mosquito breeding sites.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            prompt_input = gr.Textbox(
                label="Your Question",
                value=DEFAULT_PROMPT,
                lines=2,
            )
            submit_btn = gr.Button("Analyze")
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Result", interactive=False, lines=15)

    # analyze_image is a generator, so Gradio streams each yielded value into
    # output_text; no separate streaming flag is needed.
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, prompt_input],
        outputs=output_text,
    )

# Launch the app
if __name__ == "__main__":
    # Generator handlers run through the queue. It is on by default in
    # Gradio 4+; enabling it explicitly keeps Gradio 3 working as well.
    demo.queue().launch()