import spaces
import torch
from datetime import datetime
from transformers import AutoModel, AutoTokenizer
import gradio as gr
from PIL import Image
from decord import VideoReader, cpu
import os
import gc
import tempfile
from ultralytics import YOLO
import numpy as np
import cv2
from modelscope.hub.snapshot_download import snapshot_download
from ultralytics.nn.modules import Conv, C2f
from torch import nn
import ultralytics.nn.modules as modules
# Custom C3k2 module definition (compatibility shim for ultralytics builds
# that predate YOLOv11's C3k2 block)
class C3k2(nn.Module):
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)
        # One C2f per repeat; the n repeats are chained here, so each C2f gets
        # n=1 (passing `shortcut` positionally would land in C2f's `n` slot)
        self.m = nn.Sequential(*(C2f(c_, c_, 1, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
# Patch the Ultralytics module namespace, but only when this build does not
# already ship a native C3k2, so a real implementation is never shadowed
if not hasattr(modules, 'C3k2'):
    modules.C3k2 = C3k2
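# Quick shape sanity check for the shim (illustrative sizes, safe to remove):
#   >>> C3k2(64, 128)(torch.randn(1, 64, 32, 32)).shape
#   torch.Size([1, 128, 32, 32])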
# Fix GLIBCXX dependency ('$LD_LIBRARY_PATH' is not expanded inside a Python
# string, so the previous value is appended explicitly)
os.environ['LD_LIBRARY_PATH'] = '/usr/lib/x86_64-linux-gnu:' + os.environ.get('LD_LIBRARY_PATH', '')
# Initialize GPU with a small warm-up allocation
def initialize_gpu():
    if torch.cuda.is_available():
        torch.randn(10).cuda()

initialize_gpu()
# Load YOLO model with error handling | |
try: | |
YOLO_MODEL = YOLO('best_yolov11.pt') | |
except Exception as e: | |
raise RuntimeError(f"YOLO model loading failed: {str(e)}") | |
# Model configuration
MODEL_NAME = 'iic/mPLUG-Owl3-7B-240728'
try:
    model_dir = snapshot_download(MODEL_NAME,
                                  cache_dir='./models',
                                  revision='main')
except Exception as e:
    raise RuntimeError(f"Model download failed: {str(e)}")
# Device setup
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# File validation
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}

def get_file_extension(filename):
    return os.path.splitext(filename)[1].lower()

def is_image(filename):
    return get_file_extension(filename) in IMAGE_EXTENSIONS

def is_video(filename):
    return get_file_extension(filename) in VIDEO_EXTENSIONS
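# Usage (illustrative): the extension checks are case-insensitive via .lower()
#   is_image("photo.JPG") -> True
#   is_video("clip.webm") -> True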
def load_model_and_tokenizer():
    """Load 8-bit quantized model"""
    try:
        torch.cuda.empty_cache()
        gc.collect()
        model = AutoModel.from_pretrained(
            model_dir,
            attn_implementation='sdpa',
            trust_remote_code=True,
            load_in_8bit=True,
            device_map="auto",
            torch_dtype=torch.float16
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True
        )
        processor = model.init_processor(tokenizer)
        model.eval()
        return model, tokenizer, processor
    except Exception as e:
        print(f"Model loading error: {str(e)}")
        raise
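# Note: recent transformers releases deprecate the bare `load_in_8bit=True`
# flag in favour of `quantization_config=BitsAndBytesConfig(load_in_8bit=True)`;
# the legacy flag still works on older versions but emits a deprecation warning
# on newer ones.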
def process_yolo_results(results):
    """Process YOLO detection results"""
    machinery_mapping = {
        'tower_crane': "Tower Crane",
        'mobile_crane': "Mobile Crane",
        'compactor': "Compactor/Roller",
        'roller': "Compactor/Roller",
        'bulldozer': "Bulldozer",
        'dozer': "Bulldozer",
        'excavator': "Excavator",
        'dump_truck': "Dump Truck",
        'truck': "Dump Truck",
        'concrete_mixer_truck': "Concrete Mixer",
        'loader': "Loader",
        'pump_truck': "Pump Truck",
        'pile_driver': "Pile Driver",
        'grader': "Grader",
        'other_vehicle': "Other Vehicle"
    }
    counts = {"Worker": 0, **{v: 0 for v in machinery_mapping.values()}}
    for r in results:
        for box in r.boxes:
            if box.conf.item() < 0.5:  # confidence threshold
                continue
            cls_name = YOLO_MODEL.names[int(box.cls.item())].lower()
            if cls_name == 'worker':
                counts["Worker"] += 1
                continue
            for key, value in machinery_mapping.items():
                if key in cls_name:
                    counts[value] += 1
                    break
    return counts["Worker"], sum(counts.values()) - counts["Worker"], counts
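# Illustrative return from process_yolo_results for a frame with three workers
# and one excavator (zero-count machinery classes are retained in the dict):
#   (3, 1, {"Worker": 3, "Excavator": 1, "Tower Crane": 0, ...})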
def detect_people_and_machinery(media_path):
    """GPU-accelerated detection over an image or a sampled video"""
    try:
        max_people = 0
        max_machines = {k: 0 for k in [
            "Tower Crane", "Mobile Crane", "Compactor/Roller", "Bulldozer",
            "Excavator", "Dump Truck", "Concrete Mixer", "Loader",
            "Pump Truck", "Pile Driver", "Grader", "Other Vehicle"
        ]}
        if isinstance(media_path, str) and is_video(media_path):
            cap = cv2.VideoCapture(media_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            sample_rate = max(1, int(fps))  # roughly one frame per second
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                if int(cap.get(cv2.CAP_PROP_POS_FRAMES)) % sample_rate == 0:
                    results = YOLO_MODEL(frame)
                    people, machines, types = process_yolo_results(results)
                    max_people = max(max_people, people)
                    for k in max_machines:
                        max_machines[k] = max(max_machines[k], types.get(k, 0))
            cap.release()
        else:
            img = cv2.imread(media_path) if isinstance(media_path, str) else cv2.cvtColor(np.array(media_path), cv2.COLOR_RGB2BGR)
            if img is None:
                raise ValueError(f"Could not read image: {media_path}")
            results = YOLO_MODEL(img)
            max_people, _, types = process_yolo_results(results)
            for k in max_machines:
                max_machines[k] = types.get(k, 0)
        filtered = {k: v for k, v in max_machines.items() if v > 0}
        return max_people, sum(filtered.values()), filtered
    except Exception as e:
        print(f"Detection error: {str(e)}")
        return 0, 0, {}
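# Usage sketch (hypothetical file name):
#   people, machinery_total, breakdown = detect_people_and_machinery("site.jpg")
#   -> e.g. (4, 2, {"Excavator": 1, "Dump Truck": 1})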
def analyze_video_activities(video_path):
    """Video analysis with chunk processing"""
    try:
        model, tokenizer, processor = load_model_and_tokenizer()
        responses = []
        vr = VideoReader(video_path, ctx=cpu(0))
        # Sample roughly one frame per second, then feed the sampled frames to
        # the model in chunks of 16 to bound peak memory
        frame_step = max(1, int(vr.get_avg_fps()))
        total_frames = len(vr)
        sampled_indices = list(range(0, total_frames, frame_step))
        for i in range(0, len(sampled_indices), 16):
            chunk = sampled_indices[i:i + 16]
            frames = [Image.fromarray(vr[j].asnumpy()) for j in chunk]
            inputs = processor(
                [{"role": "user", "content": "Analyze construction activities", "video_frames": frames}],
                videos=[frames]
            ).to(DEVICE)
            response = model.generate(**inputs, max_new_tokens=200)
            responses.append(response[0])
            del frames, inputs
            torch.cuda.empty_cache()
        del model, tokenizer, processor
        return "\n".join(responses)
    except Exception as e:
        print(f"Video analysis error: {str(e)}")
        return "Activity analysis unavailable"
def analyze_image_activities(image_path):
    """Image analysis pipeline"""
    try:
        model, tokenizer, processor = load_model_and_tokenizer()
        image = Image.open(image_path).convert("RGB")
        inputs = processor(
            [{"role": "user", "content": "Analyze construction site", "images": [image]}],
            images=[image]
        ).to(DEVICE)
        response = model.generate(**inputs, max_new_tokens=200)
        del model, tokenizer, processor, image, inputs
        torch.cuda.empty_cache()
        return response[0]
    except Exception as e:
        print(f"Image analysis error: {str(e)}")
        return "Activity analysis unavailable"
def annotate_video_with_bboxes(video_path):
    """Video annotation with detection overlay"""
    try:
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back if the container reports 0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        temp_file.close()  # cv2 writes to the path, not the open handle
        writer = cv2.VideoWriter(temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % 5 == 0:  # run detection on every 5th frame to save compute
                results = YOLO_MODEL(frame)
                counts = {}
                for r in results:
                    for box in r.boxes:
                        if box.conf.item() < 0.5:
                            continue
                        cls_id = int(box.cls.item())
                        class_name = YOLO_MODEL.names[cls_id]
                        counts[class_name] = counts.get(class_name, 0) + 1
                        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                        cv2.putText(frame, f"{class_name} {box.conf.item():.2f}",
                                    (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
                summary = ", ".join([f"{k}:{v}" for k, v in counts.items()])
                cv2.putText(frame, summary, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
            writer.write(frame)
            frame_count += 1
        cap.release()
        writer.release()
        return temp_file.name
    except Exception as e:
        print(f"Video annotation error: {str(e)}")
        return None
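# Note: 'mp4v'-encoded MP4s do not play in every browser. If the Gradio video
# widget shows a blank player, re-encoding to H.264 is a common workaround
# (requires ffmpeg on the host; output path below is hypothetical):
#   os.system(f"ffmpeg -y -i {annotated_path} -vcodec libx264 annotated_h264.mp4")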
# On ZeroGPU Spaces the GPU is only attached inside functions decorated with
# @spaces.GPU; the otherwise-unused `import spaces` above suggests that was
# the intent here.
@spaces.GPU
def process_diary(day, date, media):
    """Main processing pipeline"""
    try:
        if not media:
            return [day, date, "No data", "No data", "No data", "No data", None]
        # Gradio may hand over a file path (newer versions) or a file-like
        # object (older versions); normalise to a path either way
        src_path = media if isinstance(media, str) else media.name
        # Preserve the original extension so the is_image()/is_video() checks
        # still work on the temporary copy
        with tempfile.NamedTemporaryFile(delete=False, suffix=get_file_extension(src_path)) as tmp:
            with open(src_path, 'rb') as src:
                tmp.write(src.read())
            media_path = tmp.name
        detected_people, detected_machinery, machine_types = detect_people_and_machinery(media_path)
        annotated_video = None
        try:
            if is_image(src_path):
                activities = analyze_image_activities(media_path)
            else:
                activities = analyze_video_activities(media_path)
                annotated_video = annotate_video_with_bboxes(media_path)
        except Exception as e:
            activities = f"Analysis error: {str(e)}"
        os.remove(media_path)
        return [
            day,
            date,
            str(detected_people),
            str(detected_machinery),
            ", ".join([f"{k}:{v}" for k, v in machine_types.items()]),
            activities,
            annotated_video
        ]
    except Exception as e:
        print(f"Processing error: {str(e)}")
        return [day, date, "Error", "Error", "Error", "Error", None]
# Gradio Interface
with gr.Blocks(title="Digital Site Diary", css="video {height: auto !important;}") as demo:
    gr.Markdown("# 🏗️ Digital Construction Diary")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Site Details")
            day = gr.Textbox(label="Day Number", value="1")
            date = gr.Textbox(label="Date", value=datetime.now().strftime("%Y-%m-%d"))
            media = gr.File(label="Upload Media", file_types=["image", "video"])
            submit_btn = gr.Button("Generate Report", variant="primary")
        with gr.Column():
            gr.Markdown("### Safety Report")
            model_day = gr.Textbox(label="Day")
            model_date = gr.Textbox(label="Date")
            model_people = gr.Textbox(label="Worker Count")
            model_machinery = gr.Textbox(label="Machinery Count")
            model_machinery_types = gr.Textbox(label="Machinery Breakdown")
            model_activities = gr.Textbox(label="Activity Analysis", lines=4)
            model_video = gr.Video(label="Safety Annotations")

    submit_btn.click(
        process_diary,
        inputs=[day, date, media],
        outputs=[
            model_day,
            model_date,
            model_people,
            model_machinery,
            model_machinery_types,
            model_activities,
            model_video
        ]
    )
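# For long-running inference, Gradio's request queue avoids HTTP timeouts:
#   demo.queue()  # optional; call before launch()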
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)