# V3Test / app.py
import spaces
import torch
import argparse
import os
import sys
import pickle # For serializing frames
import gc
import tempfile
import subprocess
from datetime import datetime
from transformers import AutoModel, AutoTokenizer
from modelscope.hub.snapshot_download import snapshot_download
from PIL import Image
from decord import VideoReader, cpu
import cv2
import gradio as gr
from ultralytics import YOLO
import numpy as np
import io
# Install flash-attn, skipping the CUDA build so a prebuilt wheel is used if available
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    shell=True
)
# --------------------------------------------------------------------
# Command-line arguments
# --------------------------------------------------------------------
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int)
# New arguments for subprocess inference (unused in this version)
parser.add_argument("--chunk_inference", action="store_true", help="Run inference on a chunk (subprocess mode).")
parser.add_argument("--input_file", type=str, help="Path to serialized input chunk frames.")
parser.add_argument("--output_file", type=str, help="Path to file where inference result is written.")
parser.add_argument("--inference_prompt", type=str, help="Inference prompt for the chunk.")
parser.add_argument("--model_path_arg", type=str, help="Model path for the subprocess.")
args = parser.parse_args()
device = args.device
assert device in ['cuda', 'mps']
# Global model configuration
MODEL_NAME = 'iic/mPLUG-Owl3-7B-240728'
MODEL_CACHE_DIR = os.getenv('TRANSFORMERS_CACHE', './models')
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
# Download and cache the model (only in the main process)
if not args.chunk_inference:
    try:
        model_path = snapshot_download(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
    except Exception as e:
        print(f"Error downloading model: {str(e)}")
        model_path = os.path.join(MODEL_CACHE_DIR, MODEL_NAME)
else:
    model_path = args.model_path_arg
MAX_NUM_FRAMES = 64
# Initialize YOLO model (assumed to be lightweight)
YOLO_MODEL = YOLO('./best_yolov11.pt') # Load YOLOv11 model
# File type validation
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
def get_file_extension(filename):
    return os.path.splitext(filename)[1].lower()

def is_image(filename):
    return get_file_extension(filename) in IMAGE_EXTENSIONS

def is_video(filename):
    return get_file_extension(filename) in VIDEO_EXTENSIONS
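# Quick illustration (hypothetical filenames): uploads are classified purely by extension.
#   is_image("site_photo.jpg")   -> True
#   is_video("walkthrough.mp4")  -> True
#   is_video("notes.txt")        -> False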
# --------------------------------------------------------------------
# Model Loading and Inference Functions
# --------------------------------------------------------------------
def load_model_and_tokenizer():
    """Load a fresh instance of the model and tokenizer."""
    try:
        # Clear GPU memory if using CUDA (only at initial load)
        if device == "cuda":
            torch.cuda.empty_cache()
            gc.collect()
        model = AutoModel.from_pretrained(
            model_path,
            attn_implementation='sdpa',
            trust_remote_code=True,
            torch_dtype=torch.half,
            device_map='auto'
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model.eval()
        processor = model.init_processor(tokenizer)
        return model, tokenizer, processor
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise
def process_video_chunk(video_frames, model, tokenizer, processor, prompt):
    """Process a chunk of video frames with the mPLUG model."""
    messages = [{
        "role": "user",
        "content": prompt,
        "video_frames": video_frames
    }]
    model_messages = []
    videos = []
    for msg in messages:
        content_str = msg["content"]
        if "video_frames" in msg and msg["video_frames"]:
            content_str += "<|video|>"
            videos.append(msg["video_frames"])
        model_messages.append({"role": msg["role"], "content": content_str})
    model_messages.append({"role": "assistant", "content": ""})
    inputs = processor(
        model_messages,
        images=None,
        videos=videos if videos else None
    )
    inputs.to(device)  # selected device: 'cuda' (default) or 'mps'
    inputs.update({
        'tokenizer': tokenizer,
        'max_new_tokens': 100,
        'decode_text': True,
        'use_cache': False  # disable caching to reduce memory buildup
    })
    with torch.no_grad():
        response = model.generate(**inputs)
    del inputs  # delete inputs to free temporary memory
    return response[0]
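# Sketch of how this helper is driven (see analyze_video_activities_single_instance below):
# a "<|video|>" placeholder is appended to the prompt and the PIL frames are handed to the
# processor via `videos`. The `frames` list here is a hypothetical stand-in.
#   response = process_video_chunk(frames, model, tokenizer, processor,
#                                  "Describe the construction activities in this clip.")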
# --------------------------------------------------------------------
# Video and YOLO functions (unchanged)
# --------------------------------------------------------------------
def encode_video_in_chunks(video_path):
    """Extract frames from a video in chunks of up to MAX_NUM_FRAMES, sampled at ~1 FPS."""
    vr = VideoReader(video_path, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps())  # frame step for ~1 FPS sampling
    frame_idx = list(range(0, len(vr), sample_fps))
    chunks = [frame_idx[i:i + MAX_NUM_FRAMES] for i in range(0, len(frame_idx), MAX_NUM_FRAMES)]
    for chunk_idx, chunk in enumerate(chunks):
        frames = vr.get_batch(chunk).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        yield chunk_idx, frames
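# Example consumer (sketch, hypothetical path): yields (chunk_idx, frames) pairs where each
# chunk holds up to MAX_NUM_FRAMES PIL images sampled at roughly one frame per second.
#   for chunk_idx, frames in encode_video_in_chunks("site_walkthrough.mp4"):
#       print(chunk_idx, len(frames))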
def process_yolo_results(results):
    """Process YOLO detection results and count people and machinery."""
    people_count = 0
    machine_types = {
        "Tower Crane": 0, "Mobile Crane": 0, "Compactor/Roller": 0, "Bulldozer": 0,
        "Excavator": 0, "Dump Truck": 0, "Concrete Mixer": 0, "Loader": 0,
        "Pump Truck": 0, "Pile Driver": 0, "Grader": 0, "Other Vehicle": 0
    }
    machinery_mapping = {
        'tower_crane': "Tower Crane",
        'mobile_crane': "Mobile Crane",
        'compactor': "Compactor/Roller",
        'roller': "Compactor/Roller",
        'bulldozer': "Bulldozer",
        'dozer': "Bulldozer",
        'excavator': "Excavator",
        'dump_truck': "Dump Truck",
        'truck': "Dump Truck",
        'concrete_mixer_truck': "Concrete Mixer",
        'loader': "Loader",
        'pump_truck': "Pump Truck",
        'pile_driver': "Pile Driver",
        'grader': "Grader",
        'other_vehicle': "Other Vehicle"
    }
    for r in results:
        for box in r.boxes:
            cls = int(box.cls[0])
            conf = float(box.conf[0])
            class_name = YOLO_MODEL.names[cls]
            if class_name.lower() == 'worker' and conf > 0.5:
                people_count += 1
            if conf > 0.5:
                class_lower = class_name.lower()
                for key, value in machinery_mapping.items():
                    if key in class_lower:
                        machine_types[value] += 1
                        break
    total_machinery = sum(machine_types.values())
    return people_count, total_machinery, machine_types
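# Example (sketch): counting detections on a single frame with the YOLO model loaded above.
# `frame` is a hypothetical BGR image as returned by cv2.
#   people, total_machinery, per_type = process_yolo_results(YOLO_MODEL(frame))
#   # e.g. people == 2 and per_type["Excavator"] == 1 for a frame showing two workers and an excavator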
def detect_people_and_machinery(media_path):
    """Detect people and machinery using YOLOv11 for both images and videos."""
    try:
        max_people_count = 0
        max_machine_types = {
            "Tower Crane": 0, "Mobile Crane": 0, "Compactor/Roller": 0, "Bulldozer": 0,
            "Excavator": 0, "Dump Truck": 0, "Concrete Mixer": 0, "Loader": 0,
            "Pump Truck": 0, "Pile Driver": 0, "Grader": 0, "Other Vehicle": 0
        }
        if isinstance(media_path, str) and is_video(media_path):
            cap = cv2.VideoCapture(media_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            sample_rate = max(1, int(fps))
            frame_count = 0
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % sample_rate == 0:
                    results = YOLO_MODEL(frame)
                    people, _, machine_types = process_yolo_results(results)
                    max_people_count = max(max_people_count, people)
                    for k, v in machine_types.items():
                        max_machine_types[k] = max(max_machine_types[k], v)
                frame_count += 1
            cap.release()
        else:
            if isinstance(media_path, str):
                img = cv2.imread(media_path)
            else:
                img = cv2.cvtColor(np.array(media_path), cv2.COLOR_RGB2BGR)
            results = YOLO_MODEL(img)
            max_people_count, _, max_machine_types = process_yolo_results(results)
        max_machine_types = {k: v for k, v in max_machine_types.items() if v > 0}
        total_machinery_count = sum(max_machine_types.values())
        return max_people_count, total_machinery_count, max_machine_types
    except Exception as e:
        print(f"Error in YOLO detection: {str(e)}")
        return 0, 0, {}
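# Example usage (sketch, hypothetical path): the same entry point handles image paths,
# video paths, and in-memory PIL images.
#   people, machinery_total, machinery_by_type = detect_people_and_machinery("uploads/site.jpg")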
def process_image(image_path, model, tokenizer, processor, prompt):
    """Process a single image with the mPLUG model."""
    try:
        image = Image.open(image_path)
        messages = [{
            "role": "user",
            "content": prompt,
            "images": [image]
        }]
        model_messages = []
        images = []
        for msg in messages:
            content_str = msg["content"]
            if "images" in msg and msg["images"]:
                content_str += "<|image|>"
                images.extend(msg["images"])
            model_messages.append({"role": msg["role"], "content": content_str})
        model_messages.append({"role": "assistant", "content": ""})
        inputs = processor(model_messages, images=images, videos=None)
        inputs.to(device)  # selected device: 'cuda' (default) or 'mps'
        inputs.update({
            'tokenizer': tokenizer,
            'max_new_tokens': 100,
            'decode_text': True,
            'use_cache': False
        })
        with torch.no_grad():
            response = model.generate(**inputs)
        del inputs
        return response[0]
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return "Error processing image"
def analyze_image_activities(image_path):
    """Analyze an image using the mPLUG model."""
    try:
        model, tokenizer, processor = load_model_and_tokenizer()
        prompt = ("Analyze this construction site image and describe the activities happening. "
                  "Focus on construction activities, machinery usage, and worker actions.")
        response = process_image(image_path, model, tokenizer, processor, prompt)
        del model, tokenizer, processor
        torch.cuda.empty_cache()  # final cleanup after image processing
        gc.collect()
        return response
    except Exception as e:
        print(f"Error analyzing image: {str(e)}")
        return "Error analyzing image activities"
def annotate_video_with_bboxes(video_path):
    """
    Reads the video frame-by-frame, runs YOLO, draws bounding boxes,
    writes a per-frame summary of detected classes on the frame, and saves
    the annotated video. Returns the annotated video path.
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    annotated_video_path = out_file.name
    out_file.close()
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(annotated_video_path, fourcc, fps, (w, h))
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        results = YOLO_MODEL(frame)
        frame_counts = {}
        for r in results:
            for box in r.boxes:
                cls_id = int(box.cls[0])
                conf = float(box.conf[0])
                if conf < 0.5:
                    continue
                x1, y1, x2, y2 = box.xyxy[0]
                class_name = YOLO_MODEL.names[cls_id]
                x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                color = (0, 255, 0)
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                label_text = f"{class_name} {conf:.2f}"
                cv2.putText(frame, label_text, (x1, y1 - 6),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
                frame_counts[class_name] = frame_counts.get(class_name, 0) + 1
        summary_str = ", ".join(f"{cls_name}: {count}" for cls_name, count in frame_counts.items())
        cv2.putText(frame, summary_str, (15, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 0), 2)
        writer.write(frame)
    cap.release()
    writer.release()
    return annotated_video_path
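# Example (sketch, hypothetical path): returns the path of a temporary .mp4 with boxes and
# per-frame class counts burned in, suitable for the gr.Video output defined below.
#   annotated_path = annotate_video_with_bboxes("uploads/site_walkthrough.mp4")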
# --------------------------------------------------------------------
# Adjusted Video Analysis with Single mPLUG Instance (No Reload)
# --------------------------------------------------------------------
@spaces.GPU
def analyze_video_activities_single_instance(video_path):
    """Analyze a video using the mPLUG model with chunking.
    A single mPLUG model instance is reused for all chunks, with no per-chunk cleanup."""
    try:
        all_responses = []
        chunk_generator = encode_video_in_chunks(video_path)
        # Load the model instance once
        model, tokenizer, processor = load_model_and_tokenizer()
        for chunk_idx, video_frames in chunk_generator:
            prompt = (
                "Analyze this construction site video chunk and describe the activities happening. "
                "Focus on construction activities, machinery usage, and worker actions."
            )
            with torch.no_grad():
                response = process_video_chunk(video_frames, model, tokenizer, processor, prompt)
            all_responses.append(f"Time period {chunk_idx + 1}:\n{response}")
            # No per-chunk cache clearing is performed here
        # Final cleanup after processing all chunks
        del model, tokenizer, processor
        torch.cuda.empty_cache()
        gc.collect()
        return "\n\n".join(all_responses)
    except Exception as e:
        print(f"Error analyzing video: {str(e)}")
        return "Error analyzing video activities"
# --------------------------------------------------------------------
# Gradio Interface and Main Launch (only executed in main process)
# --------------------------------------------------------------------
@spaces.GPU
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media):
    """Process the site diary entry."""
    if media is None:
        return [day, date, "No media uploaded", "No media uploaded", "No media uploaded", "No media uploaded", None]
    try:
        if not hasattr(media, 'name'):
            raise ValueError("Invalid file upload")
        file_ext = get_file_extension(media.name)
        if not (is_image(media.name) or is_video(media.name)):
            raise ValueError(f"Unsupported file type: {file_ext}")
        # Copy the upload to a temporary file so OpenCV/decord can read it from disk
        with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
            temp_path = temp_file.name
            if hasattr(media, 'name') and os.path.exists(media.name):
                with open(media.name, 'rb') as f:
                    temp_file.write(f.read())
            else:
                file_content = media.read() if hasattr(media, 'read') else media
                temp_file.write(file_content if isinstance(file_content, bytes) else file_content.read())
        detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(temp_path)
        annotated_video_path = None
        if is_image(media.name):
            detected_activities = analyze_image_activities(temp_path)
        else:
            detected_activities = analyze_video_activities_single_instance(temp_path)
            annotated_video_path = annotate_video_with_bboxes(temp_path)
        if os.path.exists(temp_path):
            os.remove(temp_path)
        detected_types_str = ", ".join([f"{k}: {v}" for k, v in detected_machinery_types.items()])
        return [day, date, str(detected_people), str(detected_machinery), detected_types_str, detected_activities, annotated_video_path]
    except Exception as e:
        print(f"Error processing media: {str(e)}")
        return [day, date, "Error processing media", "Error processing media", "Error processing media", "Error processing media", None]
with gr.Blocks(title="Digital Site Diary") as demo:
    gr.Markdown("# 📝 Digital Site Diary")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### User Input")
            day = gr.Textbox(label="Day", value='9')
            date = gr.Textbox(label="Date", placeholder="YYYY-MM-DD", value=datetime.now().strftime("%Y-%m-%d"))
            total_people = gr.Number(label="Total Number of People", precision=0, value=10)
            total_machinery = gr.Number(label="Total Number of Machinery", precision=0, value=3)
            machinery_types = gr.Textbox(label="Number of Machinery Per Type",
                                         placeholder="e.g., Excavator: 2, Roller: 1",
                                         value="Excavator: 2, Roller: 1")
            activities = gr.Textbox(label="Activity",
                                    placeholder="e.g., 9 AM: Excavation, 10 AM: Concreting",
                                    value="9 AM: Excavation, 10 AM: Concreting", lines=3)
            media = gr.File(label="Upload Image/Video", file_types=["image", "video"])
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            gr.Markdown("### Model Detection")
            model_day = gr.Textbox(label="Day")
            model_date = gr.Textbox(label="Date")
            model_people = gr.Textbox(label="Total Number of People")
            model_machinery = gr.Textbox(label="Total Number of Machinery")
            model_machinery_types = gr.Textbox(label="Number of Machinery Per Type")
            model_activities = gr.Textbox(label="Activity", lines=5)
            model_annotated_video = gr.Video(label="Annotated Video")
    submit_btn.click(
        fn=process_diary,
        inputs=[day, date, total_people, total_machinery, machinery_types, activities, media],
        outputs=[model_day, model_date, model_people, model_machinery, model_machinery_types, model_activities, model_annotated_video]
    )
if __name__ == "__main__":
    demo.launch(share=False, debug=True, show_api=False, server_port=args.port, server_name=args.host)
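# Typical launch (sketch): the port below is illustrative; if --port is omitted, Gradio
# falls back to its default.
#   python app.py --device cuda --host 0.0.0.0 --port 7860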