# Hugging Face Spaces page status at capture time: Runtime error
import spaces | |
import torch | |
import argparse | |
import os | |
import gc | |
import tempfile | |
import cv2 | |
import numpy as np | |
import gradio as gr | |
from datetime import datetime | |
from PIL import Image | |
from decord import VideoReader, cpu | |
from transformers import AutoModel, AutoTokenizer | |
from modelscope.hub.snapshot_download import snapshot_download | |
from ultralytics import YOLO | |
os.system("nvidia-smi")  # dump the GPU inventory into the logs (diagnostic only)

# Pick the compute device once at import time; on CUDA hosts also run a tiny
# allocation so a broken GPU setup surfaces immediately at start-up.
if torch.cuda.is_available():
    DEVICE = "cuda"

    def debug():
        """Allocate a 10-element random tensor on the GPU as a CUDA sanity check."""
        torch.randn(10).cuda()

    debug()
else:
    DEVICE = "cpu"
# Lower-case file extensions (including the leading dot) treated as images and
# videos. Beyond the original lists, '.tif', '.webm', '.m4v', '.mpeg' and
# '.mpg' are accepted so uploads in these common formats are not rejected as
# "Unsupported file type" before detection is even attempted.
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.webm', '.m4v', '.mpeg', '.mpg'}


def get_file_extension(filename):
    """Return the lower-cased extension of *filename*, dot included ('' if none)."""
    return os.path.splitext(filename)[-1].lower()


def is_video(filename):
    """Return True if *filename*'s extension is listed in VIDEO_EXTENSIONS."""
    return get_file_extension(filename) in VIDEO_EXTENSIONS


def is_image(filename):
    """Return True if *filename*'s extension is listed in IMAGE_EXTENSIONS."""
    return get_file_extension(filename) in IMAGE_EXTENSIONS
parser = argparse.ArgumentParser(description='demo') | |
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps') | |
parser.add_argument("--host", type=str, default="0.0.0.0") | |
parser.add_argument("--port", type=int) | |
args = parser.parse_args() | |
device = args.device | |
assert device in ['cuda', 'mps'] | |
MODEL_NAME = 'iic/mPLUG-Owl3-7B-240728'
# Download location: TRANSFORMERS_CACHE when set, otherwise ./models.
MODEL_CACHE_DIR = os.getenv('TRANSFORMERS_CACHE', './models')
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

# Resolve a local directory for the model weights via ModelScope. If the
# download raises, log the error and fall back to <cache dir>/<model name>.
# NOTE(review): the fallback assumes an earlier download left the snapshot
# at exactly that path — confirm against ModelScope's cache layout.
try:
    model_path = snapshot_download(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
except Exception as download_error:
    print(f"Error downloading model: {download_error}")
    model_path = os.path.join(MODEL_CACHE_DIR, MODEL_NAME)

# Upper bound on the number of frames per chunk in encode_video_in_chunks.
MAX_NUM_FRAMES = 64
# YOLO detector used for worker/machinery counting; loaded once at import.
YOLO_MODEL = YOLO('./best_yolov11.pt')
def load_model_and_tokenizer():
    """Load the mPLUG-Owl3 model and its tokenizer from ``model_path``.

    The attention backend is picked at load time. ``flash_attention_2`` is
    requested only on CUDA hosts where the optional ``flash_attn`` package is
    importable; otherwise ``sdpa`` (PyTorch scaled-dot-product attention) is
    requested. This avoids a hard failure on CPU-only machines or images
    without flash-attn.

    Returns:
        tuple: ``(model, tokenizer, processor)``. ``processor`` is always
        ``None`` because no processor is constructed here.
    """
    import importlib.util

    # Collect unreachable Python objects first, so that empty_cache() can hand
    # the CUDA blocks they were holding back to the driver.
    gc.collect()
    if DEVICE == "cuda":
        torch.cuda.empty_cache()

    # NOTE(review): 'sdpa' as the fallback follows the mPLUG-Owl3 usage
    # examples — confirm the model's remote code supports it.
    use_flash_attn = DEVICE == "cuda" and importlib.util.find_spec("flash_attn") is not None
    attn_implementation = "flash_attention_2" if use_flash_attn else "sdpa"

    model = AutoModel.from_pretrained(
        model_path,
        attn_implementation=attn_implementation,
        trust_remote_code=True,
        torch_dtype=torch.half,
        device_map='auto',
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
    )
    return model, tokenizer, None  # no processor is created in this function
def encode_video_in_chunks(video_path):
    """Yield a video's frames, sampled about once per second, in chunks.

    Args:
        video_path: Path to a video file readable by decord.

    Yields:
        tuple: ``(chunk_idx, frames)``, where ``frames`` is a list of at most
        ``MAX_NUM_FRAMES`` ``PIL.Image`` objects.
    """
    import math

    vr = VideoReader(video_path, ctx=cpu(0))
    avg_fps = vr.get_avg_fps()
    # Step of one (rounded) second of video. A zero or sub-0.5 FPS would give
    # range() a step of 0 (ValueError), and a NaN/inf FPS makes round() raise,
    # so in those cases fall back to taking every frame.
    step = max(1, round(avg_fps)) if math.isfinite(avg_fps) else 1
    frame_idx = list(range(0, len(vr), step))
    for chunk_idx, start in enumerate(range(0, len(frame_idx), MAX_NUM_FRAMES)):
        chunk = frame_idx[start:start + MAX_NUM_FRAMES]
        frames = vr.get_batch(chunk).asnumpy()
        yield chunk_idx, [Image.fromarray(v.astype('uint8')) for v in frames]
def detect_people_and_machinery(media_path):
    """Run the YOLO detector on an image or video and report peak counts.

    For a video, one frame per ``int(fps)`` frames (i.e. about one per second)
    is analysed. The maximum per-frame count is kept for people and for each
    machine type, so the result describes the busiest sampled frame rather
    than a total over time.

    Args:
        media_path: Path to the media file. ``is_video`` (by extension) decides
            whether it is read as a video; anything else is read with
            ``cv2.imread``.

    Returns:
        tuple: ``(people_count, total_machinery_count, machine_types)``, where
        ``machine_types`` maps display names to counts and only contains types
        seen at least once. On any error, ``(0, 0, {})`` is returned and the
        error is printed.
    """
    try:
        max_people_count = 0
        max_machine_types = {}
        if is_video(media_path):
            cap = cv2.VideoCapture(media_path)
            try:
                fps = cap.get(cv2.CAP_PROP_FPS)
                # fps may be reported as 0 when unknown; never sample with step 0.
                sample_rate = max(1, int(fps))
                frame_count = 0
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    if frame_count % sample_rate == 0:
                        people, _, machine_types = process_yolo_results(YOLO_MODEL(frame))
                        max_people_count = max(max_people_count, people)
                        for name, count in machine_types.items():
                            max_machine_types[name] = max(max_machine_types.get(name, 0), count)
                    frame_count += 1
            finally:
                # Release the capture even if detection raises mid-stream.
                cap.release()
        else:
            img = cv2.imread(media_path)
            if img is None:
                # cv2.imread signals unreadable/missing files by returning None.
                raise ValueError(f"Could not read image: {media_path}")
            max_people_count, _, max_machine_types = process_yolo_results(YOLO_MODEL(img))
        max_machine_types = {k: v for k, v in max_machine_types.items() if v > 0}
        total_machinery_count = sum(max_machine_types.values())
        return max_people_count, total_machinery_count, max_machine_types
    except Exception as e:
        # Best-effort boundary for the UI: report and return empty counts.
        print(f"Error in YOLO detection: {str(e)}")
        return 0, 0, {}
# Machine-type display names reported in the diary, in output order.
_MACHINE_TYPE_NAMES = (
    "Tower Crane", "Mobile Crane", "Compactor/Roller", "Bulldozer",
    "Excavator", "Dump Truck", "Concrete Mixer", "Loader",
    "Pump Truck", "Pile Driver", "Grader", "Other Vehicle",
)

# (substring of the normalized YOLO class name, display name); the first
# matching substring wins. The first four entries are the original mapping.
# The remaining entries follow the same snake_case convention, so every
# reported type can actually be counted.
# NOTE(review): confirm these against YOLO_MODEL.names — a key that matches
# no class is harmless (it simply never fires).
_MACHINERY_CLASS_KEYS = (
    ("tower_crane", "Tower Crane"),
    ("mobile_crane", "Mobile Crane"),
    ("grader", "Grader"),
    ("other_vehicle", "Other Vehicle"),
    ("compactor", "Compactor/Roller"),
    ("roller", "Compactor/Roller"),
    ("bulldozer", "Bulldozer"),
    ("excavator", "Excavator"),
    ("dump_truck", "Dump Truck"),
    ("concrete_mixer", "Concrete Mixer"),
    ("loader", "Loader"),
    ("pump_truck", "Pump Truck"),
    ("pile_driver", "Pile Driver"),
)


def process_yolo_results(results, names=None, conf_threshold=0.5):
    """Count workers and machinery in YOLO detection results.

    Args:
        results: Iterable of YOLO result objects; each must expose ``.boxes``,
            whose items have indexable ``.cls`` and ``.conf``.
        names: Mapping/sequence from class id to class name. Defaults to
            ``YOLO_MODEL.names``.
        conf_threshold: Detections count only when confidence is strictly
            greater than this value (default 0.5, as before).

    Returns:
        tuple: ``(people_count, total_machinery, machine_types)``, where
        ``machine_types`` maps every display name in ``_MACHINE_TYPE_NAMES``
        to its count (zeros included).
    """
    if names is None:
        names = YOLO_MODEL.names
    people_count = 0
    machine_types = dict.fromkeys(_MACHINE_TYPE_NAMES, 0)
    for r in results:
        for box in r.boxes:
            conf = float(box.conf[0])
            # Written as `not >` so a NaN confidence is skipped, as before.
            if not conf > conf_threshold:
                continue
            class_name = names[int(box.cls[0])]
            # Normalize "Tower Crane" / "tower-crane" to "tower_crane" so the
            # snake_case keys match regardless of how classes were labelled.
            normalized = class_name.lower().replace(" ", "_").replace("-", "_")
            if normalized == "worker":
                people_count += 1
                continue
            for key, display_name in _MACHINERY_CLASS_KEYS:
                if key in normalized:
                    machine_types[display_name] += 1
                    break
    return people_count, sum(machine_types.values()), machine_types
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media):
    """Analyse an uploaded image/video and build the diary output row.

    ``total_people``, ``total_machinery``, ``machinery_types`` and
    ``activities`` are accepted to match the UI wiring but are not used.

    Args:
        day: Day label, echoed back unchanged.
        date: Date string, echoed back unchanged.
        media: The upload from ``gr.File``. This is either a filesystem path
            (``str`` / ``os.PathLike``) or a file wrapper whose ``.name`` holds
            the on-disk path. ``None`` when nothing was uploaded.

    Returns:
        list: ``[day, date, people, machinery, machinery_types, activities,
        None]``, with the middle four fields as strings. Those four fields are
        ``"No media uploaded"`` when ``media`` is None and ``"Error"`` when
        processing fails.
    """
    if media is None:
        return [day, date, "No media uploaded", "No media uploaded", "No media uploaded", "No media uploaded", None]
    try:
        # The upload already lives on disk with its original extension, so
        # detect on that path directly. Reading the wrapper's handle (as
        # before) can yield the wrong/empty content or fail on a plain path.
        if isinstance(media, (str, os.PathLike)):
            media_path = os.fspath(media)
        else:
            media_path = media.name
        file_ext = get_file_extension(media_path)
        if not (is_image(media_path) or is_video(media_path)):
            raise ValueError(f"Unsupported file type: {file_ext}")
        detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(media_path)
        detected_types_str = ", ".join(f"{k}: {v}" for k, v in detected_machinery_types.items())
        detected_activities = "Sample activity analysis."  # Placeholder
        return [day, date, str(detected_people), str(detected_machinery), detected_types_str, detected_activities, None]
    except Exception as e:
        # UI boundary: report the problem and fill every field with "Error".
        print(f"Error processing media: {str(e)}")
        return [day, date, "Error", "Error", "Error", "Error", None]
with gr.Blocks(title="Digital Site Diary") as demo:
    gr.Markdown("# 📝 Digital Site Diary")
    with gr.Row():
        with gr.Column():
            day = gr.Textbox(label="Day", value='9')
            # A callable default is evaluated when the page loads, so the date
            # does not stay frozen at the server's start-up day.
            date = gr.Textbox(label="Date", value=lambda: datetime.now().strftime("%Y-%m-%d"))
            total_people = gr.Number(label="Total Number of People", value=10)
            total_machinery = gr.Number(label="Total Number of Machinery", value=3)
            media = gr.File(label="Upload Image/Video", file_types=["image", "video"])
            submit_btn = gr.Button("Submit")
        with gr.Column():
            model_day = gr.Textbox(label="Day")
            model_date = gr.Textbox(label="Date")
            model_people = gr.Textbox(label="Total Number of People")
            model_machinery = gr.Textbox(label="Total Machinery")
            model_machinery_types = gr.Textbox(label="Machinery Types")
            model_activities = gr.Textbox(label="Activities")

    # Event inputs/outputs must be components, not None. process_diary takes
    # two parameters (machinery_types, activities) that have no UI widget, and
    # returns a 7th value with no display slot; invisible State holders fill
    # those positions.
    unused_machinery_types = gr.State(None)
    unused_activities = gr.State(None)
    unused_extra_output = gr.State(None)

    submit_btn.click(
        fn=process_diary,
        inputs=[day, date, total_people, total_machinery, unused_machinery_types, unused_activities, media],
        outputs=[model_day, model_date, model_people, model_machinery, model_machinery_types, model_activities, unused_extra_output],
    )
if __name__ == "__main__":
    # Serve the UI on the CLI-selected host/port (args.port is None when
    # --port is not given), with debug output on and the API page hidden.
    launch_options = dict(
        server_name=args.host,
        server_port=args.port,
        share=False,
        debug=True,
        show_api=False,
    )
    demo.launch(**launch_options)