Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,418 +1,49 @@
|
|
1 |
-
import
|
2 |
-
import
|
3 |
-
import
|
4 |
import os
|
5 |
-
import
|
6 |
-
import
|
7 |
import gc
|
8 |
-
import tempfile
|
9 |
-
import subprocess
|
10 |
-
import time
|
11 |
-
from datetime import datetime
|
12 |
-
from transformers import AutoModel, AutoTokenizer
|
13 |
-
from modelscope.hub.snapshot_download import snapshot_download
|
14 |
from PIL import Image
|
15 |
from decord import VideoReader, cpu
|
16 |
-
import
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
27 |
)
|
28 |
|
29 |
-
#
|
30 |
-
|
31 |
-
|
32 |
-
parser = argparse.ArgumentParser(description='demo')
|
33 |
-
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
|
34 |
-
parser.add_argument("--host", type=str, default="0.0.0.0")
|
35 |
-
parser.add_argument("--port", type=int)
|
36 |
-
# Arguments for subprocess inference mode
|
37 |
-
parser.add_argument("--chunk_inference", action="store_true", help="Run inference on a chunk (subprocess mode).")
|
38 |
-
parser.add_argument("--input_file", type=str, help="Path to serialized input chunk frames.")
|
39 |
-
parser.add_argument("--output_file", type=str, help="Path to file where inference result is written.")
|
40 |
-
parser.add_argument("--inference_prompt", type=str, help="Inference prompt for the chunk.")
|
41 |
-
parser.add_argument("--model_path_arg", type=str, help="Model path for the subprocess.")
|
42 |
-
args = parser.parse_args()
|
43 |
-
device = args.device
|
44 |
-
assert device in ['cuda', 'mps']
|
45 |
-
|
46 |
-
# Global model configuration
|
47 |
-
MODEL_NAME = 'iic/mPLUG-Owl3-7B-240728'
|
48 |
-
MODEL_CACHE_DIR = os.getenv('TRANSFORMERS_CACHE', './models')
|
49 |
-
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
|
50 |
-
|
51 |
-
# Download and cache the model (only in the main process)
|
52 |
-
if not args.chunk_inference:
|
53 |
-
try:
|
54 |
-
model_path = snapshot_download(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
|
55 |
-
except Exception as e:
|
56 |
-
print(f"Error downloading model: {str(e)}")
|
57 |
-
model_path = os.path.join(MODEL_CACHE_DIR, MODEL_NAME)
|
58 |
-
else:
|
59 |
-
model_path = args.model_path_arg
|
60 |
-
|
61 |
-
MAX_NUM_FRAMES = 64
|
62 |
-
|
63 |
-
# Initialize YOLO model (assumed to be lightweight)
|
64 |
-
YOLO_MODEL = YOLO('./best_yolov11.pt') # Load YOLOv11 model
|
65 |
-
|
66 |
-
# File type validation
|
67 |
-
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
|
68 |
-
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
|
69 |
-
|
70 |
-
def get_file_extension(filename):
|
71 |
-
return os.path.splitext(filename)[1].lower()
|
72 |
-
|
73 |
-
def is_image(filename):
|
74 |
-
return get_file_extension(filename) in IMAGE_EXTENSIONS
|
75 |
-
|
76 |
-
def is_video(filename):
|
77 |
-
return get_file_extension(filename) in VIDEO_EXTENSIONS
|
78 |
-
|
79 |
-
# --------------------------------------------------------------------
|
80 |
-
# Model Loading and Inference Functions
|
81 |
-
# --------------------------------------------------------------------
|
82 |
-
def load_model_and_tokenizer():
|
83 |
-
"""Load a fresh instance of the model and tokenizer."""
|
84 |
-
try:
|
85 |
-
# Clear GPU memory if using CUDA (only at initial load)
|
86 |
-
if device == "cuda":
|
87 |
-
torch.cuda.empty_cache()
|
88 |
-
gc.collect()
|
89 |
-
model = AutoModel.from_pretrained(
|
90 |
-
model_path,
|
91 |
-
attn_implementation='sdpa',
|
92 |
-
trust_remote_code=True,
|
93 |
-
torch_dtype=torch.half,
|
94 |
-
device_map='auto'
|
95 |
-
)
|
96 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
97 |
-
model.eval()
|
98 |
-
processor = model.init_processor(tokenizer)
|
99 |
-
return model, tokenizer, processor
|
100 |
-
except Exception as e:
|
101 |
-
print(f"Error loading model: {str(e)}")
|
102 |
-
raise
|
103 |
-
|
104 |
-
def process_video_chunk(video_frames, model, tokenizer, processor, prompt):
|
105 |
-
"""Process a chunk of video frames with mPLUG model."""
|
106 |
-
messages = [{
|
107 |
-
"role": "user",
|
108 |
-
"content": prompt,
|
109 |
-
"video_frames": video_frames
|
110 |
-
}]
|
111 |
-
model_messages = []
|
112 |
-
videos = []
|
113 |
-
for msg in messages:
|
114 |
-
content_str = msg["content"]
|
115 |
-
if "video_frames" in msg and msg["video_frames"]:
|
116 |
-
content_str += "<|video|>"
|
117 |
-
videos.append(msg["video_frames"])
|
118 |
-
model_messages.append({"role": msg["role"], "content": content_str})
|
119 |
-
model_messages.append({"role": "assistant", "content": ""})
|
120 |
-
inputs = processor(
|
121 |
-
model_messages,
|
122 |
-
images=None,
|
123 |
-
videos=videos if videos else None
|
124 |
-
)
|
125 |
-
inputs.to('cuda')
|
126 |
-
inputs.update({
|
127 |
-
'tokenizer': tokenizer,
|
128 |
-
'max_new_tokens': 100,
|
129 |
-
'decode_text': True,
|
130 |
-
'use_cache': False # disable caching to reduce memory buildup
|
131 |
-
})
|
132 |
-
with torch.no_grad():
|
133 |
-
response = model.generate(**inputs)
|
134 |
-
del inputs # Free temporary memory
|
135 |
-
return response[0]
|
136 |
-
|
137 |
-
# --------------------------------------------------------------------
|
138 |
-
# Video and YOLO Functions (Unchanged)
|
139 |
-
# --------------------------------------------------------------------
|
140 |
-
def encode_video_in_chunks(video_path):
|
141 |
-
"""Extract frames from a video in chunks."""
|
142 |
-
vr = VideoReader(video_path, ctx=cpu(0))
|
143 |
-
sample_fps = round(vr.get_avg_fps() / 1) # 1 FPS
|
144 |
-
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
145 |
-
chunks = [frame_idx[i:i + MAX_NUM_FRAMES] for i in range(0, len(frame_idx), MAX_NUM_FRAMES)]
|
146 |
-
for chunk_idx, chunk in enumerate(chunks):
|
147 |
-
frames = vr.get_batch(chunk).asnumpy()
|
148 |
-
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
149 |
-
yield chunk_idx, frames
|
150 |
-
|
151 |
-
def process_yolo_results(results):
|
152 |
-
"""Process YOLO detection results and count people and machinery."""
|
153 |
-
people_count = 0
|
154 |
-
machine_types = {
|
155 |
-
"Tower Crane": 0, "Mobile Crane": 0, "Compactor/Roller": 0, "Bulldozer": 0,
|
156 |
-
"Excavator": 0, "Dump Truck": 0, "Concrete Mixer": 0, "Loader": 0,
|
157 |
-
"Pump Truck": 0, "Pile Driver": 0, "Grader": 0, "Other Vehicle": 0
|
158 |
-
}
|
159 |
-
for r in results:
|
160 |
-
boxes = r.boxes
|
161 |
-
for box in boxes:
|
162 |
-
cls = int(box.cls[0])
|
163 |
-
conf = float(box.conf[0])
|
164 |
-
class_name = YOLO_MODEL.names[cls]
|
165 |
-
if class_name.lower() == 'worker' and conf > 0.5:
|
166 |
-
people_count += 1
|
167 |
-
machinery_mapping = {
|
168 |
-
'tower_crane': "Tower Crane",
|
169 |
-
'mobile_crane': "Mobile Crane",
|
170 |
-
'compactor': "Compactor/Roller",
|
171 |
-
'roller': "Compactor/Roller",
|
172 |
-
'bulldozer': "Bulldozer",
|
173 |
-
'dozer': "Bulldozer",
|
174 |
-
'excavator': "Excavator",
|
175 |
-
'dump_truck': "Dump Truck",
|
176 |
-
'truck': "Dump Truck",
|
177 |
-
'concrete_mixer_truck': "Concrete Mixer",
|
178 |
-
'loader': "Loader",
|
179 |
-
'pump_truck': "Pump Truck",
|
180 |
-
'pile_driver': "Pile Driver",
|
181 |
-
'grader': "Grader",
|
182 |
-
'other_vehicle': "Other Vehicle"
|
183 |
-
}
|
184 |
-
if conf > 0.5:
|
185 |
-
class_lower = class_name.lower()
|
186 |
-
for key, value in machinery_mapping.items():
|
187 |
-
if key in class_lower:
|
188 |
-
machine_types[value] += 1
|
189 |
-
break
|
190 |
-
total_machinery = sum(machine_types.values())
|
191 |
-
return people_count, total_machinery, machine_types
|
192 |
-
|
193 |
-
def detect_people_and_machinery(media_path):
|
194 |
-
"""Detect people and machinery using YOLOv11 for both images and videos."""
|
195 |
-
try:
|
196 |
-
max_people_count = 0
|
197 |
-
max_machine_types = {
|
198 |
-
"Tower Crane": 0, "Mobile Crane": 0, "Compactor/Roller": 0, "Bulldozer": 0,
|
199 |
-
"Excavator": 0, "Dump Truck": 0, "Concrete Mixer": 0, "Loader": 0,
|
200 |
-
"Pump Truck": 0, "Pile Driver": 0, "Grader": 0, "Other Vehicle": 0
|
201 |
-
}
|
202 |
-
if isinstance(media_path, str) and is_video(media_path):
|
203 |
-
cap = cv2.VideoCapture(media_path)
|
204 |
-
fps = cap.get(cv2.CAP_PROP_FPS)
|
205 |
-
sample_rate = max(1, int(fps))
|
206 |
-
frame_count = 0
|
207 |
-
while cap.isOpened():
|
208 |
-
ret, frame = cap.read()
|
209 |
-
if not ret:
|
210 |
-
break
|
211 |
-
if frame_count % sample_rate == 0:
|
212 |
-
results = YOLO_MODEL(frame)
|
213 |
-
people, _, machine_types = process_yolo_results(results)
|
214 |
-
max_people_count = max(max_people_count, people)
|
215 |
-
for k, v in machine_types.items():
|
216 |
-
max_machine_types[k] = max(max_machine_types[k], v)
|
217 |
-
frame_count += 1
|
218 |
-
cap.release()
|
219 |
-
else:
|
220 |
-
if isinstance(media_path, str):
|
221 |
-
img = cv2.imread(media_path)
|
222 |
-
else:
|
223 |
-
img = cv2.cvtColor(np.array(media_path), cv2.COLOR_RGB2BGR)
|
224 |
-
results = YOLO_MODEL(img)
|
225 |
-
max_people_count, _, max_machine_types = process_yolo_results(results)
|
226 |
-
max_machine_types = {k: v for k, v in max_machine_types.items() if v > 0}
|
227 |
-
total_machinery_count = sum(max_machine_types.values())
|
228 |
-
return max_people_count, total_machinery_count, max_machine_types
|
229 |
-
except Exception as e:
|
230 |
-
print(f"Error in YOLO detection: {str(e)}")
|
231 |
-
return 0, 0, {}
|
232 |
-
|
233 |
-
def process_image(image_path, model, tokenizer, processor, prompt):
|
234 |
-
"""Process single image with mPLUG model."""
|
235 |
-
try:
|
236 |
-
image = Image.open(image_path)
|
237 |
-
messages = [{
|
238 |
-
"role": "user",
|
239 |
-
"content": prompt,
|
240 |
-
"images": [image]
|
241 |
-
}]
|
242 |
-
model_messages = []
|
243 |
-
images = []
|
244 |
-
for msg in messages:
|
245 |
-
content_str = msg["content"]
|
246 |
-
if "images" in msg and msg["images"]:
|
247 |
-
content_str += "<|image|>"
|
248 |
-
images.extend(msg["images"])
|
249 |
-
model_messages.append({"role": msg["role"], "content": content_str})
|
250 |
-
model_messages.append({"role": "assistant", "content": ""})
|
251 |
-
inputs = processor(model_messages, images=images, videos=None)
|
252 |
-
inputs.to('cuda')
|
253 |
-
inputs.update({
|
254 |
-
'tokenizer': tokenizer,
|
255 |
-
'max_new_tokens': 100,
|
256 |
-
'decode_text': True,
|
257 |
-
'use_cache': False
|
258 |
-
})
|
259 |
-
with torch.no_grad():
|
260 |
-
response = model.generate(**inputs)
|
261 |
-
del inputs
|
262 |
-
return response[0]
|
263 |
-
except Exception as e:
|
264 |
-
print(f"Error processing image: {str(e)}")
|
265 |
-
return "Error processing image"
|
266 |
-
|
267 |
-
def analyze_image_activities(image_path):
|
268 |
-
"""Analyze image using mPLUG model."""
|
269 |
-
try:
|
270 |
-
model, tokenizer, processor = load_model_and_tokenizer()
|
271 |
-
prompt = ("Analyze this construction site image and describe the activities happening. "
|
272 |
-
"Focus on construction activities, machinery usage, and worker actions.")
|
273 |
-
response = process_image(image_path, model, tokenizer, processor, prompt)
|
274 |
-
del model, tokenizer, processor
|
275 |
-
torch.cuda.empty_cache() # Final cleanup after image processing
|
276 |
-
gc.collect()
|
277 |
-
return response
|
278 |
-
except Exception as e:
|
279 |
-
print(f"Error analyzing image: {str(e)}")
|
280 |
-
return "Error analyzing image activities"
|
281 |
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
writes a per-frame summary of detected classes on the frame, and saves
|
286 |
-
the annotated video. Returns the annotated video path.
|
287 |
-
"""
|
288 |
-
cap = cv2.VideoCapture(video_path)
|
289 |
-
fps = cap.get(cv2.CAP_PROP_FPS)
|
290 |
-
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
291 |
-
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
292 |
-
out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
293 |
-
annotated_video_path = out_file.name
|
294 |
-
out_file.close()
|
295 |
-
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
296 |
-
writer = cv2.VideoWriter(annotated_video_path, fourcc, fps, (w, h))
|
297 |
-
while True:
|
298 |
-
ret, frame = cap.read()
|
299 |
-
if not ret:
|
300 |
-
break
|
301 |
-
results = YOLO_MODEL(frame)
|
302 |
-
frame_counts = {}
|
303 |
-
for r in results:
|
304 |
-
boxes = r.boxes
|
305 |
-
for box in boxes:
|
306 |
-
cls_id = int(box.cls[0])
|
307 |
-
conf = float(box.conf[0])
|
308 |
-
if conf < 0.5:
|
309 |
-
continue
|
310 |
-
x1, y1, x2, y2 = box.xyxy[0]
|
311 |
-
class_name = YOLO_MODEL.names[cls_id]
|
312 |
-
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
313 |
-
color = (0, 255, 0)
|
314 |
-
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
|
315 |
-
label_text = f"{class_name} {conf:.2f}"
|
316 |
-
cv2.putText(frame, label_text, (x1, y1 - 6),
|
317 |
-
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1)
|
318 |
-
frame_counts[class_name] = frame_counts.get(class_name, 0) + 1
|
319 |
-
summary_str = ", ".join(f"{cls_name}: {count}" for cls_name, count in frame_counts.items())
|
320 |
-
cv2.putText(frame, summary_str, (15, 30),
|
321 |
-
cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)
|
322 |
-
writer.write(frame)
|
323 |
-
cap.release()
|
324 |
-
writer.release()
|
325 |
-
return annotated_video_path
|
326 |
|
327 |
-
# --------------------------------------------------------------------
|
328 |
-
# Subprocess Worker: Executed when --chunk_inference flag is provided
|
329 |
-
# --------------------------------------------------------------------
|
330 |
-
if args.chunk_inference:
|
331 |
-
# In worker mode, load the serialized frames from the input file
|
332 |
-
try:
|
333 |
-
with open(args.input_file, "rb") as f:
|
334 |
-
frames_serialized = pickle.load(f)
|
335 |
-
video_frames = []
|
336 |
-
for img_bytes in frames_serialized:
|
337 |
-
video_frames.append(Image.open(io.BytesIO(img_bytes)))
|
338 |
-
except Exception as e:
|
339 |
-
print(f"Error reading input frames: {str(e)}")
|
340 |
-
sys.exit(1)
|
341 |
-
try:
|
342 |
-
model, tokenizer, processor = load_model_and_tokenizer()
|
343 |
-
response = process_video_chunk(video_frames, model, tokenizer, processor, args.inference_prompt)
|
344 |
-
with open(args.output_file, "w") as f:
|
345 |
-
f.write(response)
|
346 |
-
del model, tokenizer, processor
|
347 |
-
torch.cuda.empty_cache()
|
348 |
-
gc.collect()
|
349 |
-
except Exception as e:
|
350 |
-
with open(args.output_file, "w") as f:
|
351 |
-
f.write(f"Error in chunk inference: {str(e)}")
|
352 |
-
sys.exit(0)
|
353 |
-
|
354 |
-
# --------------------------------------------------------------------
|
355 |
-
# Main Video Analysis Function Using Subprocess Isolation
|
356 |
-
# --------------------------------------------------------------------
|
357 |
-
@spaces.GPU
|
358 |
-
def analyze_video_activities_subprocess(video_path):
|
359 |
-
"""Analyze video by processing each chunk in a separate subprocess.
|
360 |
-
Each subprocess loads a fresh model instance to avoid GPU memory buildup."""
|
361 |
-
try:
|
362 |
-
all_responses = []
|
363 |
-
chunk_generator = encode_video_in_chunks(video_path)
|
364 |
-
for chunk_idx, video_frames in chunk_generator:
|
365 |
-
# Serialize each frame in the chunk to bytes
|
366 |
-
temp_input = tempfile.NamedTemporaryFile(suffix=".pkl", delete=False)
|
367 |
-
frames_serializable = []
|
368 |
-
for img in video_frames:
|
369 |
-
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tf:
|
370 |
-
img.save(tf, format="PNG")
|
371 |
-
tf.seek(0)
|
372 |
-
frames_serializable.append(tf.read())
|
373 |
-
os.remove(tf.name)
|
374 |
-
with open(temp_input.name, "wb") as f:
|
375 |
-
pickle.dump(frames_serializable, f)
|
376 |
-
# Create a temporary file for subprocess output
|
377 |
-
temp_output = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
|
378 |
-
temp_output.close()
|
379 |
-
prompt = ("Analyze this construction site video chunk and describe the activities happening. "
|
380 |
-
"Focus on construction activities, machinery usage, and worker actions.")
|
381 |
-
# Launch subprocess for this chunk
|
382 |
-
subprocess.run([
|
383 |
-
sys.executable, __file__,
|
384 |
-
"--chunk_inference",
|
385 |
-
"--input_file", temp_input.name,
|
386 |
-
"--output_file", temp_output.name,
|
387 |
-
"--inference_prompt", prompt,
|
388 |
-
"--model_path_arg", model_path,
|
389 |
-
"--device", device
|
390 |
-
], check=True)
|
391 |
-
with open(temp_output.name, "r") as f:
|
392 |
-
response = f.read().strip()
|
393 |
-
all_responses.append(f"Time period {chunk_idx + 1}:\n{response}")
|
394 |
-
os.remove(temp_input.name)
|
395 |
-
os.remove(temp_output.name)
|
396 |
-
time.sleep(2) # Allow time for GPU memory to fully clear before next chunk
|
397 |
-
return "\n\n".join(all_responses)
|
398 |
-
except Exception as e:
|
399 |
-
print(f"Error in subprocess chunk inference: {str(e)}")
|
400 |
-
return "Error analyzing video activities"
|
401 |
-
|
402 |
-
# --------------------------------------------------------------------
|
403 |
-
# Gradio Interface and Main Launch (only executed in main process)
|
404 |
-
# --------------------------------------------------------------------
|
405 |
-
@spaces.GPU
|
406 |
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media):
|
407 |
-
"""Process the site diary entry
|
|
|
|
|
408 |
if media is None:
|
409 |
-
return [day, date, "No media uploaded", "No media uploaded", "No media uploaded",
|
|
|
410 |
try:
|
411 |
if not hasattr(media, 'name'):
|
412 |
raise ValueError("Invalid file upload")
|
|
|
413 |
file_ext = get_file_extension(media.name)
|
414 |
if not (is_image(media.name) or is_video(media.name)):
|
415 |
raise ValueError(f"Unsupported file type: {file_ext}")
|
|
|
416 |
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
|
417 |
temp_path = temp_file.name
|
418 |
if hasattr(media, 'name') and os.path.exists(media.name):
|
@@ -421,53 +52,412 @@ def process_diary(day, date, total_people, total_machinery, machinery_types, act
|
|
421 |
else:
|
422 |
file_content = media.read() if hasattr(media, 'read') else media
|
423 |
temp_file.write(file_content if isinstance(file_content, bytes) else file_content.read())
|
|
|
424 |
detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(temp_path)
|
|
|
425 |
annotated_video_path = None
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
|
|
|
|
|
|
|
|
434 |
detected_types_str = ", ".join([f"{k}: {v}" for k, v in detected_machinery_types.items()])
|
435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
436 |
except Exception as e:
|
437 |
print(f"Error processing media: {str(e)}")
|
438 |
-
return [day, date, "Error processing media", "Error processing media",
|
|
|
439 |
|
440 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
441 |
gr.Markdown("# 📝 Digital Site Diary")
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
submit_btn.click(
|
467 |
fn=process_diary,
|
468 |
inputs=[day, date, total_people, total_machinery, machinery_types, activities, media],
|
469 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
470 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
|
472 |
if __name__ == "__main__":
|
473 |
-
demo.launch(share=
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from datetime import datetime
|
3 |
+
import tempfile
|
4 |
import os
|
5 |
+
import json
|
6 |
+
import torch
|
7 |
import gc
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from PIL import Image
|
9 |
from decord import VideoReader, cpu
|
10 |
+
from yolo_detection import (
|
11 |
+
detect_people_and_machinery,
|
12 |
+
annotate_video_with_bboxes,
|
13 |
+
is_image,
|
14 |
+
is_video
|
15 |
+
)
|
16 |
+
from image_captioning import (
|
17 |
+
analyze_image_activities,
|
18 |
+
analyze_video_activities,
|
19 |
+
process_video_chunk,
|
20 |
+
load_model_and_tokenizer,
|
21 |
+
MAX_NUM_FRAMES
|
22 |
)
|
23 |
|
24 |
+
# Global storage for activities and media paths
|
25 |
+
global_activities = []
|
26 |
+
global_media_path = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
+
# Create tmp directory for storing frames
|
29 |
+
tmp_dir = os.path.join('.', 'tmp')
|
30 |
+
os.makedirs(tmp_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media):
|
33 |
+
"""Process the site diary entry"""
|
34 |
+
global global_activities, global_media_path
|
35 |
+
|
36 |
if media is None:
|
37 |
+
return [day, date, "No media uploaded", "No media uploaded", "No media uploaded", None, None, [], None]
|
38 |
+
|
39 |
try:
|
40 |
if not hasattr(media, 'name'):
|
41 |
raise ValueError("Invalid file upload")
|
42 |
+
|
43 |
file_ext = get_file_extension(media.name)
|
44 |
if not (is_image(media.name) or is_video(media.name)):
|
45 |
raise ValueError(f"Unsupported file type: {file_ext}")
|
46 |
+
|
47 |
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
|
48 |
temp_path = temp_file.name
|
49 |
if hasattr(media, 'name') and os.path.exists(media.name):
|
|
|
52 |
else:
|
53 |
file_content = media.read() if hasattr(media, 'read') else media
|
54 |
temp_file.write(file_content if isinstance(file_content, bytes) else file_content.read())
|
55 |
+
|
56 |
detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(temp_path)
|
57 |
+
print(f"Detected people: {detected_people}, machinery: {detected_machinery}, types: {detected_machinery_types}")
|
58 |
annotated_video_path = None
|
59 |
+
|
60 |
+
detected_activities = analyze_image_activities(temp_path) if is_image(media.name) else analyze_video_activities(temp_path)
|
61 |
+
|
62 |
+
print(f"Detected activities: {detected_activities}")
|
63 |
+
|
64 |
+
# Store activities and media path globally for chat mode
|
65 |
+
global_activities = detected_activities
|
66 |
+
global_media_path = temp_path
|
67 |
+
|
68 |
+
if is_video(media.name):
|
69 |
+
annotated_video_path = temp_path # Or use annotate_video_with_bboxes(temp_path) if implemented
|
70 |
+
|
71 |
detected_types_str = ", ".join([f"{k}: {v}" for k, v in detected_machinery_types.items()])
|
72 |
+
|
73 |
+
# We'll return the activities as a list for the card display
|
74 |
+
# Clear the chat history when loading new media
|
75 |
+
chat_history = []
|
76 |
+
|
77 |
+
# Extract data for the activity table
|
78 |
+
activity_rows = []
|
79 |
+
for activity in detected_activities:
|
80 |
+
time = activity.get('time', 'Unknown')
|
81 |
+
summary = activity.get('summary', 'No description available')
|
82 |
+
activity_rows.append([time, summary])
|
83 |
+
|
84 |
+
return [day, date, str(detected_people), str(detected_machinery),
|
85 |
+
detected_types_str, gr.update(visible=True), annotated_video_path,
|
86 |
+
detected_activities, chat_history, activity_rows]
|
87 |
+
|
88 |
except Exception as e:
|
89 |
print(f"Error processing media: {str(e)}")
|
90 |
+
return [day, date, "Error processing media", "Error processing media",
|
91 |
+
"Error processing media", None, None, [], None, []]
|
92 |
|
93 |
+
def get_file_extension(filename):
|
94 |
+
return os.path.splitext(filename)[1].lower()
|
95 |
+
|
96 |
+
def on_card_click(activity_indices, history, evt: gr.SelectData):
|
97 |
+
"""Handle clicking on an activity card in the gallery"""
|
98 |
+
global global_activities, global_media_path
|
99 |
+
|
100 |
+
# Get the index of the selected activity from the SelectData event
|
101 |
+
selected_idx = evt.index
|
102 |
+
|
103 |
+
# Map the gallery index to the actual activity index
|
104 |
+
if selected_idx < 0 or selected_idx >= len(activity_indices):
|
105 |
+
return [gr.update(visible=True), gr.update(visible=False), [], None]
|
106 |
+
|
107 |
+
card_idx = activity_indices[selected_idx]
|
108 |
+
print(f"Gallery item {selected_idx} clicked, corresponds to activity index: {card_idx}")
|
109 |
+
|
110 |
+
if card_idx < 0 or card_idx >= len(global_activities):
|
111 |
+
return [gr.update(visible=True), gr.update(visible=False), [], None]
|
112 |
+
|
113 |
+
selected_activity = global_activities[card_idx]
|
114 |
+
chunk_video_path = None
|
115 |
+
|
116 |
+
# Use the pre-saved chunk video if available
|
117 |
+
if 'chunk_path' in selected_activity and os.path.exists(selected_activity['chunk_path']):
|
118 |
+
chunk_video_path = selected_activity['chunk_path']
|
119 |
+
print(f"Using pre-saved chunk video: {chunk_video_path}")
|
120 |
+
else:
|
121 |
+
# Fallback to full video if chunk not available
|
122 |
+
chunk_video_path = global_media_path
|
123 |
+
print(f"Chunk video not available, using full video: {chunk_video_path}")
|
124 |
+
|
125 |
+
# Add the selected activity to chat history
|
126 |
+
history = []
|
127 |
+
history.append((None, f"🎬 Selected video at timestamp {selected_activity['time']}"))
|
128 |
+
|
129 |
+
# Add the thumbnail to the chat as a visual element
|
130 |
+
if 'thumbnail' in selected_activity and os.path.exists(selected_activity['thumbnail']):
|
131 |
+
# Use the tuple format for images in chatbot
|
132 |
+
thumbnail_path = selected_activity['thumbnail']
|
133 |
+
history.append((None, f"📷 Video frame at {selected_activity['time']}"))
|
134 |
+
history.append((None, thumbnail_path))
|
135 |
+
|
136 |
+
# Format message about the detected activity
|
137 |
+
activity_info = f"I detected the following activity:\n\n{selected_activity['summary']}"
|
138 |
+
if selected_activity['objects']:
|
139 |
+
activity_info += f"\n\nIdentified objects: {', '.join(selected_activity['objects'])}"
|
140 |
+
|
141 |
+
history.append(("Tell me about this video segment", activity_info))
|
142 |
+
|
143 |
+
return [gr.update(visible=False), gr.update(visible=True), history, chunk_video_path]
|
144 |
+
|
145 |
+
def chat_with_video(message, history):
|
146 |
+
"""Chat with the mPLUG model about the selected video segment"""
|
147 |
+
global global_activities, global_media_path
|
148 |
+
|
149 |
+
try:
|
150 |
+
# Get the selected activity from the history to identify which chunk we're discussing
|
151 |
+
selected_chunk_idx = None
|
152 |
+
selected_time = None
|
153 |
+
selected_activity = None
|
154 |
+
|
155 |
+
for entry in history:
|
156 |
+
if entry[0] is None and "Selected video at timestamp" in entry[1]:
|
157 |
+
time_str = entry[1].split("Selected video at timestamp ")[1]
|
158 |
+
selected_time = time_str.strip()
|
159 |
+
break
|
160 |
+
|
161 |
+
# Find the corresponding chunk
|
162 |
+
if selected_time:
|
163 |
+
for i, activity in enumerate(global_activities):
|
164 |
+
if activity.get('time') == selected_time:
|
165 |
+
selected_chunk_idx = activity.get('chunk_id')
|
166 |
+
selected_activity = activity
|
167 |
+
break
|
168 |
+
|
169 |
+
# If we found the chunk, use the model to analyze it
|
170 |
+
if selected_chunk_idx is not None and global_media_path and selected_activity:
|
171 |
+
# Load model
|
172 |
+
model, tokenizer, processor = load_model_and_tokenizer()
|
173 |
+
|
174 |
+
# Generate prompt based on user question and add context about what's in the video
|
175 |
+
context = f"This video shows construction site activities at timestamp {selected_time}."
|
176 |
+
if selected_activity.get('objects'):
|
177 |
+
context += f" The scene contains {', '.join(selected_activity.get('objects'))}."
|
178 |
+
|
179 |
+
prompt = f"{context} Analyze this segment of construction site video and answer this question: {message}"
|
180 |
+
|
181 |
+
# This would ideally use the specific chunk, but for simplicity we'll use the global path
|
182 |
+
# In a production system, you'd extract just that chunk of the video
|
183 |
+
vr = VideoReader(global_media_path, ctx=cpu(0))
|
184 |
+
|
185 |
+
# Get the frames for this chunk
|
186 |
+
sample_fps = round(vr.get_avg_fps() / 1)
|
187 |
+
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
188 |
+
|
189 |
+
# Extract frames for the specific chunk
|
190 |
+
chunk_size = MAX_NUM_FRAMES # From the constants in image_captioning.py
|
191 |
+
start_idx = selected_chunk_idx * chunk_size
|
192 |
+
end_idx = min(start_idx + chunk_size, len(frame_idx))
|
193 |
+
|
194 |
+
chunk_frames = frame_idx[start_idx:end_idx]
|
195 |
+
if chunk_frames:
|
196 |
+
frames = vr.get_batch(chunk_frames).asnumpy()
|
197 |
+
frames_pil = [Image.fromarray(v.astype('uint8')) for v in frames]
|
198 |
+
|
199 |
+
# Process frames with model
|
200 |
+
response = process_video_chunk(frames_pil, model, tokenizer, processor, prompt)
|
201 |
+
|
202 |
+
|
203 |
+
# If we couldn't save a frame, just return the text response
|
204 |
+
# Clean up
|
205 |
+
del model, tokenizer, processor
|
206 |
+
torch.cuda.empty_cache()
|
207 |
+
gc.collect()
|
208 |
+
|
209 |
+
return history + [(message, response)]
|
210 |
+
else:
|
211 |
+
return history + [(message, "Could not extract frames for this segment.")]
|
212 |
+
else:
|
213 |
+
# Fallback response if we can't identify the chunk
|
214 |
+
thumbnail = None
|
215 |
+
response_text = f"I'm analyzing your question about the video segment: {message}\n\nBased on what I can see in this segment, it appears to show construction activity with various machinery and workers on site. The specific details would depend on the exact timestamp you're referring to."
|
216 |
+
|
217 |
+
# Try to get a thumbnail from the selected activity if available
|
218 |
+
if selected_activity and 'thumbnail' in selected_activity and os.path.exists(selected_activity['thumbnail']):
|
219 |
+
thumbnail = selected_activity['thumbnail']
|
220 |
+
new_history = history + [(message, response_text)]
|
221 |
+
new_history.append((None, f"📷 Video frame at {selected_time}"))
|
222 |
+
new_history.append((None, thumbnail))
|
223 |
+
return new_history
|
224 |
+
|
225 |
+
return history + [(message, response_text)]
|
226 |
+
|
227 |
+
except Exception as e:
|
228 |
+
print(f"Error in chat_with_video: {str(e)}")
|
229 |
+
return history + [(message, f"I encountered an error while processing your question. Let me try to answer based on what I can see: {message}\n\nThe video appears to show construction site activities, but I'm having trouble with the detailed analysis at the moment.")]
|
230 |
+
|
231 |
+
# Native Gradio activity cards
|
232 |
+
def create_activity_cards_ui(activities):
|
233 |
+
"""Create activity cards using native Gradio components"""
|
234 |
+
if not activities:
|
235 |
+
return gr.HTML("<div class='activity-timeline'><h3>No activities detected</h3></div>"), []
|
236 |
+
|
237 |
+
# Prepare data for gallery
|
238 |
+
thumbnails = []
|
239 |
+
captions = []
|
240 |
+
activity_indices = []
|
241 |
+
|
242 |
+
for i, activity in enumerate(activities):
|
243 |
+
thumbnail = activity.get('thumbnail', '')
|
244 |
+
time = activity.get('time', 'Unknown')
|
245 |
+
summary = activity.get('summary', 'No description available')
|
246 |
+
objects_list = activity.get('objects', [])
|
247 |
+
objects_text = f"Objects: {', '.join(objects_list)}" if objects_list else ""
|
248 |
+
|
249 |
+
# Truncate summary if too long
|
250 |
+
if len(summary) > 150:
|
251 |
+
summary = summary[:147] + "..."
|
252 |
+
|
253 |
+
thumbnails.append(thumbnail)
|
254 |
+
captions.append(f"Timestamp: {time} | {summary}")
|
255 |
+
activity_indices.append(i)
|
256 |
+
|
257 |
+
# Create a gallery for the thumbnails
|
258 |
+
gallery = gr.Gallery(
|
259 |
+
value=[(path, caption) for path, caption in zip(thumbnails, captions)],
|
260 |
+
columns=5,
|
261 |
+
rows=None,
|
262 |
+
height="auto",
|
263 |
+
object_fit="contain",
|
264 |
+
label="Activity Timeline"
|
265 |
+
)
|
266 |
+
|
267 |
+
return gallery, activity_indices
|
268 |
+
|
269 |
+
# Create the Gradio interface
|
270 |
+
with gr.Blocks(title="Digital Site Diary", css="") as demo:
|
271 |
+
|
272 |
gr.Markdown("# 📝 Digital Site Diary")
|
273 |
+
|
274 |
+
# Activity data and indices storage
|
275 |
+
activity_data = gr.State([])
|
276 |
+
activity_indices = gr.State([])
|
277 |
+
|
278 |
+
# Create tabs for different views
|
279 |
+
with gr.Tabs() as tabs:
|
280 |
+
with gr.Tab("Site Diary"):
|
281 |
+
with gr.Row():
|
282 |
+
# User Input Column
|
283 |
+
with gr.Column():
|
284 |
+
gr.Markdown("### User Input")
|
285 |
+
day = gr.Textbox(label="Day",value='9')
|
286 |
+
date = gr.Textbox(label="Date", placeholder="YYYY-MM-DD", value=datetime.now().strftime("%Y-%m-%d"))
|
287 |
+
total_people = gr.Number(label="Total Number of People", precision=0, value=10)
|
288 |
+
total_machinery = gr.Number(label="Total Number of Machinery", precision=0, value=3)
|
289 |
+
machinery_types = gr.Textbox(
|
290 |
+
label="Number of Machinery Per Type",
|
291 |
+
placeholder="e.g., Excavator: 2, Roller: 1",
|
292 |
+
value="Excavator: 2, Roller: 1"
|
293 |
+
)
|
294 |
+
activities = gr.Textbox(
|
295 |
+
label="Activity",
|
296 |
+
placeholder="e.g., 9 AM: Excavation, 10 AM: Concreting",
|
297 |
+
value="9 AM: Excavation, 10 AM: Concreting",
|
298 |
+
lines=3
|
299 |
+
)
|
300 |
+
media = gr.File(label="Upload Image/Video", file_types=["image", "video"])
|
301 |
+
submit_btn = gr.Button("Submit", variant="primary")
|
302 |
+
|
303 |
+
# Model Detection Column
|
304 |
+
with gr.Column():
|
305 |
+
gr.Markdown("### Model Detection")
|
306 |
+
model_day = gr.Textbox(label="Day")
|
307 |
+
model_date = gr.Textbox(label="Date")
|
308 |
+
model_people = gr.Textbox(label="Total Number of People")
|
309 |
+
model_machinery = gr.Textbox(label="Total Number of Machinery")
|
310 |
+
model_machinery_types = gr.Textbox(label="Number of Machinery Per Type")
|
311 |
+
# Activity Row with Timestamps
|
312 |
+
with gr.Row():
|
313 |
+
gr.Markdown("#### Activities with Timestamps")
|
314 |
+
model_activities = gr.Dataframe(
|
315 |
+
headers=["Time", "Activity Description"],
|
316 |
+
datatype=["str", "str"],
|
317 |
+
label="Detected Activities",
|
318 |
+
interactive=False,
|
319 |
+
wrap=True
|
320 |
+
)
|
321 |
+
|
322 |
+
# Activity timeline section
|
323 |
+
with gr.Row():
|
324 |
+
# Timeline View (default visible)
|
325 |
+
with gr.Column(visible=True) as timeline_view:
|
326 |
+
activity_gallery = gr.Gallery(label="Activity Timeline")
|
327 |
+
model_annotated_video = gr.Video(label="Full Video")
|
328 |
+
|
329 |
+
# Chat View (initially hidden)
|
330 |
+
with gr.Column(visible=False) as chat_view:
|
331 |
+
chunk_video = gr.Video(label="Chunk video")
|
332 |
+
chatbot = gr.Chatbot(height=400)
|
333 |
+
chat_input = gr.Textbox(
|
334 |
+
placeholder="Ask about this video segment...",
|
335 |
+
show_label=False
|
336 |
+
)
|
337 |
+
back_btn = gr.Button("← Back to Timeline")
|
338 |
+
|
339 |
+
# Connect the submit button to the processing function
|
340 |
submit_btn.click(
|
341 |
fn=process_diary,
|
342 |
inputs=[day, date, total_people, total_machinery, machinery_types, activities, media],
|
343 |
+
outputs=[
|
344 |
+
model_day,
|
345 |
+
model_date,
|
346 |
+
model_people,
|
347 |
+
model_machinery,
|
348 |
+
model_machinery_types,
|
349 |
+
timeline_view,
|
350 |
+
model_annotated_video,
|
351 |
+
activity_data,
|
352 |
+
chatbot,
|
353 |
+
model_activities
|
354 |
+
]
|
355 |
+
)
|
356 |
+
|
357 |
+
# Process activity data into gallery
|
358 |
+
activity_data.change(
|
359 |
+
fn=create_activity_cards_ui,
|
360 |
+
inputs=[activity_data],
|
361 |
+
outputs=[activity_gallery, activity_indices]
|
362 |
+
)
|
363 |
+
|
364 |
+
# Handle gallery selection
|
365 |
+
activity_gallery.select(
|
366 |
+
fn=on_card_click,
|
367 |
+
inputs=[activity_indices, chatbot],
|
368 |
+
outputs=[timeline_view, chat_view, chatbot, chunk_video]
|
369 |
+
)
|
370 |
+
|
371 |
+
# Chat submission
|
372 |
+
chat_input.submit(
|
373 |
+
fn=chat_with_video,
|
374 |
+
inputs=[chat_input, chatbot],
|
375 |
+
outputs=[chatbot]
|
376 |
+
)
|
377 |
+
|
378 |
+
# Back button
|
379 |
+
back_btn.click(
|
380 |
+
fn=lambda: [gr.update(visible=True), gr.update(visible=False)],
|
381 |
+
inputs=None,
|
382 |
+
outputs=[timeline_view, chat_view]
|
383 |
)
|
384 |
+
|
385 |
+
# Add enhanced CSS styling
|
386 |
+
gr.HTML("""
|
387 |
+
<style>
|
388 |
+
/* Gallery customizations */
|
389 |
+
.gradio-container .gallery-item {
|
390 |
+
border: 1px solid #444444 !important;
|
391 |
+
border-radius: 8px !important;
|
392 |
+
padding: 8px !important;
|
393 |
+
margin: 10px !important;
|
394 |
+
cursor: pointer !important;
|
395 |
+
transition: all 0.3s !important;
|
396 |
+
background: #18181b !important;
|
397 |
+
box-shadow: 0 2px 5px rgba(0,0,0,0.2) !important;
|
398 |
+
}
|
399 |
+
|
400 |
+
.gradio-container .gallery-item:hover {
|
401 |
+
transform: translateY(-2px) !important;
|
402 |
+
box-shadow: 0 4px 12px rgba(0,0,0,0.25) !important;
|
403 |
+
border-color: #007bff !important;
|
404 |
+
background: #202025 !important;
|
405 |
+
}
|
406 |
+
|
407 |
+
.gradio-container .gallery-item.selected {
|
408 |
+
border: 2px solid #007bff !important;
|
409 |
+
background: #202030 !important;
|
410 |
+
}
|
411 |
+
|
412 |
+
/* Improved image display */
|
413 |
+
.gradio-container .gallery-item img {
|
414 |
+
height: 180px !important;
|
415 |
+
object-fit: cover !important;
|
416 |
+
border-radius: 4px !important;
|
417 |
+
border: 1px solid #444444 !important;
|
418 |
+
margin-bottom: 8px !important;
|
419 |
+
}
|
420 |
+
|
421 |
+
/* Caption styling */
|
422 |
+
.gradio-container .caption {
|
423 |
+
color: #e0e0e0 !important;
|
424 |
+
font-size: 0.9em !important;
|
425 |
+
margin-top: 8px !important;
|
426 |
+
line-height: 1.4 !important;
|
427 |
+
padding: 0 4px !important;
|
428 |
+
}
|
429 |
+
|
430 |
+
/* Gallery container */
|
431 |
+
.gradio-container [id*='gallery'] > div:first-child {
|
432 |
+
background-color: #27272a !important;
|
433 |
+
padding: 15px !important;
|
434 |
+
border-radius: 10px !important;
|
435 |
+
}
|
436 |
+
|
437 |
+
/* Chatbot styling */
|
438 |
+
.gradio-container .chatbot {
|
439 |
+
background-color: #27272a !important;
|
440 |
+
border-radius: 10px !important;
|
441 |
+
border: 1px solid #444444 !important;
|
442 |
+
}
|
443 |
+
|
444 |
+
.gradio-container .chatbot .message.user {
|
445 |
+
background-color: #18181b !important;
|
446 |
+
border-radius: 8px !important;
|
447 |
+
}
|
448 |
+
|
449 |
+
.gradio-container .chatbot .message.bot {
|
450 |
+
background-color: #202030 !important;
|
451 |
+
border-radius: 8px !important;
|
452 |
+
}
|
453 |
+
|
454 |
+
/* Button styling */
|
455 |
+
.gradio-container button.secondary {
|
456 |
+
background-color: #3d4452 !important;
|
457 |
+
color: white !important;
|
458 |
+
}
|
459 |
+
</style>
|
460 |
+
""")
|
461 |
|
462 |
if __name__ == "__main__":
|
463 |
+
demo.launch(share=True, allowed_paths=["./tmp"])
|