# Hosting-page residue, kept as a comment so the file parses:
# author: assentian1970 | commit e5cd3ca (verified) | message: "Update app.py"
import spaces
import gradio as gr
from datetime import datetime
import tempfile
import os
import json
import torch
import gc
def debug():
    """Create a small random tensor and move it to CUDA.

    Runs once at import time (see the call below). Presumably this is meant to
    initialise / warm up the GPU early — NOTE(review): confirm intent; on
    ZeroGPU Spaces the `spaces` package may intercept this call.
    """
    probe = torch.randn(10)
    probe.cuda()


debug()
from PIL import Image
from decord import VideoReader, cpu
from yolo_detection import (
detect_people_and_machinery,
annotate_video_with_bboxes,
is_image,
is_video
)
from image_captioning import (
analyze_image_activities,
analyze_video_activities,
process_video_chunk,
load_model_and_tokenizer,
MAX_NUM_FRAMES
)
# Module-level state shared by the Gradio callbacks: the activities detected in
# the most recent upload, and the path of that upload's temporary copy.
global_activities, global_media_path = [], None

# Directory for storing extracted frames; created at import time if missing.
tmp_dir = os.path.join(os.curdir, 'tmp')
os.makedirs(tmp_dir, exist_ok=True)
@spaces.GPU
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media):
    """Process the site diary entry: detect people/machinery and activities in the upload.

    ``day`` and ``date`` are echoed back unchanged. ``total_people``,
    ``total_machinery``, ``machinery_types`` and ``activities`` are accepted to
    match the UI wiring but are not used here.

    Side effects: copies the upload to a temporary file (``delete=False``) and
    stores the detected activities and that temp path in ``global_activities``
    / ``global_media_path`` so the chat handlers can reuse them.

    Args:
        media: The uploaded file, either a filesystem path (str) or an object
            exposing ``.name`` (and optionally ``.read()``).

    Returns:
        A 10-element list matching the ``submit_btn.click`` outputs:
        ``[day, date, people, machinery, machinery_types_str,
        timeline_view_update, annotated_video_path, activities,
        chat_history, activity_rows]``.
    """
    global global_activities, global_media_path
    import shutil
    import traceback

    if media is None:
        msg = "No media uploaded"
        # Must return exactly one value per output wired to submit_btn.click (10).
        return [day, date, msg, msg, msg, gr.update(), None, [], [], []]

    temp_path = None
    try:
        # Accept both a plain path string and a file-like object with `.name`.
        if isinstance(media, str):
            media_name = media
        elif hasattr(media, 'name'):
            media_name = media.name
        else:
            raise ValueError("Invalid file upload")

        file_ext = get_file_extension(media_name)
        if not (is_image(media_name) or is_video(media_name)):
            raise ValueError(f"Unsupported file type: {file_ext}")

        # Copy the upload to a temp file that outlives this call (delete=False),
        # because its path is stored globally and reopened by the chat handlers.
        with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
            temp_path = temp_file.name
            if os.path.exists(media_name):
                # Stream the copy so large videos are not read into memory at once.
                with open(media_name, 'rb') as src:
                    shutil.copyfileobj(src, temp_file)
            elif isinstance(media, str):
                raise FileNotFoundError(f"Uploaded file not found: {media_name}")
            else:
                file_content = media.read() if hasattr(media, 'read') else media
                if not isinstance(file_content, bytes):
                    file_content = file_content.read()
                temp_file.write(file_content)

        detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(temp_path)
        print(f"Detected people: {detected_people}, machinery: {detected_machinery}, types: {detected_machinery_types}")

        if is_image(media_name):
            detected_activities = analyze_image_activities(temp_path)
        else:
            detected_activities = analyze_video_activities(temp_path)
        print(f"Detected activities: {detected_activities}")

        # Store activities and media path globally for chat mode.
        global_activities = detected_activities
        global_media_path = temp_path

        annotated_video_path = annotate_video_with_bboxes(temp_path) if is_video(media_name) else None

        detected_types_str = ", ".join(f"{k}: {v}" for k, v in detected_machinery_types.items())

        # Rows for the "Detected Activities" table: [time, description].
        activity_rows = [
            [activity.get('time', 'Unknown'), activity.get('summary', 'No description available')]
            for activity in detected_activities
        ]

        # New media starts a fresh conversation, hence the empty chat history.
        return [day, date, str(detected_people), str(detected_machinery),
                detected_types_str, gr.update(visible=True), annotated_video_path,
                detected_activities, [], activity_rows]
    except Exception as e:
        print(f"Error processing media: {str(e)}")
        traceback.print_exc()
        # Drop the temp copy unless it has already been handed to the chat handlers.
        if temp_path and temp_path != global_media_path:
            try:
                os.remove(temp_path)
            except OSError:
                pass
        err = "Error processing media"
        return [day, date, err, err, err, gr.update(), None, [], [], []]
def get_file_extension(filename):
    """Return the extension of ``filename`` in lower case, including the leading dot.

    Yields an empty string when the name has no extension; as with
    ``os.path.splitext``, a leading-dot name such as ``.bashrc`` has none.
    """
    _stem, extension = os.path.splitext(filename)
    return extension.lower()
def on_card_click(activity_indices, history, evt: gr.SelectData):
    """Handle clicking on an activity card in the gallery by opening the chat view.

    Args:
        activity_indices: Maps a gallery position to an index into
            ``global_activities`` (built by ``create_activity_cards_ui``).
        history: Current chatbot history. Ignored: every selected card starts
            a fresh conversation.
        evt: Gradio selection event; ``evt.index`` is the clicked gallery position.

    Returns:
        ``[timeline_view update, chat_view update, chat history, video path]``.
        For an out-of-range selection the timeline stays visible, the chat view
        is hidden, and the history and video are cleared.
    """
    global global_activities, global_media_path

    # Map the gallery index to the actual activity index.
    selected_idx = evt.index
    if selected_idx < 0 or selected_idx >= len(activity_indices):
        return [gr.update(visible=True), gr.update(visible=False), [], None]

    card_idx = activity_indices[selected_idx]
    print(f"Gallery item {selected_idx} clicked, corresponds to activity index: {card_idx}")
    if card_idx < 0 or card_idx >= len(global_activities):
        return [gr.update(visible=True), gr.update(visible=False), [], None]

    selected_activity = global_activities[card_idx]

    # Prefer the pre-saved chunk video; fall back to the full uploaded video.
    chunk_path = selected_activity.get('chunk_path')
    if chunk_path and os.path.exists(chunk_path):
        chunk_video_path = chunk_path
        print(f"Using pre-saved chunk video: {chunk_video_path}")
    else:
        chunk_video_path = global_media_path
        print(f"Chunk video not available, using full video: {chunk_video_path}")

    timestamp = selected_activity.get('time', 'Unknown')

    # chat_with_video looks for this exact marker to recover the selected segment.
    history = [(None, f"🎬 Selected video at timestamp {timestamp}")]

    thumbnail_path = selected_activity.get('thumbnail')
    if thumbnail_path and os.path.exists(thumbnail_path):
        history.append((None, f"📷 Video frame at {timestamp}"))
        # Chatbot renders a file only when wrapped in a tuple; a bare string is shown as text.
        history.append((None, (thumbnail_path,)))

    summary = selected_activity.get('summary', 'No description available')
    activity_info = f"I detected the following activity:\n\n{summary}"
    objects = selected_activity.get('objects')
    if objects:
        activity_info += f"\n\nIdentified objects: {', '.join(map(str, objects))}"
    history.append(("Tell me about this video segment", activity_info))

    return [gr.update(visible=False), gr.update(visible=True), history, chunk_video_path]
def _load_chunk_frames(video_path, chunk_idx):
    """Decode the frames of chunk ``chunk_idx`` of ``video_path`` as PIL images ([] if none)."""
    vr = VideoReader(video_path, ctx=cpu(0))
    # Sample roughly one frame per second. max(1, ...) avoids a zero range step
    # (ValueError) for videos whose reported average fps rounds to 0.
    step = max(1, round(vr.get_avg_fps()))
    frame_idx = list(range(0, len(vr), step))
    # NOTE(review): assumes chunks were formed the same way during analysis
    # (1 fps sampling, MAX_NUM_FRAMES frames per chunk) — confirm against
    # image_captioning.analyze_video_activities.
    start = chunk_idx * MAX_NUM_FRAMES
    chunk_frames = frame_idx[start:start + MAX_NUM_FRAMES]
    if not chunk_frames:
        return []
    frames = vr.get_batch(chunk_frames).asnumpy()
    return [Image.fromarray(frame.astype('uint8')) for frame in frames]


def chat_with_video(message, history):
    """Chat with the mPLUG model about the selected video segment.

    The segment is recovered from the chat history: ``on_card_click`` seeds it
    with a "🎬 Selected video at timestamp <time>" bot message, and that
    timestamp is matched against ``global_activities`` to find the chunk.

    NOTE(review): this handler loads and runs the model but, unlike
    ``process_diary``, is not decorated with ``@spaces.GPU`` — confirm GPU
    access is available here (e.g. on ZeroGPU).

    Args:
        message: The user's question.
        history: Chatbot history as a list of (user, bot) pairs (None is treated as empty).

    Returns:
        A new history list extended with the (message, answer) pair; in the
        fallback path it may also carry a thumbnail image entry. An empty
        message returns the history unchanged.
    """
    global global_activities, global_media_path
    history = list(history or [])
    if not message or not message.strip():
        return history

    try:
        # Find the timestamp announced when the card was selected.
        selected_time = None
        marker = "Selected video at timestamp "
        for entry in history:
            bot_text = entry[1]
            if entry[0] is None and isinstance(bot_text, str) and marker in bot_text:
                selected_time = bot_text.partition(marker)[2].strip()
                break

        # Find the corresponding chunk.
        selected_chunk_idx = None
        selected_activity = None
        if selected_time:
            for activity in global_activities:
                if str(activity.get('time')).strip() == selected_time:
                    selected_chunk_idx = activity.get('chunk_id')
                    selected_activity = activity
                    break

        if selected_chunk_idx is not None and global_media_path and selected_activity:
            # Decode frames first so the model is only loaded when there is something to analyze.
            frames_pil = _load_chunk_frames(global_media_path, selected_chunk_idx)
            if not frames_pil:
                return history + [(message, "Could not extract frames for this segment.")]

            context = f"This video shows construction site activities at timestamp {selected_time}."
            objects = selected_activity.get('objects')
            if objects:
                context += f" The scene contains {', '.join(map(str, objects))}."
            prompt = f"{context} Analyze this segment of construction site video and answer this question: {message}"

            model, tokenizer, processor = load_model_and_tokenizer()
            try:
                response = process_video_chunk(frames_pil, model, tokenizer, processor, prompt)
            finally:
                # Release the model on success and failure alike.
                del model, tokenizer, processor
                gc.collect()
                torch.cuda.empty_cache()
            return history + [(message, response)]

        # Fallback when the segment cannot be identified: a canned reply, not model output.
        response_text = f"I'm analyzing your question about the video segment: {message}\n\nBased on what I can see in this segment, it appears to show construction activity with various machinery and workers on site. The specific details would depend on the exact timestamp you're referring to."
        new_history = history + [(message, response_text)]
        thumbnail = selected_activity.get('thumbnail') if selected_activity else None
        if thumbnail and os.path.exists(thumbnail):
            new_history.append((None, f"📷 Video frame at {selected_time}"))
            # Chatbot renders a file only when wrapped in a tuple.
            new_history.append((None, (thumbnail,)))
        return new_history
    except Exception as e:
        print(f"Error in chat_with_video: {str(e)}")
        return history + [(message, f"I encountered an error while processing your question. Let me try to answer based on what I can see: {message}\n\nThe video appears to show construction site activities, but I'm having trouble with the detailed analysis at the moment.")]
# Native Gradio activity cards
def create_activity_cards_ui(activities):
    """Build the activity-timeline gallery from the detected activities.

    Only activities whose ``thumbnail`` file exists are shown, because the
    Gallery cannot render an item with an empty or missing image path.

    Args:
        activities: List of activity dicts; the ``thumbnail``, ``time`` and
            ``summary`` keys are read. None is treated as empty.

    Returns:
        ``(gallery, activity_indices)`` where ``activity_indices[k]`` is the
        index in ``activities`` of the k-th gallery item (consumed by
        ``on_card_click``). When nothing can be shown, returns ``([], [])``,
        which clears the Gallery output.
    """
    items = []
    activity_indices = []
    for i, activity in enumerate(activities or []):
        thumbnail = activity.get('thumbnail')
        if not thumbnail or not os.path.exists(thumbnail):
            continue
        timestamp = activity.get('time', 'Unknown')
        summary = activity.get('summary', 'No description available')
        # Truncate long summaries to 150 characters including the ellipsis.
        if len(summary) > 150:
            summary = summary[:147] + "..."
        items.append((thumbnail, f"Timestamp: {timestamp} | {summary}"))
        activity_indices.append(i)

    if not items:
        # The bound output is a Gallery: clear it with an empty value instead of
        # returning a component of a different type.
        return [], []

    gallery = gr.Gallery(
        value=items,
        columns=5,
        rows=None,
        height="auto",
        object_fit="contain",
        label="Activity Timeline"
    )
    return gallery, activity_indices
# Create the Gradio interface
with gr.Blocks(title="Digital Site Diary", css="") as demo:
    gr.Markdown("# 📝 Digital Site Diary")

    # Activity data and indices storage
    activity_data = gr.State([])     # activity dicts returned by process_diary
    activity_indices = gr.State([])  # gallery position -> index into activity_data

    # Create tabs for different views
    with gr.Tabs() as tabs:
        with gr.Tab("Site Diary"):
            with gr.Row():
                # User Input Column
                with gr.Column():
                    gr.Markdown("### User Input")
                    day = gr.Textbox(label="Day", value='9')
                    # A callable value is evaluated on every page load, so a
                    # long-running app does not keep showing its start-up date.
                    date = gr.Textbox(
                        label="Date",
                        placeholder="YYYY-MM-DD",
                        value=lambda: datetime.now().strftime("%Y-%m-%d"),
                    )
                    total_people = gr.Number(label="Total Number of People", precision=0, value=10)
                    total_machinery = gr.Number(label="Total Number of Machinery", precision=0, value=3)
                    machinery_types = gr.Textbox(
                        label="Number of Machinery Per Type",
                        placeholder="e.g., Excavator: 2, Roller: 1",
                        value="Excavator: 2, Roller: 1"
                    )
                    activities = gr.Textbox(
                        label="Activity",
                        placeholder="e.g., 9 AM: Excavation, 10 AM: Concreting",
                        value="9 AM: Excavation, 10 AM: Concreting",
                        lines=3
                    )
                    media = gr.File(label="Upload Image/Video", file_types=["image", "video"])
                    submit_btn = gr.Button("Submit", variant="primary")

                # Model Detection Column
                with gr.Column():
                    gr.Markdown("### Model Detection")
                    model_day = gr.Textbox(label="Day")
                    model_date = gr.Textbox(label="Date")
                    model_people = gr.Textbox(label="Total Number of People")
                    model_machinery = gr.Textbox(label="Total Number of Machinery")
                    model_machinery_types = gr.Textbox(label="Number of Machinery Per Type")

            # Activity Row with Timestamps
            with gr.Row():
                gr.Markdown("#### Activities with Timestamps")
                model_activities = gr.Dataframe(
                    headers=["Time", "Activity Description"],
                    datatype=["str", "str"],
                    label="Detected Activities",
                    interactive=False,
                    wrap=True
                )

            # Activity timeline section
            with gr.Row():
                # Timeline View (default visible)
                with gr.Column(visible=True) as timeline_view:
                    activity_gallery = gr.Gallery(label="Activity Timeline")
                    model_annotated_video = gr.Video(label="Full Video")

                # Chat View (initially hidden)
                with gr.Column(visible=False) as chat_view:
                    chunk_video = gr.Video(label="Chunk video")
                    chatbot = gr.Chatbot(height=400)
                    chat_input = gr.Textbox(
                        placeholder="Ask about this video segment...",
                        show_label=False
                    )
                    back_btn = gr.Button("← Back to Timeline")

    # Connect the submit button to the processing function
    submit_btn.click(
        fn=process_diary,
        inputs=[day, date, total_people, total_machinery, machinery_types, activities, media],
        outputs=[
            model_day,
            model_date,
            model_people,
            model_machinery,
            model_machinery_types,
            timeline_view,
            model_annotated_video,
            activity_data,
            chatbot,
            model_activities
        ]
    )

    # Process activity data into gallery
    activity_data.change(
        fn=create_activity_cards_ui,
        inputs=[activity_data],
        outputs=[activity_gallery, activity_indices]
    )

    # Handle gallery selection
    activity_gallery.select(
        fn=on_card_click,
        inputs=[activity_indices, chatbot],
        outputs=[timeline_view, chat_view, chatbot, chunk_video]
    )

    # Chat submission; afterwards clear the textbox for the next question.
    chat_input.submit(
        fn=chat_with_video,
        inputs=[chat_input, chatbot],
        outputs=[chatbot]
    ).then(
        fn=lambda: "",
        inputs=None,
        outputs=[chat_input]
    )

    # Back button
    back_btn.click(
        fn=lambda: [gr.update(visible=True), gr.update(visible=False)],
        inputs=None,
        outputs=[timeline_view, chat_view]
    )

    # Add enhanced CSS styling
    gr.HTML("""
<style>
/* Gallery customizations */
.gradio-container .gallery-item {
border: 1px solid #444444 !important;
border-radius: 8px !important;
padding: 8px !important;
margin: 10px !important;
cursor: pointer !important;
transition: all 0.3s !important;
background: #18181b !important;
box-shadow: 0 2px 5px rgba(0,0,0,0.2) !important;
}
.gradio-container .gallery-item:hover {
transform: translateY(-2px) !important;
box-shadow: 0 4px 12px rgba(0,0,0,0.25) !important;
border-color: #007bff !important;
background: #202025 !important;
}
.gradio-container .gallery-item.selected {
border: 2px solid #007bff !important;
background: #202030 !important;
}
/* Improved image display */
.gradio-container .gallery-item img {
height: 180px !important;
object-fit: cover !important;
border-radius: 4px !important;
border: 1px solid #444444 !important;
margin-bottom: 8px !important;
}
/* Caption styling */
.gradio-container .caption {
color: #e0e0e0 !important;
font-size: 0.9em !important;
margin-top: 8px !important;
line-height: 1.4 !important;
padding: 0 4px !important;
}
/* Gallery container */
.gradio-container [id*='gallery'] > div:first-child {
background-color: #27272a !important;
padding: 15px !important;
border-radius: 10px !important;
}
/* Chatbot styling */
.gradio-container .chatbot {
background-color: #27272a !important;
border-radius: 10px !important;
border: 1px solid #444444 !important;
}
.gradio-container .chatbot .message.user {
background-color: #18181b !important;
border-radius: 8px !important;
}
.gradio-container .chatbot .message.bot {
background-color: #202030 !important;
border-radius: 8px !important;
}
/* Button styling */
.gradio-container button.secondary {
background-color: #3d4452 !important;
color: white !important;
}
</style>
""")
if __name__ == "__main__":
    # Let Gradio serve files from the ./tmp frame directory.
    launch_options = {"allowed_paths": ["./tmp"]}
    demo.launch(**launch_options)