assentian1970 committed
Commit 83b0d3a · verified · 1 Parent(s): ecd2075

Update app.py

Files changed (1): app.py +170 -197
app.py CHANGED
@@ -1,227 +1,200 @@
- #!/usr/bin/env python
- # encoding: utf-8
  import spaces
  import torch
- @spaces.GPU
- def debug():
-     torch.randn(10).cuda()
- debug()
  import argparse
- from transformers import AutoModel, AutoTokenizer
  import gradio as gr
  from PIL import Image
  from decord import VideoReader, cpu
- import io
- import os
- os.system("nvidia-smi")
- import copy
- import requests
- import base64
- import json
- import traceback
- import re
- import modelscope_studio as mgr
  from modelscope.hub.snapshot_download import snapshot_download

- # Configuration
- model_path = snapshot_download('iic/mPLUG-Owl3-7B-240728', cache_dir='./')
- device_map = "auto"
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
- # Argparser
- parser = argparse.ArgumentParser(description='demo')
- parser.add_argument('--device', type=str, default='cuda', help='cuda, mps or cpu')
- parser.add_argument("--host", type=str, default="0.0.0.0")
- parser.add_argument("--port", type=int, default=7860)
- args = parser.parse_args()
- device = args.device
-
- # Limit CPU threads and disable the TensorExpr fuser before loading the model
- torch.set_num_threads(4)
- torch._C._jit_set_texpr_fuser_enabled(False)

- # Load the model with SDPA attention (avoids the flash-attn dependency)
- model = AutoModel.from_pretrained(
-     model_path,
-     trust_remote_code=True,
-     torch_dtype=torch.bfloat16 if 'int4' not in model_path else torch.float32,
-     attn_implementation="sdpa"  # scaled dot-product attention instead of flash-attn
- ).to(device)

- tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
- model.eval()

- # Constants
- ERROR_MSG = "Error occurred, please check inputs and try again"
- MAX_NUM_FRAMES = 64
  IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
- VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}

  def get_file_extension(filename):
-     return os.path.splitext(filename)[1].lower()

  def is_image(filename):
      return get_file_extension(filename) in IMAGE_EXTENSIONS

- def is_video(filename):
-     return get_file_extension(filename) in VIDEO_EXTENSIONS

- def create_multimodal_input(upload_image_disabled=False, upload_video_disabled=False):
-     return mgr.MultimodalInput(
-         upload_image_button_props={'label': 'Upload Image', 'disabled': upload_image_disabled, 'file_count': 'multiple'},
-         upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
-         submit_button_props={'label': 'Submit'}
-     )

- @spaces.GPU
- def chat(images, messages, params):
-     try:
-         response = model.chat(
-             images=images,
-             messages=messages,
-             tokenizer=tokenizer,
-             **params
-         )
-         return 0, response, None
-     except Exception as e:
-         print(f"Error in chat: {str(e)}")
-         traceback.print_exc()
-         return -1, ERROR_MSG, None

- def encode_image(image):
-     try:
-         if not isinstance(image, Image.Image):
-             image = Image.open(image.file.path).convert("RGB")
-
-         max_size = 448 * 16
-         if max(image.size) > max_size:
-             ratio = max_size / max(image.size)
-             new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
-             image = image.resize(new_size, Image.BICUBIC)
-
-         return image
-     except Exception as e:
-         raise gr.Error(f"Image processing error: {str(e)}")

- def encode_video(video):
      try:
-         vr = VideoReader(video.file.path, ctx=cpu(0))
-         sample_fps = round(vr.get_avg_fps() / 1)
-         frame_idx = [i for i in range(0, len(vr), sample_fps)]
-
-         if len(frame_idx) > MAX_NUM_FRAMES:
-             frame_idx = frame_idx[:MAX_NUM_FRAMES]
-
-         frames = vr.get_batch(frame_idx).asnumpy()
-         return [Image.fromarray(frame.astype('uint8')) for frame in frames]
-     except Exception as e:
-         raise gr.Error(f"Video processing error: {str(e)}")

- def process_inputs(_question, _app_cfg):
-     try:
-         files = _question.files
-         text = _question.text
-         pattern = r"\[mm_media\]\d+\[/mm_media\]"
-         matches = re.split(pattern, text)
-
-         if len(matches) != len(files) + 1:
-             raise gr.Error("Media placeholders don't match uploaded files count")
-
-         message = []
-         media_count = 0
-
-         for i, match in enumerate(matches):
-             if match.strip():
-                 message.append({"type": "text", "content": match.strip()})
-
-             if i < len(files):
-                 file = files[i]
-                 if is_image(file.file.path):
-                     message.append({"type": "image", "content": encode_image(file)})
-                 elif is_video(file.file.path):
-                     message.append({"type": "video", "content": encode_video(file)})
-                 media_count += 1
-
-         return message, media_count
      except Exception as e:
-         traceback.print_exc()
-         raise gr.Error(f"Input processing failed: {str(e)}")

- def generate_response(_question, _chat_history, _app_cfg, params_form):
      try:
-         params = {
-             'max_new_tokens': 2048,
-             'temperature': 0.7 if params_form == 'Sampling' else 1.0,
-             'top_p': 0.8 if params_form == 'Sampling' else None,
-             'num_beams': 3 if params_form == 'Beam Search' else 1,
-             'repetition_penalty': 1.1
-         }
-
-         processed_input, media_count = process_inputs(_question, _app_cfg)
-         _app_cfg['media_count'] += media_count
-
-         code, response, _ = chat(
-             images=[item['content'] for item in processed_input if item['type'] == 'image'],
-             messages=[{"role": "user", "content": processed_input}],
-             params=params
-         )
-
-         if code != 0:
-             raise gr.Error("Model response generation failed")
-
-         _chat_history.append((_question, response))
-         return _chat_history, _app_cfg
-
      except Exception as e:
-         traceback.print_exc()
-         raise gr.Error(f"Generation failed: {str(e)}")
-
- def reset_chat():
-     return [], {'media_count': 0, 'ctx': []}
-
- with gr.Blocks(css="video {height: auto !important;}") as demo:
-     with gr.Tab("mPLUG-Owl3"):
-         gr.Markdown("## mPLUG-Owl3 Multi-Modal Chat Interface")
-
-         # State management
-         app_state = gr.State({'media_count': 0, 'ctx': []})
-
-         # Chat interface
-         chatbot = mgr.Chatbot(height=600)
-         input_interface = create_multimodal_input()
-
-         # Controls
-         with gr.Row():
-             decode_type = gr.Radio(
-                 choices=['Beam Search', 'Sampling'],
-                 value='Sampling',
-                 label="Decoding Strategy"
-             )
-             clear_btn = gr.Button("Clear History")
-             regenerate_btn = gr.Button("Regenerate")
-
-         # Event handlers
-         input_interface.submit(
-             generate_response,
-             [input_interface, chatbot, app_state, decode_type],
-             [chatbot, app_state]
-         )
-
-         clear_btn.click(
-             reset_chat,
-             outputs=[chatbot, app_state]
-         )
-
-         regenerate_btn.click(
-             lambda history: history[:-1] if history else [],
-             inputs=[chatbot],
-             outputs=[chatbot]
-         )

  if __name__ == "__main__":
-     demo.launch(
-         server_name=args.host,
-         server_port=args.port,
-         share=False,
-         debug=True
-     )
 
 
 
  import spaces
  import torch
  import argparse
+ import os
+ import gc
+ import tempfile
+ import cv2
+ import numpy as np
  import gradio as gr
+ from datetime import datetime
  from PIL import Image
  from decord import VideoReader, cpu
+ from transformers import AutoModel, AutoTokenizer
  from modelscope.hub.snapshot_download import snapshot_download
+ from ultralytics import YOLO

+ os.system("nvidia-smi")

+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

+ if DEVICE == "cuda":
+     def debug():
+         torch.randn(10).cuda()
+     debug()

  IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
+ VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv'}  # extend as needed

  def get_file_extension(filename):
+     return os.path.splitext(filename)[-1].lower()
+
+ def is_video(filename):
+     return get_file_extension(filename) in VIDEO_EXTENSIONS

  def is_image(filename):
      return get_file_extension(filename) in IMAGE_EXTENSIONS

+ parser = argparse.ArgumentParser(description='demo')
+ parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
+ parser.add_argument("--host", type=str, default="0.0.0.0")
+ parser.add_argument("--port", type=int)
+ args = parser.parse_args()
+ device = args.device
+ assert device in ['cuda', 'mps']

+ MODEL_NAME = 'iic/mPLUG-Owl3-7B-240728'
+ MODEL_CACHE_DIR = os.getenv('TRANSFORMERS_CACHE', './models')
+ os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

+ try:
+     model_path = snapshot_download(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
+ except Exception as e:
+     print(f"Error downloading model: {str(e)}")
+     model_path = os.path.join(MODEL_CACHE_DIR, MODEL_NAME)

+ YOLO_MODEL = YOLO('./best_yolov11.pt')
+ MAX_NUM_FRAMES = 64

+ # Machinery categories reported in the diary; detector class names map onto these.
+ MACHINERY_CLASSES = [
+     "Tower Crane", "Mobile Crane", "Compactor/Roller", "Bulldozer",
+     "Excavator", "Dump Truck", "Concrete Mixer", "Loader",
+     "Pump Truck", "Pile Driver", "Grader", "Other Vehicle"
+ ]

+ def load_model_and_tokenizer():
+     if DEVICE == "cuda":
+         torch.cuda.empty_cache()
+         gc.collect()
+     model = AutoModel.from_pretrained(
+         model_path,
+         attn_implementation='flash_attention_2',
+         trust_remote_code=True,
+         torch_dtype=torch.half,
+         device_map='auto'
+     )
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_path,
+         trust_remote_code=True
+     )
+     return model, tokenizer, None  # no separate processor is loaded here

+ def encode_video_in_chunks(video_path):
+     # Sample ~1 frame per second and yield the frames in chunks of MAX_NUM_FRAMES.
+     vr = VideoReader(video_path, ctx=cpu(0))
+     sample_fps = max(1, round(vr.get_avg_fps()))
+     frame_idx = list(range(0, len(vr), sample_fps))
+     chunks = [frame_idx[i:i + MAX_NUM_FRAMES] for i in range(0, len(frame_idx), MAX_NUM_FRAMES)]
+     for chunk_idx, chunk in enumerate(chunks):
+         frames = vr.get_batch(chunk).asnumpy()
+         frames = [Image.fromarray(v.astype('uint8')) for v in frames]
+         yield chunk_idx, frames

+ def detect_people_and_machinery(media_path):
+     # Run YOLO on an image, or on ~1 frame per second of a video,
+     # keeping the per-frame maximum of each count.
      try:
+         max_people_count = 0
+         max_machine_types = dict.fromkeys(MACHINERY_CLASSES, 0)
+
+         if is_video(media_path):
+             cap = cv2.VideoCapture(media_path)
+             fps = cap.get(cv2.CAP_PROP_FPS)
+             sample_rate = max(1, int(fps))
+             frame_count = 0
+             while cap.isOpened():
+                 ret, frame = cap.read()
+                 if not ret:
+                     break
+                 if frame_count % sample_rate == 0:
+                     results = YOLO_MODEL(frame)
+                     people, _, machine_types = process_yolo_results(results)
+                     max_people_count = max(max_people_count, people)
+                     for k, v in machine_types.items():
+                         max_machine_types[k] = max(max_machine_types[k], v)
+                 frame_count += 1
+             cap.release()
+         else:
+             img = cv2.imread(media_path)
+             results = YOLO_MODEL(img)
+             max_people_count, _, max_machine_types = process_yolo_results(results)
+
+         max_machine_types = {k: v for k, v in max_machine_types.items() if v > 0}
+         total_machinery_count = sum(max_machine_types.values())
+         return max_people_count, total_machinery_count, max_machine_types
      except Exception as e:
+         print(f"Error in YOLO detection: {str(e)}")
+         return 0, 0, {}
+
+ # Maps detector class-name substrings to diary categories; this covers a subset of
+ # MACHINERY_CLASSES and should be extended to match the classes of best_yolov11.pt.
+ MACHINERY_MAPPING = {
+     'tower_crane': "Tower Crane",
+     'mobile_crane': "Mobile Crane",
+     'grader': "Grader",
+     'other_vehicle': "Other Vehicle"
+ }
+
+ def process_yolo_results(results):
+     people_count = 0
+     machine_types = dict.fromkeys(MACHINERY_CLASSES, 0)
+     for r in results:
+         for box in r.boxes:
+             cls = int(box.cls[0])
+             conf = float(box.conf[0])
+             class_name = YOLO_MODEL.names[cls]
+             if class_name.lower() == 'worker' and conf > 0.5:
+                 people_count += 1
+             if conf > 0.5:
+                 for key, value in MACHINERY_MAPPING.items():
+                     if key in class_name.lower():
+                         machine_types[value] += 1
+                         break
+     return people_count, sum(machine_types.values()), machine_types

+ @spaces.GPU
+ def process_diary(day, date, total_people, total_machinery, media):
+     if media is None:
+         return [day, date, "No media uploaded", "No media uploaded", "No media uploaded", "No media uploaded"]
      try:
+         src_path = media if isinstance(media, str) else media.name  # gr.File may yield a path or a file wrapper
+         file_ext = get_file_extension(src_path)
+         if not (is_image(src_path) or is_video(src_path)):
+             raise ValueError(f"Unsupported file type: {file_ext}")
+
+         # Copy the upload to a fresh temp file so OpenCV/decord can reopen it by path.
+         with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
+             temp_path = temp_file.name
+             with open(src_path, 'rb') as src:
+                 temp_file.write(src.read())
+
+         detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(temp_path)
+         detected_types_str = ", ".join(f"{k}: {v}" for k, v in detected_machinery_types.items())
+         detected_activities = "Sample activity analysis."  # placeholder
+
+         os.remove(temp_path)
+         return [day, date, str(detected_people), str(detected_machinery), detected_types_str, detected_activities]
      except Exception as e:
+         print(f"Error processing media: {str(e)}")
+         return [day, date, "Error", "Error", "Error", "Error"]

+ with gr.Blocks(title="Digital Site Diary") as demo:
+     gr.Markdown("# 📝 Digital Site Diary")
+     with gr.Row():
+         with gr.Column():
+             day = gr.Textbox(label="Day", value='9')
+             date = gr.Textbox(label="Date", value=datetime.now().strftime("%Y-%m-%d"))
+             total_people = gr.Number(label="Total Number of People", value=10)
+             total_machinery = gr.Number(label="Total Number of Machinery", value=3)
+             media = gr.File(label="Upload Image/Video", file_types=["image", "video"])
+             submit_btn = gr.Button("Submit")
+         with gr.Column():
+             model_day = gr.Textbox(label="Day")
+             model_date = gr.Textbox(label="Date")
+             model_people = gr.Textbox(label="Total Number of People")
+             model_machinery = gr.Textbox(label="Total Machinery")
+             model_machinery_types = gr.Textbox(label="Machinery Types")
+             model_activities = gr.Textbox(label="Activities")
+     submit_btn.click(
+         fn=process_diary,
+         inputs=[day, date, total_people, total_machinery, media],
+         outputs=[model_day, model_date, model_people, model_machinery, model_machinery_types, model_activities]
+     )

  if __name__ == "__main__":
+     demo.launch(share=False, debug=True, show_api=False, server_port=args.port, server_name=args.host)
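
For reference, the frame-sampling arithmetic behind the new encode_video_in_chunks can be checked in isolation. A minimal sketch, assuming a hypothetical 30 fps clip with 3000 frames (numbers not from this commit):

    # ~1 sampled frame per second, split into model-sized chunks
    avg_fps, n_frames, MAX_NUM_FRAMES = 30, 3000, 64
    sample_fps = max(1, round(avg_fps))
    frame_idx = list(range(0, n_frames, sample_fps))   # 100 indices: 0, 30, 60, ...
    chunks = [frame_idx[i:i + MAX_NUM_FRAMES]
              for i in range(0, len(frame_idx), MAX_NUM_FRAMES)]
    print([len(c) for c in chunks])                    # [64, 36] -> two chunks per video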
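
Similarly, the YOLO path can be smoke-tested without launching Gradio. A minimal sketch, assuming ./best_yolov11.pt is present and site_photo.jpg is a hypothetical local image:

    # Peak people count, total machinery, and per-type breakdown for one file.
    people, machinery_total, breakdown = detect_people_and_machinery("site_photo.jpg")
    print(f"people={people}, machinery={machinery_total}")
    for name, count in breakdown.items():
        print(f"  {name}: {count}")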