import spaces
import torch
import argparse
import os
import sys
import pickle  # For serializing frames
import gc
import tempfile
import subprocess
from datetime import datetime
from transformers import AutoModel, AutoTokenizer
from modelscope.hub.snapshot_download import snapshot_download
from PIL import Image
from decord import VideoReader, cpu
import cv2
import gradio as gr
from ultralytics import YOLO
import numpy as np
import io
from azure.storage.blob import BlobServiceClient

# Install flash-attn (using prebuilt wheel mode if needed)
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'}, shell=True
)
# --------------------------------------------------------------------
# Command-line arguments
# --------------------------------------------------------------------
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int)
# New arguments for subprocess inference (unused in this version)
parser.add_argument("--chunk_inference", action="store_true", help="Run inference on a chunk (subprocess mode).")
parser.add_argument("--input_file", type=str, help="Path to serialized input chunk frames.")
parser.add_argument("--output_file", type=str, help="Path to file where inference result is written.")
parser.add_argument("--inference_prompt", type=str, help="Inference prompt for the chunk.")
parser.add_argument("--model_path_arg", type=str, help="Model path for the subprocess.")
args = parser.parse_args()

device = args.device
assert device in ['cuda', 'mps']
# Global model configuration
MODEL_NAME = 'iic/mPLUG-Owl3-7B-240728'
MODEL_CACHE_DIR = '/data/models'
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

# Download and cache the model (only in the main process)
if not args.chunk_inference:
    try:
        model_path = snapshot_download(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
    except Exception as e:
        print(f"Error downloading model: {str(e)}")
        model_path = os.path.join(MODEL_CACHE_DIR, MODEL_NAME)
else:
    # In worker mode, use the passed model path (unused here)
    model_path = args.model_path_arg
MAX_NUM_FRAMES = 32

# Initialize YOLO model (assumed to be lightweight)
YOLO_MODEL = YOLO('./best_yolov11.pt')  # Load YOLOv11 model

# File type validation
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}

def get_file_extension(filename):
    return os.path.splitext(filename)[1].lower()

def is_image(filename):
    return get_file_extension(filename) in IMAGE_EXTENSIONS

def is_video(filename):
    return get_file_extension(filename) in VIDEO_EXTENSIONS
# --------------------------------------------------------------------
# Azure Blob Storage Setup
# --------------------------------------------------------------------
CONTAINER_NAME = "logs"  # Replace with your actual container name
# Read the connection string from the environment rather than hard-coding credentials
connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
if not connection_string:
    raise ValueError("Azure storage connection string not found in environment variables. "
                     "Please set AZURE_STORAGE_CONNECTION_STRING.")
blob_service_client = BlobServiceClient.from_connection_string(connection_string)
def list_blobs():
    """List video blobs in the specified container."""
    try:
        container_client = blob_service_client.get_container_client(CONTAINER_NAME)
        blobs = container_client.list_blobs()
        return [blob.name for blob in blobs if is_video(blob.name)]
    except Exception as e:
        print(f"Error listing blobs: {str(e)}")
        return []

# Fetch blob names at startup
blob_names = list_blobs()
# --------------------------------------------------------------------
# Model Loading and Inference Functions
# --------------------------------------------------------------------
def load_model_and_tokenizer():
    """Load a fresh instance of the model and tokenizer."""
    try:
        # Clear GPU memory if using CUDA
        if device == "cuda":
            torch.cuda.empty_cache()
            gc.collect()
        model = AutoModel.from_pretrained(
            model_path,
            attn_implementation='sdpa',
            trust_remote_code=True,
            torch_dtype=torch.half,
            device_map='auto'
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model.eval()
        processor = model.init_processor(tokenizer)
        return model, tokenizer, processor
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise
def process_video_chunk(video_frames, model, tokenizer, processor, prompt):
    """Process a chunk of video frames with the mPLUG model."""
    messages = [{
        "role": "user",
        "content": prompt,
        "video_frames": video_frames
    }]
    model_messages = []
    videos = []
    for msg in messages:
        content_str = msg["content"]
        if "video_frames" in msg and msg["video_frames"]:
            content_str += "<|video|>"
            videos.append(msg["video_frames"])
        model_messages.append({"role": msg["role"], "content": content_str})
    model_messages.append({"role": "assistant", "content": ""})
    inputs = processor(
        model_messages,
        images=None,
        videos=videos if videos else None
    )
    inputs.to(device)
    inputs.update({
        'tokenizer': tokenizer,
        'max_new_tokens': 100,
        'decode_text': True,
    })
    response = model.generate(**inputs)
    return response[0]
# --------------------------------------------------------------------
# Video and YOLO functions
# --------------------------------------------------------------------
def encode_video_in_chunks(video_path):
    """Extract frames from a video at roughly 1 FPS and yield them in chunks of MAX_NUM_FRAMES."""
    vr = VideoReader(video_path, ctx=cpu(0))
    frame_stride = max(1, round(vr.get_avg_fps()))  # sample about one frame per second
    frame_idx = list(range(0, len(vr), frame_stride))
    chunks = [frame_idx[i:i + MAX_NUM_FRAMES] for i in range(0, len(frame_idx), MAX_NUM_FRAMES)]
    for chunk_idx, chunk in enumerate(chunks):
        frames = vr.get_batch(chunk).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        yield chunk_idx, frames
def process_yolo_results(results):
    """Process YOLO detection results and count people and machinery."""
    people_count = 0
    machine_types = {
        "Tower Crane": 0, "Mobile Crane": 0, "Compactor/Roller": 0, "Bulldozer": 0,
        "Excavator": 0, "Dump Truck": 0, "Concrete Mixer": 0, "Loader": 0,
        "Pump Truck": 0, "Pile Driver": 0, "Grader": 0, "Other Vehicle": 0
    }
    # Map raw YOLO class names to the reported machinery categories.
    # More specific keys come before the generic 'truck' so that, e.g.,
    # 'concrete_mixer_truck' is not counted as a dump truck.
    machinery_mapping = {
        'tower_crane': "Tower Crane",
        'mobile_crane': "Mobile Crane",
        'compactor': "Compactor/Roller",
        'roller': "Compactor/Roller",
        'bulldozer': "Bulldozer",
        'dozer': "Bulldozer",
        'excavator': "Excavator",
        'concrete_mixer_truck': "Concrete Mixer",
        'pump_truck': "Pump Truck",
        'dump_truck': "Dump Truck",
        'truck': "Dump Truck",
        'loader': "Loader",
        'pile_driver': "Pile Driver",
        'grader': "Grader",
        'other_vehicle': "Other Vehicle"
    }
    for r in results:
        boxes = r.boxes
        for box in boxes:
            cls = int(box.cls[0])
            conf = float(box.conf[0])
            class_name = YOLO_MODEL.names[cls]
            if class_name.lower() == 'worker' and conf > 0.5:
                people_count += 1
            if conf > 0.5:
                class_lower = class_name.lower()
                for key, value in machinery_mapping.items():
                    if key in class_lower:
                        machine_types[value] += 1
                        break
    total_machinery = sum(machine_types.values())
    return people_count, total_machinery, machine_types
def detect_people_and_machinery(media_path):
    """Detect people and machinery using YOLOv11 for both images and videos."""
    try:
        max_people_count = 0
        max_machine_types = {
            "Tower Crane": 0, "Mobile Crane": 0, "Compactor/Roller": 0, "Bulldozer": 0,
            "Excavator": 0, "Dump Truck": 0, "Concrete Mixer": 0, "Loader": 0,
            "Pump Truck": 0, "Pile Driver": 0, "Grader": 0, "Other Vehicle": 0
        }
        if isinstance(media_path, str) and is_video(media_path):
            cap = cv2.VideoCapture(media_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            sample_rate = max(1, int(fps))
            frame_count = 0
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % sample_rate == 0:
                    results = YOLO_MODEL(frame)
                    people, _, machine_types = process_yolo_results(results)
                    max_people_count = max(max_people_count, people)
                    for k, v in machine_types.items():
                        max_machine_types[k] = max(max_machine_types[k], v)
                frame_count += 1
            cap.release()
        else:
            if isinstance(media_path, str):
                img = cv2.imread(media_path)
            else:
                img = cv2.cvtColor(np.array(media_path), cv2.COLOR_RGB2BGR)
            results = YOLO_MODEL(img)
            max_people_count, _, max_machine_types = process_yolo_results(results)
        max_machine_types = {k: v for k, v in max_machine_types.items() if v > 0}
        total_machinery_count = sum(max_machine_types.values())
        return max_people_count, total_machinery_count, max_machine_types
    except Exception as e:
        print(f"Error in YOLO detection: {str(e)}")
        return 0, 0, {}
def process_image(image_path, model, tokenizer, processor, prompt):
    """Process a single image with the mPLUG model."""
    try:
        image = Image.open(image_path)
        messages = [{
            "role": "user",
            "content": prompt,
            "images": [image]
        }]
        model_messages = []
        images = []
        for msg in messages:
            content_str = msg["content"]
            if "images" in msg and msg["images"]:
                content_str += "<|image|>"
                images.extend(msg["images"])
            model_messages.append({"role": msg["role"], "content": content_str})
        model_messages.append({"role": "assistant", "content": ""})
        inputs = processor(model_messages, images=images, videos=None)
        inputs.to(device)
        inputs.update({
            'tokenizer': tokenizer,
            'max_new_tokens': 100,
            'decode_text': True,
        })
        response = model.generate(**inputs)
        return response[0]
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return "Error processing image"
def analyze_image_activities(image_path):
    """Analyze image using mPLUG model."""
    try:
        model, tokenizer, processor = load_model_and_tokenizer()
        prompt = ("Analyze this construction site image and describe the activities happening. "
                  "Focus on construction activities, machinery usage, and worker actions.")
        response = process_image(image_path, model, tokenizer, processor, prompt)
        del model, tokenizer, processor
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        return response
    except Exception as e:
        print(f"Error analyzing image: {str(e)}")
        return "Error analyzing image activities"
def annotate_video_with_bboxes(video_path):
    """
    Reads the video frame-by-frame, runs YOLO, draws bounding boxes,
    writes a per-frame summary of detected classes on the frame, and saves
    the annotated video. Returns the annotated video path.
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    annotated_video_path = out_file.name
    out_file.close()
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(annotated_video_path, fourcc, fps, (w, h))
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        results = YOLO_MODEL(frame)
        frame_counts = {}
        for r in results:
            boxes = r.boxes
            for box in boxes:
                cls_id = int(box.cls[0])
                conf = float(box.conf[0])
                if conf < 0.5:
                    continue
                x1, y1, x2, y2 = box.xyxy[0]
                class_name = YOLO_MODEL.names[cls_id]
                x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                color = (0, 255, 0)
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                label_text = f"{class_name} {conf:.2f}"
                cv2.putText(frame, label_text, (x1, y1 - 6),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
                frame_counts[class_name] = frame_counts.get(class_name, 0) + 1
        summary_str = ", ".join(f"{cls_name}: {count}" for cls_name, count in frame_counts.items())
        cv2.putText(frame, summary_str, (15, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)
        writer.write(frame)
    cap.release()
    writer.release()
    return annotated_video_path
# --------------------------------------------------------------------
# Video Analysis with a Shared mPLUG Instance
# --------------------------------------------------------------------
def analyze_video_activities(video_path):
    """Analyze video using the mPLUG model with chunking.
    The same model instance is reused across all chunks."""
    try:
        print(f"Analyzing video: {video_path}")
        all_responses = []
        chunk_generator = encode_video_in_chunks(video_path)
        print("Video chunk generator initialized.")
        # Load the model instance once and reuse it for every chunk
        model, tokenizer, processor = load_model_and_tokenizer()
        print("Model loaded.")
        for chunk_idx, video_frames in chunk_generator:
            prompt = ("Analyze this construction site video chunk and describe the activities happening. "
                      "Focus on construction activities, machinery usage, and worker actions.")
            response = process_video_chunk(video_frames, model, tokenizer, processor, prompt)
            all_responses.append(f"Time period {chunk_idx + 1}:\n{response}")
            print(f"Processed chunk {chunk_idx + 1}: {response}")
        # Release the shared model instance once all chunks are processed
        del model, tokenizer, processor
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        return "\n\n".join(all_responses)
    except Exception as e:
        print(f"Error analyzing video: {str(e)}")
        return "Error analyzing video activities"
# --------------------------------------------------------------------
# Gradio Interface and Main Launch (only executed in main process)
# --------------------------------------------------------------------
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media_source, local_file, azure_blob):
    """Process the site diary entry with media from local file or Azure Blob Storage."""
    # Initialize media_path
    media_path = None
    try:
        if media_source == "Local File":
            if local_file is None:
                return [day, date, "No media uploaded", "No media uploaded", "No media uploaded", "No media uploaded", None]
            media_path = local_file  # local_file is a string path in Gradio
        else:  # Azure Blob
            if not azure_blob:
                return [day, date, "No blob selected", "No blob selected", "No blob selected", "No blob selected", None]
            try:
                blob_client = blob_service_client.get_blob_client(container=CONTAINER_NAME, blob=azure_blob)
                with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(azure_blob)[1]) as temp_file:
                    temp_path = temp_file.name
                    blob_data = blob_client.download_blob()
                    blob_data.readinto(temp_file)
                media_path = temp_path
            except Exception as e:
                print(f"Error downloading blob: {str(e)}")
                return [day, date, "Error downloading blob", "Error downloading blob", "Error downloading blob", "Error downloading blob", None]

        # Process the media file
        file_ext = get_file_extension(media_path)
        if not (is_image(media_path) or is_video(media_path)):
            raise ValueError(f"Unsupported file type: {file_ext}")

        detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(media_path)
        print(f"Detected people: {detected_people}, Detected machinery: {detected_machinery}, Machinery types: {detected_machinery_types}")

        print("Analyzing activities...")
        print(f"Media type: {'Image' if is_image(media_path) else 'Video'}")
        print(f"Media path: {media_path}")
        annotated_video_path = None
        if is_image(media_path):
            detected_activities = analyze_image_activities(media_path)
        else:
            detected_activities = analyze_video_activities(media_path)
            annotated_video_path = annotate_video_with_bboxes(media_path)

        detected_types_str = ", ".join([f"{k}: {v}" for k, v in detected_machinery_types.items()])
        return [day, date, str(detected_people), str(detected_machinery), detected_types_str, detected_activities, annotated_video_path]
    except Exception as e:
        print(f"Error processing media: {str(e)}")
        return [day, date, "Error processing media", "Error processing media", "Error processing media", "Error processing media", None]
    finally:
        # Clean up GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        # Remove temporary file if downloaded from Azure
        if media_source == "Azure Blob" and media_path and os.path.exists(media_path):
            try:
                os.remove(media_path)
                print(f"Removed temporary file: {media_path}")
            except Exception as e:
                print(f"Error removing temporary file: {str(e)}")
with gr.Blocks(title="Digital Site Diary") as demo:
    gr.Markdown("# 📝 Digital Site Diary")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### User Input")
            day = gr.Textbox(label="Day", value='9')
            date = gr.Textbox(label="Date", placeholder="YYYY-MM-DD", value=datetime.now().strftime("%Y-%m-%d"))
            total_people = gr.Number(label="Total Number of People", precision=0, value=10)
            total_machinery = gr.Number(label="Total Number of Machinery", precision=0, value=3)
            machinery_types = gr.Textbox(label="Number of Machinery Per Type",
                                         placeholder="e.g., Excavator: 2, Roller: 1",
                                         value="Excavator: 2, Roller: 1")
            activities = gr.Textbox(label="Activity",
                                    placeholder="e.g., 9 AM: Excavation, 10 AM: Concreting",
                                    value="9 AM: Excavation, 10 AM: Concreting", lines=3)
            media_source = gr.Radio(["Local File", "Azure Blob"], label="Media Source", value="Local File")
            local_file = gr.File(label="Upload Image/Video", file_types=["image", "video"], visible=True)
            azure_blob = gr.Dropdown(label="Select Video from Azure", choices=blob_names, visible=False)
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            gr.Markdown("### Model Detection")
            model_day = gr.Textbox(label="Day")
            model_date = gr.Textbox(label="Date")
            model_people = gr.Textbox(label="Total Number of People")
            model_machinery = gr.Textbox(label="Total Number of Machinery")
            model_machinery_types = gr.Textbox(label="Number of Machinery Per Type")
            model_activities = gr.Textbox(label="Activity", lines=5)
            model_annotated_video = gr.Video(label="Annotated Video")

    # Update visibility based on media source
    def update_visibility(source):
        if source == "Local File":
            return gr.update(visible=True), gr.update(visible=False)
        else:
            return gr.update(visible=False), gr.update(visible=True)

    media_source.change(fn=update_visibility, inputs=media_source, outputs=[local_file, azure_blob])

    submit_btn.click(
        fn=process_diary,
        inputs=[day, date, total_people, total_machinery, machinery_types, activities, media_source, local_file, azure_blob],
        outputs=[model_day, model_date, model_people, model_machinery, model_machinery_types, model_activities, model_annotated_video]
    )
if __name__ == "__main__":
    demo.launch(share=False, debug=True, show_api=False, server_port=args.port, server_name=args.host)