File size: 4,178 Bytes
fa3f706
 
 
 
 
35283c1
 
fa3f706
 
 
 
 
 
 
35283c1
fa3f706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35283c1
fa3f706
 
 
35283c1
fa3f706
 
 
 
 
 
35283c1
fa3f706
 
 
 
 
35283c1
fa3f706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35283c1
 
 
 
 
 
 
 
 
 
 
 
 
fa3f706
35283c1
 
 
 
 
 
 
 
 
fa3f706
 
 
 
 
 
 
 
 
 
35283c1
fa3f706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35283c1
 
fa3f706
 
 
35283c1
fa3f706
 
 
 
35283c1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# app.py
import gradio as gr
from unsloth import FastLanguageModel
import torch
from PIL import Image
from transformers import TextIteratorStreamer
from threading import Thread
import os

# --- Configuration ---
# 1. Base Model Name (must match the one used for training)
BASE_MODEL_NAME = "unsloth/gemma-3n-E4B-it"

# 2. Your PEFT (LoRA) Model Name on Hugging Face Hub
PEFT_MODEL_NAME = "lyimo/mosquito-breeding-detection"

# 3. Max sequence length (should match or exceed training setting)
MAX_SEQ_LENGTH = 2048

# --- Load Model and Tokenizer ---
print("Loading base model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None, # Auto-detect
    load_in_4bit=True, # Match training setting
)

print("Loading LoRA adapters...")
model = FastLanguageModel.get_peft_model(model, peft_model_name=PEFT_MODEL_NAME)

print("Setting up chat template...")
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

print("Model and tokenizer loaded successfully!")


# --- Inference Function ---
def analyze_image(image, prompt):
    """
    Analyzes the image using the fine-tuned model and streams the output.
    """
    if image is None:
        return "Please upload an image."

    temp_image_path = "temp_uploaded_image.jpg"
    try:
        image.save(temp_image_path)

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": temp_image_path},
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        full_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(
            full_prompt,
            return_tensors="pt",
        ).to(model.device)

        # Use TextIteratorStreamer for simpler, more robust streaming
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Define generation arguments
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=1024,
            # You can add other generation parameters here
            # temperature=0.7,
            # top_p=0.95,
            # do_sample=True
        )

        # Run generation in a separate thread to avoid blocking the UI
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield the generated text as it becomes available
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

    except Exception as e:
        error_msg = f"An error occurred during processing: {str(e)}"
        print(error_msg)
        yield error_msg
    finally:
        # Clean up the temporary image file
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# 🦟 Mosquito Breeding Site Detector")
    gr.Markdown("Upload an image and ask the AI to analyze it for potential mosquito breeding sites.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            prompt_input = gr.Textbox(
                label="Your Question",
                value="Can you analyze this image for mosquito breeding sites and recommend what to do?",
                lines=2
            )
            submit_btn = gr.Button("Analyze")
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Result", interactive=False, lines=15)

    # Connect the button to the function
    # The 'streaming=True' flag in Gradio 3 is deprecated. The streaming behavior
    # is now automatically handled by using a generator function (with 'yield').
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, prompt_input],
        outputs=output_text
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()