assentian1970 committed
Commit 7e990b0 · verified · 1 Parent(s): 722e2d7

Update app.py

Files changed (1)
  1. app.py +173 -224
app.py CHANGED
@@ -1,53 +1,59 @@
 #!/usr/bin/env python
 # encoding: utf-8
-
-
-
 import spaces
 import torch
-import os
-import gc
-import tempfile
-import numpy as np
-import cv2
-from datetime import datetime
-from PIL import Image
-from decord import VideoReader, cpu
+@spaces.GPU
+def debug():
+    torch.randn(10).cuda()
+debug()
+import argparse
 from transformers import AutoModel, AutoTokenizer
 import gradio as gr
-from ultralytics import YOLO
+from PIL import Image
+from decord import VideoReader, cpu
+import io
+import os
+os.system("nvidia-smi")
+import copy
+import requests
+import base64
+import json
+import traceback
+import re
+import modelscope_studio as mgr
 from modelscope.hub.snapshot_download import snapshot_download

-# Initialize GPU first
-@spaces.GPU
-def initialize_gpu():
-    return torch.randn(10).cuda()
+# Configuration
+model_dir = snapshot_download('iic/mPLUG-Owl3-7B-240728', cache_dir='./')
+device_map = "auto"
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+# Argparser
+parser = argparse.ArgumentParser(description='demo')
+parser.add_argument('--device', type=str, default='cuda', help='cuda, mps or cpu')
+parser.add_argument("--host", type=str, default="0.0.0.0")
+parser.add_argument("--port", type=int, default=7860)
+args = parser.parse_args()
+device = args.device
+
+# Load model and tokenizer
+model_path = './iic/mPLUG-Owl3-7B-240728'
+model = AutoModel.from_pretrained(
+    model_path,
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16 if 'int4' not in model_path else torch.float32,
+    attn_implementation="flash_attention_2" if device == 'cuda' else None
+).to(device)

-initialize_gpu()
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+model.eval()

-# Configuration
-MODEL_NAME = 'mPLUG-Owl3'
-YOLO_MODEL = YOLO('best_yolov11.pt')
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+# Constants
+ERROR_MSG = "Error occurred, please check inputs and try again"
 MAX_NUM_FRAMES = 64
 IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
 VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}

-# Download models
-model_dir = snapshot_download('iic/mPLUG-Owl3-7B-240728', cache_dir='./')
-
-
-
-# Replace the model loading section with:
-# Load models with ZeroGPU optimization
-model = AutoModel.from_pretrained(
-    model_dir,
-    attn_implementation='sdpa',
-    trust_remote_code=True,
-    torch_dtype=torch.float16, # Use float16 instead of bfloat16
-    device_map="auto"
-)
-
 def get_file_extension(filename):
     return os.path.splitext(filename)[1].lower()

@@ -57,219 +63,162 @@ def is_image(filename):
 def is_video(filename):
     return get_file_extension(filename) in VIDEO_EXTENSIONS

-def process_yolo_results(results):
-    counts = {
-        "people": 0,
-        "machinery": {
-            "Tower Crane": 0, "Mobile Crane": 0, "Compactor/Roller": 0,
-            "Bulldozer": 0, "Excavator": 0, "Dump Truck": 0,
-            "Concrete Mixer": 0, "Loader": 0, "Pump Truck": 0,
-            "Pile Driver": 0, "Grader": 0, "Other Vehicle": 0
-        }
-    }
-
-    for r in results:
-        for box in r.boxes:
-            cls_id = int(box.cls[0])
-            conf = float(box.conf[0])
-            if conf < 0.5:
-                continue
-
-            class_name = YOLO_MODEL.names[cls_id].lower()
-
-            if 'worker' in class_name:
-                counts["people"] += 1
-            else:
-                machinery_mapping = {
-                    'tower_crane': "Tower Crane",
-                    'mobile_crane': "Mobile Crane",
-                    'compactor': "Compactor/Roller",
-                    'roller': "Compactor/Roller",
-                    'bulldozer': "Bulldozer",
-                    'excavator': "Excavator",
-                    'dump_truck': "Dump Truck",
-                    'concrete_mixer': "Concrete Mixer",
-                    'loader': "Loader",
-                    'pump_truck': "Pump Truck",
-                    'pile_driver': "Pile Driver",
-                    'grader': "Grader"
-                }
-                counts["machinery"][machinery_mapping.get(class_name, "Other Vehicle")] += 1
-
-    return counts
+def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
+    return mgr.MultimodalInput(
+        upload_image_button_props={'label': 'Upload Image', 'disabled': upload_image_disabled, 'file_count': 'multiple'},
+        upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
+        submit_button_props={'label': 'Submit'}
+    )
+
+@spaces.GPU
+def chat(images, messages, params):
+    try:
+        response = model.chat(
+            images=images,
+            messages=messages,
+            tokenizer=tokenizer,
+            **params
+        )
+        return 0, response, None
+    except Exception as e:
+        print(f"Error in chat: {str(e)}")
+        traceback.print_exc()
+        return -1, ERROR_MSG, None

-def detect_objects(media_path):
-    if is_video(media_path):
-        cap = cv2.VideoCapture(media_path)
-        max_counts = {"people": 0, "machinery": {}}
-        frame_count = 0
+def encode_image(image):
+    try:
+        if not isinstance(image, Image.Image):
+            image = Image.open(image.file.path).convert("RGB")

-        while cap.isOpened():
-            ret, frame = cap.read()
-            if not ret:
-                break
-
-            if frame_count % 30 == 0: # Process every 30th frame
-                results = YOLO_MODEL(frame)
-                counts = process_yolo_results(results)
-
-                max_counts["people"] = max(max_counts["people"], counts["people"])
-                for key, value in counts["machinery"].items():
-                    max_counts["machinery"][key] = max(max_counts["machinery"].get(key, 0), value)
-
-            frame_count += 1
+        max_size = 448 * 16
+        if max(image.size) > max_size:
+            ratio = max_size / max(image.size)
+            new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
+            image = image.resize(new_size, Image.BICUBIC)

-        cap.release()
-        return max_counts
-
-    else:
-        img = cv2.imread(media_path)
-        results = YOLO_MODEL(img)
-        return process_yolo_results(results)
+        return image
+    except Exception as e:
+        raise gr.Error(f"Image processing error: {str(e)}")

-def analyze_media(media_path):
+def encode_video(video):
     try:
-        if is_image(media_path):
-            return analyze_image(media_path)
-        return analyze_video(media_path)
+        vr = VideoReader(video.file.path, ctx=cpu(0))
+        sample_fps = round(vr.get_avg_fps() / 1)
+        frame_idx = [i for i in range(0, len(vr), sample_fps)]
+
+        if len(frame_idx) > MAX_NUM_FRAMES:
+            frame_idx = frame_idx[:MAX_NUM_FRAMES]
+
+        frames = vr.get_batch(frame_idx).asnumpy()
+        return [Image.fromarray(frame.astype('uint8')) for frame in frames]
     except Exception as e:
-        print(f"Analysis error: {str(e)}")
-        return "Analysis unavailable"
+        raise gr.Error(f"Video processing error: {str(e)}")

-def analyze_image(image_path):
+def process_inputs(_question, _app_cfg):
     try:
-        image = Image.open(image_path).convert("RGB")
-        messages = [{
-            "role": "user",
-            "content": "Analyze this construction site image. Describe visible activities, equipment, and safety observations.",
-            "images": [image]
-        }]
-
-        inputs = model.build_inputs(
-            messages=messages,
-            tokenizer=tokenizer,
-            max_new_tokens=1000,
-            padding=True
-        )
-        inputs = inputs.to(DEVICE)
+        files = _question.files
+        text = _question.text
+        pattern = r"\[mm_media\]\d+\[/mm_media\]"
+        matches = re.split(pattern, text)

-        with torch.no_grad():
-            outputs = model.generate(**inputs)
+        if len(matches) != len(files) + 1:
+            raise gr.Error("Media placeholders don't match uploaded files count")

-        return tokenizer.decode(outputs[0], skip_special_tokens=True)
+        message = []
+        media_count = 0
+
+        for i, match in enumerate(matches):
+            if match.strip():
+                message.append({"type": "text", "content": match.strip()})
+
+            if i < len(files):
+                file = files[i]
+                if is_image(file.file.path):
+                    message.append({"type": "image", "content": encode_image(file)})
+                elif is_video(file.file.path):
+                    message.append({"type": "video", "content": encode_video(file)})
+                media_count += 1
+
+        return message, media_count
     except Exception as e:
-        print(f"Image analysis error: {str(e)}")
-        return "Image analysis failed"
+        traceback.print_exc()
+        raise gr.Error(f"Input processing failed: {str(e)}")

-def analyze_video(video_path):
+def generate_response(_question, _chat_history, _app_cfg, params_form):
     try:
-        vr = VideoReader(video_path, ctx=cpu(0))
-        frame_step = max(1, len(vr) // MAX_NUM_FRAMES)
-        frames = [Image.fromarray(vr[i].asnumpy()) for i in range(0, len(vr), frame_step)]
+        params = {
+            'max_new_tokens': 2048,
+            'temperature': 0.7 if params_form == 'Sampling' else 1.0,
+            'top_p': 0.8 if params_form == 'Sampling' else None,
+            'num_beams': 3 if params_form == 'Beam Search' else 1,
+            'repetition_penalty': 1.1
+        }

-        messages = [{
-            "role": "user",
-            "content": "Analyze this construction site video. Describe ongoing activities, equipment usage, and safety observations.",
-            "videos": frames[:MAX_NUM_FRAMES]
-        }]
+        processed_input, media_count = process_inputs(_question, _app_cfg)
+        _app_cfg['media_count'] += media_count

-        inputs = model.build_inputs(
-            messages=messages,
-            tokenizer=tokenizer,
-            max_new_tokens=1000,
-            padding=True
+        code, response, _ = chat(
+            images=[item['content'] for item in processed_input if item['type'] == 'image'],
+            messages=[{"role": "user", "content": processed_input}],
+            params=params
         )
-        inputs = inputs.to(DEVICE)

-        with torch.no_grad():
-            outputs = model.generate(**inputs)
+        if code != 0:
+            raise gr.Error("Model response generation failed")

-        return tokenizer.decode(outputs[0], skip_special_tokens=True)
+        _chat_history.append((_question, response))
+        return _chat_history, _app_cfg
+
     except Exception as e:
-        print(f"Video analysis error: {str(e)}")
-        return "Video analysis failed"
+        traceback.print_exc()
+        raise gr.Error(f"Generation failed: {str(e)}")

-def annotate_video(input_path):
-    cap = cv2.VideoCapture(input_path)
-    fps = cap.get(cv2.CAP_PROP_FPS)
-    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-
-    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
-        output_path = temp_file.name
+def reset_chat():
+    return [], {'media_count': 0, 'ctx': []}
+
+with gr.Blocks(css="video {height: auto !important;}") as demo:
+    with gr.Tab("mPLUG-Owl3"):
+        gr.Markdown("## mPLUG-Owl3 Multi-Modal Chat Interface")

-    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
-
-    while cap.isOpened():
-        ret, frame = cap.read()
-        if not ret:
-            break
-
-        results = YOLO_MODEL(frame)
-        annotated_frame = results[0].plot()
-        writer.write(annotated_frame)
+        # State management
+        app_state = gr.State({'media_count': 0, 'ctx': []})

-    cap.release()
-    writer.release()
-    return output_path
-
-@spaces.GPU
-def process_entry(day, date, media):
-    try:
-        if not media:
-            return [day, date, "No media", "No media", "No media", None]
-
-        with tempfile.NamedTemporaryFile(delete=False) as tmp:
-            tmp.write(media.read())
-            tmp_path = tmp.name
-
-        detection = detect_objects(tmp_path)
-        analysis = analyze_media(tmp_path)
-        annotated_video = annotate_video(tmp_path) if is_video(tmp_path) else None
+        # Chat interface
+        chatbot = mgr.Chatbot(height=600)
+        input_interface = create_multimodal_input()

-        machinery_str = ", ".join(
-            f"{k}: {v}" for k, v in detection['machinery'].items() if v > 0
-        ) if isinstance(detection, dict) else "Detection failed"
+        # Controls
+        with gr.Row():
+            decode_type = gr.Radio(
+                choices=['Beam Search', 'Sampling'],
+                value='Sampling',
+                label="Decoding Strategy"
+            )
+            clear_btn = gr.Button("Clear History")
+            regenerate_btn = gr.Button("Regenerate")

-        return [
-            day,
-            date,
-            str(detection.get('people', 0)),
-            machinery_str,
-            analysis,
-            annotated_video
-        ]
-    except Exception as e:
-        print(f"Processing error: {str(e)}")
-        return [day, date, "Error", "Error", "Error", None]
-
-with gr.Blocks(title="Construction Site Diary", css="footer {visibility: hidden}") as demo:
-    gr.Markdown("# 🏗️ Digital Construction Site Diary")
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            day_input = gr.Number(label="Day Number", value=1)
-            date_input = gr.Textbox(label="Date", value=datetime.now().strftime("%Y-%m-%d"))
-            media_input = gr.File(label="Upload Site Photo/Video", file_types=["image", "video"])
-            submit_btn = gr.Button("Analyze Site", variant="primary")
-
-        with gr.Column(scale=2):
-            day_output = gr.Textbox(label="Day")
-            date_output = gr.Textbox(label="Date")
-            people_output = gr.Textbox(label="People Detected")
-            machinery_output = gr.Textbox(label="Equipment Detected")
-            analysis_output = gr.Textbox(label="Activity Analysis", lines=4)
-            video_output = gr.Video(label="Annotated Video Preview")
-
-    submit_btn.click(
-        fn=process_entry,
-        inputs=[day_input, date_input, media_input],
-        outputs=[day_output, date_output, people_output, machinery_output, analysis_output, video_output]
-    )
+        # Event handlers
+        input_interface.submit(
+            generate_response,
+            [input_interface, chatbot, app_state, decode_type],
+            [chatbot, app_state]
+        )
+
+        clear_btn.click(
+            reset_chat,
+            outputs=[chatbot, app_state]
+        )
+
+        regenerate_btn.click(
+            lambda history: history[:-1] if history else [],
+            inputs=[chatbot],
+            outputs=[chatbot]
+        )

 if __name__ == "__main__":
     demo.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        share=False
+        server_name=args.host,
+        server_port=args.port,
+        share=False,
+        debug=True
     )