import spaces
import torch
from datetime import datetime
from transformers import AutoModel, AutoTokenizer
import gradio as gr
from PIL import Image
from decord import VideoReader, cpu
import os
import gc
import tempfile
from ultralytics import YOLO
import numpy as np
import cv2
from modelscope.hub.snapshot_download import snapshot_download
# Fix GLIBCXX dependency. Note that "$LD_LIBRARY_PATH" is not expanded inside
# a Python string, so the existing value must be appended explicitly.
os.environ['LD_LIBRARY_PATH'] = '/usr/lib/x86_64-linux-gnu:' + os.environ.get('LD_LIBRARY_PATH', '')
# Initialize GPU
def initialize_gpu():
    if torch.cuda.is_available():
        torch.randn(10).cuda()

initialize_gpu()
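# The eager torch.randn(...).cuda() call above forces CUDA context creation at
# startup so the first user request doesn't pay the warm-up cost. The `spaces`
# import is presumably there for its side effects on Hugging Face Spaces (GPU
# allocation); it is not referenced again in this file.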
# Load YOLO model with error handling
try:
    YOLO_MODEL = YOLO('best_yolov11.pt')
except Exception as e:
    raise RuntimeError(f"YOLO model loading failed: {str(e)}")
# Model configuration
MODEL_NAME = 'iic/mPLUG-Owl3-7B-240728'
try:
    model_dir = snapshot_download(
        MODEL_NAME,
        cache_dir='./models',
        revision='v1.0.0'  # Verified working revision
    )
except Exception as e:
    raise RuntimeError(f"Model download failed: {str(e)}")
# Device setup
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# File validation
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}

def get_file_extension(filename):
    return os.path.splitext(filename)[1].lower()

def is_image(filename):
    return get_file_extension(filename) in IMAGE_EXTENSIONS

def is_video(filename):
    return get_file_extension(filename) in VIDEO_EXTENSIONS
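# Illustrative checks (matching is case-insensitive via .lower()):
#   is_image("site.JPG")        -> True
#   is_video("walkthrough.mp4") -> True
#   is_video("notes.txt")       -> False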
def load_model_and_tokenizer():
    """Load the 8-bit quantized model with memory optimizations."""
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        model = AutoModel.from_pretrained(
            model_dir,
            attn_implementation='sdpa',
            trust_remote_code=True,
            load_in_8bit=True,
            device_map="auto",
            torch_dtype=torch.float16
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True
        )
        processor = model.init_processor(tokenizer)
        model.eval()
        return model, tokenizer, processor
    except Exception as e:
        print(f"Model loading error: {str(e)}")
        raise
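# Note: `load_in_8bit=True` is deprecated in recent transformers releases in
# favour of an explicit quantization config. A hedged equivalent (assuming
# bitsandbytes is installed) would be:
#
#     from transformers import BitsAndBytesConfig
#     model = AutoModel.from_pretrained(
#         model_dir,
#         trust_remote_code=True,
#         quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#         device_map="auto",
#     )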
def process_yolo_results(results):
    """Process YOLO detection results with safety checks."""
    # Most specific keys first: the substring matching below would otherwise
    # count e.g. a 'concrete_mixer_truck' detection as a generic 'truck'.
    machinery_mapping = {
        'tower_crane': "Tower Crane",
        'mobile_crane': "Mobile Crane",
        'compactor': "Compactor/Roller",
        'roller': "Compactor/Roller",
        'bulldozer': "Bulldozer",
        'dozer': "Bulldozer",
        'excavator': "Excavator",
        'concrete_mixer_truck': "Concrete Mixer",
        'dump_truck': "Dump Truck",
        'pump_truck': "Pump Truck",
        'pile_driver': "Pile Driver",
        'loader': "Loader",
        'grader': "Grader",
        'other_vehicle': "Other Vehicle",
        'truck': "Dump Truck"
    }
    counts = {"Worker": 0, **{v: 0 for v in machinery_mapping.values()}}
    try:
        for r in results:
            for box in r.boxes:
                if box.conf.item() < 0.5:  # confidence threshold
                    continue
                cls_name = YOLO_MODEL.names[int(box.cls.item())].lower()
                if cls_name == 'worker':
                    counts["Worker"] += 1
                    continue
                for key, value in machinery_mapping.items():
                    if key in cls_name:
                        counts[value] += 1
                        break
    except Exception as e:
        print(f"YOLO processing error: {str(e)}")
    return counts["Worker"], sum(counts.values()) - counts["Worker"], counts
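# Illustrative usage (a sketch, not executed at import time):
#
#     results = YOLO_MODEL("frame.jpg")  # hypothetical image path
#     people, machinery_total, breakdown = process_yolo_results(results)
#
# `people` is the worker count, `machinery_total` the sum of all machinery
# detections, and `breakdown` the per-category counts (including "Worker").
# The 0.5 confidence cutoff and the substring matching against the model's
# class names are heuristics; classes that match nothing are not counted.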
def detect_people_and_machinery(media_path):
    """GPU-accelerated detection with memory management."""
    try:
        max_people = 0
        max_machines = {k: 0 for k in [
            "Tower Crane", "Mobile Crane", "Compactor/Roller", "Bulldozer",
            "Excavator", "Dump Truck", "Concrete Mixer", "Loader",
            "Pump Truck", "Pile Driver", "Grader", "Other Vehicle"
        ]}
        if isinstance(media_path, str) and is_video(media_path):
            cap = cv2.VideoCapture(media_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            sample_rate = max(1, int(fps))  # roughly one sampled frame per second
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                if int(cap.get(cv2.CAP_PROP_POS_FRAMES)) % sample_rate == 0:
                    results = YOLO_MODEL(frame)
                    people, machines, types = process_yolo_results(results)
                    max_people = max(max_people, people)
                    for k in max_machines:
                        max_machines[k] = max(max_machines[k], types.get(k, 0))
            cap.release()
        else:
            img = cv2.imread(media_path) if isinstance(media_path, str) else cv2.cvtColor(np.array(media_path), cv2.COLOR_RGB2BGR)
            if img is None:
                raise ValueError(f"Could not read image: {media_path}")
            results = YOLO_MODEL(img)
            max_people, _, types = process_yolo_results(results)
            for k in max_machines:
                max_machines[k] = types.get(k, 0)
        filtered = {k: v for k, v in max_machines.items() if v > 0}
        return max_people, sum(filtered.values()), filtered
    except Exception as e:
        print(f"Detection error: {str(e)}")
        return 0, 0, {}
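# Per-frame maxima (rather than sums across frames) are reported for videos
# because the same worker or machine appears in many sampled frames; summing
# would overcount. The headcount is therefore "most seen in any single frame".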
def analyze_video_activities(video_path):
    """Video analysis with chunked processing and memory cleanup."""
    try:
        model, tokenizer, processor = load_model_and_tokenizer()
        responses = []
        vr = VideoReader(video_path, ctx=cpu(0))
        frame_step = max(1, int(vr.get_avg_fps()))  # ~1 sampled frame per second
        sampled_indices = list(range(0, len(vr), frame_step))
        # Process the sampled frames in 16-frame chunks
        for i in range(0, len(sampled_indices), 16):
            frames = [Image.fromarray(vr[j].asnumpy()) for j in sampled_indices[i:i + 16]]
            # Message format per the mPLUG-Owl3 model card: a <|video|>
            # placeholder in the user turn plus an empty assistant turn.
            messages = [
                {"role": "user", "content": "<|video|>\nAnalyze the construction activities in this video."},
                {"role": "assistant", "content": ""}
            ]
            inputs = processor(messages, images=None, videos=[frames]).to(DEVICE)
            inputs.update({'tokenizer': tokenizer, 'max_new_tokens': 200, 'decode_text': True})
            response = model.generate(**inputs)
            responses.append(response[0])
            del frames, inputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        del model, tokenizer, processor
        gc.collect()
        return "\n".join(responses)
    except Exception as e:
        print(f"Video analysis error: {str(e)}")
        return "Activity analysis unavailable"
def analyze_image_activities(image_path):
    """Image analysis with memory cleanup."""
    try:
        model, tokenizer, processor = load_model_and_tokenizer()
        image = Image.open(image_path).convert("RGB")
        # Message format per the mPLUG-Owl3 model card (<|image|> placeholder).
        messages = [
            {"role": "user", "content": "<|image|>\nAnalyze the construction site in this image."},
            {"role": "assistant", "content": ""}
        ]
        inputs = processor(messages, images=[image], videos=None).to(DEVICE)
        inputs.update({'tokenizer': tokenizer, 'max_new_tokens': 200, 'decode_text': True})
        response = model.generate(**inputs)
        del model, tokenizer, processor, image, inputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return response[0]
    except Exception as e:
        print(f"Image analysis error: {str(e)}")
        return "Activity analysis unavailable"
def annotate_video_with_bboxes(video_path):
    """Video annotation with efficient frame processing."""
    try:
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        temp_file.close()  # only the path is needed; VideoWriter opens it itself
        writer = cv2.VideoWriter(temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Annotate every 5th frame to reduce load; other frames pass through unchanged
            if frame_count % 5 == 0:
                results = YOLO_MODEL(frame)
                counts = {}
                for r in results:
                    for box in r.boxes:
                        if box.conf.item() < 0.5:
                            continue
                        cls_id = int(box.cls.item())
                        class_name = YOLO_MODEL.names[cls_id]
                        counts[class_name] = counts.get(class_name, 0) + 1
                        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                        cv2.putText(frame, f"{class_name} {box.conf.item():.2f}",
                                    (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
                summary = ", ".join(f"{k}:{v}" for k, v in counts.items())
                cv2.putText(frame, summary, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
            writer.write(frame)
            frame_count += 1
        cap.release()
        writer.release()
        return temp_file.name
    except Exception as e:
        print(f"Video annotation error: {str(e)}")
        return None
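# Optional sketch: cv2's 'mp4v' output often will not play in browsers, so
# gr.Video may show a blank player. A common workaround (an addition here,
# not part of the original app) is to re-encode to H.264, assuming an
# `ffmpeg` binary is available on PATH.
def reencode_for_browser(src_path):
    """Re-encode a video to H.264 for in-browser playback (hedged sketch)."""
    import subprocess
    dst_path = src_path.rsplit('.', 1)[0] + '_h264.mp4'
    try:
        subprocess.run(
            ['ffmpeg', '-y', '-i', src_path,
             '-vcodec', 'libx264', '-pix_fmt', 'yuv420p', dst_path],
            check=True, capture_output=True
        )
        return dst_path
    except Exception:
        return src_path  # fall back to the original file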
def process_diary(day, date, media):
    """Main processing pipeline with error handling."""
    try:
        if not media:
            return [day, date, "No data", "No data", "No data", "No data", None]
        # gr.File supplies a temp-file wrapper; copy it to a path that keeps
        # the original extension so is_image()/is_video() still work.
        src_path = media.name if hasattr(media, 'name') else str(media)
        with tempfile.NamedTemporaryFile(delete=False, suffix=get_file_extension(src_path)) as tmp:
            with open(src_path, 'rb') as src:
                tmp.write(src.read())
            media_path = tmp.name
        detected_people, detected_machinery, machine_types = detect_people_and_machinery(media_path)
        annotated_video = None
        try:
            if is_image(media_path):
                activities = analyze_image_activities(media_path)
            else:
                activities = analyze_video_activities(media_path)
                annotated_video = annotate_video_with_bboxes(media_path)
        except Exception as e:
            activities = f"Analysis error: {str(e)}"
        os.remove(media_path)
        return [
            day,
            date,
            str(detected_people),
            str(detected_machinery),
            ", ".join(f"{k}:{v}" for k, v in machine_types.items()),
            activities,
            annotated_video
        ]
    except Exception as e:
        print(f"Processing error: {str(e)}")
        return [day, date, "Error", "Error", "Error", "Error", None]
# Gradio Interface
with gr.Blocks(title="Digital Site Diary", css="video {height: auto !important;}") as demo:
    gr.Markdown("# 🏗️ Digital Construction Diary")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Site Details")
            day = gr.Textbox(label="Day Number", value="1")
            date = gr.Textbox(label="Date", value=datetime.now().strftime("%Y-%m-%d"))
            media = gr.File(label="Upload Media", file_types=["image", "video"])
            submit_btn = gr.Button("Generate Report", variant="primary")
        with gr.Column():
            gr.Markdown("### Safety Report")
            model_day = gr.Textbox(label="Day")
            model_date = gr.Textbox(label="Date")
            model_people = gr.Textbox(label="Worker Count")
            model_machinery = gr.Textbox(label="Machinery Count")
            model_machinery_types = gr.Textbox(label="Machinery Breakdown")
            model_activities = gr.Textbox(label="Activity Analysis", lines=4)
            model_video = gr.Video(label="Safety Annotations")
    submit_btn.click(
        process_diary,
        inputs=[day, date, media],
        outputs=[
            model_day,
            model_date,
            model_people,
            model_machinery,
            model_machinery_types,
            model_activities,
            model_video
        ]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)