|
|
|
import os
import tempfile
import threading

# Unsloth asks to be imported before transformers/torch so its patches apply.
from unsloth import FastLanguageModel

import gradio as gr
import torch
from PIL import Image
from transformers import TextIteratorStreamer, TextStreamer
|
|
|
|
|
|
|
# Base checkpoint the adapters were fine-tuned from. Kept for reference: when
# PEFT_MODEL_NAME is a LoRA adapter repo, the base model is taken from that
# repo's adapter config when it is loaded below.
BASE_MODEL_NAME = "unsloth/gemma-3n-E4B-it"

# Hugging Face repo holding the fine-tuned mosquito-breeding-site weights.
PEFT_MODEL_NAME = "lyimo/mosquito-breeding-detection"

MAX_SEQ_LENGTH = 2048

print("Loading fine-tuned model (base model + LoRA adapters)...")
# NOTE: FastLanguageModel.get_peft_model() creates *new, untrained* LoRA layers
# (it is the entry point for starting a fine-tune) and has no option to load
# saved adapters, so the trained weights were silently never used. Loading the
# adapter repo through from_pretrained() loads the base model and applies the
# saved adapters (it also works if the repo holds merged weights).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=PEFT_MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,  # let Unsloth choose the dtype for the current GPU
    load_in_4bit=True,
)
# Put the model into Unsloth's inference mode before calling generate().
FastLanguageModel.for_inference(model)

print("Setting up chat template...")
from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

print("Model and tokenizer loaded successfully!")
|
|
|
|
|
def analyze_image(image, prompt):
    """Stream the fine-tuned model's analysis of an uploaded image.

    This is a generator wired to a Gradio event. Each yielded string replaces
    the contents of the output textbox, so every yield carries the full text
    generated so far, not just the newest fragment.

    Args:
        image: Image from the ``gr.Image(type="pil")`` input, or ``None`` when
            nothing has been uploaded.
        prompt: The user's question, sent to the model together with the image.

    Yields:
        "Please upload an image." when ``image`` is ``None``; otherwise the
        accumulated generated text, or an error message if processing fails.
    """
    if image is None:
        # This function is a generator, so ``return "..."`` would end it without
        # ever showing the message in the UI; it must be yielded.
        yield "Please upload an image."
        return

    temp_image_path = None
    try:
        # One unique file per request: a fixed shared filename would let
        # concurrent requests overwrite (or delete) each other's image.
        fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        # JPEG cannot store alpha or palette modes (e.g. RGBA PNG uploads).
        image.convert("RGB").save(temp_image_path, format="JPEG")

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": temp_image_path},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Tokenize through the chat template in a single step. Rendering the
        # template to a string and then tokenizing only that text discards the
        # image, so the model would never see it.
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

        # generate() blocks until it finishes, so it runs in a worker thread
        # while this generator reads decoded text from the streamer as it is
        # produced. skip_prompt keeps the echoed prompt out of the answer.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_errors = []

        def _generate():
            # Run generation; record any failure so it can be re-raised below.
            try:
                model.generate(**inputs, max_new_tokens=1024, streamer=streamer)
            except Exception as exc:
                generation_errors.append(exc)
                # Without an end signal, iterating the streamer blocks forever.
                streamer.end()

        worker = threading.Thread(target=_generate, daemon=True)
        worker.start()

        output_text = ""
        for new_text in streamer:
            output_text += new_text
            yield output_text
        worker.join()

        if generation_errors:
            raise generation_errors[0]

    except Exception as e:
        error_msg = f"An error occurred during processing: {str(e)}"
        print(error_msg)
        yield error_msg

    finally:
        # Also runs when Gradio closes the generator early (e.g. on cancel).
        if temp_image_path is not None and os.path.exists(temp_image_path):
            os.remove(temp_image_path)
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# 🦟 Mosquito Breeding Site Detector")
    gr.Markdown("Upload an image and ask the AI to analyze it for potential mosquito breeding sites.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            prompt_input = gr.Textbox(
                label="Your Question",
                value="Can you analyze this image for mosquito breeding sites and recommend what to do?",
                lines=2,
            )
            submit_btn = gr.Button("Analyze")

        with gr.Column():
            output_text = gr.Textbox(label="Analysis Result", interactive=False, lines=15)

    # analyze_image is a generator, and Gradio streams generator outputs
    # automatically. ``streaming`` is not a parameter of event listeners such
    # as .click(), so passing it was rejected with a TypeError.
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, prompt_input],
        outputs=output_text,
    )


if __name__ == "__main__":
    # Older Gradio releases only stream generator outputs when the queue is
    # enabled; newer releases enable it by default, so this is harmless there.
    demo.queue()
    demo.launch()
|
|