# NOTE: Hugging Face Spaces page residue removed from the top of this file
# (it read "Spaces: / Runtime error / Runtime error" and broke parsing).
import spaces | |
import gradio as gr | |
from datetime import datetime | |
import tempfile | |
import os | |
import json | |
import torch | |
import gc | |
def debug():
    """Warm up CUDA by allocating a small tensor on the GPU.

    Called once at import time so the GPU context is initialized early
    (useful on Spaces/ZeroGPU). Guarded so that CPU-only machines do not
    crash at import: ``.cuda()`` raises RuntimeError without a device.
    """
    if torch.cuda.is_available():
        torch.randn(10).cuda()


debug()
from PIL import Image | |
from decord import VideoReader, cpu | |
from yolo_detection import ( | |
detect_people_and_machinery, | |
annotate_video_with_bboxes, | |
is_image, | |
is_video | |
) | |
from image_captioning import ( | |
analyze_image_activities, | |
analyze_video_activities, | |
process_video_chunk, | |
load_model_and_tokenizer, | |
MAX_NUM_FRAMES | |
) | |
# Global storage for activities and media paths.
# These are written by process_diary() and read by the chat handlers
# (on_card_click / chat_with_video) — shared across all sessions.
global_activities = []   # most recent list of detected activity dicts
global_media_path = None  # temp-file path of the most recent upload

# Create tmp directory for storing frames (thumbnails / chunk clips);
# must match the allowed_paths passed to demo.launch().
tmp_dir = os.path.join('.', 'tmp')
os.makedirs(tmp_dir, exist_ok=True)
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media):
    """Process the site diary entry: run detection and captioning on the media.

    Parameters mirror the Gradio inputs; ``media`` is a gr.File upload or None.

    Returns:
        A 10-element list matching the submit button's outputs:
        [day, date, people, machinery, machinery types string,
         timeline-visibility update, annotated video path,
         detected activities, chat history, activity table rows].
    """
    global global_activities, global_media_path

    if media is None:
        # BUGFIX: this path used to return only 9 elements while the handler
        # is wired to 10 outputs; keep the arity identical to the other paths.
        return [day, date, "No media uploaded", "No media uploaded",
                "No media uploaded", None, None, [], None, []]

    try:
        if not hasattr(media, 'name'):
            raise ValueError("Invalid file upload")

        file_ext = get_file_extension(media.name)
        if not (is_image(media.name) or is_video(media.name)):
            raise ValueError(f"Unsupported file type: {file_ext}")

        # Copy the upload into a private temp file. delete=False is deliberate:
        # the path is used by the detectors below and kept for chat mode.
        with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
            temp_path = temp_file.name
            if os.path.exists(media.name):
                with open(media.name, 'rb') as f:
                    temp_file.write(f.read())
            else:
                # Fall back to reading the upload object directly.
                file_content = media.read() if hasattr(media, 'read') else media
                temp_file.write(file_content if isinstance(file_content, bytes)
                                else file_content.read())

        detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(temp_path)
        print(f"Detected people: {detected_people}, machinery: {detected_machinery}, types: {detected_machinery_types}")

        annotated_video_path = None
        detected_activities = (analyze_image_activities(temp_path)
                               if is_image(media.name)
                               else analyze_video_activities(temp_path))
        print(f"Detected activities: {detected_activities}")

        # Store activities and media path globally for chat mode.
        global_activities = detected_activities
        global_media_path = temp_path

        if is_video(media.name):
            annotated_video_path = annotate_video_with_bboxes(temp_path)

        detected_types_str = ", ".join(f"{k}: {v}" for k, v in detected_machinery_types.items())

        # New media resets the chat conversation.
        chat_history = []

        # Rows for the activity Dataframe: [time, summary].
        activity_rows = [[activity.get('time', 'Unknown'),
                          activity.get('summary', 'No description available')]
                         for activity in detected_activities]

        return [day, date, str(detected_people), str(detected_machinery),
                detected_types_str, gr.update(visible=True), annotated_video_path,
                detected_activities, chat_history, activity_rows]
    except Exception as e:
        print(f"Error processing media: {str(e)}")
        return [day, date, "Error processing media", "Error processing media",
                "Error processing media", None, None, [], None, []]
def get_file_extension(filename):
    """Return the lowercased extension of *filename*, including the dot."""
    _, ext = os.path.splitext(filename)
    return ext.lower()
def on_card_click(activity_indices, history, evt: gr.SelectData):
    """Handle a click on an activity card: switch to chat view for that chunk."""
    global global_activities, global_media_path

    # Outputs that keep the timeline visible (used for invalid selections).
    stay_on_timeline = [gr.update(visible=True), gr.update(visible=False), [], None]

    # Map the clicked gallery position back to an activity index.
    selected_idx = evt.index
    if not (0 <= selected_idx < len(activity_indices)):
        return stay_on_timeline

    card_idx = activity_indices[selected_idx]
    print(f"Gallery item {selected_idx} clicked, corresponds to activity index: {card_idx}")
    if not (0 <= card_idx < len(global_activities)):
        return stay_on_timeline

    selected_activity = global_activities[card_idx]

    # Prefer the pre-saved per-chunk clip; fall back to the full video.
    if 'chunk_path' in selected_activity and os.path.exists(selected_activity['chunk_path']):
        chunk_video_path = selected_activity['chunk_path']
        print(f"Using pre-saved chunk video: {chunk_video_path}")
    else:
        chunk_video_path = global_media_path
        print(f"Chunk video not available, using full video: {chunk_video_path}")

    # Start a fresh chat history describing the selected segment.
    history = [(None, f"🎬 Selected video at timestamp {selected_activity['time']}")]

    # Show the chunk thumbnail inside the chat when one was saved.
    if 'thumbnail' in selected_activity and os.path.exists(selected_activity['thumbnail']):
        thumbnail_path = selected_activity['thumbnail']
        history.append((None, f"📷 Video frame at {selected_activity['time']}"))
        history.append((None, thumbnail_path))

    # Seed the conversation with the detection summary.
    activity_info = f"I detected the following activity:\n\n{selected_activity['summary']}"
    if selected_activity['objects']:
        activity_info += f"\n\nIdentified objects: {', '.join(selected_activity['objects'])}"
    history.append(("Tell me about this video segment", activity_info))

    return [gr.update(visible=False), gr.update(visible=True), history, chunk_video_path]
def chat_with_video(message, history):
    """Chat with the mPLUG model about the selected video segment.

    Args:
        message: The user's question.
        history: Chatbot history as (user, bot) tuples. The segment under
            discussion is recovered from the "Selected video at timestamp"
            marker that on_card_click inserted.

    Returns:
        The updated chat history including the model's (or fallback) answer.
    """
    global global_activities, global_media_path
    try:
        # Recover which chunk the conversation is about from the marker message.
        selected_chunk_idx = None
        selected_time = None
        selected_activity = None
        for entry in history:
            if entry[0] is None and "Selected video at timestamp" in entry[1]:
                selected_time = entry[1].split("Selected video at timestamp ")[1].strip()
                break

        # Map the timestamp back to the analyzed activity/chunk.
        if selected_time:
            for activity in global_activities:
                if activity.get('time') == selected_time:
                    selected_chunk_idx = activity.get('chunk_id')
                    selected_activity = activity
                    break

        if selected_chunk_idx is not None and global_media_path and selected_activity:
            model, tokenizer, processor = load_model_and_tokenizer()
            try:
                # Give the model context about what was already detected.
                context = f"This video shows construction site activities at timestamp {selected_time}."
                if selected_activity.get('objects'):
                    context += f" The scene contains {', '.join(selected_activity.get('objects'))}."
                prompt = f"{context} Analyze this segment of construction site video and answer this question: {message}"

                # NOTE: we re-read the full video and slice out this chunk's
                # frame indices; a production system would use the saved clip.
                vr = VideoReader(global_media_path, ctx=cpu(0))
                sample_fps = round(vr.get_avg_fps())  # ~1 sampled frame per second
                frame_idx = list(range(0, len(vr), sample_fps))

                # Frame indices belonging to the selected chunk
                # (chunking must match image_captioning.py).
                chunk_size = MAX_NUM_FRAMES
                start_idx = selected_chunk_idx * chunk_size
                end_idx = min(start_idx + chunk_size, len(frame_idx))
                chunk_frames = frame_idx[start_idx:end_idx]

                if not chunk_frames:
                    return history + [(message, "Could not extract frames for this segment.")]

                frames = vr.get_batch(chunk_frames).asnumpy()
                frames_pil = [Image.fromarray(v.astype('uint8')) for v in frames]
                response = process_video_chunk(frames_pil, model, tokenizer, processor, prompt)
                return history + [(message, response)]
            finally:
                # BUGFIX: always release GPU memory. The original only cleaned
                # up on the successful-frames path, leaking the model when
                # frame extraction came up empty or the model call raised.
                del model, tokenizer, processor
                torch.cuda.empty_cache()
                gc.collect()
        else:
            # Fallback response if we can't identify the chunk.
            thumbnail = None
            response_text = f"I'm analyzing your question about the video segment: {message}\n\nBased on what I can see in this segment, it appears to show construction activity with various machinery and workers on site. The specific details would depend on the exact timestamp you're referring to."
            # Attach a thumbnail from the matched activity when available.
            if selected_activity and 'thumbnail' in selected_activity and os.path.exists(selected_activity['thumbnail']):
                thumbnail = selected_activity['thumbnail']
                new_history = history + [(message, response_text)]
                new_history.append((None, f"📷 Video frame at {selected_time}"))
                new_history.append((None, thumbnail))
                return new_history
            return history + [(message, response_text)]
    except Exception as e:
        print(f"Error in chat_with_video: {str(e)}")
        return history + [(message, f"I encountered an error while processing your question. Let me try to answer based on what I can see: {message}\n\nThe video appears to show construction site activities, but I'm having trouble with the detailed analysis at the moment.")]
# Native Gradio activity cards
def create_activity_cards_ui(activities):
    """Build the activity-timeline gallery from the detected activities.

    Args:
        activities: list of activity dicts (keys used: 'thumbnail', 'time',
            'summary') as produced by the analyzers.

    Returns:
        (gallery, activity_indices): a gr.Gallery of (thumbnail, caption)
        pairs, and the parallel list mapping each gallery position back to
        its index in *activities* (consumed by on_card_click).
    """
    if not activities:
        return gr.HTML("<div class='activity-timeline'><h3>No activities detected</h3></div>"), []

    # Prepare data for the gallery. (Removed dead code: an 'objects' caption
    # string was computed here but never used.)
    thumbnails = []
    captions = []
    activity_indices = []
    for i, activity in enumerate(activities):
        thumbnail = activity.get('thumbnail', '')
        time = activity.get('time', 'Unknown')
        summary = activity.get('summary', 'No description available')
        # Truncate long summaries so captions stay readable in the grid.
        if len(summary) > 150:
            summary = summary[:147] + "..."
        thumbnails.append(thumbnail)
        captions.append(f"Timestamp: {time} | {summary}")
        activity_indices.append(i)

    gallery = gr.Gallery(
        value=[(path, caption) for path, caption in zip(thumbnails, captions)],
        columns=5,
        rows=None,
        height="auto",
        object_fit="contain",
        label="Activity Timeline"
    )
    return gallery, activity_indices
# Create the Gradio interface.
# Layout: a "Site Diary" tab with a user-input column and a model-detection
# column, plus a timeline/chat area that toggles visibility when an activity
# card is clicked.
with gr.Blocks(title="Digital Site Diary", css="") as demo:
    gr.Markdown("# 📝 Digital Site Diary")

    # Per-session state: detected activities and the gallery-position ->
    # activity-index mapping produced by create_activity_cards_ui.
    activity_data = gr.State([])
    activity_indices = gr.State([])

    # Create tabs for different views
    with gr.Tabs() as tabs:
        with gr.Tab("Site Diary"):
            with gr.Row():
                # User Input Column
                with gr.Column():
                    gr.Markdown("### User Input")
                    day = gr.Textbox(label="Day", value='9')
                    date = gr.Textbox(label="Date", placeholder="YYYY-MM-DD", value=datetime.now().strftime("%Y-%m-%d"))
                    total_people = gr.Number(label="Total Number of People", precision=0, value=10)
                    total_machinery = gr.Number(label="Total Number of Machinery", precision=0, value=3)
                    machinery_types = gr.Textbox(
                        label="Number of Machinery Per Type",
                        placeholder="e.g., Excavator: 2, Roller: 1",
                        value="Excavator: 2, Roller: 1"
                    )
                    activities = gr.Textbox(
                        label="Activity",
                        placeholder="e.g., 9 AM: Excavation, 10 AM: Concreting",
                        value="9 AM: Excavation, 10 AM: Concreting",
                        lines=3
                    )
                    media = gr.File(label="Upload Image/Video", file_types=["image", "video"])
                    submit_btn = gr.Button("Submit", variant="primary")

                # Model Detection Column (populated by process_diary)
                with gr.Column():
                    gr.Markdown("### Model Detection")
                    model_day = gr.Textbox(label="Day")
                    model_date = gr.Textbox(label="Date")
                    model_people = gr.Textbox(label="Total Number of People")
                    model_machinery = gr.Textbox(label="Total Number of Machinery")
                    model_machinery_types = gr.Textbox(label="Number of Machinery Per Type")

            # Activity Row with Timestamps
            with gr.Row():
                gr.Markdown("#### Activities with Timestamps")
                model_activities = gr.Dataframe(
                    headers=["Time", "Activity Description"],
                    datatype=["str", "str"],
                    label="Detected Activities",
                    interactive=False,
                    wrap=True
                )

            # Activity timeline section
            with gr.Row():
                # Timeline View (default visible)
                with gr.Column(visible=True) as timeline_view:
                    activity_gallery = gr.Gallery(label="Activity Timeline")
                    model_annotated_video = gr.Video(label="Full Video")

                # Chat View (initially hidden; toggled by on_card_click / back_btn)
                with gr.Column(visible=False) as chat_view:
                    chunk_video = gr.Video(label="Chunk video")
                    chatbot = gr.Chatbot(height=400)
                    chat_input = gr.Textbox(
                        placeholder="Ask about this video segment...",
                        show_label=False
                    )
                    back_btn = gr.Button("← Back to Timeline")

    # Connect the submit button to the processing function.
    # NOTE: the outputs list has 10 components — every return path of
    # process_diary must match this arity.
    submit_btn.click(
        fn=process_diary,
        inputs=[day, date, total_people, total_machinery, machinery_types, activities, media],
        outputs=[
            model_day,
            model_date,
            model_people,
            model_machinery,
            model_machinery_types,
            timeline_view,
            model_annotated_video,
            activity_data,
            chatbot,
            model_activities
        ]
    )

    # Rebuild the gallery whenever new activity data arrives.
    activity_data.change(
        fn=create_activity_cards_ui,
        inputs=[activity_data],
        outputs=[activity_gallery, activity_indices]
    )

    # Handle gallery selection (switches timeline -> chat view).
    activity_gallery.select(
        fn=on_card_click,
        inputs=[activity_indices, chatbot],
        outputs=[timeline_view, chat_view, chatbot, chunk_video]
    )

    # Chat submission
    chat_input.submit(
        fn=chat_with_video,
        inputs=[chat_input, chatbot],
        outputs=[chatbot]
    )

    # Back button: restore the timeline view and hide the chat view.
    back_btn.click(
        fn=lambda: [gr.update(visible=True), gr.update(visible=False)],
        inputs=None,
        outputs=[timeline_view, chat_view]
    )

    # Add enhanced CSS styling (dark-theme cards, gallery and chatbot skins).
    gr.HTML("""
    <style>
        /* Gallery customizations */
        .gradio-container .gallery-item {
            border: 1px solid #444444 !important;
            border-radius: 8px !important;
            padding: 8px !important;
            margin: 10px !important;
            cursor: pointer !important;
            transition: all 0.3s !important;
            background: #18181b !important;
            box-shadow: 0 2px 5px rgba(0,0,0,0.2) !important;
        }
        .gradio-container .gallery-item:hover {
            transform: translateY(-2px) !important;
            box-shadow: 0 4px 12px rgba(0,0,0,0.25) !important;
            border-color: #007bff !important;
            background: #202025 !important;
        }
        .gradio-container .gallery-item.selected {
            border: 2px solid #007bff !important;
            background: #202030 !important;
        }
        /* Improved image display */
        .gradio-container .gallery-item img {
            height: 180px !important;
            object-fit: cover !important;
            border-radius: 4px !important;
            border: 1px solid #444444 !important;
            margin-bottom: 8px !important;
        }
        /* Caption styling */
        .gradio-container .caption {
            color: #e0e0e0 !important;
            font-size: 0.9em !important;
            margin-top: 8px !important;
            line-height: 1.4 !important;
            padding: 0 4px !important;
        }
        /* Gallery container */
        .gradio-container [id*='gallery'] > div:first-child {
            background-color: #27272a !important;
            padding: 15px !important;
            border-radius: 10px !important;
        }
        /* Chatbot styling */
        .gradio-container .chatbot {
            background-color: #27272a !important;
            border-radius: 10px !important;
            border: 1px solid #444444 !important;
        }
        .gradio-container .chatbot .message.user {
            background-color: #18181b !important;
            border-radius: 8px !important;
        }
        .gradio-container .chatbot .message.bot {
            background-color: #202030 !important;
            border-radius: 8px !important;
        }
        /* Button styling */
        .gradio-container button.secondary {
            background-color: #3d4452 !important;
            color: white !important;
        }
    </style>
    """)
if __name__ == "__main__":
    # allowed_paths lets Gradio serve thumbnails/chunk clips saved under ./tmp
    demo.launch(allowed_paths=["./tmp"])