lyimo committed on
Commit fa3f706 · verified · 1 Parent(s): 51372bb

Create app.py

Files changed (1)
  1. app.py +167 -0
app.py ADDED
@@ -0,0 +1,167 @@
# app.py
# Unsloth recommends importing it before transformers/peft so its patches apply.
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template

import os
import threading

import gradio as gr
from peft import PeftModel
from transformers import TextIteratorStreamer

# --- Configuration ---
# 1. Base model name (must match the one used for training)
BASE_MODEL_NAME = "unsloth/gemma-3n-E4B-it"

# 2. PEFT (LoRA) adapter repo on the Hugging Face Hub
PEFT_MODEL_NAME = "lyimo/mosquito-breeding-detection"

# 3. Max sequence length (should match or exceed the training setting)
MAX_SEQ_LENGTH = 2048

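# Note: load_in_4bit below keeps the E4B checkpoint within a small GPU's
# memory budget (very roughly 4-6 GB for weights; an estimate, not a
# measurement from this repo).
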
# --- Load Model and Tokenizer ---
print("Loading base model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,         # Auto-detect
    load_in_4bit=True,  # Match training setting
)

print("Loading LoRA adapters...")
# FastLanguageModel.get_peft_model attaches *fresh* adapters for training and
# has no peft_model_name argument; trained adapters are loaded from the Hub
# with peft's PeftModel instead.
model = PeftModel.from_pretrained(model, PEFT_MODEL_NAME)
FastLanguageModel.for_inference(model)  # Enable Unsloth's fast inference path

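# An alternative (untested here) is to pass the adapter repo straight to
# FastLanguageModel.from_pretrained above: when the repo contains an
# adapter_config.json, Unsloth resolves the base model and attaches the
# adapters in one call.
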
print("Setting up chat template...")
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

print("Model and tokenizer loaded successfully!")

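# get_chat_template wires Gemma 3's turn format (<start_of_turn>user ...
# <end_of_turn>) onto the tokenizer, so apply_chat_template renders prompts
# in the same layout the adapters were trained on.
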
# --- Inference Function ---
def analyze_image(image, prompt):
    """
    Analyzes the uploaded image with the fine-tuned model and streams the
    generated text back to Gradio.
    """
    if image is None:
        # This function is a generator, so yield (not return) the message.
        yield "Please upload an image."
        return

    # Save the uploaded PIL image from Gradio to a temporary file; the chat
    # template expects an image path (or URL) for multimodal inputs.
    temp_image_path = "temp_uploaded_image.jpg"
    try:
        image.save(temp_image_path)

        # Construct the multimodal chat messages
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": temp_image_path},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Apply the chat template and tokenize in one step. Rendering to a
        # string with tokenize=False and re-tokenizing the text would silently
        # drop the image pixels, so build the full model inputs here (this
        # assumes the `tokenizer` Unsloth returns for Gemma 3n is the
        # multimodal processor).
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

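        # `inputs` should now hold input_ids, attention_mask and the vision
        # tensors (e.g. pixel_values) for Gemma 3n; exact keys depend on the
        # processor version.
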
        # --- Generation ---
        # Stream tokens with transformers' TextIteratorStreamer: generation
        # runs in a background thread while this generator drains the decoded
        # text and yields the accumulated output to Gradio.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=1024,
            # Other generation parameters can be added here, e.g.:
            # temperature=0.7,
            # top_p=0.95,
            # do_sample=True,
        )
        thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        output_text = ""
        for new_text in streamer:
            output_text += new_text
            yield output_text
        thread.join()

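        # Design note: TextIteratorStreamer blocks on an internal queue, so
        # partial text is yielded as soon as each chunk is decoded. It only
        # needs a .decode method, which multimodal processors forward to
        # their underlying tokenizer. Exceptions raised inside the generate
        # thread do not propagate here; they surface as an empty stream.
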
    except Exception as e:
        error_msg = f"An error occurred during processing: {e}"
        print(error_msg)
        yield error_msg
    finally:
        # Clean up the temporary image file
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)

# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# 🦟 Mosquito Breeding Site Detector")
    gr.Markdown("Upload an image and ask the AI to analyze it for potential mosquito breeding sites.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            prompt_input = gr.Textbox(
                label="Your Question",
                value="Can you analyze this image for mosquito breeding sites and recommend what to do?",
                lines=2,
            )
            submit_btn = gr.Button("Analyze")
        with gr.Column():
            output_box = gr.Textbox(label="Analysis Result", interactive=False, lines=15)

    # Connect the button to the function. click() has no `streaming` argument;
    # Gradio streams automatically because analyze_image is a generator.
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, prompt_input],
        outputs=output_box,
    )

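# Streaming relies on Gradio's request queue; recent Gradio versions enable
# it by default, and an explicit call before launch is harmless if it is
# already on.
demo.queue()
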
# Launch the app
if __name__ == "__main__":
    demo.launch()
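
A minimal requirements.txt sketch for running this app as a Hugging Face Space (package list inferred from the imports above; this commit itself specifies no versions):

# requirements.txt
unsloth
transformers
peft
accelerate
bitsandbytes
gradio
torch
pillow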