# assentian1970's picture
# Update app.py
# e5cd3ca verified
# raw
# history blame
# 19.3 kB
import spaces
import gradio as gr
from datetime import datetime
import tempfile
import os
import json
import torch
import gc
def debug():
    """Touch CUDA once at import time.

    NOTE(review): presumably forces early CUDA initialization on the Space
    (the result is discarded) — confirm this is intentional under ZeroGPU.
    """
    probe = torch.randn(10)
    probe.cuda()


debug()
from PIL import Image
from decord import VideoReader, cpu
from yolo_detection import (
detect_people_and_machinery,
annotate_video_with_bboxes,
is_image,
is_video
)
from image_captioning import (
analyze_image_activities,
analyze_video_activities,
process_video_chunk,
load_model_and_tokenizer,
MAX_NUM_FRAMES
)
# Global storage for activities and media paths
# NOTE: module-level mutable state shared across Gradio callbacks; adequate
# for a single-user demo but not safe for concurrent sessions.
global_activities = []  # activity dicts produced by the analysis step
global_media_path = None  # temp-file path of the most recently processed upload
# Create tmp directory for storing frames
tmp_dir = os.path.join('.', 'tmp')
os.makedirs(tmp_dir, exist_ok=True)  # exist_ok: idempotent across restarts
@spaces.GPU
def process_diary(day, date, total_people, total_machinery, machinery_types, activities, media):
    """Process the site diary entry: run detection and activity analysis on *media*.

    Args:
        day, date: user-entered diary fields, echoed back unchanged.
        total_people, total_machinery, machinery_types, activities: user-entered
            reference values (not consumed by the model pipeline; kept for the UI).
        media: uploaded file object (image or video), or None.

    Returns:
        A 10-element list matching the wired Gradio outputs:
        [day, date, people, machinery, machinery-types string,
         timeline-visibility update, annotated video path, activities list,
         chat history, activity table rows].
    """
    global global_activities, global_media_path
    if media is None:
        # BUGFIX: the original returned only 9 items here while the callback
        # is wired to 10 outputs; keep the arity identical to the other paths.
        return [day, date, "No media uploaded", "No media uploaded",
                "No media uploaded", None, None, [], None, []]
    try:
        if not hasattr(media, 'name'):
            raise ValueError("Invalid file upload")
        file_ext = get_file_extension(media.name)
        if not (is_image(media.name) or is_video(media.name)):
            raise ValueError(f"Unsupported file type: {file_ext}")
        # Copy the upload to a stable temp path. Deliberately not deleted here:
        # chat mode keeps reading it via global_media_path.
        with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as temp_file:
            temp_path = temp_file.name
            if hasattr(media, 'name') and os.path.exists(media.name):
                with open(media.name, 'rb') as f:
                    temp_file.write(f.read())
            else:
                file_content = media.read() if hasattr(media, 'read') else media
                temp_file.write(file_content if isinstance(file_content, bytes) else file_content.read())
        detected_people, detected_machinery, detected_machinery_types = detect_people_and_machinery(temp_path)
        print(f"Detected people: {detected_people}, machinery: {detected_machinery}, types: {detected_machinery_types}")
        annotated_video_path = None
        detected_activities = analyze_image_activities(temp_path) if is_image(media.name) else analyze_video_activities(temp_path)
        print(f"Detected activities: {detected_activities}")
        # Store activities and media path globally for chat mode
        global_activities = detected_activities
        global_media_path = temp_path
        if is_video(media.name):
            annotated_video_path = annotate_video_with_bboxes(temp_path)
        detected_types_str = ", ".join(f"{k}: {v}" for k, v in detected_machinery_types.items())
        # Clear the chat history when loading new media
        chat_history = []
        # Rows for the activity Dataframe: [time, summary] per activity.
        activity_rows = [[activity.get('time', 'Unknown'),
                          activity.get('summary', 'No description available')]
                         for activity in detected_activities]
        return [day, date, str(detected_people), str(detected_machinery),
                detected_types_str, gr.update(visible=True), annotated_video_path,
                detected_activities, chat_history, activity_rows]
    except Exception as e:
        print(f"Error processing media: {str(e)}")
        return [day, date, "Error processing media", "Error processing media",
                "Error processing media", None, None, [], None, []]
def get_file_extension(filename):
    """Return the lowercased extension of *filename*, including the leading dot."""
    _stem, extension = os.path.splitext(filename)
    return extension.lower()
def on_card_click(activity_indices, history, evt: gr.SelectData):
    """Handle clicking on an activity card in the gallery.

    Args:
        activity_indices: gallery-position -> activity-index mapping (gr.State).
        history: current chatbot history (replaced, not extended).
        evt: Gradio SelectData carrying the clicked gallery position.

    Returns:
        [timeline-visibility update, chat-visibility update, chat history,
         chunk video path].
    """
    global global_activities, global_media_path
    # Get the index of the selected activity from the SelectData event
    selected_idx = evt.index
    # Map the gallery index to the actual activity index; on anything out of
    # range, fall back to showing the timeline view unchanged.
    if selected_idx < 0 or selected_idx >= len(activity_indices):
        return [gr.update(visible=True), gr.update(visible=False), [], None]
    card_idx = activity_indices[selected_idx]
    print(f"Gallery item {selected_idx} clicked, corresponds to activity index: {card_idx}")
    if card_idx < 0 or card_idx >= len(global_activities):
        return [gr.update(visible=True), gr.update(visible=False), [], None]
    selected_activity = global_activities[card_idx]
    # Use the pre-saved chunk video if available; fall back to the full video.
    chunk_path = selected_activity.get('chunk_path')
    if chunk_path and os.path.exists(chunk_path):
        chunk_video_path = chunk_path
        print(f"Using pre-saved chunk video: {chunk_video_path}")
    else:
        chunk_video_path = global_media_path
        print(f"Chunk video not available, using full video: {chunk_video_path}")
    # BUGFIX: use .get() consistently — activity dicts are not guaranteed to
    # carry every key (the rest of the file uses .get(); the original indexed
    # ['time']/['summary']/['objects'] directly and could raise KeyError).
    time_str = selected_activity.get('time', 'Unknown')
    # Start a fresh chat history for the selected activity.
    history = [(None, f"🎬 Selected video at timestamp {time_str}")]
    # Add the thumbnail to the chat as a visual element
    thumbnail_path = selected_activity.get('thumbnail')
    if thumbnail_path and os.path.exists(thumbnail_path):
        history.append((None, f"📷 Video frame at {time_str}"))
        history.append((None, thumbnail_path))
    # Format message about the detected activity
    activity_info = f"I detected the following activity:\n\n{selected_activity.get('summary', 'No description available')}"
    objects = selected_activity.get('objects')
    if objects:
        activity_info += f"\n\nIdentified objects: {', '.join(objects)}"
    history.append(("Tell me about this video segment", activity_info))
    return [gr.update(visible=False), gr.update(visible=True), history, chunk_video_path]
def chat_with_video(message, history):
    """Chat with the mPLUG model about the selected video segment.

    The segment under discussion is recovered from the chat history marker
    ("Selected video at timestamp ..."); its frames are re-extracted from the
    full video and sent to the model together with the user's question.

    Args:
        message: the user's question.
        history: chatbot history as a list of (user, bot) tuples.

    Returns:
        The updated history with the model's (or a fallback) answer appended.
    """
    global global_activities, global_media_path
    try:
        # Recover which segment we're discussing from the history marker.
        selected_chunk_idx = None
        selected_time = None
        selected_activity = None
        for entry in history:
            if entry[0] is None and "Selected video at timestamp" in entry[1]:
                time_str = entry[1].split("Selected video at timestamp ")[1]
                selected_time = time_str.strip()
                break
        # Find the corresponding chunk
        if selected_time:
            for activity in global_activities:
                if activity.get('time') == selected_time:
                    selected_chunk_idx = activity.get('chunk_id')
                    selected_activity = activity
                    break
        # If we found the chunk, use the model to analyze it
        if selected_chunk_idx is not None and global_media_path and selected_activity:
            # Load model; release it on every exit path (see finally below).
            model, tokenizer, processor = load_model_and_tokenizer()
            try:
                # Prompt = scene context + the user's question.
                context = f"This video shows construction site activities at timestamp {selected_time}."
                if selected_activity.get('objects'):
                    context += f" The scene contains {', '.join(selected_activity.get('objects'))}."
                prompt = f"{context} Analyze this segment of construction site video and answer this question: {message}"
                # This would ideally use the specific chunk file; for simplicity
                # we re-open the full video and sample roughly 1 frame/second.
                vr = VideoReader(global_media_path, ctx=cpu(0))
                sample_fps = round(vr.get_avg_fps())  # was "/ 1": redundant
                frame_idx = list(range(0, len(vr), sample_fps))
                # Extract frames for the specific chunk
                chunk_size = MAX_NUM_FRAMES  # From the constants in image_captioning.py
                start_idx = selected_chunk_idx * chunk_size
                end_idx = min(start_idx + chunk_size, len(frame_idx))
                chunk_frames = frame_idx[start_idx:end_idx]
                if not chunk_frames:
                    return history + [(message, "Could not extract frames for this segment.")]
                frames = vr.get_batch(chunk_frames).asnumpy()
                frames_pil = [Image.fromarray(v.astype('uint8')) for v in frames]
                # Process frames with model
                response = process_video_chunk(frames_pil, model, tokenizer, processor, prompt)
            finally:
                # BUGFIX: free GPU memory on every exit path — the original
                # leaked the loaded model when no frames could be extracted.
                del model, tokenizer, processor
                torch.cuda.empty_cache()
                gc.collect()
            return history + [(message, response)]
        # Fallback response if we can't identify the chunk
        thumbnail = None
        response_text = f"I'm analyzing your question about the video segment: {message}\n\nBased on what I can see in this segment, it appears to show construction activity with various machinery and workers on site. The specific details would depend on the exact timestamp you're referring to."
        # Try to attach a thumbnail from the selected activity if available.
        if selected_activity and 'thumbnail' in selected_activity and os.path.exists(selected_activity['thumbnail']):
            thumbnail = selected_activity['thumbnail']
            new_history = history + [(message, response_text)]
            new_history.append((None, f"📷 Video frame at {selected_time}"))
            new_history.append((None, thumbnail))
            return new_history
        return history + [(message, response_text)]
    except Exception as e:
        print(f"Error in chat_with_video: {str(e)}")
        return history + [(message, f"I encountered an error while processing your question. Let me try to answer based on what I can see: {message}\n\nThe video appears to show construction site activities, but I'm having trouble with the detailed analysis at the moment.")]
# Native Gradio activity cards
def create_activity_cards_ui(activities):
    """Create activity cards using native Gradio components.

    Returns a (gallery, indices) pair where *indices* maps each gallery
    position back to its index in *activities*.
    """
    if not activities:
        return gr.HTML("<div class='activity-timeline'><h3>No activities detected</h3></div>"), []
    # Build (thumbnail, caption) pairs plus the position -> activity mapping.
    gallery_items = []
    activity_indices = []
    for position, activity in enumerate(activities):
        stamp = activity.get('time', 'Unknown')
        description = activity.get('summary', 'No description available')
        # Keep captions compact: cap the summary at 150 characters.
        if len(description) > 150:
            description = description[:147] + "..."
        gallery_items.append((activity.get('thumbnail', ''),
                              f"Timestamp: {stamp} | {description}"))
        activity_indices.append(position)
    # Create a gallery for the thumbnails
    gallery = gr.Gallery(
        value=gallery_items,
        columns=5,
        rows=None,
        height="auto",
        object_fit="contain",
        label="Activity Timeline"
    )
    return gallery, activity_indices
# Create the Gradio interface
# NOTE(review): indentation was lost in the retrieved source; the nesting
# below is reconstructed from the component/wiring structure — confirm the
# rows belong inside the "Site Diary" tab.
with gr.Blocks(title="Digital Site Diary", css="") as demo:
    gr.Markdown("# 📝 Digital Site Diary")
    # Activity data and indices storage (cross-callback state)
    activity_data = gr.State([])
    activity_indices = gr.State([])
    # Create tabs for different views
    with gr.Tabs() as tabs:
        with gr.Tab("Site Diary"):
            with gr.Row():
                # User Input Column: manually entered diary values + upload
                with gr.Column():
                    gr.Markdown("### User Input")
                    day = gr.Textbox(label="Day",value='9')
                    date = gr.Textbox(label="Date", placeholder="YYYY-MM-DD", value=datetime.now().strftime("%Y-%m-%d"))
                    total_people = gr.Number(label="Total Number of People", precision=0, value=10)
                    total_machinery = gr.Number(label="Total Number of Machinery", precision=0, value=3)
                    machinery_types = gr.Textbox(
                        label="Number of Machinery Per Type",
                        placeholder="e.g., Excavator: 2, Roller: 1",
                        value="Excavator: 2, Roller: 1"
                    )
                    activities = gr.Textbox(
                        label="Activity",
                        placeholder="e.g., 9 AM: Excavation, 10 AM: Concreting",
                        value="9 AM: Excavation, 10 AM: Concreting",
                        lines=3
                    )
                    media = gr.File(label="Upload Image/Video", file_types=["image", "video"])
                    submit_btn = gr.Button("Submit", variant="primary")
                # Model Detection Column: read-only fields filled by process_diary
                with gr.Column():
                    gr.Markdown("### Model Detection")
                    model_day = gr.Textbox(label="Day")
                    model_date = gr.Textbox(label="Date")
                    model_people = gr.Textbox(label="Total Number of People")
                    model_machinery = gr.Textbox(label="Total Number of Machinery")
                    model_machinery_types = gr.Textbox(label="Number of Machinery Per Type")
            # Activity Row with Timestamps
            with gr.Row():
                gr.Markdown("#### Activities with Timestamps")
                model_activities = gr.Dataframe(
                    headers=["Time", "Activity Description"],
                    datatype=["str", "str"],
                    label="Detected Activities",
                    interactive=False,
                    wrap=True
                )
            # Activity timeline section: timeline and chat views toggle
            # visibility against each other (see on_card_click / back_btn).
            with gr.Row():
                # Timeline View (default visible)
                with gr.Column(visible=True) as timeline_view:
                    activity_gallery = gr.Gallery(label="Activity Timeline")
                    model_annotated_video = gr.Video(label="Full Video")
                # Chat View (initially hidden)
                with gr.Column(visible=False) as chat_view:
                    chunk_video = gr.Video(label="Chunk video")
                    chatbot = gr.Chatbot(height=400)
                    chat_input = gr.Textbox(
                        placeholder="Ask about this video segment...",
                        show_label=False
                    )
                    back_btn = gr.Button("← Back to Timeline")
    # Connect the submit button to the processing function.
    # The 10 outputs must match the 10-element list process_diary returns.
    submit_btn.click(
        fn=process_diary,
        inputs=[day, date, total_people, total_machinery, machinery_types, activities, media],
        outputs=[
            model_day,
            model_date,
            model_people,
            model_machinery,
            model_machinery_types,
            timeline_view,
            model_annotated_video,
            activity_data,
            chatbot,
            model_activities
        ]
    )
    # Process activity data into gallery whenever new activities arrive.
    activity_data.change(
        fn=create_activity_cards_ui,
        inputs=[activity_data],
        outputs=[activity_gallery, activity_indices]
    )
    # Handle gallery selection (switches to the chat view).
    activity_gallery.select(
        fn=on_card_click,
        inputs=[activity_indices, chatbot],
        outputs=[timeline_view, chat_view, chatbot, chunk_video]
    )
    # Chat submission
    chat_input.submit(
        fn=chat_with_video,
        inputs=[chat_input, chatbot],
        outputs=[chatbot]
    )
    # Back button: restore timeline view, hide chat view.
    back_btn.click(
        fn=lambda: [gr.update(visible=True), gr.update(visible=False)],
        inputs=None,
        outputs=[timeline_view, chat_view]
    )
    # Add enhanced CSS styling (string content intentionally unindented so the
    # injected CSS stays byte-identical).
    gr.HTML("""
<style>
/* Gallery customizations */
.gradio-container .gallery-item {
border: 1px solid #444444 !important;
border-radius: 8px !important;
padding: 8px !important;
margin: 10px !important;
cursor: pointer !important;
transition: all 0.3s !important;
background: #18181b !important;
box-shadow: 0 2px 5px rgba(0,0,0,0.2) !important;
}
.gradio-container .gallery-item:hover {
transform: translateY(-2px) !important;
box-shadow: 0 4px 12px rgba(0,0,0,0.25) !important;
border-color: #007bff !important;
background: #202025 !important;
}
.gradio-container .gallery-item.selected {
border: 2px solid #007bff !important;
background: #202030 !important;
}
/* Improved image display */
.gradio-container .gallery-item img {
height: 180px !important;
object-fit: cover !important;
border-radius: 4px !important;
border: 1px solid #444444 !important;
margin-bottom: 8px !important;
}
/* Caption styling */
.gradio-container .caption {
color: #e0e0e0 !important;
font-size: 0.9em !important;
margin-top: 8px !important;
line-height: 1.4 !important;
padding: 0 4px !important;
}
/* Gallery container */
.gradio-container [id*='gallery'] > div:first-child {
background-color: #27272a !important;
padding: 15px !important;
border-radius: 10px !important;
}
/* Chatbot styling */
.gradio-container .chatbot {
background-color: #27272a !important;
border-radius: 10px !important;
border: 1px solid #444444 !important;
}
.gradio-container .chatbot .message.user {
background-color: #18181b !important;
border-radius: 8px !important;
}
.gradio-container .chatbot .message.bot {
background-color: #202030 !important;
border-radius: 8px !important;
}
/* Button styling */
.gradio-container button.secondary {
background-color: #3d4452 !important;
color: white !important;
}
</style>
""")

if __name__ == "__main__":
    # allowed_paths lets Gradio serve thumbnails/chunks saved under ./tmp
    demo.launch(allowed_paths=["./tmp"])