assentian1970 committed on
Commit
efd5a3f
·
verified ·
1 Parent(s): 6d1a54e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +426 -436
app.py CHANGED
@@ -1,418 +1,49 @@
1
- import spaces
2
- import torch
3
- import argparse
4
  import os
5
- import sys
6
- import pickle # For serializing frames
7
  import gc
8
- import tempfile
9
- import subprocess
10
- import time
11
- from datetime import datetime
12
- from transformers import AutoModel, AutoTokenizer
13
- from modelscope.hub.snapshot_download import snapshot_download
14
  from PIL import Image
15
  from decord import VideoReader, cpu
16
- import cv2
17
- import gradio as gr
18
- from ultralytics import YOLO
19
- import numpy as np
20
- import io
21
-
22
- # Install flash-attn (using prebuilt wheel mode if needed)
23
- subprocess.run(
24
- 'pip install flash-attn --no-build-isolation',
25
- env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
26
- shell=True
 
27
  )
28
 
29
- # --------------------------------------------------------------------
30
- # Command-line arguments
31
- # --------------------------------------------------------------------
32
- parser = argparse.ArgumentParser(description='demo')
33
- parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
34
- parser.add_argument("--host", type=str, default="0.0.0.0")
35
- parser.add_argument("--port", type=int)
36
- # Arguments for subprocess inference mode
37
- parser.add_argument("--chunk_inference", action="store_true", help="Run inference on a chunk (subprocess mode).")
38
- parser.add_argument("--input_file", type=str, help="Path to serialized input chunk frames.")
39
- parser.add_argument("--output_file", type=str, help="Path to file where inference result is written.")
40
- parser.add_argument("--inference_prompt", type=str, help="Inference prompt for the chunk.")
41
- parser.add_argument("--model_path_arg", type=str, help="Model path for the subprocess.")
42
- args = parser.parse_args()
43
- device = args.device
44
- assert device in ['cuda', 'mps']
45
-
46
- # Global model configuration
47
- MODEL_NAME = 'iic/mPLUG-Owl3-7B-240728'
48
- MODEL_CACHE_DIR = os.getenv('TRANSFORMERS_CACHE', './models')
49
- os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
50
-
51
- # Download and cache the model (only in the main process)
52
- if not args.chunk_inference:
53
- try:
54
- model_path = snapshot_download(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
55
- except Exception as e:
56
- print(f"Error downloading model: {str(e)}")
57
- model_path = os.path.join(MODEL_CACHE_DIR, MODEL_NAME)
58
- else:
59
- model_path = args.model_path_arg
60
-
61
- MAX_NUM_FRAMES = 64
62
-
63
- # Initialize YOLO model (assumed to be lightweight)
64
- YOLO_MODEL = YOLO('./best_yolov11.pt') # Load YOLOv11 model
65
-
66
- # File type validation
67
- IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
68
- VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
69
-
70
def get_file_extension(filename):
    """Return the extension of ``filename`` (leading dot included) in lower case.

    A name without an extension yields an empty string.
    """
    _root, extension = os.path.splitext(filename)
    return extension.lower()
72
-
73
def is_image(filename):
    """Report whether the extension of ``filename`` is listed in IMAGE_EXTENSIONS."""
    extension = get_file_extension(filename)
    return extension in IMAGE_EXTENSIONS
75
-
76
def is_video(filename):
    """Report whether the extension of ``filename`` is listed in VIDEO_EXTENSIONS."""
    extension = get_file_extension(filename)
    return extension in VIDEO_EXTENSIONS
78
-
79
- # --------------------------------------------------------------------
80
- # Model Loading and Inference Functions
81
- # --------------------------------------------------------------------
82
def load_model_and_tokenizer():
    """Build a new model, tokenizer and processor from ``model_path``.

    Returns:
        A ``(model, tokenizer, processor)`` tuple. The model has been switched
        to evaluation mode and the processor comes from
        ``model.init_processor(tokenizer)``.

    Raises:
        Exception: Whatever the loading calls raise; the message is printed
            and the exception is re-raised unchanged.
    """
    try:
        # Release cached GPU memory before the weights are loaded (CUDA only).
        # NOTE(review): the original indentation is not recoverable here;
        # gc.collect() is assumed to belong to the CUDA branch - confirm.
        if device == "cuda":
            torch.cuda.empty_cache()
            gc.collect()

        model_kwargs = {
            'attn_implementation': 'sdpa',
            'trust_remote_code': True,
            'torch_dtype': torch.half,
            'device_map': 'auto',
        }
        model = AutoModel.from_pretrained(model_path, **model_kwargs)
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model.eval()
        return model, tokenizer, model.init_processor(tokenizer)
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise
103
-
104
def process_video_chunk(video_frames, model, tokenizer, processor, prompt):
    """Run the model on one chunk of video frames and return its first answer.

    Args:
        video_frames: Frames of the chunk. When empty, the prompt is sent
            without the ``<|video|>`` placeholder and no video is passed on.
        model: Object exposing ``generate(**inputs)``.
        tokenizer: Tokenizer forwarded to ``model.generate``.
        processor: Callable taking ``(messages, images=..., videos=...)`` and
            returning inputs that support ``.to(device)`` and ``.update(dict)``.
        prompt: Text of the user message.

    Returns:
        The first element of what ``model.generate`` returns.
    """
    user_content = prompt
    videos = None
    if video_frames:
        user_content += "<|video|>"
        videos = [video_frames]

    # The trailing empty assistant turn is the slot the model fills in.
    model_messages = [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": ""},
    ]
    inputs = processor(model_messages, images=None, videos=videos)

    # Follow the model's own device instead of hard-coding 'cuda', so the
    # inputs land next to the weights when the script runs with --device mps.
    inputs.to(getattr(model, "device", "cuda"))
    inputs.update({
        'tokenizer': tokenizer,
        'max_new_tokens': 100,
        'decode_text': True,
        'use_cache': False,  # disable caching to reduce memory buildup
    })
    with torch.no_grad():
        response = model.generate(**inputs)
    return response[0]
136
-
137
- # --------------------------------------------------------------------
138
- # Video and YOLO Functions (Unchanged)
139
- # --------------------------------------------------------------------
140
def encode_video_in_chunks(video_path):
    """Yield the sampled frames of a video in chunks of at most MAX_NUM_FRAMES.

    One frame is sampled every ``round(average FPS)`` frames, i.e. roughly one
    frame per second of video.

    Args:
        video_path: Path handed to ``VideoReader``.

    Yields:
        ``(chunk_idx, frames)`` pairs, where ``chunk_idx`` counts from 0 and
        ``frames`` is a list of PIL images built from the decoded arrays.
    """
    vr = VideoReader(video_path, ctx=cpu(0))
    # Clamp to 1: for an average FPS below 0.5 round() yields 0 and range()
    # raises ValueError on a zero step.
    sample_step = max(1, round(vr.get_avg_fps()))
    frame_idx = list(range(0, len(vr), sample_step))

    chunk_starts = range(0, len(frame_idx), MAX_NUM_FRAMES)
    for chunk_idx, start in enumerate(chunk_starts):
        chunk = frame_idx[start:start + MAX_NUM_FRAMES]
        frames = vr.get_batch(chunk).asnumpy()
        yield chunk_idx, [Image.fromarray(frame.astype('uint8')) for frame in frames]
150
-
151
# A detection is counted only when its confidence is strictly above this.
_MIN_DETECTION_CONFIDENCE = 0.5

# Machinery categories reported to the caller, in report order.
_MACHINE_TYPE_NAMES = (
    "Tower Crane", "Mobile Crane", "Compactor/Roller", "Bulldozer",
    "Excavator", "Dump Truck", "Concrete Mixer", "Loader",
    "Pump Truck", "Pile Driver", "Grader", "Other Vehicle",
)

# Keyword found in a YOLO class name -> reported machinery category.
_MACHINERY_KEYWORDS = {
    'tower_crane': "Tower Crane",
    'mobile_crane': "Mobile Crane",
    'compactor': "Compactor/Roller",
    'roller': "Compactor/Roller",
    'bulldozer': "Bulldozer",
    'dozer': "Bulldozer",
    'excavator': "Excavator",
    'dump_truck': "Dump Truck",
    'truck': "Dump Truck",
    'concrete_mixer_truck': "Concrete Mixer",
    'loader': "Loader",
    'pump_truck': "Pump Truck",
    'pile_driver': "Pile Driver",
    'grader': "Grader",
    'other_vehicle': "Other Vehicle",
}

# Keywords are tried longest first so that a generic keyword such as 'truck'
# cannot claim 'concrete_mixer_truck' or 'pump_truck' before the specific
# keyword is tested.
_MACHINERY_KEYWORDS_LONGEST_FIRST = tuple(
    sorted(_MACHINERY_KEYWORDS, key=len, reverse=True)
)


def _machine_type_for(class_name):
    """Return the machinery category for a class name, or None if it is not machinery."""
    class_lower = class_name.lower()
    for keyword in _MACHINERY_KEYWORDS_LONGEST_FIRST:
        if keyword in class_lower:
            return _MACHINERY_KEYWORDS[keyword]
    return None


def process_yolo_results(results):
    """Count workers and machinery in YOLO detection results.

    Only boxes whose confidence is strictly greater than 0.5 are counted.
    Class ids are translated to names through ``YOLO_MODEL.names``.

    Args:
        results: Iterable of result objects; each exposes ``boxes``, and each
            box exposes ``cls[0]`` and ``conf[0]``.

    Returns:
        ``(people_count, total_machinery, machine_types)`` where
        ``machine_types`` maps every machinery category to its count
        (categories that were not seen stay at 0).
    """
    people_count = 0
    machine_types = dict.fromkeys(_MACHINE_TYPE_NAMES, 0)

    for result in results:
        for box in result.boxes:
            if float(box.conf[0]) <= _MIN_DETECTION_CONFIDENCE:
                continue
            class_name = YOLO_MODEL.names[int(box.cls[0])]
            if class_name.lower() == 'worker':
                people_count += 1
                continue
            machine_type = _machine_type_for(class_name)
            if machine_type is not None:
                machine_types[machine_type] += 1

    total_machinery = sum(machine_types.values())
    return people_count, total_machinery, machine_types
192
-
193
def detect_people_and_machinery(media_path):
    """Detect people and machinery with YOLO in an image or a video.

    For a video (a string path accepted by ``is_video``) one frame out of
    every ``int(fps)`` is analysed and, per category, the highest count seen
    in any analysed frame is kept. Anything else is treated as a single image:
    a string is read with ``cv2.imread``, any other object is converted with
    ``np.array`` and treated as RGB.

    Args:
        media_path: Path to an image/video file, or an in-memory image.

    Returns:
        ``(people_count, total_machinery_count, machine_types)`` where
        ``machine_types`` only contains categories with a count above zero.
        On any error the message is printed and ``(0, 0, {})`` is returned.
    """
    try:
        max_people_count = 0
        max_machine_types = {
            "Tower Crane": 0, "Mobile Crane": 0, "Compactor/Roller": 0, "Bulldozer": 0,
            "Excavator": 0, "Dump Truck": 0, "Concrete Mixer": 0, "Loader": 0,
            "Pump Truck": 0, "Pile Driver": 0, "Grader": 0, "Other Vehicle": 0
        }
        if isinstance(media_path, str) and is_video(media_path):
            cap = cv2.VideoCapture(media_path)
            try:
                # Analyse roughly one frame per second; never a step of 0.
                sample_rate = max(1, int(cap.get(cv2.CAP_PROP_FPS)))
                frame_count = 0
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    if frame_count % sample_rate == 0:
                        people, _, machine_types = process_yolo_results(YOLO_MODEL(frame))
                        max_people_count = max(max_people_count, people)
                        for name, count in machine_types.items():
                            max_machine_types[name] = max(max_machine_types[name], count)
                    frame_count += 1
            finally:
                # Release the capture even when detection fails mid-video.
                cap.release()
        else:
            if isinstance(media_path, str):
                img = cv2.imread(media_path)
                if img is None:
                    # imread signals an unreadable file by returning None.
                    raise ValueError(f"Could not read image: {media_path}")
            else:
                img = cv2.cvtColor(np.array(media_path), cv2.COLOR_RGB2BGR)
            max_people_count, _, max_machine_types = process_yolo_results(YOLO_MODEL(img))

        max_machine_types = {k: v for k, v in max_machine_types.items() if v > 0}
        total_machinery_count = sum(max_machine_types.values())
        return max_people_count, total_machinery_count, max_machine_types
    except Exception as e:
        print(f"Error in YOLO detection: {str(e)}")
        return 0, 0, {}
232
-
233
def process_image(image_path, model, tokenizer, processor, prompt):
    """Run the model on a single image and return its first answer.

    Args:
        image_path: Path passed to ``Image.open``.
        model: Object exposing ``generate(**inputs)``.
        tokenizer: Tokenizer forwarded to ``model.generate``.
        processor: Callable taking ``(messages, images=..., videos=...)`` and
            returning inputs that support ``.to(device)`` and ``.update(dict)``.
        prompt: Text of the user message; ``<|image|>`` is appended to it.

    Returns:
        The first element of what ``model.generate`` returns, or the string
        ``"Error processing image"`` when anything fails (the error is printed).
    """
    try:
        # Keep the image open only for as long as the model needs it, so the
        # underlying file handle is closed afterwards.
        with Image.open(image_path) as image:
            model_messages = [
                {"role": "user", "content": prompt + "<|image|>"},
                {"role": "assistant", "content": ""},
            ]
            inputs = processor(model_messages, images=[image], videos=None)
            # Follow the model's own device instead of hard-coding 'cuda'.
            inputs.to(getattr(model, "device", "cuda"))
            inputs.update({
                'tokenizer': tokenizer,
                'max_new_tokens': 100,
                'decode_text': True,
                'use_cache': False
            })
            with torch.no_grad():
                response = model.generate(**inputs)
        return response[0]
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return "Error processing image"
266
-
267
def analyze_image_activities(image_path):
    """Describe the activities visible in a construction-site image.

    Loads a fresh model, asks it about the image, then drops the model and
    clears the CUDA cache.

    Returns:
        The model's answer, or ``"Error analyzing image activities"`` when an
        exception occurs (the error is printed).
    """
    prompt = ("Analyze this construction site image and describe the activities happening. "
              "Focus on construction activities, machinery usage, and worker actions.")
    try:
        model, tokenizer, processor = load_model_and_tokenizer()
        response = process_image(image_path, model, tokenizer, processor, prompt)
        # Final cleanup after image processing.
        del model, tokenizer, processor
        torch.cuda.empty_cache()
        gc.collect()
    except Exception as e:
        print(f"Error analyzing image: {str(e)}")
        return "Error analyzing image activities"
    return response
281
 
282
def annotate_video_with_bboxes(video_path):
    """Write a copy of the video with YOLO detections drawn on every frame.

    Each frame is run through ``YOLO_MODEL``. Boxes with a confidence of at
    least 0.5 are drawn with a ``"<class> <confidence>"`` label, and a
    per-frame summary of the counted classes is written in the top-left
    corner.

    Args:
        video_path: Path of the video to read.

    Returns:
        Path of the annotated ``.mp4`` file (a temporary file that is not
        deleted automatically).
    """
    cap = cv2.VideoCapture(video_path)
    writer = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Reserve a file name only; the writer opens the path itself.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as out_file:
            annotated_video_path = out_file.name

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(annotated_video_path, fourcc, fps, (w, h))

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_counts = {}
            for result in YOLO_MODEL(frame):
                for box in result.boxes:
                    conf = float(box.conf[0])
                    if conf < 0.5:
                        continue
                    class_name = YOLO_MODEL.names[int(box.cls[0])]
                    x1, y1, x2, y2 = (int(coord) for coord in box.xyxy[0])
                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    cv2.putText(frame, f"{class_name} {conf:.2f}", (x1, y1 - 6),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
                    frame_counts[class_name] = frame_counts.get(class_name, 0) + 1

            summary_str = ", ".join(f"{cls_name}: {count}" for cls_name, count in frame_counts.items())
            cv2.putText(frame, summary_str, (15, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)
            writer.write(frame)
    finally:
        # Release both handles even when detection or writing fails, so the
        # input file and the output container are not left open.
        cap.release()
        if writer is not None:
            writer.release()
    return annotated_video_path
326
 
327
- # --------------------------------------------------------------------
328
- # Subprocess Worker: Executed when --chunk_inference flag is provided
329
- # --------------------------------------------------------------------
330
- if args.chunk_inference:
331
- # In worker mode, load the serialized frames from the input file
332
- try:
333
- with open(args.input_file, "rb") as f:
334
- frames_serialized = pickle.load(f)
335
- video_frames = []
336
- for img_bytes in frames_serialized:
337
- video_frames.append(Image.open(io.BytesIO(img_bytes)))
338
- except Exception as e:
339
- print(f"Error reading input frames: {str(e)}")
340
- sys.exit(1)
341
- try:
342
- model, tokenizer, processor = load_model_and_tokenizer()
343
- response = process_video_chunk(video_frames, model, tokenizer, processor, args.inference_prompt)
344
- with open(args.output_file, "w") as f:
345
- f.write(response)
346
- del model, tokenizer, processor
347
- torch.cuda.empty_cache()
348
- gc.collect()
349
- except Exception as e:
350
- with open(args.output_file, "w") as f:
351
- f.write(f"Error in chunk inference: {str(e)}")
352
- sys.exit(0)
353
-
354
- # --------------------------------------------------------------------
355
- # Main Video Analysis Function Using Subprocess Isolation
356
- # --------------------------------------------------------------------
357
@spaces.GPU
def analyze_video_activities_subprocess(video_path):
    """Analyze a video chunk by chunk, running each chunk in its own subprocess.

    Every chunk from ``encode_video_in_chunks`` is serialized as a pickled
    list of PNG byte strings and handed to this same script in
    ``--chunk_inference`` mode, which writes its answer to a text file. A
    fresh process per chunk means each inference starts with a freshly loaded
    model.

    Args:
        video_path: Path of the video to analyze.

    Returns:
        The per-chunk answers, each prefixed with ``"Time period N:"`` and
        joined by blank lines, or ``"Error analyzing video activities"`` when
        anything fails (the error is printed).
    """
    prompt = ("Analyze this construction site video chunk and describe the activities happening. "
              "Focus on construction activities, machinery usage, and worker actions.")
    try:
        all_responses = []
        for chunk_idx, video_frames in encode_video_in_chunks(video_path):
            # Encode each frame to PNG in memory; the worker decodes the bytes
            # with Image.open(io.BytesIO(...)).
            frames_serializable = []
            for img in video_frames:
                buffer = io.BytesIO()
                img.save(buffer, format="PNG")
                frames_serializable.append(buffer.getvalue())

            input_path = None
            output_path = None
            try:
                with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as temp_input:
                    input_path = temp_input.name
                    pickle.dump(frames_serializable, temp_input)
                # Only the name is needed; the subprocess writes the content.
                with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as temp_output:
                    output_path = temp_output.name

                # Launch subprocess for this chunk
                subprocess.run([
                    sys.executable, __file__,
                    "--chunk_inference",
                    "--input_file", input_path,
                    "--output_file", output_path,
                    "--inference_prompt", prompt,
                    "--model_path_arg", model_path,
                    "--device", device
                ], check=True)
                with open(output_path, "r") as f:
                    response = f.read().strip()
            finally:
                # Remove the temp files even when the subprocess fails, so a
                # failing chunk does not leave its frames on disk.
                for path in (input_path, output_path):
                    if path and os.path.exists(path):
                        os.remove(path)

            all_responses.append(f"Time period {chunk_idx + 1}:\n{response}")
            time.sleep(2)  # Allow time for GPU memory to fully clear before next chunk
        return "\n\n".join(all_responses)
    except Exception as e:
        print(f"Error in subprocess chunk inference: {str(e)}")
        return "Error analyzing video activities"
401
-
402
- # --------------------------------------------------------------------
403
- # Gradio Interface and Main Launch (only executed in main process)
404
- # --------------------------------------------------------------------
405
- @spaces.GPU
406
  def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media):
407
- """Process the site diary entry."""
 
 
408
  if media is None:
409
- return [day, date, "No media uploaded", "No media uploaded", "No media uploaded", "No media uploaded", None]
 
410
  try:
411
  if not hasattr(media, 'name'):
412
  raise ValueError("Invalid file upload")
 
413
  file_ext = get_file_extension(media.name)
414
  if not (is_image(media.name) or is_video(media.name)):
415
  raise ValueError(f"Unsupported file type: {file_ext}")
 
416
  with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
417
  temp_path = temp_file.name
418
  if hasattr(media, 'name') and os.path.exists(media.name):
@@ -421,53 +52,412 @@ def process_diary(day, date, total_people, total_machinery, machinery_types, act
421
  else:
422
  file_content = media.read() if hasattr(media, 'read') else media
423
  temp_file.write(file_content if isinstance(file_content, bytes) else file_content.read())
 
424
  detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(temp_path)
 
425
  annotated_video_path = None
426
- if is_image(media.name):
427
- detected_activities = analyze_image_activities(temp_path)
428
- else:
429
- # Use the subprocess-based video analysis for each chunk
430
- detected_activities = analyze_video_activities_subprocess(temp_path)
431
- annotated_video_path = annotate_video_with_bboxes(temp_path)
432
- if os.path.exists(temp_path):
433
- os.remove(temp_path)
 
 
 
 
434
  detected_types_str = ", ".join([f"{k}: {v}" for k, v in detected_machinery_types.items()])
435
- return [day, date, str(detected_people), str(detected_machinery), detected_types_str, detected_activities, annotated_video_path]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  except Exception as e:
437
  print(f"Error processing media: {str(e)}")
438
- return [day, date, "Error processing media", "Error processing media", "Error processing media", "Error processing media", None]
 
439
 
440
- with gr.Blocks(title="Digital Site Diary") as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  gr.Markdown("# 📝 Digital Site Diary")
442
- with gr.Row():
443
- with gr.Column():
444
- gr.Markdown("### User Input")
445
- day = gr.Textbox(label="Day", value='9')
446
- date = gr.Textbox(label="Date", placeholder="YYYY-MM-DD", value=datetime.now().strftime("%Y-%m-%d"))
447
- total_people = gr.Number(label="Total Number of People", precision=0, value=10)
448
- total_machinery = gr.Number(label="Total Number of Machinery", precision=0, value=3)
449
- machinery_types = gr.Textbox(label="Number of Machinery Per Type",
450
- placeholder="e.g., Excavator: 2, Roller: 1",
451
- value="Excavator: 2, Roller: 1")
452
- activities = gr.Textbox(label="Activity",
453
- placeholder="e.g., 9 AM: Excavation, 10 AM: Concreting",
454
- value="9 AM: Excavation, 10 AM: Concreting", lines=3)
455
- media = gr.File(label="Upload Image/Video", file_types=["image", "video"])
456
- submit_btn = gr.Button("Submit", variant="primary")
457
- with gr.Column():
458
- gr.Markdown("### Model Detection")
459
- model_day = gr.Textbox(label="Day")
460
- model_date = gr.Textbox(label="Date")
461
- model_people = gr.Textbox(label="Total Number of People")
462
- model_machinery = gr.Textbox(label="Total Number of Machinery")
463
- model_machinery_types = gr.Textbox(label="Number of Machinery Per Type")
464
- model_activities = gr.Textbox(label="Activity", lines=5)
465
- model_annotated_video = gr.Video(label="Annotated Video")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  submit_btn.click(
467
  fn=process_diary,
468
  inputs=[day, date, total_people, total_machinery, machinery_types, activities, media],
469
- outputs=[model_day, model_date, model_people, model_machinery, model_machinery_types, model_activities, model_annotated_video]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
 
472
  if __name__ == "__main__":
473
- demo.launch(share=False, debug=True, show_api=False, server_port=args.port, server_name=args.host)
 
1
+ import gradio as gr
2
+ from datetime import datetime
3
+ import tempfile
4
  import os
5
+ import json
6
+ import torch
7
  import gc
 
 
 
 
 
 
8
  from PIL import Image
9
  from decord import VideoReader, cpu
10
+ from yolo_detection import (
11
+ detect_people_and_machinery,
12
+ annotate_video_with_bboxes,
13
+ is_image,
14
+ is_video
15
+ )
16
+ from image_captioning import (
17
+ analyze_image_activities,
18
+ analyze_video_activities,
19
+ process_video_chunk,
20
+ load_model_and_tokenizer,
21
+ MAX_NUM_FRAMES
22
  )
23
 
24
+ # Global storage for activities and media paths
25
+ global_activities = []
26
+ global_media_path = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ # Create tmp directory for storing frames
29
+ tmp_dir = os.path.join('.', 'tmp')
30
+ os.makedirs(tmp_dir, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media):
33
+ """Process the site diary entry"""
34
+ global global_activities, global_media_path
35
+
36
  if media is None:
37
+ return [day, date, "No media uploaded", "No media uploaded", "No media uploaded", None, None, [], None]
38
+
39
  try:
40
  if not hasattr(media, 'name'):
41
  raise ValueError("Invalid file upload")
42
+
43
  file_ext = get_file_extension(media.name)
44
  if not (is_image(media.name) or is_video(media.name)):
45
  raise ValueError(f"Unsupported file type: {file_ext}")
46
+
47
  with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
48
  temp_path = temp_file.name
49
  if hasattr(media, 'name') and os.path.exists(media.name):
 
52
  else:
53
  file_content = media.read() if hasattr(media, 'read') else media
54
  temp_file.write(file_content if isinstance(file_content, bytes) else file_content.read())
55
+
56
  detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(temp_path)
57
+ print(f"Detected people: {detected_people}, machinery: {detected_machinery}, types: {detected_machinery_types}")
58
  annotated_video_path = None
59
+
60
+ detected_activities = analyze_image_activities(temp_path) if is_image(media.name) else analyze_video_activities(temp_path)
61
+
62
+ print(f"Detected activities: {detected_activities}")
63
+
64
+ # Store activities and media path globally for chat mode
65
+ global_activities = detected_activities
66
+ global_media_path = temp_path
67
+
68
+ if is_video(media.name):
69
+ annotated_video_path = temp_path # Or use annotate_video_with_bboxes(temp_path) if implemented
70
+
71
  detected_types_str = ", ".join([f"{k}: {v}" for k, v in detected_machinery_types.items()])
72
+
73
+ # We'll return the activities as a list for the card display
74
+ # Clear the chat history when loading new media
75
+ chat_history = []
76
+
77
+ # Extract data for the activity table
78
+ activity_rows = []
79
+ for activity in detected_activities:
80
+ time = activity.get('time', 'Unknown')
81
+ summary = activity.get('summary', 'No description available')
82
+ activity_rows.append([time, summary])
83
+
84
+ return [day, date, str(detected_people), str(detected_machinery),
85
+ detected_types_str, gr.update(visible=True), annotated_video_path,
86
+ detected_activities, chat_history, activity_rows]
87
+
88
  except Exception as e:
89
  print(f"Error processing media: {str(e)}")
90
+ return [day, date, "Error processing media", "Error processing media",
91
+ "Error processing media", None, None, [], None, []]
92
 
93
def get_file_extension(filename):
    """Return the lower-cased file extension of ``filename``, including the dot.

    Returns an empty string when ``filename`` has no extension.
    """
    suffix = os.path.splitext(filename)[-1]
    return suffix.lower()
95
+
96
def on_card_click(activity_indices, history, evt: gr.SelectData):
    """Handle a click on an activity card in the gallery.

    Args:
        activity_indices: List mapping gallery positions to indices into
            ``global_activities``.
        history: Current chat history. It is not reused: selecting a card
            always starts a new conversation.
        evt: Selection event; ``evt.index`` is the clicked gallery position.

    Returns:
        A list of four values: an update for the first component (hidden on
        success, shown otherwise), an update for the second (shown on
        success, hidden otherwise), the new chat history, and the path of the
        video to play (the activity's own chunk when it exists on disk,
        otherwise the full uploaded video). When the click cannot be mapped
        to an activity the history is empty and the path is ``None``.
    """
    not_selected = [gr.update(visible=True), gr.update(visible=False), [], None]

    selected_idx = evt.index
    if not 0 <= selected_idx < len(activity_indices):
        return not_selected

    card_idx = activity_indices[selected_idx]
    print(f"Gallery item {selected_idx} clicked, corresponds to activity index: {card_idx}")
    if not 0 <= card_idx < len(global_activities):
        return not_selected

    selected_activity = global_activities[card_idx]
    # Same defaults as the other readers of activity dicts in this file, so
    # an activity with a missing field does not raise KeyError here.
    timestamp = selected_activity.get('time', 'Unknown')
    summary = selected_activity.get('summary', 'No description available')
    objects = selected_activity.get('objects') or []

    chunk_path = selected_activity.get('chunk_path')
    if chunk_path and os.path.exists(chunk_path):
        chunk_video_path = chunk_path
        print(f"Using pre-saved chunk video: {chunk_video_path}")
    else:
        # Fallback to full video if chunk not available
        chunk_video_path = global_media_path
        print(f"Chunk video not available, using full video: {chunk_video_path}")

    history = [(None, f"🎬 Selected video at timestamp {timestamp}")]

    thumbnail_path = selected_activity.get('thumbnail')
    if thumbnail_path and os.path.exists(thumbnail_path):
        history.append((None, f"📷 Video frame at {timestamp}"))
        # A file is sent to the chatbot as a one-element tuple; a bare string
        # would be rendered as text instead of the image.
        history.append((None, (thumbnail_path,)))

    activity_info = f"I detected the following activity:\n\n{summary}"
    if objects:
        activity_info += f"\n\nIdentified objects: {', '.join(objects)}"
    history.append(("Tell me about this video segment", activity_info))

    return [gr.update(visible=False), gr.update(visible=True), history, chunk_video_path]
144
+
145
def _selected_timestamp(history):
    """Return the timestamp announced by the card-selection message in history, or None."""
    marker = "Selected video at timestamp "
    for entry in history:
        bot_message = entry[1]
        # Entries may carry files or None instead of text; only text can match.
        if entry[0] is None and isinstance(bot_message, str) and marker in bot_message:
            return bot_message.split(marker)[1].strip()
    return None


def _read_chunk_frames(video_path, chunk_idx):
    """Return the PIL frames of chunk ``chunk_idx``, sampled about once per second."""
    vr = VideoReader(video_path, ctx=cpu(0))
    # Clamp to 1: round() gives 0 for an average FPS below 0.5 and range()
    # rejects a zero step.
    sample_step = max(1, round(vr.get_avg_fps()))
    frame_idx = list(range(0, len(vr), sample_step))
    start_idx = chunk_idx * MAX_NUM_FRAMES
    chunk_frames = frame_idx[start_idx:start_idx + MAX_NUM_FRAMES]
    if not chunk_frames:
        return []
    frames = vr.get_batch(chunk_frames).asnumpy()
    return [Image.fromarray(frame.astype('uint8')) for frame in frames]


def chat_with_video(message, history):
    """Answer a question about the video segment selected in the chat.

    The selected segment is recovered from the "Selected video at timestamp"
    message in ``history`` and matched against ``global_activities`` by its
    ``'time'`` value. When the segment is found, its frames are read from
    ``global_media_path`` and passed to the model together with the question.
    Otherwise a generic fallback answer is returned, followed by the
    activity's thumbnail when one exists.

    Args:
        message: The user's question.
        history: Chat history as a list of ``(user, bot)`` pairs; ``None`` is
            treated as empty.

    Returns:
        A new history list: ``history`` plus the question/answer pair (and the
        thumbnail entries in the fallback case). Errors are printed and
        answered with an apology message instead of being raised.
    """
    history = history or []
    try:
        selected_time = _selected_timestamp(history)

        selected_activity = None
        if selected_time:
            selected_activity = next(
                (activity for activity in global_activities
                 if activity.get('time') == selected_time),
                None,
            )
        selected_chunk_idx = selected_activity.get('chunk_id') if selected_activity else None

        # If we found the chunk, use the model to analyze it
        if selected_chunk_idx is not None and global_media_path and selected_activity:
            # Read the frames first so the model is only loaded when there is
            # something to analyze.
            frames_pil = _read_chunk_frames(global_media_path, selected_chunk_idx)
            if not frames_pil:
                return history + [(message, "Could not extract frames for this segment.")]

            context = f"This video shows construction site activities at timestamp {selected_time}."
            if selected_activity.get('objects'):
                context += f" The scene contains {', '.join(selected_activity.get('objects'))}."
            prompt = f"{context} Analyze this segment of construction site video and answer this question: {message}"

            model, tokenizer, processor = load_model_and_tokenizer()
            try:
                response = process_video_chunk(frames_pil, model, tokenizer, processor, prompt)
            finally:
                # Free the model even when inference fails, so a failed
                # question does not keep the weights in GPU memory.
                del model, tokenizer, processor
                torch.cuda.empty_cache()
                gc.collect()
            return history + [(message, response)]

        # Fallback response if we can't identify the chunk
        response_text = f"I'm analyzing your question about the video segment: {message}\n\nBased on what I can see in this segment, it appears to show construction activity with various machinery and workers on site. The specific details would depend on the exact timestamp you're referring to."
        new_history = history + [(message, response_text)]

        thumbnail = selected_activity.get('thumbnail') if selected_activity else None
        if thumbnail and os.path.exists(thumbnail):
            new_history.append((None, f"📷 Video frame at {selected_time}"))
            # A file is sent to the chatbot as a one-element tuple; a bare
            # string would be rendered as text instead of the image.
            new_history.append((None, (thumbnail,)))
        return new_history

    except Exception as e:
        print(f"Error in chat_with_video: {str(e)}")
        return history + [(message, f"I encountered an error while processing your question. Let me try to answer based on what I can see: {message}\n\nThe video appears to show construction site activities, but I'm having trouble with the detailed analysis at the moment.")]
230
+
231
+ # Native Gradio activity cards
232
def create_activity_cards_ui(activities):
    """Create the activity timeline gallery from a list of activity dicts.

    Each activity with a thumbnail file that exists on disk becomes one
    gallery item captioned ``"Timestamp: <time> | <summary>"``; summaries
    longer than 150 characters are cut to 147 characters plus ``"..."``.
    Activities without a usable thumbnail are left out, because the gallery
    has no image to show for them.

    Args:
        activities: List of dicts read with the keys ``'thumbnail'``,
            ``'time'`` and ``'summary'``; missing keys fall back to defaults.

    Returns:
        ``(component, activity_indices)``. ``component`` is a ``gr.Gallery``,
        or a ``gr.HTML`` notice when ``activities`` is empty.
        ``activity_indices[k]`` is the index in ``activities`` of the k-th
        gallery item, so a gallery click can be mapped back to its activity.
    """
    if not activities:
        return gr.HTML("<div class='activity-timeline'><h3>No activities detected</h3></div>"), []

    gallery_items = []
    activity_indices = []
    for index, activity in enumerate(activities):
        thumbnail = activity.get('thumbnail', '')
        if not thumbnail or not os.path.exists(thumbnail):
            continue

        timestamp = activity.get('time', 'Unknown')
        summary = activity.get('summary', 'No description available')
        # Truncate summary if too long
        if len(summary) > 150:
            summary = summary[:147] + "..."

        gallery_items.append((thumbnail, f"Timestamp: {timestamp} | {summary}"))
        activity_indices.append(index)

    gallery = gr.Gallery(
        value=gallery_items,
        columns=5,
        rows=None,
        height="auto",
        object_fit="contain",
        label="Activity Timeline"
    )
    return gallery, activity_indices
268
+
269
+ # Create the Gradio interface
270
# NOTE(review): the nesting below is reconstructed from the component order
# and the section comments -- confirm against the running layout.
with gr.Blocks(title="Digital Site Diary", css="") as demo:

    gr.Markdown("# 📝 Digital Site Diary")

    # Per-session state: the detected activities and their gallery positions.
    activity_data = gr.State([])
    activity_indices = gr.State([])

    # Single-tab layout holding the form, the model output and the timeline.
    with gr.Tabs() as tabs:
        with gr.Tab("Site Diary"):
            with gr.Row():
                # Left column: values entered by the user.
                with gr.Column():
                    gr.Markdown("### User Input")
                    day = gr.Textbox(label="Day", value="9")
                    date = gr.Textbox(
                        label="Date",
                        placeholder="YYYY-MM-DD",
                        value=datetime.now().strftime("%Y-%m-%d"),
                    )
                    total_people = gr.Number(
                        label="Total Number of People", precision=0, value=10
                    )
                    total_machinery = gr.Number(
                        label="Total Number of Machinery", precision=0, value=3
                    )
                    machinery_types = gr.Textbox(
                        label="Number of Machinery Per Type",
                        placeholder="e.g., Excavator: 2, Roller: 1",
                        value="Excavator: 2, Roller: 1",
                    )
                    activities = gr.Textbox(
                        label="Activity",
                        placeholder="e.g., 9 AM: Excavation, 10 AM: Concreting",
                        value="9 AM: Excavation, 10 AM: Concreting",
                        lines=3,
                    )
                    media = gr.File(
                        label="Upload Image/Video", file_types=["image", "video"]
                    )
                    submit_btn = gr.Button("Submit", variant="primary")

                # Right column: values filled in by the submit handler.
                with gr.Column():
                    gr.Markdown("### Model Detection")
                    model_day = gr.Textbox(label="Day")
                    model_date = gr.Textbox(label="Date")
                    model_people = gr.Textbox(label="Total Number of People")
                    model_machinery = gr.Textbox(label="Total Number of Machinery")
                    model_machinery_types = gr.Textbox(
                        label="Number of Machinery Per Type"
                    )
                    # Table of detected activities, one row per timestamp.
                    with gr.Row():
                        gr.Markdown("#### Activities with Timestamps")
                        model_activities = gr.Dataframe(
                            headers=["Time", "Activity Description"],
                            datatype=["str", "str"],
                            label="Detected Activities",
                            interactive=False,
                            wrap=True,
                        )

            # Timeline and chat share one row; only one of them is visible.
            with gr.Row():
                # Timeline view, shown by default.
                with gr.Column(visible=True) as timeline_view:
                    activity_gallery = gr.Gallery(label="Activity Timeline")
                    model_annotated_video = gr.Video(label="Full Video")

                # Chat view, hidden until a gallery item is selected.
                with gr.Column(visible=False) as chat_view:
                    chunk_video = gr.Video(label="Chunk video")
                    chatbot = gr.Chatbot(height=400)
                    chat_input = gr.Textbox(
                        placeholder="Ask about this video segment...",
                        show_label=False,
                    )
                    back_btn = gr.Button("← Back to Timeline")

    # Submit: run the diary pipeline and populate the model-detection widgets.
    _diary_inputs = [
        day,
        date,
        total_people,
        total_machinery,
        machinery_types,
        activities,
        media,
    ]
    _diary_outputs = [
        model_day,
        model_date,
        model_people,
        model_machinery,
        model_machinery_types,
        timeline_view,
        model_annotated_video,
        activity_data,
        chatbot,
        model_activities,
    ]
    submit_btn.click(fn=process_diary, inputs=_diary_inputs, outputs=_diary_outputs)

    # Whenever the stored activities change, rebuild the gallery from them.
    activity_data.change(
        fn=create_activity_cards_ui,
        inputs=[activity_data],
        outputs=[activity_gallery, activity_indices],
    )

    # Selecting a gallery item hands control to the card-click handler.
    activity_gallery.select(
        fn=on_card_click,
        inputs=[activity_indices, chatbot],
        outputs=[timeline_view, chat_view, chatbot, chunk_video],
    )

    # Pressing Enter in the chat box sends the question to the chat handler.
    chat_input.submit(
        fn=chat_with_video,
        inputs=[chat_input, chatbot],
        outputs=[chatbot],
    )

    def _show_timeline():
        """Return updates that show the timeline column and hide the chat column."""
        return [gr.update(visible=True), gr.update(visible=False)]

    # Back button: swap the two views again.
    back_btn.click(
        fn=_show_timeline,
        inputs=None,
        outputs=[timeline_view, chat_view],
    )

    # Extra styling, injected as a <style> element through an HTML component.
    gr.HTML("""
    <style>
    /* Gallery customizations */
    .gradio-container .gallery-item {
        border: 1px solid #444444 !important;
        border-radius: 8px !important;
        padding: 8px !important;
        margin: 10px !important;
        cursor: pointer !important;
        transition: all 0.3s !important;
        background: #18181b !important;
        box-shadow: 0 2px 5px rgba(0,0,0,0.2) !important;
    }

    .gradio-container .gallery-item:hover {
        transform: translateY(-2px) !important;
        box-shadow: 0 4px 12px rgba(0,0,0,0.25) !important;
        border-color: #007bff !important;
        background: #202025 !important;
    }

    .gradio-container .gallery-item.selected {
        border: 2px solid #007bff !important;
        background: #202030 !important;
    }

    /* Improved image display */
    .gradio-container .gallery-item img {
        height: 180px !important;
        object-fit: cover !important;
        border-radius: 4px !important;
        border: 1px solid #444444 !important;
        margin-bottom: 8px !important;
    }

    /* Caption styling */
    .gradio-container .caption {
        color: #e0e0e0 !important;
        font-size: 0.9em !important;
        margin-top: 8px !important;
        line-height: 1.4 !important;
        padding: 0 4px !important;
    }

    /* Gallery container */
    .gradio-container [id*='gallery'] > div:first-child {
        background-color: #27272a !important;
        padding: 15px !important;
        border-radius: 10px !important;
    }

    /* Chatbot styling */
    .gradio-container .chatbot {
        background-color: #27272a !important;
        border-radius: 10px !important;
        border: 1px solid #444444 !important;
    }

    .gradio-container .chatbot .message.user {
        background-color: #18181b !important;
        border-radius: 8px !important;
    }

    .gradio-container .chatbot .message.bot {
        background-color: #202030 !important;
        border-radius: 8px !important;
    }

    /* Button styling */
    .gradio-container button.secondary {
        background-color: #3d4452 !important;
        color: white !important;
    }
    </style>
    """)

if __name__ == "__main__":
    # Launch with a public share link; files under ./tmp may be served.
    demo.launch(share=True, allowed_paths=["./tmp"])