# app.py
import os
import threading

import gradio as gr
import torch
from PIL import Image
from peft import PeftModel
from transformers import TextIteratorStreamer
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template

# --- Configuration ---
# 1. Base model name (must match the one used for training)
BASE_MODEL_NAME = "unsloth/gemma-3n-E4B-it"
# 2. Your PEFT (LoRA) adapter repo on the Hugging Face Hub
#    Replace with your own 'username/repo-name' if you trained your own adapter
PEFT_MODEL_NAME = "lyimo/mosquito-breeding-detection"
# 3. Max sequence length (should match or exceed the training setting)
MAX_SEQ_LENGTH = 2048

# --- Load Model and Tokenizer ---
print("Loading base model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,         # Auto-detect
    load_in_4bit=True,  # Match training setting
)

print("Loading LoRA adapters...")
# Note: get_peft_model() creates *fresh* adapters for training; to load
# adapters that were already trained and pushed to the Hub, attach them
# with PeftModel instead.
model = PeftModel.from_pretrained(model, PEFT_MODEL_NAME)
FastLanguageModel.for_inference(model)  # Enable Unsloth's fast inference path

print("Setting up chat template...")
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

print("Model and tokenizer loaded successfully!")


# --- Inference Function ---
def analyze_image(image, prompt):
    """
    Analyzes the image with the fine-tuned model, streaming partial output.
    """
    if image is None:
        yield "Please upload an image."
        return

    # Save the uploaded image temporarily; the chat template below loads it
    # from this path. (Passing the PIL object directly also works with
    # recent transformers versions.)
    temp_image_path = "temp_uploaded_image.jpg"
    try:
        # Gradio may hand back an RGBA image for PNG uploads; JPEG needs RGB.
        image.convert("RGB").save(temp_image_path)

        # Construct the multimodal message
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": temp_image_path},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Apply the chat template and tokenize in one step so the image is
        # actually encoded into the model inputs. (Rendering the template to a
        # string and tokenizing that text alone would silently drop the image.)
        # This assumes the loaded tokenizer is Gemma's multimodal processor.
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

        # --- Generation ---
        # Run generate() in a background thread and stream partial text back
        # to Gradio as tokens arrive.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=1024,
            # You can add other generation parameters here:
            # temperature=0.7,
            # top_p=0.95,
            # do_sample=True,
        )
        thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        output_text = ""
        for new_text in streamer:
            output_text += new_text
            yield output_text  # Gradio updates the textbox with each yield
        thread.join()

    except Exception as e:
        error_msg = f"An error occurred during processing: {str(e)}"
        print(error_msg)
        yield error_msg
    finally:
        # Clean up the temporary image file
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)
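
# Optional helper (not part of the original app): a minimal sketch for running
# the generator to completion from a Python shell or script, without the
# Gradio UI. The function name and the example file name are placeholders.
def analyze_image_blocking(image_path, question):
    """Consume analyze_image's stream and return only the final text."""
    result = ""
    for partial in analyze_image(Image.open(image_path), question):
        result = partial
    return result

# Example (assumes a local test image):
#   print(analyze_image_blocking("puddle.jpg", "Is this a breeding site?"))
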
# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# 🦟 Mosquito Breeding Site Detector")
    gr.Markdown("Upload an image and ask the AI to analyze it for potential mosquito breeding sites.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            prompt_input = gr.Textbox(
                label="Your Question",
                value="Can you analyze this image for mosquito breeding sites and recommend what to do?",
                lines=2,
            )
            submit_btn = gr.Button("Analyze")
        with gr.Column():
            output_box = gr.Textbox(label="Analysis Result", interactive=False, lines=15)

    # Connect the button to the function. Because analyze_image is a
    # generator, Gradio streams each yielded value into the textbox
    # automatically; .click() has no 'streaming' argument.
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, prompt_input],
        outputs=output_box,
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
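
# --- Running the app (assumed setup; versions are not pinned by this script) ---
#   pip install unsloth peft transformers gradio torch pillow
#   python app.py
# Gradio serves the UI on http://127.0.0.1:7860 by default.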