Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,227 +1,200 @@
|
|
1 |
-
#!/usr/bin/env python
|
2 |
-
# encoding: utf-8
|
3 |
import spaces
|
4 |
import torch
|
5 |
-
@spaces.GPU
|
6 |
-
def debug():
|
7 |
-
torch.randn(10).cuda()
|
8 |
-
debug()
|
9 |
import argparse
|
10 |
-
|
|
|
|
|
|
|
|
|
11 |
import gradio as gr
|
|
|
12 |
from PIL import Image
|
13 |
from decord import VideoReader, cpu
|
14 |
-
import
|
15 |
-
import os
|
16 |
-
os.system("nvidia-smi")
|
17 |
-
import copy
|
18 |
-
import requests
|
19 |
-
import base64
|
20 |
-
import json
|
21 |
-
import traceback
|
22 |
-
import re
|
23 |
-
import modelscope_studio as mgr
|
24 |
from modelscope.hub.snapshot_download import snapshot_download
|
|
|
25 |
|
26 |
-
|
27 |
-
model_dir = snapshot_download('iic/mPLUG-Owl3-7B-240728', cache_dir='./')
|
28 |
-
device_map = "auto"
|
29 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
30 |
-
|
31 |
-
# Argparser
|
32 |
-
parser = argparse.ArgumentParser(description='demo')
|
33 |
-
parser.add_argument('--device', type=str, default='cuda', help='cuda, mps or cpu')
|
34 |
-
parser.add_argument("--host", type=str, default="0.0.0.0")
|
35 |
-
parser.add_argument("--port", type=int, default=7860)
|
36 |
-
args = parser.parse_args()
|
37 |
-
device = args.device
|
38 |
-
|
39 |
-
# Before model loading, add:
|
40 |
-
torch.set_num_threads(4) # Limit CPU threads
|
41 |
-
torch._C._jit_set_texpr_fuser_enabled(False)
|
42 |
|
43 |
-
|
44 |
-
model = AutoModel.from_pretrained(
|
45 |
-
model_path,
|
46 |
-
trust_remote_code=True,
|
47 |
-
torch_dtype=torch.bfloat16 if 'int4' not in model_path else torch.float32,
|
48 |
-
attn_implementation="sdpa" # Use scaled dot-product attention instead of flash-attn
|
49 |
-
).to(device)
|
50 |
|
51 |
-
|
52 |
-
|
|
|
|
|
53 |
|
54 |
-
# Constants
|
55 |
-
ERROR_MSG = "Error occurred, please check inputs and try again"
|
56 |
-
MAX_NUM_FRAMES = 64
|
57 |
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
|
58 |
-
VIDEO_EXTENSIONS = {'.mp4', '.
|
59 |
|
60 |
def get_file_extension(filename):
|
61 |
-
return os.path.splitext(filename)[1].lower()
|
|
|
|
|
|
|
62 |
|
63 |
def is_image(filename):
|
64 |
return get_file_extension(filename) in IMAGE_EXTENSIONS
|
65 |
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
upload_video_button_props={'label': 'Upload Video', 'disabled': upload_video_disabled, 'file_count': 'single'},
|
73 |
-
submit_button_props={'label': 'Submit'}
|
74 |
-
)
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
messages=messages,
|
82 |
-
tokenizer=tokenizer,
|
83 |
-
**params
|
84 |
-
)
|
85 |
-
return 0, response, None
|
86 |
-
except Exception as e:
|
87 |
-
print(f"Error in chat: {str(e)}")
|
88 |
-
traceback.print_exc()
|
89 |
-
return -1, ERROR_MSG, None
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
if not isinstance(image, Image.Image):
|
94 |
-
image = Image.open(image.file.path).convert("RGB")
|
95 |
-
|
96 |
-
max_size = 448 * 16
|
97 |
-
if max(image.size) > max_size:
|
98 |
-
ratio = max_size / max(image.size)
|
99 |
-
new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
|
100 |
-
image = image.resize(new_size, Image.BICUBIC)
|
101 |
-
|
102 |
-
return image
|
103 |
-
except Exception as e:
|
104 |
-
raise gr.Error(f"Image processing error: {str(e)}")
|
105 |
|
106 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
try:
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
-
def process_inputs(_question, _app_cfg):
|
121 |
-
try:
|
122 |
-
files = _question.files
|
123 |
-
text = _question.text
|
124 |
-
pattern = r"\[mm_media\]\d+\[/mm_media\]"
|
125 |
-
matches = re.split(pattern, text)
|
126 |
-
|
127 |
-
if len(matches) != len(files) + 1:
|
128 |
-
raise gr.Error("Media placeholders don't match uploaded files count")
|
129 |
-
|
130 |
-
message = []
|
131 |
-
media_count = 0
|
132 |
-
|
133 |
-
for i, match in enumerate(matches):
|
134 |
-
if match.strip():
|
135 |
-
message.append({"type": "text", "content": match.strip()})
|
136 |
-
|
137 |
-
if i < len(files):
|
138 |
-
file = files[i]
|
139 |
-
if is_image(file.file.path):
|
140 |
-
message.append({"type": "image", "content": encode_image(file)})
|
141 |
-
elif is_video(file.file.path):
|
142 |
-
message.append({"type": "video", "content": encode_video(file)})
|
143 |
-
media_count += 1
|
144 |
-
|
145 |
-
return message, media_count
|
146 |
except Exception as e:
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
-
|
|
|
|
|
|
|
151 |
try:
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
)
|
168 |
-
|
169 |
-
if code != 0:
|
170 |
-
raise gr.Error("Model response generation failed")
|
171 |
-
|
172 |
-
_chat_history.append((_question, response))
|
173 |
-
return _chat_history, _app_cfg
|
174 |
-
|
175 |
except Exception as e:
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
with gr.
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
regenerate_btn = gr.Button("Regenerate")
|
202 |
-
|
203 |
-
# Event handlers
|
204 |
-
input_interface.submit(
|
205 |
-
generate_response,
|
206 |
-
[input_interface, chatbot, app_state, decode_type],
|
207 |
-
[chatbot, app_state]
|
208 |
-
)
|
209 |
-
|
210 |
-
clear_btn.click(
|
211 |
-
reset_chat,
|
212 |
-
outputs=[chatbot, app_state]
|
213 |
-
)
|
214 |
-
|
215 |
-
regenerate_btn.click(
|
216 |
-
lambda history: history[:-1] if history else [],
|
217 |
-
inputs=[chatbot],
|
218 |
-
outputs=[chatbot]
|
219 |
-
)
|
220 |
|
221 |
if __name__ == "__main__":
|
222 |
-
demo.launch(
|
223 |
-
server_name=args.host,
|
224 |
-
server_port=args.port,
|
225 |
-
share=False,
|
226 |
-
debug=True
|
227 |
-
)
|
|
|
|
|
|
|
1 |
import spaces
|
2 |
import torch
|
|
|
|
|
|
|
|
|
3 |
import argparse
|
4 |
+
import os
|
5 |
+
import gc
|
6 |
+
import tempfile
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
import gradio as gr
|
10 |
+
from datetime import datetime
|
11 |
from PIL import Image
|
12 |
from decord import VideoReader, cpu
|
13 |
+
from transformers import AutoModel, AutoTokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
from modelscope.hub.snapshot_download import snapshot_download
|
15 |
+
from ultralytics import YOLO
|
16 |
|
17 |
+
os.system("nvidia-smi")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
if DEVICE == "cuda":
|
22 |
+
def debug():
|
23 |
+
torch.randn(10).cuda()
|
24 |
+
debug()
|
25 |
|
|
|
|
|
|
|
26 |
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
|
27 |
+
VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv'} # Example, define properly
|
28 |
|
29 |
def get_file_extension(filename):
|
30 |
+
return os.path.splitext(filename)[-1].lower()
|
31 |
+
|
32 |
+
def is_video(filename):
|
33 |
+
return get_file_extension(filename) in VIDEO_EXTENSIONS
|
34 |
|
35 |
def is_image(filename):
|
36 |
return get_file_extension(filename) in IMAGE_EXTENSIONS
|
37 |
|
38 |
+
parser = argparse.ArgumentParser(description='demo')
|
39 |
+
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
|
40 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
41 |
+
parser.add_argument("--port", type=int)
|
42 |
+
args = parser.parse_args()
|
43 |
+
device = args.device
|
44 |
+
assert device in ['cuda', 'mps']
|
45 |
|
46 |
+
MODEL_NAME = 'iic/mPLUG-Owl3-7B-240728'
|
47 |
+
MODEL_CACHE_DIR = os.getenv('TRANSFORMERS_CACHE', './models')
|
48 |
+
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
|
|
|
|
|
|
|
49 |
|
50 |
+
try:
|
51 |
+
model_path = snapshot_download(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
|
52 |
+
except Exception as e:
|
53 |
+
print(f"Error downloading model: {str(e)}")
|
54 |
+
model_path = os.path.join(MODEL_CACHE_DIR, MODEL_NAME)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
+
YOLO_MODEL = YOLO('./best_yolov11.pt')
|
57 |
+
MAX_NUM_FRAMES = 64
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
+
def load_model_and_tokenizer():
|
60 |
+
if DEVICE == "cuda":
|
61 |
+
torch.cuda.empty_cache()
|
62 |
+
gc.collect()
|
63 |
+
model = AutoModel.from_pretrained(
|
64 |
+
model_path,
|
65 |
+
attn_implementation='flash_attention_2',
|
66 |
+
trust_remote_code=True,
|
67 |
+
torch_dtype=torch.half,
|
68 |
+
device_map='auto'
|
69 |
+
)
|
70 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
71 |
+
model_path,
|
72 |
+
trust_remote_code=True
|
73 |
+
)
|
74 |
+
return model, tokenizer, None # Assuming processor is missing
|
75 |
+
|
76 |
+
def encode_video_in_chunks(video_path):
|
77 |
+
vr = VideoReader(video_path, ctx=cpu(0))
|
78 |
+
sample_fps = round(vr.get_avg_fps() / 1)
|
79 |
+
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
80 |
+
chunks = [frame_idx[i:i + MAX_NUM_FRAMES] for i in range(0, len(frame_idx), MAX_NUM_FRAMES)]
|
81 |
+
for chunk_idx, chunk in enumerate(chunks):
|
82 |
+
frames = vr.get_batch(chunk).asnumpy()
|
83 |
+
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
84 |
+
yield chunk_idx, frames
|
85 |
+
|
86 |
+
def detect_people_and_machinery(media_path):
|
87 |
try:
|
88 |
+
max_people_count = 0
|
89 |
+
max_machine_types = {key: 0 for key in [
|
90 |
+
"Tower Crane", "Mobile Crane", "Compactor/Roller", "Bulldozer",
|
91 |
+
"Excavator", "Dump Truck", "Concrete Mixer", "Loader",
|
92 |
+
"Pump Truck", "Pile Driver", "Grader", "Other Vehicle"
|
93 |
+
]}
|
94 |
+
|
95 |
+
if is_video(media_path):
|
96 |
+
cap = cv2.VideoCapture(media_path)
|
97 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
98 |
+
sample_rate = max(1, int(fps))
|
99 |
+
frame_count = 0
|
100 |
+
while cap.isOpened():
|
101 |
+
ret, frame = cap.read()
|
102 |
+
if not ret:
|
103 |
+
break
|
104 |
+
if frame_count % sample_rate == 0:
|
105 |
+
results = YOLO_MODEL(frame)
|
106 |
+
people, _, machine_types = process_yolo_results(results)
|
107 |
+
max_people_count = max(max_people_count, people)
|
108 |
+
for k, v in machine_types.items():
|
109 |
+
max_machine_types[k] = max(max_machine_types[k], v)
|
110 |
+
frame_count += 1
|
111 |
+
cap.release()
|
112 |
+
else:
|
113 |
+
img = cv2.imread(media_path)
|
114 |
+
results = YOLO_MODEL(img)
|
115 |
+
max_people_count, _, max_machine_types = process_yolo_results(results)
|
116 |
+
|
117 |
+
max_machine_types = {k: v for k, v in max_machine_types.items() if v > 0}
|
118 |
+
total_machinery_count = sum(max_machine_types.values())
|
119 |
+
return max_people_count, total_machinery_count, max_machine_types
|
120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
except Exception as e:
|
122 |
+
print(f"Error in YOLO detection: {str(e)}")
|
123 |
+
return 0, 0, {}
|
124 |
+
|
125 |
+
def process_yolo_results(results):
|
126 |
+
people_count = 0
|
127 |
+
machine_types = {key: 0 for key in [
|
128 |
+
"Tower Crane", "Mobile Crane", "Compactor/Roller", "Bulldozer",
|
129 |
+
"Excavator", "Dump Truck", "Concrete Mixer", "Loader",
|
130 |
+
"Pump Truck", "Pile Driver", "Grader", "Other Vehicle"
|
131 |
+
]}
|
132 |
+
for r in results:
|
133 |
+
for box in r.boxes:
|
134 |
+
cls = int(box.cls[0])
|
135 |
+
conf = float(box.conf[0])
|
136 |
+
class_name = YOLO_MODEL.names[cls]
|
137 |
+
if class_name.lower() == 'worker' and conf > 0.5:
|
138 |
+
people_count += 1
|
139 |
+
machinery_mapping = {
|
140 |
+
'tower_crane': "Tower Crane",
|
141 |
+
'mobile_crane': "Mobile Crane",
|
142 |
+
'grader': "Grader",
|
143 |
+
'other_vehicle': "Other Vehicle"
|
144 |
+
}
|
145 |
+
if conf > 0.5:
|
146 |
+
for key, value in machinery_mapping.items():
|
147 |
+
if key in class_name.lower():
|
148 |
+
machine_types[value] += 1
|
149 |
+
break
|
150 |
+
return people_count, sum(machine_types.values()), machine_types
|
151 |
|
152 |
+
@spaces.GPU
|
153 |
+
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media):
|
154 |
+
if media is None:
|
155 |
+
return [day, date, "No media uploaded", "No media uploaded", "No media uploaded", "No media uploaded", None]
|
156 |
try:
|
157 |
+
file_ext = get_file_extension(media.name)
|
158 |
+
if not (is_image(media.name) or is_video(media.name)):
|
159 |
+
raise ValueError(f"Unsupported file type: {file_ext}")
|
160 |
+
|
161 |
+
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
|
162 |
+
temp_path = temp_file.name
|
163 |
+
temp_file.write(media.read())
|
164 |
+
|
165 |
+
detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(temp_path)
|
166 |
+
detected_types_str = ", ".join([f"{k}: {v}" for k, v in detected_machinery_types.items()])
|
167 |
+
detected_activities = "Sample activity analysis." # Placeholder
|
168 |
+
|
169 |
+
os.remove(temp_path)
|
170 |
+
return [day, date, str(detected_people), str(detected_machinery), detected_types_str, detected_activities, None]
|
171 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
except Exception as e:
|
173 |
+
print(f"Error processing media: {str(e)}")
|
174 |
+
return [day, date, "Error", "Error", "Error", "Error", None]
|
175 |
+
|
176 |
+
with gr.Blocks(title="Digital Site Diary") as demo:
|
177 |
+
gr.Markdown("# 📝 Digital Site Diary")
|
178 |
+
with gr.Row():
|
179 |
+
with gr.Column():
|
180 |
+
day = gr.Textbox(label="Day", value='9')
|
181 |
+
date = gr.Textbox(label="Date", value=datetime.now().strftime("%Y-%m-%d"))
|
182 |
+
total_people = gr.Number(label="Total Number of People", value=10)
|
183 |
+
total_machinery = gr.Number(label="Total Number of Machinery", value=3)
|
184 |
+
media = gr.File(label="Upload Image/Video", file_types=["image", "video"])
|
185 |
+
submit_btn = gr.Button("Submit")
|
186 |
+
with gr.Column():
|
187 |
+
model_day = gr.Textbox(label="Day")
|
188 |
+
model_date = gr.Textbox(label="Date")
|
189 |
+
model_people = gr.Textbox(label="Total Number of People")
|
190 |
+
model_machinery = gr.Textbox(label="Total Machinery")
|
191 |
+
model_machinery_types = gr.Textbox(label="Machinery Types")
|
192 |
+
model_activities = gr.Textbox(label="Activities")
|
193 |
+
submit_btn.click(
|
194 |
+
fn=process_diary,
|
195 |
+
inputs=[day, date, total_people, total_machinery, None, None, media],
|
196 |
+
outputs=[model_day, model_date, model_people, model_machinery, model_machinery_types, model_activities, None]
|
197 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
if __name__ == "__main__":
|
200 |
+
demo.launch(share=False, debug=True, show_api=False, server_port=args.port, server_name=args.host)
|
|
|
|
|
|
|
|
|
|