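# Hugging Face file-viewer header (not part of the program):
#   author: assentian1970 - last commit: "Update app.py" (83b0d3a, verified)
#   viewer links: raw / history / blame - file size: 7.87 kB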
import spaces
import torch
import argparse
import os
import gc
import tempfile
import cv2
import numpy as np
import gradio as gr
from datetime import datetime
from PIL import Image
from decord import VideoReader, cpu
from transformers import AutoModel, AutoTokenizer
from modelscope.hub.snapshot_download import snapshot_download
from ultralytics import YOLO
# Print the GPU/driver table into the start-up log; the return code is not checked.
os.system("nvidia-smi")

# Compute device label: "cuda" when torch can see a CUDA device, otherwise "cpu".
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

if DEVICE == "cuda":

    def debug():
        """Allocate a small random tensor on the GPU as a start-up smoke test."""
        torch.randn(10).cuda()

    debug()
# Lower-case file suffixes (with the leading dot) that are treated as images.
IMAGE_EXTENSIONS = {
    '.bmp',
    '.jpeg',
    '.jpg',
    '.png',
    '.tiff',
    '.webp',
}

# Lower-case file suffixes (with the leading dot) that are treated as videos.
VIDEO_EXTENSIONS = {
    '.avi',
    '.mkv',
    '.mov',
    '.mp4',
}  # Example, define properly
def get_file_extension(filename):
    """Return the extension of *filename* in lower case, leading dot included.

    An empty string is returned when the name has no extension.
    """
    _, extension = os.path.splitext(filename)
    return extension.lower()
def is_video(filename):
    """Tell whether *filename* ends in one of the ``VIDEO_EXTENSIONS``."""
    extension = get_file_extension(filename)
    return extension in VIDEO_EXTENSIONS
def is_image(filename):
    """Tell whether *filename* ends in one of the ``IMAGE_EXTENSIONS``."""
    extension = get_file_extension(filename)
    return extension in IMAGE_EXTENSIONS
# Command-line options for the demo server.
parser = argparse.ArgumentParser(description='demo')
parser.add_argument(
    '--device',
    type=str,
    default='cuda',
    # Validate here instead of with `assert`: assertions are stripped under
    # `python -O`, and argparse reports a proper usage error to the user.
    choices=['cuda', 'mps'],
    help='cuda or mps',
)
parser.add_argument("--host", type=str, default="0.0.0.0")
# No default: the port stays None unless given on the command line.
parser.add_argument("--port", type=int)
args = parser.parse_args()
device = args.device
# Upper bound on the number of sampled frames placed in one video chunk.
MAX_NUM_FRAMES = 64

# Model identifier passed to ModelScope's snapshot_download.
MODEL_NAME = 'iic/mPLUG-Owl3-7B-240728'

# Download/cache directory; TRANSFORMERS_CACHE overrides the local default.
MODEL_CACHE_DIR = os.environ.get('TRANSFORMERS_CACHE', './models')
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

try:
    model_path = snapshot_download(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
except Exception as e:
    # Best effort: fall back to the path the model would occupy inside the
    # cache directory so a previously stored copy can still be used.
    print(f"Error downloading model: {str(e)}")
    model_path = os.path.join(MODEL_CACHE_DIR, MODEL_NAME)

# Detector weights are read from a file next to the working directory.
YOLO_MODEL = YOLO('./best_yolov11.pt')
def load_model_and_tokenizer():
    """Load the model and tokenizer stored under ``model_path``.

    Returns:
        A ``(model, tokenizer, processor)`` triple. The processor slot is
        always ``None``; this function does not build one.
    """
    # Drop cached CUDA allocations and collect garbage before loading weights.
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

    model_options = {
        'attn_implementation': 'flash_attention_2',
        'trust_remote_code': True,
        'torch_dtype': torch.half,
        'device_map': 'auto',
    }
    model = AutoModel.from_pretrained(model_path, **model_options)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer, None  # Assuming processor is missing
def encode_video_in_chunks(video_path):
    """Yield sampled frames of a video in chunks of at most ``MAX_NUM_FRAMES``.

    Frames are sampled with a stride equal to the video's rounded average
    FPS, i.e. roughly one frame per second of footage.

    Args:
        video_path: Path of the video file, opened with decord on the CPU.

    Yields:
        ``(chunk_idx, frames)`` pairs, where ``frames`` is a list of PIL
        images for that chunk. Nothing is yielded for a video without frames.
    """
    vr = VideoReader(video_path, ctx=cpu(0))
    # round() gives 0 for an average FPS below 0.5 and range() rejects a zero
    # step, so clamp the stride to at least one frame.
    sample_fps = max(1, round(vr.get_avg_fps() / 1))
    frame_idx = list(range(0, len(vr), sample_fps))
    chunk_starts = range(0, len(frame_idx), MAX_NUM_FRAMES)
    for chunk_idx, start in enumerate(chunk_starts):
        chunk = frame_idx[start:start + MAX_NUM_FRAMES]
        frames = vr.get_batch(chunk).asnumpy()
        yield chunk_idx, [Image.fromarray(v.astype('uint8')) for v in frames]
def detect_people_and_machinery(media_path):
    """Count workers and machinery in an image or video with the YOLO model.

    For a video, one frame per ``int(fps)`` frames is analysed and the
    per-frame maximum is kept for the worker count and for each machinery
    type. For an image, the single detection result is used directly.

    Args:
        media_path: Path of the media file; ``is_video`` decides which branch
            is taken, everything else is read as an image.

    Returns:
        ``(people_count, total_machinery_count, machinery_types)`` where
        ``machinery_types`` maps type name to count and contains only types
        with a non-zero count. On any error ``(0, 0, {})`` is returned and the
        error is printed.
    """
    try:
        max_people_count = 0
        max_machine_types = {key: 0 for key in [
            "Tower Crane", "Mobile Crane", "Compactor/Roller", "Bulldozer",
            "Excavator", "Dump Truck", "Concrete Mixer", "Loader",
            "Pump Truck", "Pile Driver", "Grader", "Other Vehicle"
        ]}
        if is_video(media_path):
            cap = cv2.VideoCapture(media_path)
            try:
                fps = cap.get(cv2.CAP_PROP_FPS)
                # A reported FPS of 0 would otherwise make the modulo below fail.
                sample_rate = max(1, int(fps))
                frame_count = 0
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    if frame_count % sample_rate == 0:
                        results = YOLO_MODEL(frame)
                        people, _, machine_types = process_yolo_results(results)
                        max_people_count = max(max_people_count, people)
                        for k, v in machine_types.items():
                            max_machine_types[k] = max(max_machine_types[k], v)
                    frame_count += 1
            finally:
                # Release the capture even when inference raises mid-video.
                cap.release()
        else:
            img = cv2.imread(media_path)
            if img is None:
                # cv2.imread signals an unreadable file with None; do not pass
                # that on to the detector as if it were an image.
                raise ValueError(f"Could not read image: {media_path}")
            results = YOLO_MODEL(img)
            max_people_count, _, max_machine_types = process_yolo_results(results)
        max_machine_types = {k: v for k, v in max_machine_types.items() if v > 0}
        total_machinery_count = sum(max_machine_types.values())
        return max_people_count, total_machinery_count, max_machine_types
    except Exception as e:
        print(f"Error in YOLO detection: {str(e)}")
        return 0, 0, {}
def process_yolo_results(results, conf_threshold=0.5):
    """Tally workers and machinery found in YOLO detection results.

    Args:
        results: Iterable of detection results; each item exposes ``boxes``,
            and every box exposes ``cls`` and ``conf`` (first element used).
        conf_threshold: Detections must have a confidence strictly greater
            than this value to be counted. Defaults to 0.5.

    Returns:
        ``(people_count, machinery_total, machine_types)`` where
        ``machine_types`` maps every known machinery type to its count
        (zero counts included) and ``machinery_total`` is their sum.
    """
    machine_types = {key: 0 for key in [
        "Tower Crane", "Mobile Crane", "Compactor/Roller", "Bulldozer",
        "Excavator", "Dump Truck", "Concrete Mixer", "Loader",
        "Pump Truck", "Pile Driver", "Grader", "Other Vehicle"
    ]}
    # Substring of the (lower-cased) detector class name -> reported type.
    # Built once here rather than once per detected box.
    machinery_mapping = {
        'tower_crane': "Tower Crane",
        'mobile_crane': "Mobile Crane",
        'grader': "Grader",
        'other_vehicle': "Other Vehicle"
    }
    people_count = 0
    for r in results:
        for box in r.boxes:
            conf = float(box.conf[0])
            if not conf > conf_threshold:
                continue
            class_name = YOLO_MODEL.names[int(box.cls[0])].lower()
            if class_name == 'worker':
                people_count += 1
            for key, value in machinery_mapping.items():
                if key in class_name:
                    # Count each box once, under the first matching type.
                    machine_types[value] += 1
                    break
    return people_count, sum(machine_types.values()), machine_types
@spaces.GPU
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media):
    """Analyse the uploaded media and build the values for the output column.

    Args:
        day: Day label, echoed back unchanged.
        date: Date string, echoed back unchanged.
        total_people: Not used by this function.
        total_machinery: Not used by this function.
        machinery_types: Not used by this function.
        activities: Not used by this function.
        media: The upload, or ``None``. Either a path string or an object
            exposing ``name`` (and optionally ``read()``) is accepted; which
            one arrives presumably depends on the Gradio version - TODO confirm.

    Returns:
        A 7-item list: day, date, detected people, detected machinery total,
        machinery types as ``"type: count"`` text, activities text and a
        trailing ``None``. The four detection fields read "No media uploaded"
        when nothing was uploaded and "Error" when processing failed.
    """
    if media is None:
        return [day, date, "No media uploaded", "No media uploaded", "No media uploaded", "No media uploaded", None]
    temp_path = None
    try:
        source_name = media if isinstance(media, str) else media.name
        file_ext = get_file_extension(source_name)
        if not (is_image(source_name) or is_video(source_name)):
            raise ValueError(f"Unsupported file type: {file_ext}")
        with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
            temp_path = temp_file.name
            if hasattr(media, "read"):
                temp_file.write(media.read())
            else:
                # Path-like upload: copy the bytes from the file on disk.
                with open(source_name, "rb") as source_file:
                    temp_file.write(source_file.read())
        detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(temp_path)
        detected_types_str = ", ".join([f"{k}: {v}" for k, v in detected_machinery_types.items()])
        detected_activities = "Sample activity analysis."  # Placeholder
        return [day, date, str(detected_people), str(detected_machinery), detected_types_str, detected_activities, None]
    except Exception as e:
        print(f"Error processing media: {str(e)}")
        return [day, date, "Error", "Error", "Error", "Error", None]
    finally:
        # Remove the temporary copy on every path, including failures.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
# Two-column form: user-entered diary fields on the left, model output on the right.
with gr.Blocks(title="Digital Site Diary") as demo:
    gr.Markdown("# 📝 Digital Site Diary")

    # process_diary takes seven arguments and returns seven values, but the
    # form has no widget for two of the inputs and for the last output.
    # Event wiring accepts components only (not None), so hidden State
    # components holding None fill those slots.
    machinery_types_placeholder = gr.State(None)
    activities_placeholder = gr.State(None)
    unused_output = gr.State(None)

    with gr.Row():
        with gr.Column():
            day = gr.Textbox(label="Day", value='9')
            date = gr.Textbox(label="Date", value=datetime.now().strftime("%Y-%m-%d"))
            total_people = gr.Number(label="Total Number of People", value=10)
            total_machinery = gr.Number(label="Total Number of Machinery", value=3)
            media = gr.File(label="Upload Image/Video", file_types=["image", "video"])
            submit_btn = gr.Button("Submit")
        with gr.Column():
            model_day = gr.Textbox(label="Day")
            model_date = gr.Textbox(label="Date")
            model_people = gr.Textbox(label="Total Number of People")
            model_machinery = gr.Textbox(label="Total Machinery")
            model_machinery_types = gr.Textbox(label="Machinery Types")
            model_activities = gr.Textbox(label="Activities")

    submit_btn.click(
        fn=process_diary,
        inputs=[
            day,
            date,
            total_people,
            total_machinery,
            machinery_types_placeholder,
            activities_placeholder,
            media,
        ],
        outputs=[
            model_day,
            model_date,
            model_people,
            model_machinery,
            model_machinery_types,
            model_activities,
            unused_output,
        ],
    )
# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    launch_options = dict(
        share=False,
        debug=True,
        show_api=False,
        server_port=args.port,
        server_name=args.host,
    )
    demo.launch(**launch_options)