import spaces
import gradio as gr
from datetime import datetime
import tempfile
import os
import json
import torch
import gc
import shutil  # Added for directory cleanup
from azure.storage.blob import BlobServiceClient  # Added for Azure integration
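# Presumably a startup self-check: touching the GPU once at import time so a
# missing or misconfigured CUDA device fails loudly before the UI comes up
# (assumption about intent; the original code gives no explanation).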
def debug():
    torch.randn(10).cuda()

debug()
from PIL import Image
from decord import VideoReader, cpu
from yolo_detection import (
    detect_people_and_machinery,  # Keep for images
    # annotate_video_with_bboxes,  # Replaced by the unified function
    process_video_unified,  # The new unified video function
    is_image,
    is_video
)
from image_captioning import (
    analyze_image_activities,
    analyze_video_activities,
    process_video_chunk,
    load_model_and_tokenizer,
    MAX_NUM_FRAMES
)
# Load the model instance once at startup
gc.collect()
torch.cuda.empty_cache()
model, tokenizer, processor = load_model_and_tokenizer()
print("Model loaded.")
# Azure Blob Storage Setup
CONTAINER_NAME = "logs"  # Replace with your actual container name

# Read the connection string (which embeds a SAS token) from the environment
# rather than hard-coding the credential in source.
connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
if not connection_string:
    print("Warning: AZURE_STORAGE_CONNECTION_STRING not found. Azure Blob functionality will be disabled.")

# Initialize the Azure Blob Service Client if a connection string is available
blob_service_client = None
if connection_string:
    try:
        blob_service_client = BlobServiceClient.from_connection_string(connection_string)
        print("Azure Blob Service Client initialized successfully.")
    except Exception as e:
        print(f"Error initializing BlobServiceClient: {str(e)}")
        blob_service_client = None
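# A minimal sketch of how the secret is expected to be supplied (assumption:
# the Space is configured with a repository secret, which Hugging Face exposes
# to the process as an environment variable):
#
#     AZURE_STORAGE_CONNECTION_STRING="BlobEndpoint=...;SharedAccessSignature=sv=..."
#
# Hard-coding the SAS-bearing string in source would hand write access to the
# storage account to anyone who can read the repository.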
def list_blobs():
    """List video blobs in the specified Azure container."""
    if not blob_service_client:
        print("Cannot list blobs: BlobServiceClient is not initialized.")
        return []
    try:
        container_client = blob_service_client.get_container_client(CONTAINER_NAME)
        blobs = container_client.list_blobs()
        video_extensions = ['.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v']
        blob_list = [blob.name for blob in blobs if any(blob.name.lower().endswith(ext) for ext in video_extensions)]
        print(f"Found {len(blob_list)} video blobs in container '{CONTAINER_NAME}': {blob_list}")
        return blob_list
    except Exception as e:
        print(f"Error listing blobs in container '{CONTAINER_NAME}': {str(e)}")
        return []
# Fetch blob names at startup
blob_names = list_blobs()
print(f"Populated azure_blob dropdown with {len(blob_names)} options.")

# Global storage for activities and media paths
global_activities = []
global_media_path = None
global_temp_media_path = None  # Path of a file downloaded from Azure, kept for cleanup

# Create a tmp directory for storing frames
tmp_dir = os.path.join('.', 'tmp')
os.makedirs(tmp_dir, exist_ok=True)
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media_source, local_file, azure_blob):
    """Process the site diary entry with media from a local file or Azure Blob Storage.

    Returns a 10-element list matching the Gradio outputs: day, date, people count,
    machinery count, machinery types, timeline-view visibility update, annotated
    video path, activity list, chat history, and activity table rows.
    """
    global global_activities, global_media_path, global_temp_media_path
    global_temp_media_path = None  # Reset before processing

    if media_source == "Local File":
        if local_file is None:
            return [day, date, "No media uploaded", "No media uploaded", "No media uploaded", None, None, [], None, []]
        media_path = local_file  # local_file is a string path in Gradio
        print(f"Processing local file: {media_path}")
    else:  # Azure Blob
        if not azure_blob or not blob_service_client:
            return [day, date, "No blob selected or Azure not configured", "No blob selected or Azure not configured",
                    "No blob selected or Azure not configured", None, None, [], None, []]
        try:
            blob_client = blob_service_client.get_blob_client(container=CONTAINER_NAME, blob=azure_blob)
            with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(azure_blob)[1]) as temp_file:
                temp_path = temp_file.name
                blob_data = blob_client.download_blob()
                blob_data.readinto(temp_file)
            media_path = temp_path
            global_temp_media_path = media_path  # Store for cleanup
            print(f"Downloaded Azure blob '{azure_blob}' to temporary file: {media_path}")
        except Exception as e:
            print(f"Error downloading blob '{azure_blob}': {str(e)}")
            return [day, date, "Error downloading blob", "Error downloading blob", "Error downloading blob", None, None, [], None, []]
    # Ensure cleanup happens even on error
    try:
        file_ext = get_file_extension(media_path)
        if not (is_image(media_path) or is_video(media_path)):
            raise ValueError(f"Unsupported file type: {file_ext}")

        annotated_video_path = None  # Initialize
        if is_image(media_path):
            # Process the image with the original functions
            detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(media_path)
            detected_activities = analyze_image_activities(media_path)
        else:  # It's a video
            # Process the video with the unified function
            print("Processing video with unified YOLO function...")
            detected_people, detected_machinery, detected_machinery_types, annotated_video_path = process_video_unified(media_path)
            print(f"Unified YOLO results - People: {detected_people}, Machinery: {detected_machinery}, Types: {detected_machinery_types}, Annotated Video: {annotated_video_path}")
            # Now analyze activities
            detected_activities = analyze_video_activities(media_path, model, tokenizer, processor)

        # Debug the detected activities
        print(f"Detected activities (raw): {detected_activities}")
        print(f"Type of detected_activities: {type(detected_activities)}")

        # Ensure detected_activities is a list of dictionaries
        if isinstance(detected_activities, str):
            print("Warning: detected_activities is a string, converting to list of dicts.")
            detected_activities = [{"time": "Unknown", "summary": detected_activities}]
        elif not isinstance(detected_activities, list):
            print("Warning: detected_activities is not a list, wrapping in a list.")
            detected_activities = [{"time": "Unknown", "summary": str(detected_activities)}]

        # Validate each activity; write the converted entry back into the list
        # (rebinding the loop variable alone would not change it)
        for i, activity in enumerate(detected_activities):
            if not isinstance(activity, dict):
                print(f"Warning: Invalid activity format: {activity}, converting.")
                detected_activities[i] = {"time": "Unknown", "summary": str(activity)}
        print(f"Processed detected_activities: {detected_activities}")
        # Store activities and media path globally for chat mode
        global_activities = detected_activities
        global_media_path = media_path

        # Video annotation is handled inside process_video_unified, and temporary
        # Azure downloads are cleaned up in the finally block below.

        detected_types_str = ", ".join([f"{k}: {v}" for k, v in detected_machinery_types.items()])

        # Return the activities as a list for the card display and clear the
        # chat history when loading new media
        chat_history = []

        # Extract data for the activity table
        activity_rows = []
        for activity in detected_activities:
            time = activity.get('time', 'Unknown')
            summary = activity.get('summary', 'No description available')
            activity_rows.append([time, summary])
        print(f"Activity rows for Dataframe: {activity_rows}")

        return [day, date, str(detected_people), str(detected_machinery),
                detected_types_str, gr.update(visible=True), annotated_video_path,
                detected_activities, chat_history, activity_rows]
    except Exception as e:
        print(f"Error processing media: {str(e)}")
        # Cleanup is handled in the finally block
        return [day, date, "Error processing media", "Error processing media",
                "Error processing media", None, None, [], None, []]
    finally:
        # Clean up temporary files and GPU memory
        print("Running cleanup...")
        if global_temp_media_path and os.path.exists(global_temp_media_path):
            try:
                os.remove(global_temp_media_path)
                print(f"Cleaned up temporary Azure file: {global_temp_media_path}")
            except OSError as e:
                print(f"Error removing temporary Azure file {global_temp_media_path}: {e}")
        # Clear GPU cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
            print("Cleared GPU cache.")
def get_file_extension(filename):
    return os.path.splitext(filename)[1].lower()
def on_card_click(activity_indices, history, evt: gr.SelectData):
    """Handle clicking on an activity card in the gallery."""
    global global_activities, global_media_path
    # Get the index of the selected activity from the SelectData event
    selected_idx = evt.index

    # Map the gallery index to the actual activity index
    if selected_idx < 0 or selected_idx >= len(activity_indices):
        return [gr.update(visible=True), gr.update(visible=False), [], None]
    card_idx = activity_indices[selected_idx]
    print(f"Gallery item {selected_idx} clicked, corresponds to activity index: {card_idx}")
    if card_idx < 0 or card_idx >= len(global_activities):
        return [gr.update(visible=True), gr.update(visible=False), [], None]

    selected_activity = global_activities[card_idx]
    chunk_video_path = None

    # Use the pre-saved chunk video if available
    if 'chunk_path' in selected_activity and os.path.exists(selected_activity['chunk_path']):
        chunk_video_path = selected_activity['chunk_path']
        print(f"Using pre-saved chunk video: {chunk_video_path}")
    else:
        # Fall back to the full video if the chunk is not available
        chunk_video_path = global_media_path
        print(f"Chunk video not available, using full video: {chunk_video_path}")

    # Add the selected activity to the chat history
    history = []
    history.append((None, f"🎬 Selected video at timestamp {selected_activity['time']}"))

    # Add the thumbnail to the chat as a visual element
    if 'thumbnail' in selected_activity and os.path.exists(selected_activity['thumbnail']):
        # Use the tuple format for images in the chatbot
        thumbnail_path = selected_activity['thumbnail']
        history.append((None, f"📷 Video frame at {selected_activity['time']}"))
        history.append((None, thumbnail_path))

    # Format a message about the detected activity (the objects key may be absent)
    activity_info = f"I detected the following activity:\n\n{selected_activity['summary']}"
    if selected_activity.get('objects'):
        activity_info += f"\n\nIdentified objects: {', '.join(selected_activity['objects'])}"
    history.append(("Tell me about this video segment", activity_info))

    return [gr.update(visible=False), gr.update(visible=True), history, chunk_video_path]
def chat_with_video(message, history):
    """Chat with the mPLUG model about the selected video segment."""
    global global_activities, global_media_path
    try:
        # Read the timestamp back out of the chat history to identify which
        # chunk is being discussed
        selected_chunk_idx = None
        selected_time = None
        selected_activity = None
        for entry in history:
            if entry[0] is None and "Selected video at timestamp" in entry[1]:
                time_str = entry[1].split("Selected video at timestamp ")[1]
                selected_time = time_str.strip()
                break

        # Find the corresponding chunk
        if selected_time:
            for i, activity in enumerate(global_activities):
                if activity.get('time') == selected_time:
                    selected_chunk_idx = activity.get('chunk_id')
                    selected_activity = activity
                    break

        # If we found the chunk, use the model to analyze it
        if selected_chunk_idx is not None and global_media_path and selected_activity:
            # Build a prompt from the user question plus context about the scene
            context = f"This video shows construction site activities at timestamp {selected_time}."
            if selected_activity.get('objects'):
                context += f" The scene contains {', '.join(selected_activity.get('objects'))}."
            prompt = f"{context} Analyze this segment of construction site video and answer this question: {message}"

            # Ideally this would read the specific chunk file; for simplicity we
            # re-open the full video and slice out the chunk's frames.
            vr = VideoReader(global_media_path, ctx=cpu(0))

            # Sample roughly one frame per second
            sample_fps = round(vr.get_avg_fps() / 1)
            frame_idx = [i for i in range(0, len(vr), sample_fps)]

            # Extract the frame indices for this chunk
            chunk_size = MAX_NUM_FRAMES
            start_idx = selected_chunk_idx * chunk_size
            end_idx = min(start_idx + chunk_size, len(frame_idx))
            chunk_frames = frame_idx[start_idx:end_idx]

            if chunk_frames:
                frames = vr.get_batch(chunk_frames).asnumpy()
                frames_pil = [Image.fromarray(v.astype('uint8')) for v in frames]
                # Process the frames with the model
                response = process_video_chunk(frames_pil, model, tokenizer, processor, prompt)
                return history + [(message, response)]
            else:
                return history + [(message, "Could not extract frames for this segment.")]
        else:
            # Fallback response if we can't identify the chunk
            thumbnail = None
            response_text = f"I'm analyzing your question about the video segment: {message}\n\nBased on what I can see in this segment, it appears to show construction activity with various machinery and workers on site. The specific details would depend on the exact timestamp you're referring to."

            # Try to attach a thumbnail from the selected activity if available
            if selected_activity and 'thumbnail' in selected_activity and os.path.exists(selected_activity['thumbnail']):
                thumbnail = selected_activity['thumbnail']
                new_history = history + [(message, response_text)]
                new_history.append((None, f"📷 Video frame at {selected_time}"))
                new_history.append((None, thumbnail))
                return new_history
            return history + [(message, response_text)]
    except Exception as e:
        print(f"Error in chat_with_video: {str(e)}")
        return history + [(message, f"I encountered an error while processing your question. Let me try to answer based on what I can see: {message}\n\nThe video appears to show construction site activities, but I'm having trouble with the detailed analysis at the moment.")]
# Native Gradio activity cards
def create_activity_cards_ui(activities):
    """Create activity cards using native Gradio components."""
    if not activities:
        # The corresponding output component is a Gallery, so return an empty
        # gallery rather than an HTML fragment
        return gr.Gallery(value=[], label="Activity Timeline"), []

    # Prepare data for the gallery
    thumbnails = []
    captions = []
    activity_indices = []
    for i, activity in enumerate(activities):
        thumbnail = activity.get('thumbnail', '')
        time = activity.get('time', 'Unknown')
        summary = activity.get('summary', 'No description available')
        objects_list = activity.get('objects', [])
        objects_text = f"Objects: {', '.join(objects_list)}" if objects_list else ""  # currently unused in the caption

        # Truncate the summary if it is too long
        if len(summary) > 150:
            summary = summary[:147] + "..."

        thumbnails.append(thumbnail)
        captions.append(f"Timestamp: {time} | {summary}")
        activity_indices.append(i)

    # Create a gallery of (image, caption) tuples for the thumbnails
    gallery = gr.Gallery(
        value=[(path, caption) for path, caption in zip(thumbnails, captions)],
        columns=5,
        rows=None,
        height="auto",
        object_fit="contain",
        label="Activity Timeline"
    )
    return gallery, activity_indices
# Create the Gradio interface
with gr.Blocks(title="Digital Site Diary", css="") as demo:
    gr.Markdown("# 📝 Digital Site Diary")

    # Activity data and indices storage
    activity_data = gr.State([])
    activity_indices = gr.State([])

    # Create tabs for different views
    with gr.Tabs() as tabs:
        with gr.Tab("Site Diary"):
            with gr.Row():
                # User Input Column
                with gr.Column():
                    gr.Markdown("### User Input")
                    day = gr.Textbox(label="Day", value='9')
                    date = gr.Textbox(label="Date", placeholder="YYYY-MM-DD", value=datetime.now().strftime("%Y-%m-%d"))
                    total_people = gr.Number(label="Total Number of People", precision=0, value=10)
                    total_machinery = gr.Number(label="Total Number of Machinery", precision=0, value=3)
                    machinery_types = gr.Textbox(
                        label="Number of Machinery Per Type",
                        placeholder="e.g., Excavator: 2, Roller: 1",
                        value="Excavator: 2, Roller: 1"
                    )
                    activities = gr.Textbox(
                        label="Activity",
                        placeholder="e.g., 9 AM: Excavation, 10 AM: Concreting",
                        value="9 AM: Excavation, 10 AM: Concreting",
                        lines=3
                    )
                    media_source = gr.Radio(["Local File", "Azure Blob"], label="Media Source", value="Local File")
                    local_file = gr.File(label="Upload Image/Video", file_types=["image", "video"], visible=True)
                    azure_blob = gr.Dropdown(label="Select Video from Azure", choices=blob_names, visible=False)
                    submit_btn = gr.Button("Submit", variant="primary")
                # Model Detection Column
                with gr.Column():
                    gr.Markdown("### Model Detection")
                    model_day = gr.Textbox(label="Day")
                    model_date = gr.Textbox(label="Date")
                    model_people = gr.Textbox(label="Total Number of People")
                    model_machinery = gr.Textbox(label="Total Number of Machinery")
                    model_machinery_types = gr.Textbox(label="Number of Machinery Per Type")

            # Activity Row with Timestamps
            with gr.Row():
                gr.Markdown("#### Activities with Timestamps")
                model_activities = gr.Dataframe(
                    headers=["Time", "Activity Description"],
                    datatype=["str", "str"],
                    label="Detected Activities",
                    interactive=False,
                    wrap=True
                )

            # Activity timeline section
            with gr.Row():
                # Timeline View (visible by default)
                with gr.Column(visible=True) as timeline_view:
                    activity_gallery = gr.Gallery(label="Activity Timeline")
                    model_annotated_video = gr.Video(label="Full Video")

                # Chat View (initially hidden)
                with gr.Column(visible=False) as chat_view:
                    chunk_video = gr.Video(label="Chunk video")
                    chatbot = gr.Chatbot(height=400)
                    chat_input = gr.Textbox(
                        placeholder="Ask about this video segment...",
                        show_label=False
                    )
                    back_btn = gr.Button("← Back to Timeline")
    # Update input visibility based on the selected media source
    def update_visibility(source):
        if source == "Local File":
            return gr.update(visible=True), gr.update(visible=False)
        else:
            return gr.update(visible=False), gr.update(visible=True)

    media_source.change(fn=update_visibility, inputs=media_source, outputs=[local_file, azure_blob])

    # Connect the submit button to the processing function
    submit_btn.click(
        fn=process_diary,
        inputs=[day, date, total_people, total_machinery, machinery_types, activities, media_source, local_file, azure_blob],
        outputs=[
            model_day,
            model_date,
            model_people,
            model_machinery,
            model_machinery_types,
            timeline_view,
            model_annotated_video,
            activity_data,
            chatbot,
            model_activities
        ]
    )

    # Process activity data into the gallery
    activity_data.change(
        fn=create_activity_cards_ui,
        inputs=[activity_data],
        outputs=[activity_gallery, activity_indices]
    )

    # Handle gallery selection
    activity_gallery.select(
        fn=on_card_click,
        inputs=[activity_indices, chatbot],
        outputs=[timeline_view, chat_view, chatbot, chunk_video]
    )

    # Chat submission
    chat_input.submit(
        fn=chat_with_video,
        inputs=[chat_input, chatbot],
        outputs=[chatbot]
    )

    # Back button
    back_btn.click(
        fn=lambda: [gr.update(visible=True), gr.update(visible=False)],
        inputs=None,
        outputs=[timeline_view, chat_view]
    )
    # Add enhanced CSS styling
    gr.HTML("""
    <style>
        /* Gallery customizations */
        .gradio-container .gallery-item {
            border: 1px solid #444444 !important;
            border-radius: 8px !important;
            padding: 8px !important;
            margin: 10px !important;
            cursor: pointer !important;
            transition: all 0.3s !important;
            background: #18181b !important;
            box-shadow: 0 2px 5px rgba(0,0,0,0.2) !important;
        }
        .gradio-container .gallery-item:hover {
            transform: translateY(-2px) !important;
            box-shadow: 0 4px 12px rgba(0,0,0,0.25) !important;
            border-color: #007bff !important;
            background: #202025 !important;
        }
        .gradio-container .gallery-item.selected {
            border: 2px solid #007bff !important;
            background: #202030 !important;
        }
        /* Improved image display */
        .gradio-container .gallery-item img {
            height: 180px !important;
            object-fit: cover !important;
            border-radius: 4px !important;
            border: 1px solid #444444 !important;
            margin-bottom: 8px !important;
        }
        /* Caption styling */
        .gradio-container .caption {
            color: #e0e0e0 !important;
            font-size: 0.9em !important;
            margin-top: 8px !important;
            line-height: 1.4 !important;
            padding: 0 4px !important;
        }
        /* Gallery container */
        .gradio-container [id*='gallery'] > div:first-child {
            background-color: #27272a !important;
            padding: 15px !important;
            border-radius: 10px !important;
        }
        /* Chatbot styling */
        .gradio-container .chatbot {
            background-color: #27272a !important;
            border-radius: 10px !important;
            border: 1px solid #444444 !important;
        }
        .gradio-container .chatbot .message.user {
            background-color: #18181b !important;
            border-radius: 8px !important;
        }
        .gradio-container .chatbot .message.bot {
            background-color: #202030 !important;
            border-radius: 8px !important;
        }
        /* Button styling */
        .gradio-container button.secondary {
            background-color: #3d4452 !important;
            color: white !important;
        }
    </style>
    """)
if __name__ == "__main__": | |
demo.launch(allowed_paths=["./tmp"], share=True) | |