lyimo committed
Commit 35283c1 · verified · 1 Parent(s): fa3f706

Update app.py

Files changed (1)
app.py +34 -68
app.py CHANGED
@@ -3,7 +3,8 @@ import gradio as gr
 from unsloth import FastLanguageModel
 import torch
 from PIL import Image
-from transformers import TextStreamer
+from transformers import TextIteratorStreamer
+from threading import Thread
 import os
 
 # --- Configuration ---
@@ -11,8 +12,7 @@ import os
 BASE_MODEL_NAME = "unsloth/gemma-3n-E4B-it"
 
 # 2. Your PEFT (LoRA) Model Name on Hugging Face Hub
-# Replace 'your-username' and 'your-model-repo-name' with your actual details
-PEFT_MODEL_NAME = "lyimo/mosquito-breeding-detection" # Or your Hugging Face repo path
+PEFT_MODEL_NAME = "lyimo/mosquito-breeding-detection"
 
 # 3. Max sequence length (should match or exceed training setting)
 MAX_SEQ_LENGTH = 2048
@@ -35,99 +35,63 @@ tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")
 
 print("Model and tokenizer loaded successfully!")
 
+
 # --- Inference Function ---
 def analyze_image(image, prompt):
     """
-    Analyzes the image using the fine-tuned model.
+    Analyzes the image using the fine-tuned model and streams the output.
     """
     if image is None:
         return "Please upload an image."
 
-    # Save the uploaded image temporarily (or pass the PIL object, see notes)
-    # Unsloth's tokenizer often expects the image path during apply_chat_template
-    # for multimodal inputs.
     temp_image_path = "temp_uploaded_image.jpg"
     try:
-        image.save(temp_image_path) # Save PIL image from Gradio
+        image.save(temp_image_path)
 
-        # Construct messages
         messages = [
             {
                 "role": "user",
                 "content": [
-                    {"type": "image", "image": temp_image_path}, # Pass the temporary path
+                    {"type": "image", "image": temp_image_path},
                     {"type": "text", "text": prompt}
                 ]
             }
         ]
 
-        # Apply chat template
         full_prompt = tokenizer.apply_chat_template(
             messages,
             tokenize=False,
             add_generation_prompt=True
         )
 
-        # Tokenize inputs
         inputs = tokenizer(
            full_prompt,
            return_tensors="pt",
        ).to(model.device)
 
-        # --- Generation ---
-        # Collect the output text
-        output_text = ""
-        def text_collector(text):
-            nonlocal output_text
-            output_text += text
-
-        # Create a custom streamer to capture text
-        class GradioTextStreamer:
-            def __init__(self, tokenizer, callback=None):
-                self.tokenizer = tokenizer
-                self.callback = callback
-                self.token_cache = []
-                self.print_len = 0
-
-            def put(self, value):
-                if self.callback:
-                    # Decode the current token(s)
-                    self.token_cache.extend(value.tolist())
-                    text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True)
-                    # Call the callback with the new text
-                    self.callback(text[len(output_text):]) # Send only the new part
-                    # Update output_text locally to track progress
-                    nonlocal output_text
-                    output_text = text
-
-            def end(self):
-                if self.callback:
-                    # Ensure any remaining text is sent
-                    self.callback("") # Signal end, or send final text if needed differently
-                    self.token_cache = []
-                    self.print_len = 0
-
-        streamer = GradioTextStreamer(tokenizer, callback=text_collector)
-
-        # Start generation in a separate thread to allow streaming
-        import threading
-        def generate_text():
-            _ = model.generate(
-                **inputs,
-                max_new_tokens=1024,
-                streamer=streamer,
-                # You can add other generation parameters here
-                # temperature=0.7,
-                # top_p=0.95,
-                # do_sample=True
-            )
-            # Signal completion after generation finishes
-            yield output_text # Final yield to ensure completeness
+        # Use TextIteratorStreamer for simpler, more robust streaming
+        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+        # Define generation arguments
+        generation_kwargs = dict(
+            **inputs,
+            streamer=streamer,
+            max_new_tokens=1024,
+            # You can add other generation parameters here
+            # temperature=0.7,
+            # top_p=0.95,
+            # do_sample=True
+        )
 
-        # Yield initial output and then stream updates
-        yield output_text # Initial empty or partial output
-        for _ in generate_text(): # This loop runs the generation
-            yield output_text # Yield updated text as it's generated
+        # Run generation in a separate thread to avoid blocking the UI
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        # Yield the generated text as it becomes available
+        generated_text = ""
+        for new_text in streamer:
+            generated_text += new_text
+            yield generated_text
 
     except Exception as e:
         error_msg = f"An error occurred during processing: {str(e)}"
@@ -138,6 +102,7 @@ def analyze_image(image, prompt):
         if os.path.exists(temp_image_path):
             os.remove(temp_image_path)
 
+
 # --- Gradio Interface ---
 with gr.Blocks() as demo:
     gr.Markdown("# 🦟 Mosquito Breeding Site Detector")
@@ -155,13 +120,14 @@ with gr.Blocks() as demo:
     output_text = gr.Textbox(label="Analysis Result", interactive=False, lines=15)
 
     # Connect the button to the function
+    # The 'streaming=True' flag in Gradio 3 is deprecated. The streaming behavior
+    # is now automatically handled by using a generator function (with 'yield').
     submit_btn.click(
         fn=analyze_image,
         inputs=[image_input, prompt_input],
-        outputs=output_text, # Stream to the textbox
-        streaming=True # Enable streaming output
+        outputs=output_text
     )
 
 # Launch the app
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
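
The heart of this change is the generate-on-a-thread pattern: model.generate() blocks until decoding finishes, while TextIteratorStreamer exposes the decoded text as an ordinary Python iterator, so the two must run concurrently. A minimal, self-contained sketch of the same pattern, assuming a tiny text-only placeholder model (sshleifer/tiny-gpt2 stands in for the app's Gemma model; any causal LM works the same way):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model for illustration only; the app itself loads the
# Unsloth Gemma base model plus the LoRA adapter.
model_name = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Stagnant water near the house", return_tensors="pt")

# skip_prompt=True drops the echoed input tokens; skip_special_tokens is
# forwarded to tokenizer.decode() for each emitted chunk.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs on a worker thread while
# the main thread consumes the streamer as an ordinary iterator.
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, streamer=streamer, max_new_tokens=40))
thread.start()

for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()

Each pass through the loop corresponds to one yield in analyze_image above, which is what lets Gradio repaint the textbox incrementally instead of waiting for the full answer.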
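
A related caveat, which depends on the Gradio version installed in the Space: in Gradio 3, generator-based handlers like analyze_image require the request queue to be enabled (newer releases enable it by default). If the app errors on submit or the textbox only fills in at the end, enabling the queue before launching is the usual fix:

if __name__ == "__main__":
    demo.queue()   # required for generator (streaming) handlers in Gradio 3
    demo.launch()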