import spaces
import torch
import argparse
import os
import sys
import pickle  # For serializing frames
import gc
import tempfile
import subprocess
from datetime import datetime
from transformers import AutoModel, AutoTokenizer
from modelscope.hub.snapshot_download import snapshot_download
from PIL import Image
from decord import VideoReader, cpu
import cv2
import gradio as gr
from ultralytics import YOLO
import numpy as np
import io
# Install flash-attn (using the prebuilt wheel if available).
# Merge with os.environ so pip still sees PATH and the rest of the environment.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    shell=True
)
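# Note: the model below is loaded with attn_implementation='sdpa', so the
# flash-attn install above is an optional optimization; the app should still
# run if the wheel is unavailable.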
# --------------------------------------------------------------------
# Command-line arguments
# --------------------------------------------------------------------
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int)
# New arguments for subprocess inference (unused in this version)
parser.add_argument("--chunk_inference", action="store_true", help="Run inference on a chunk (subprocess mode).")
parser.add_argument("--input_file", type=str, help="Path to serialized input chunk frames.")
parser.add_argument("--output_file", type=str, help="Path to file where inference result is written.")
parser.add_argument("--inference_prompt", type=str, help="Inference prompt for the chunk.")
parser.add_argument("--model_path_arg", type=str, help="Model path for the subprocess.")
args = parser.parse_args()

device = args.device
assert device in ['cuda', 'mps']

# Global model configuration
MODEL_NAME = 'iic/mPLUG-Owl3-7B-240728'
MODEL_CACHE_DIR = os.getenv('TRANSFORMERS_CACHE', './models')
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
# Download and cache the model (only in the main process)
if not args.chunk_inference:
    try:
        model_path = snapshot_download(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
    except Exception as e:
        print(f"Error downloading model: {str(e)}")
        model_path = os.path.join(MODEL_CACHE_DIR, MODEL_NAME)
else:
    model_path = args.model_path_arg
MAX_NUM_FRAMES = 64

# Initialize YOLO model (assumed to be lightweight)
YOLO_MODEL = YOLO('./best_yolov11.pt')  # Load YOLOv11 model

# File type validation
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}


def get_file_extension(filename):
    return os.path.splitext(filename)[1].lower()


def is_image(filename):
    return get_file_extension(filename) in IMAGE_EXTENSIONS


def is_video(filename):
    return get_file_extension(filename) in VIDEO_EXTENSIONS
# --------------------------------------------------------------------
# Model Loading and Inference Functions
# --------------------------------------------------------------------
def load_model_and_tokenizer():
    """Load a fresh instance of the model and tokenizer."""
    try:
        # Clear GPU memory if using CUDA (only at initial load)
        if device == "cuda":
            torch.cuda.empty_cache()
            gc.collect()
        model = AutoModel.from_pretrained(
            model_path,
            attn_implementation='sdpa',
            trust_remote_code=True,
            torch_dtype=torch.half,
            device_map='auto'
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model.eval()
        processor = model.init_processor(tokenizer)
        return model, tokenizer, processor
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise
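# Note: device_map='auto' in load_model_and_tokenizer relies on the
# `accelerate` package to place the model weights; it is assumed to be
# installed in the Space's environment.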
def process_video_chunk(video_frames, model, tokenizer, processor, prompt):
    """Process a chunk of video frames with the mPLUG model."""
    messages = [{
        "role": "user",
        "content": prompt,
        "video_frames": video_frames
    }]
    model_messages = []
    videos = []
    for msg in messages:
        content_str = msg["content"]
        if "video_frames" in msg and msg["video_frames"]:
            content_str += "<|video|>"
            videos.append(msg["video_frames"])
        model_messages.append({"role": msg["role"], "content": content_str})
    model_messages.append({"role": "assistant", "content": ""})
    inputs = processor(
        model_messages,
        images=None,
        videos=videos if videos else None
    )
    inputs.to(device)  # use the configured device instead of hard-coding 'cuda'
    inputs.update({
        'tokenizer': tokenizer,
        'max_new_tokens': 100,
        'decode_text': True,
        'use_cache': False  # disable caching to reduce memory buildup
    })
    with torch.no_grad():
        response = model.generate(**inputs)
    del inputs  # delete inputs to free temporary memory
    return response[0]
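# The "<|video|>" placeholder appended in process_video_chunk is the marker the
# mPLUG-Owl3 processor (loaded via trust_remote_code) uses to splice the frame
# list into the prompt, mirroring the model card's chat format.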
# --------------------------------------------------------------------
# Video and YOLO functions (unchanged)
# --------------------------------------------------------------------
def encode_video_in_chunks(video_path):
    """Extract frames from a video in chunks, sampling roughly one frame per second."""
    vr = VideoReader(video_path, ctx=cpu(0))
    sample_interval = round(vr.get_avg_fps())  # stride between sampled frames (~1 FPS)
    frame_idx = [i for i in range(0, len(vr), sample_interval)]
    chunks = [frame_idx[i:i + MAX_NUM_FRAMES] for i in range(0, len(frame_idx), MAX_NUM_FRAMES)]
    for chunk_idx, chunk in enumerate(chunks):
        frames = vr.get_batch(chunk).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        yield chunk_idx, frames
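# encode_video_in_chunks yields at most MAX_NUM_FRAMES PIL images per chunk, so
# peak memory stays bounded for long videos; chunks are consumed one at a time
# in analyze_video_activities_single_instance below.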
def process_yolo_results(results):
    """Process YOLO detection results and count people and machinery."""
    people_count = 0
    machine_types = {
        "Tower Crane": 0, "Mobile Crane": 0, "Compactor/Roller": 0, "Bulldozer": 0,
        "Excavator": 0, "Dump Truck": 0, "Concrete Mixer": 0, "Loader": 0,
        "Pump Truck": 0, "Pile Driver": 0, "Grader": 0, "Other Vehicle": 0
    }
    # Map raw YOLO class names onto the reporting categories above
    machinery_mapping = {
        'tower_crane': "Tower Crane",
        'mobile_crane': "Mobile Crane",
        'compactor': "Compactor/Roller",
        'roller': "Compactor/Roller",
        'bulldozer': "Bulldozer",
        'dozer': "Bulldozer",
        'excavator': "Excavator",
        'dump_truck': "Dump Truck",
        'truck': "Dump Truck",
        'concrete_mixer_truck': "Concrete Mixer",
        'loader': "Loader",
        'pump_truck': "Pump Truck",
        'pile_driver': "Pile Driver",
        'grader': "Grader",
        'other_vehicle': "Other Vehicle"
    }
    for r in results:
        boxes = r.boxes
        for box in boxes:
            cls = int(box.cls[0])
            conf = float(box.conf[0])
            class_name = YOLO_MODEL.names[cls]
            if class_name.lower() == 'worker' and conf > 0.5:
                people_count += 1
            if conf > 0.5:
                class_lower = class_name.lower()
                for key, value in machinery_mapping.items():
                    if key in class_lower:
                        machine_types[value] += 1
                        break
    total_machinery = sum(machine_types.values())
    return people_count, total_machinery, machine_types
def detect_people_and_machinery(media_path):
    """Detect people and machinery using YOLOv11 for both images and videos."""
    try:
        max_people_count = 0
        max_machine_types = {
            "Tower Crane": 0, "Mobile Crane": 0, "Compactor/Roller": 0, "Bulldozer": 0,
            "Excavator": 0, "Dump Truck": 0, "Concrete Mixer": 0, "Loader": 0,
            "Pump Truck": 0, "Pile Driver": 0, "Grader": 0, "Other Vehicle": 0
        }
        if isinstance(media_path, str) and is_video(media_path):
            cap = cv2.VideoCapture(media_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            sample_rate = max(1, int(fps))
            frame_count = 0
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % sample_rate == 0:
                    results = YOLO_MODEL(frame)
                    people, _, machine_types = process_yolo_results(results)
                    max_people_count = max(max_people_count, people)
                    for k, v in machine_types.items():
                        max_machine_types[k] = max(max_machine_types[k], v)
                frame_count += 1
            cap.release()
        else:
            if isinstance(media_path, str):
                img = cv2.imread(media_path)
            else:
                img = cv2.cvtColor(np.array(media_path), cv2.COLOR_RGB2BGR)
            results = YOLO_MODEL(img)
            max_people_count, _, max_machine_types = process_yolo_results(results)
        max_machine_types = {k: v for k, v in max_machine_types.items() if v > 0}
        total_machinery_count = sum(max_machine_types.values())
        return max_people_count, total_machinery_count, max_machine_types
    except Exception as e:
        print(f"Error in YOLO detection: {str(e)}")
        return 0, 0, {}
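# For videos, detect_people_and_machinery reports the per-frame maximum (one
# sampled frame per second) rather than a running sum, so the same worker or
# machine seen in many frames is not double-counted.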
def process_image(image_path, model, tokenizer, processor, prompt):
    """Process a single image with the mPLUG model."""
    try:
        image = Image.open(image_path)
        messages = [{
            "role": "user",
            "content": prompt,
            "images": [image]
        }]
        model_messages = []
        images = []
        for msg in messages:
            content_str = msg["content"]
            if "images" in msg and msg["images"]:
                content_str += "<|image|>"
                images.extend(msg["images"])
            model_messages.append({"role": msg["role"], "content": content_str})
        model_messages.append({"role": "assistant", "content": ""})
        inputs = processor(model_messages, images=images, videos=None)
        inputs.to(device)  # use the configured device instead of hard-coding 'cuda'
        inputs.update({
            'tokenizer': tokenizer,
            'max_new_tokens': 100,
            'decode_text': True,
            'use_cache': False
        })
        with torch.no_grad():
            response = model.generate(**inputs)
        del inputs
        return response[0]
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return "Error processing image"
def analyze_image_activities(image_path):
    """Analyze image using mPLUG model."""
    try:
        model, tokenizer, processor = load_model_and_tokenizer()
        prompt = ("Analyze this construction site image and describe the activities happening. "
                  "Focus on construction activities, machinery usage, and worker actions.")
        response = process_image(image_path, model, tokenizer, processor, prompt)
        del model, tokenizer, processor
        torch.cuda.empty_cache()  # Final cleanup after image processing
        gc.collect()
        return response
    except Exception as e:
        print(f"Error analyzing image: {str(e)}")
        return "Error analyzing image activities"
def annotate_video_with_bboxes(video_path):
    """
    Reads the video frame-by-frame, runs YOLO, draws bounding boxes,
    writes a per-frame summary of detected classes on the frame, and saves
    the annotated video. Returns the annotated video path.
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    annotated_video_path = out_file.name
    out_file.close()
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(annotated_video_path, fourcc, fps, (w, h))
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        results = YOLO_MODEL(frame)
        frame_counts = {}
        for r in results:
            boxes = r.boxes
            for box in boxes:
                cls_id = int(box.cls[0])
                conf = float(box.conf[0])
                if conf < 0.5:
                    continue
                x1, y1, x2, y2 = box.xyxy[0]
                class_name = YOLO_MODEL.names[cls_id]
                x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                color = (0, 255, 0)
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                label_text = f"{class_name} {conf:.2f}"
                cv2.putText(frame, label_text, (x1, y1 - 6),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
                frame_counts[class_name] = frame_counts.get(class_name, 0) + 1
        summary_str = ", ".join(f"{cls_name}: {count}" for cls_name, count in frame_counts.items())
        cv2.putText(frame, summary_str, (15, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)
        writer.write(frame)
    cap.release()
    writer.release()
    return annotated_video_path
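# Note: the 'mp4v' codec is convenient with OpenCV but is not guaranteed to
# play back in every browser; if the annotated clip does not render in the
# Gradio Video component, re-encoding to H.264 (e.g. with ffmpeg) is a possible
# workaround (assumption, not verified here).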
# --------------------------------------------------------------------
# Adjusted Video Analysis with Single mPLUG Instance (No Reload)
# --------------------------------------------------------------------
def analyze_video_activities_single_instance(video_path):
    """Analyze video using mPLUG model with chunking.
    Uses a single mPLUG model instance for all chunks without any per-chunk cleanup."""
    try:
        all_responses = []
        chunk_generator = encode_video_in_chunks(video_path)
        # Load model instance once
        model, tokenizer, processor = load_model_and_tokenizer()
        for chunk_idx, video_frames in chunk_generator:
            prompt = (
                "Analyze this construction site video chunk and describe the activities happening. "
                "Focus on construction activities, machinery usage, and worker actions."
            )
            with torch.no_grad():
                response = process_video_chunk(video_frames, model, tokenizer, processor, prompt)
            all_responses.append(f"Time period {chunk_idx + 1}:\n{response}")
            # No per-chunk cache clearing is performed here
        # Final cleanup after processing all chunks
        del model, tokenizer, processor
        torch.cuda.empty_cache()
        gc.collect()
        return "\n\n".join(all_responses)
    except Exception as e:
        print(f"Error analyzing video: {str(e)}")
        return "Error analyzing video activities"
# --------------------------------------------------------------------
# Gradio Interface and Main Launch (only executed in main process)
# --------------------------------------------------------------------
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media):
    """Process the site diary entry."""
    if media is None:
        return [day, date, "No media uploaded", "No media uploaded", "No media uploaded", "No media uploaded", None]
    try:
        if not hasattr(media, 'name'):
            raise ValueError("Invalid file upload")
        file_ext = get_file_extension(media.name)
        if not (is_image(media.name) or is_video(media.name)):
            raise ValueError(f"Unsupported file type: {file_ext}")
        with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
            temp_path = temp_file.name
            if hasattr(media, 'name') and os.path.exists(media.name):
                with open(media.name, 'rb') as f:
                    temp_file.write(f.read())
            else:
                file_content = media.read() if hasattr(media, 'read') else media
                temp_file.write(file_content if isinstance(file_content, bytes) else file_content.read())
        detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(temp_path)
        annotated_video_path = None
        if is_image(media.name):
            detected_activities = analyze_image_activities(temp_path)
        else:
            detected_activities = analyze_video_activities_single_instance(temp_path)
            annotated_video_path = annotate_video_with_bboxes(temp_path)
        if os.path.exists(temp_path):
            os.remove(temp_path)
        detected_types_str = ", ".join([f"{k}: {v}" for k, v in detected_machinery_types.items()])
        return [day, date, str(detected_people), str(detected_machinery), detected_types_str, detected_activities, annotated_video_path]
    except Exception as e:
        print(f"Error processing media: {str(e)}")
        return [day, date, "Error processing media", "Error processing media", "Error processing media", "Error processing media", None]
with gr.Blocks(title="Digital Site Diary") as demo:
    gr.Markdown("# 📝 Digital Site Diary")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### User Input")
            day = gr.Textbox(label="Day", value='9')
            date = gr.Textbox(label="Date", placeholder="YYYY-MM-DD", value=datetime.now().strftime("%Y-%m-%d"))
            total_people = gr.Number(label="Total Number of People", precision=0, value=10)
            total_machinery = gr.Number(label="Total Number of Machinery", precision=0, value=3)
            machinery_types = gr.Textbox(label="Number of Machinery Per Type",
                                         placeholder="e.g., Excavator: 2, Roller: 1",
                                         value="Excavator: 2, Roller: 1")
            activities = gr.Textbox(label="Activity",
                                    placeholder="e.g., 9 AM: Excavation, 10 AM: Concreting",
                                    value="9 AM: Excavation, 10 AM: Concreting", lines=3)
            media = gr.File(label="Upload Image/Video", file_types=["image", "video"])
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            gr.Markdown("### Model Detection")
            model_day = gr.Textbox(label="Day")
            model_date = gr.Textbox(label="Date")
            model_people = gr.Textbox(label="Total Number of People")
            model_machinery = gr.Textbox(label="Total Number of Machinery")
            model_machinery_types = gr.Textbox(label="Number of Machinery Per Type")
            model_activities = gr.Textbox(label="Activity", lines=5)
            model_annotated_video = gr.Video(label="Annotated Video")
    submit_btn.click(
        fn=process_diary,
        inputs=[day, date, total_people, total_machinery, machinery_types, activities, media],
        outputs=[model_day, model_date, model_people, model_machinery, model_machinery_types, model_activities, model_annotated_video]
    )
if __name__ == "__main__":
    demo.launch(share=False, debug=True, show_api=False, server_port=args.port, server_name=args.host)
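# Example launch (assumed invocation and script name; when --port is omitted,
# Gradio falls back to its default port, typically 7860):
#   python app.py --device cuda --host 0.0.0.0 --port 7860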