# V3Test / app.py
# Last commit: 6c50846 "Update app.py" by assentian1970 (verified)
import spaces
import gradio as gr
from datetime import datetime
import tempfile
import os
import json
import torch
import gc
import shutil # Added for directory cleanup
from azure.storage.blob import BlobServiceClient # Added for Azure integration
def debug():
    """Allocate a small random tensor and move it onto the CUDA device.

    Invoked once at import time below, presumably as a GPU warm-up / sanity
    check before the model is loaded (NOTE(review): confirm intent).
    """
    sample = torch.randn(10)
    sample.cuda()


debug()
from PIL import Image
from decord import VideoReader, cpu
from yolo_detection import (
detect_people_and_machinery, # Keep for images
# annotate_video_with_bboxes, # Replaced by unified function
process_video_unified, # Import the new unified function
is_image,
is_video
)
from image_captioning import (
analyze_image_activities,
analyze_video_activities,
process_video_chunk,
load_model_and_tokenizer,
MAX_NUM_FRAMES
)
# Load model instance once, at import time, so process_diary and
# chat_with_video share the same model/tokenizer/processor globals.
# A Python GC pass and a CUDA cache release run first to free as much GPU
# memory as possible before the weights are loaded.
gc.collect()
torch.cuda.empty_cache()
model, tokenizer, processor = load_model_and_tokenizer()
print("Model loaded.")
# Azure Blob Storage Setup.
# The connection string carries credentials (an account key or SAS token), so
# it must never be committed to source. It is read from the environment
# (e.g. a Hugging Face Space secret). When it is unset, Azure features are
# disabled and only local uploads work.
# NOTE(review): the SAS token that was previously hard-coded here must be
# treated as leaked and revoked.
CONTAINER_NAME = os.environ.get("AZURE_STORAGE_CONTAINER", "logs")
connection_string = os.environ.get("AZURE_STORAGE_CONNECTION_STRING")
if not connection_string:
    print("Warning: AZURE_STORAGE_CONNECTION_STRING not found. Azure Blob functionality will be disabled.")
# Initialize Azure Blob Service Client if connection string is available;
# any failure leaves blob_service_client as None (Azure disabled).
blob_service_client = None
if connection_string:
    try:
        blob_service_client = BlobServiceClient.from_connection_string(connection_string)
        print("Azure Blob Service Client initialized successfully.")
    except Exception as e:
        print(f"Error initializing BlobServiceClient: {str(e)}")
        blob_service_client = None
def list_blobs():
    """Return the names of the video blobs stored in CONTAINER_NAME.

    A blob counts as a video when its name ends, case-insensitively, with one
    of the known video extensions. Returns [] when the Blob client was not
    initialised or when listing fails; errors are printed, not raised.
    """
    if not blob_service_client:
        print("Cannot list blobs: BlobServiceClient is not initialized.")
        return []
    video_suffixes = ('.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v')
    try:
        container = blob_service_client.get_container_client(CONTAINER_NAME)
        video_names = [
            item.name
            for item in container.list_blobs()
            if item.name.lower().endswith(video_suffixes)
        ]
        print(f"Found {len(video_names)} video blobs in container '{CONTAINER_NAME}': {video_names}")
        return video_names
    except Exception as err:
        print(f"Error listing blobs in container '{CONTAINER_NAME}': {str(err)}")
        return []
# Fetch blob names at startup. They become the fixed choices of the
# "Select Video from Azure" dropdown; nothing in this file refreshes them later.
blob_names = list_blobs()
print(f"Populated azure_blob dropdown with {len(blob_names)} options.")
# Global storage for activities and media paths.
# process_diary writes these; on_card_click and chat_with_video read them to
# find the activities and media of the most recent run.
global_activities = []  # activities from the most recent process_diary run
global_media_path = None  # media file those activities were extracted from
global_temp_media_path = None  # Store path if downloaded from Azure for cleanup
# Create tmp directory for storing frames.
# ./tmp is also listed in demo.launch(allowed_paths=...) so Gradio may serve
# files from it. NOTE(review): whatever writes frames here lives outside this
# file — confirm.
tmp_dir = os.path.join('.', 'tmp')
os.makedirs(tmp_dir, exist_ok=True)
def _error_outputs(day, date, message):
    """Return process_diary's 10 outputs for a run that produced no results.

    The three detection text boxes show *message*. Timeline visibility,
    annotated video and chat history are None; the activities and table rows
    are empty. The order matches the submit_btn outputs list.
    """
    return [day, date, message, message, message, None, None, [], None, []]


def _remove_temp_file(path):
    """Delete the temporary media file at *path* if it exists (best effort).

    Failures are printed rather than raised, so cleanup never masks the
    caller's real result.
    """
    if path and os.path.exists(path):
        try:
            os.remove(path)
            print(f"Cleaned up temporary Azure file: {path}")
        except OSError as e:
            print(f"Error removing temporary Azure file {path}: {e}")


def _download_azure_blob(blob_name):
    """Download *blob_name* from CONTAINER_NAME into a new temporary file.

    The temp file keeps the blob's extension so is_image()/is_video() can
    classify it. Because it is created with delete=False, the caller owns its
    removal. If the download itself fails, the partial file is removed here
    before the exception propagates.

    Returns:
        Path of the downloaded temporary file.
    """
    blob_client = blob_service_client.get_blob_client(container=CONTAINER_NAME, blob=blob_name)
    suffix = os.path.splitext(blob_name)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_path = temp_file.name
    try:
        with open(temp_path, "wb") as out:
            blob_client.download_blob().readinto(out)
    except Exception:
        _remove_temp_file(temp_path)
        raise
    return temp_path


def _normalize_activities(detected_activities):
    """Coerce the activity analyzer's output into a list of dicts.

    A bare string, or any other non-list value, is wrapped as a single
    {"time": "Unknown", "summary": ...} entry. Non-dict items inside a list
    are replaced by such an entry.
    """
    if isinstance(detected_activities, str):
        print("Warning: detected_activities is a string, converting to list of dicts.")
        return [{"time": "Unknown", "summary": detected_activities}]
    if not isinstance(detected_activities, list):
        print("Warning: detected_activities is not a list, wrapping in a list.")
        return [{"time": "Unknown", "summary": str(detected_activities)}]
    normalized = []
    for activity in detected_activities:
        if not isinstance(activity, dict):
            print(f"Warning: Invalid activity format: {activity}, converting.")
            activity = {"time": "Unknown", "summary": str(activity)}
        # Build a new list: rebinding the loop variable alone would leave the
        # invalid item in place, and it would later crash on activity.get().
        normalized.append(activity)
    return normalized


@spaces.GPU
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media_source, local_file, azure_blob):
    """Process the site diary entry with media from local file or Azure Blob Storage.

    Runs people/machinery detection and activity analysis on the selected
    image or video. The resulting activities and media path are stored in
    module globals for the chat view.

    Args:
        day, date: Echoed back unchanged into the "Model Detection" column.
        total_people, total_machinery, machinery_types, activities: User-entered
            values; not currently used by the processing.
        media_source: "Local File" selects *local_file*; any other value selects
            *azure_blob*.
        local_file: Uploaded file path (or a file wrapper exposing ``.name``).
        azure_blob: Blob name inside CONTAINER_NAME.

    Returns:
        A 10-element list in submit_btn output order: day, date, people count,
        machinery count, machinery-types string, timeline visibility update,
        annotated video path, activity dicts, fresh chat history, and
        [time, summary] table rows. On failure, placeholders from _error_outputs.

    Temp-file lifecycle: an Azure download becomes global_media_path, which
    on_card_click and chat_with_video read after this call returns. It is
    therefore deleted here only when processing fails. Otherwise it lives until
    the next call replaces it.
    NOTE(review): if @spaces.GPU runs this function in a separate worker
    process, these global assignments may not reach the other callbacks — confirm.
    """
    global global_activities, global_media_path, global_temp_media_path

    # The previous run's Azure download is about to be superseded. Delete it,
    # and make sure nothing keeps pointing at the removed file.
    if global_temp_media_path:
        if global_media_path == global_temp_media_path:
            global_media_path = None
        _remove_temp_file(global_temp_media_path)
    global_temp_media_path = None

    if media_source == "Local File":
        if local_file is None:
            return _error_outputs(day, date, "No media uploaded")
        # Gradio passes a path string; older versions pass a tempfile wrapper.
        media_path = getattr(local_file, "name", local_file)
        print(f"Processing local file: {media_path}")
    else:  # Azure Blob
        if not azure_blob or not blob_service_client:
            return _error_outputs(day, date, "No blob selected or Azure not configured")
        try:
            media_path = _download_azure_blob(azure_blob)
        except Exception as e:
            print(f"Error downloading blob '{azure_blob}': {str(e)}")
            return _error_outputs(day, date, "Error downloading blob")
        global_temp_media_path = media_path
        print(f"Downloaded Azure blob '{azure_blob}' to temporary file: {media_path}")

    succeeded = False
    try:
        file_ext = get_file_extension(media_path)
        if not (is_image(media_path) or is_video(media_path)):
            raise ValueError(f"Unsupported file type: {file_ext}")

        annotated_video_path = None
        if is_image(media_path):
            detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(media_path)
            detected_activities = analyze_image_activities(media_path)
        else:
            # One YOLO pass yields the counts and the annotated video.
            print("Processing video with unified YOLO function...")
            detected_people, detected_machinery, detected_machinery_types, annotated_video_path = process_video_unified(media_path)
            print(f"Unified YOLO results - People: {detected_people}, Machinery: {detected_machinery}, Types: {detected_machinery_types}, Annotated Video: {annotated_video_path}")
            detected_activities = analyze_video_activities(media_path, model, tokenizer, processor)

        print(f"Detected activities (raw): {detected_activities}")
        print(f"Type of detected_activities: {type(detected_activities)}")
        detected_activities = _normalize_activities(detected_activities)
        print(f"Processed detected_activities: {detected_activities}")

        # detected_machinery_types is expected to be a mapping of type -> count.
        detected_types_str = ", ".join(f"{k}: {v}" for k, v in detected_machinery_types.items())
        activity_rows = [
            [activity.get('time', 'Unknown'), activity.get('summary', 'No description available')]
            for activity in detected_activities
        ]
        print(f"Activity rows for Dataframe: {activity_rows}")

        # New media starts with an empty chat history.
        chat_history = []
        outputs = [day, date, str(detected_people), str(detected_machinery),
                   detected_types_str, gr.update(visible=True), annotated_video_path,
                   detected_activities, chat_history, activity_rows]

        # Publish results for chat mode only once everything has succeeded.
        global_activities = detected_activities
        global_media_path = media_path
        succeeded = True
        return outputs
    except Exception as e:
        print(f"Error processing media: {str(e)}")
        return _error_outputs(day, date, "Error processing media")
    finally:
        print("Running cleanup...")
        if not succeeded:
            # The download is not referenced by global_media_path; drop it now.
            _remove_temp_file(global_temp_media_path)
            global_temp_media_path = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
            print("Cleared GPU cache.")
def get_file_extension(filename):
    """Return *filename*'s extension in lower case, dot included ('' when absent)."""
    _root, extension = os.path.splitext(filename)
    return extension.lower()
def on_card_click(activity_indices, history, evt: gr.SelectData):
    """Handle clicking on an activity card in the gallery.

    Switches from the timeline view to the chat view and seeds a fresh chat
    history for the clicked activity.

    Args:
        activity_indices: Gallery position -> index into global_activities
            (built by create_activity_cards_ui).
        history: Current chatbot history. It is ignored: each card starts a new
            conversation.
        evt: Gradio select event; ``evt.index`` is the clicked gallery position.

    Returns:
        [timeline_view update, chat_view update, chat history, video path for
        the chunk player]. For an out-of-range click the timeline stays visible
        and the history is empty.
    """
    selected_idx = evt.index
    if selected_idx < 0 or selected_idx >= len(activity_indices):
        return [gr.update(visible=True), gr.update(visible=False), [], None]
    card_idx = activity_indices[selected_idx]
    print(f"Gallery item {selected_idx} clicked, corresponds to activity index: {card_idx}")
    if card_idx < 0 or card_idx >= len(global_activities):
        return [gr.update(visible=True), gr.update(visible=False), [], None]

    selected_activity = global_activities[card_idx]
    time_label = selected_activity.get('time', 'Unknown')

    # Prefer the pre-saved clip for this chunk; fall back to the full video.
    # The truthiness check guards os.path.exists, which raises TypeError on None.
    chunk_path = selected_activity.get('chunk_path')
    if chunk_path and os.path.exists(chunk_path):
        chunk_video_path = chunk_path
        print(f"Using pre-saved chunk video: {chunk_video_path}")
    else:
        chunk_video_path = global_media_path
        print(f"Chunk video not available, using full video: {chunk_video_path}")

    # chat_with_video locates the selected segment by parsing this exact
    # "Selected video at timestamp <time>" message, so keep its wording.
    new_history = [(None, f"🎬 Selected video at timestamp {time_label}")]

    thumbnail_path = selected_activity.get('thumbnail')
    if thumbnail_path and os.path.exists(thumbnail_path):
        new_history.append((None, f"📷 Video frame at {time_label}"))
        # Chatbot renders a (filepath,) tuple as media; a bare string would be
        # displayed as plain text.
        new_history.append((None, (thumbnail_path,)))

    summary = selected_activity.get('summary', 'No description available')
    activity_info = f"I detected the following activity:\n\n{summary}"
    objects = selected_activity.get('objects')
    if objects:
        activity_info += f"\n\nIdentified objects: {', '.join(map(str, objects))}"
    new_history.append(("Tell me about this video segment", activity_info))

    return [gr.update(visible=False), gr.update(visible=True), new_history, chunk_video_path]
def chat_with_video(message, history):
    """Chat with the mPLUG model about the selected video segment.

    The segment is identified from the "Selected video at timestamp <time>"
    message that on_card_click puts into the history. That time is matched
    against global_activities to get the activity's chunk_id. The frames for
    that chunk are then sampled from global_media_path (about one per second,
    MAX_NUM_FRAMES per chunk) and sent to the model together with the question.

    Args:
        message: The user's question.
        history: Chatbot history as (user, bot) pairs. It may be None, because
            process_diary sets the chatbot to None when processing fails.

    Returns:
        The history extended with the new exchange. Errors are reported as a
        bot reply rather than raised.
    """
    history = list(history or [])
    try:
        marker = "Selected video at timestamp"
        selected_time = None
        for entry in history:
            # Media messages are tuples/dicts, not str; skip anything non-textual.
            if entry[0] is None and isinstance(entry[1], str) and marker in entry[1]:
                selected_time = entry[1].partition(marker)[2].strip()
                break

        # Find the activity (and its chunk) recorded for that timestamp.
        selected_activity = None
        selected_chunk_idx = None
        if selected_time:
            for activity in global_activities:
                if activity.get('time') == selected_time:
                    selected_activity = activity
                    selected_chunk_idx = activity.get('chunk_id')
                    break

        if selected_chunk_idx is not None and global_media_path and selected_activity:
            context = f"This video shows construction site activities at timestamp {selected_time}."
            objects = selected_activity.get('objects')
            if objects:
                context += f" The scene contains {', '.join(objects)}."
            prompt = f"{context} Analyze this segment of construction site video and answer this question: {message}"

            # Re-read the whole video; NOTE(review): a production version could
            # extract only this chunk.
            vr = VideoReader(global_media_path, ctx=cpu(0))
            # Sample roughly one frame per second. The max() guards against a
            # step of 0 (range() raises ValueError) when the average fps rounds to 0.
            sample_step = max(1, round(vr.get_avg_fps()))
            frame_idx = list(range(0, len(vr), sample_step))
            start_idx = selected_chunk_idx * MAX_NUM_FRAMES
            chunk_frames = frame_idx[start_idx:start_idx + MAX_NUM_FRAMES]
            if not chunk_frames:
                return history + [(message, "Could not extract frames for this segment.")]

            frames = vr.get_batch(chunk_frames).asnumpy()
            frames_pil = [Image.fromarray(frame.astype('uint8')) for frame in frames]
            response = process_video_chunk(frames_pil, model, tokenizer, processor, prompt)
            return history + [(message, response)]

        # Fallback response if we can't identify the chunk.
        response_text = f"I'm analyzing your question about the video segment: {message}\n\nBased on what I can see in this segment, it appears to show construction activity with various machinery and workers on site. The specific details would depend on the exact timestamp you're referring to."
        thumbnail = selected_activity.get('thumbnail') if selected_activity else None
        if thumbnail and os.path.exists(thumbnail):
            return history + [
                (message, response_text),
                (None, f"📷 Video frame at {selected_time}"),
                # (filepath,) tuple so the Chatbot renders the image.
                (None, (thumbnail,)),
            ]
        return history + [(message, response_text)]
    except Exception as e:
        print(f"Error in chat_with_video: {str(e)}")
        return history + [(message, f"I encountered an error while processing your question. Let me try to answer based on what I can see: {message}\n\nThe video appears to show construction site activities, but I'm having trouble with the detailed analysis at the moment.")]
# Native Gradio activity cards
def create_activity_cards_ui(activities):
    """Build the "Activity Timeline" gallery from the detected activities.

    Each activity that has a thumbnail becomes one gallery item, captioned
    "Timestamp: <time> | <summary>" with the summary truncated to 150 characters.

    Args:
        activities: List of activity dicts (keys read: 'thumbnail', 'time',
            'summary'). None or [] yields an empty gallery.

    Returns:
        (gallery, activity_indices): the update for the activity gallery, and a
        list whose k-th entry is the index in *activities* of the k-th gallery
        item. on_card_click uses that list to map a clicked item back to its
        activity.
    """
    items = []
    activity_indices = []
    for idx, activity in enumerate(activities or []):
        thumbnail = activity.get('thumbnail')
        if not thumbnail:
            # A gallery item needs an image path, and an empty one yields a
            # broken entry. Skipping is safe because activity_indices keeps the
            # gallery-position -> activity mapping correct.
            continue
        time_label = activity.get('time', 'Unknown')
        summary = str(activity.get('summary') or 'No description available')
        if len(summary) > 150:
            summary = summary[:147] + "..."
        items.append((thumbnail, f"Timestamp: {time_label} | {summary}"))
        activity_indices.append(idx)

    if not items:
        # The output component is a Gallery, so clear it. Returning an HTML
        # component (as before) is not something a Gallery can display.
        return gr.update(value=[], label="Activity Timeline (no activities detected)"), []

    gallery = gr.Gallery(
        value=items,
        columns=5,
        rows=None,
        height="auto",
        object_fit="contain",
        label="Activity Timeline",
    )
    return gallery, activity_indices
# Create the Gradio interface.
# Layout: one "Site Diary" tab with user inputs next to the model's results,
# a table of timestamped activities, and a timeline area that toggles between
# a gallery view and a per-segment chat view.
with gr.Blocks(title="Digital Site Diary", css="") as demo:
    gr.Markdown("# 📝 Digital Site Diary")
    # Activity data and indices storage (gr.State, i.e. per session):
    # activity_data receives the activity dicts returned by process_diary;
    # activity_indices maps gallery positions to indices into that list.
    activity_data = gr.State([])
    activity_indices = gr.State([])
    # Create tabs for different views
    with gr.Tabs() as tabs:
        with gr.Tab("Site Diary"):
            with gr.Row():
                # User Input Column.
                # NOTE(review): total_people, total_machinery, machinery_types and
                # activities are passed to process_diary, but it does not use them.
                with gr.Column():
                    gr.Markdown("### User Input")
                    day = gr.Textbox(label="Day", value='9')
                    date = gr.Textbox(label="Date", placeholder="YYYY-MM-DD", value=datetime.now().strftime("%Y-%m-%d"))
                    total_people = gr.Number(label="Total Number of People", precision=0, value=10)
                    total_machinery = gr.Number(label="Total Number of Machinery", precision=0, value=3)
                    machinery_types = gr.Textbox(
                        label="Number of Machinery Per Type",
                        placeholder="e.g., Excavator: 2, Roller: 1",
                        value="Excavator: 2, Roller: 1"
                    )
                    activities = gr.Textbox(
                        label="Activity",
                        placeholder="e.g., 9 AM: Excavation, 10 AM: Concreting",
                        value="9 AM: Excavation, 10 AM: Concreting",
                        lines=3
                    )
                    # Exactly one of local_file / azure_blob is visible at a time
                    # (see update_visibility below).
                    media_source = gr.Radio(["Local File", "Azure Blob"], label="Media Source", value="Local File")
                    local_file = gr.File(label="Upload Image/Video", file_types=["image", "video"], visible=True)
                    azure_blob = gr.Dropdown(label="Select Video from Azure", choices=blob_names, visible=False)
                    submit_btn = gr.Button("Submit", variant="primary")
                # Model Detection Column: filled from process_diary's outputs.
                with gr.Column():
                    gr.Markdown("### Model Detection")
                    model_day = gr.Textbox(label="Day")
                    model_date = gr.Textbox(label="Date")
                    model_people = gr.Textbox(label="Total Number of People")
                    model_machinery = gr.Textbox(label="Total Number of Machinery")
                    model_machinery_types = gr.Textbox(label="Number of Machinery Per Type")
            # Activity Row with Timestamps: read-only table of [time, summary] rows.
            with gr.Row():
                gr.Markdown("#### Activities with Timestamps")
                model_activities = gr.Dataframe(
                    headers=["Time", "Activity Description"],
                    datatype=["str", "str"],
                    label="Detected Activities",
                    interactive=False,
                    wrap=True
                )
            # Activity timeline section
            with gr.Row():
                # Timeline View (default visible): gallery of activity cards plus
                # the annotated full video.
                with gr.Column(visible=True) as timeline_view:
                    activity_gallery = gr.Gallery(label="Activity Timeline")
                    model_annotated_video = gr.Video(label="Full Video")
                # Chat View (initially hidden): shown by on_card_click with the
                # selected segment's video and a chat about it.
                with gr.Column(visible=False) as chat_view:
                    chunk_video = gr.Video(label="Chunk video")
                    chatbot = gr.Chatbot(height=400)
                    chat_input = gr.Textbox(
                        placeholder="Ask about this video segment...",
                        show_label=False
                    )
                    back_btn = gr.Button("← Back to Timeline")
    # Update visibility based on media source: show the uploader for
    # "Local File", otherwise the Azure blob dropdown.
    def update_visibility(source):
        if source == "Local File":
            return gr.update(visible=True), gr.update(visible=False)
        else:
            return gr.update(visible=False), gr.update(visible=True)
    media_source.change(fn=update_visibility, inputs=media_source, outputs=[local_file, azure_blob])
    # Connect the submit button to the processing function.
    # The outputs are positional: keep them in the same order as the
    # 10-element list that process_diary returns.
    submit_btn.click(
        fn=process_diary,
        inputs=[day, date, total_people, total_machinery, machinery_types, activities, media_source, local_file, azure_blob],
        outputs=[
            model_day,
            model_date,
            model_people,
            model_machinery,
            model_machinery_types,
            timeline_view,
            model_annotated_video,
            activity_data,
            chatbot,
            model_activities
        ]
    )
    # Process activity data into gallery: whenever activity_data changes,
    # rebuild the gallery and the gallery -> activity index map.
    activity_data.change(
        fn=create_activity_cards_ui,
        inputs=[activity_data],
        outputs=[activity_gallery, activity_indices]
    )
    # Handle gallery selection: switch to the chat view for the clicked card.
    activity_gallery.select(
        fn=on_card_click,
        inputs=[activity_indices, chatbot],
        outputs=[timeline_view, chat_view, chatbot, chunk_video]
    )
    # Chat submission.
    # NOTE(review): chat_input is not cleared after a message is submitted.
    chat_input.submit(
        fn=chat_with_video,
        inputs=[chat_input, chatbot],
        outputs=[chatbot]
    )
    # Back button: show the timeline again and hide the chat view.
    back_btn.click(
        fn=lambda: [gr.update(visible=True), gr.update(visible=False)],
        inputs=None,
        outputs=[timeline_view, chat_view]
    )
    # Add enhanced CSS styling. The rules are injected through a raw <style>
    # tag inside an HTML component (the Blocks css argument above is empty).
    # NOTE(review): selectors such as .gallery-item / .caption depend on
    # Gradio's internal DOM and may not match every Gradio version — confirm.
    gr.HTML("""
<style>
/* Gallery customizations */
.gradio-container .gallery-item {
border: 1px solid #444444 !important;
border-radius: 8px !important;
padding: 8px !important;
margin: 10px !important;
cursor: pointer !important;
transition: all 0.3s !important;
background: #18181b !important;
box-shadow: 0 2px 5px rgba(0,0,0,0.2) !important;
}
.gradio-container .gallery-item:hover {
transform: translateY(-2px) !important;
box-shadow: 0 4px 12px rgba(0,0,0,0.25) !important;
border-color: #007bff !important;
background: #202025 !important;
}
.gradio-container .gallery-item.selected {
border: 2px solid #007bff !important;
background: #202030 !important;
}
/* Improved image display */
.gradio-container .gallery-item img {
height: 180px !important;
object-fit: cover !important;
border-radius: 4px !important;
border: 1px solid #444444 !important;
margin-bottom: 8px !important;
}
/* Caption styling */
.gradio-container .caption {
color: #e0e0e0 !important;
font-size: 0.9em !important;
margin-top: 8px !important;
line-height: 1.4 !important;
padding: 0 4px !important;
}
/* Gallery container */
.gradio-container [id*='gallery'] > div:first-child {
background-color: #27272a !important;
padding: 15px !important;
border-radius: 10px !important;
}
/* Chatbot styling */
.gradio-container .chatbot {
background-color: #27272a !important;
border-radius: 10px !important;
border: 1px solid #444444 !important;
}
.gradio-container .chatbot .message.user {
background-color: #18181b !important;
border-radius: 8px !important;
}
.gradio-container .chatbot .message.bot {
background-color: #202030 !important;
border-radius: 8px !important;
}
/* Button styling */
.gradio-container button.secondary {
background-color: #3d4452 !important;
color: white !important;
}
</style>
""")
if __name__ == "__main__":
    # Start the app. Gradio only serves local files from permitted locations,
    # so ./tmp (the tmp_dir created at startup) is whitelisted. share=True
    # requests a public share link.
    demo.launch(
        share=True,
        allowed_paths=["./tmp"],
    )