# app.py — author: mlbench123, commit 4d60115 ("Update app.py", verified)
import gradio as gr
import numpy as np
from PIL import Image
import cv2
import os
import uuid
from extract_frames import video_to_keyframes
from apply_mask import apply_mask_and_crop
from run_gmm import run_gmm_inference
from compose_video import compose_final_video
# Create every working directory the pipeline reads from or writes into.
# exist_ok=True makes this a no-op on subsequent launches.
for _required_dir in (
    "video_outputs/extracted_frames",
    "video_outputs/masked_frames",
    "video_outputs/output_heatmap",
    "video_inputs",
    "assets",
):
    os.makedirs(_required_dir, exist_ok=True)
# Get first frame for preview
def get_first_frame(video_path):
    """Return the first frame of *video_path* as an RGB PIL image.

    Args:
        video_path: Path to a video file, or ``None``/empty string when no
            video is selected (Gradio sends ``None`` when the input is cleared).

    Returns:
        A ``PIL.Image.Image`` in RGB channel order, or ``None`` if no path was
        given or the first frame could not be read.
    """
    # Don't hand a missing path to OpenCV; there is simply nothing to preview.
    if not video_path:
        return None
    cap = cv2.VideoCapture(video_path)
    try:
        success, frame = cap.read()
    finally:
        # Release the capture handle even if read() raises.
        cap.release()
    if not success:
        return None
    # OpenCV decodes frames in BGR order; PIL expects RGB.
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return Image.fromarray(frame)
def _clear_folder(folder):
    """Ensure *folder* exists and delete every regular file directly inside it."""
    os.makedirs(folder, exist_ok=True)
    for name in os.listdir(folder):
        file_path = os.path.join(folder, name)
        # os.remove() cannot delete directories; leave them alone instead of crashing.
        if os.path.isfile(file_path):
            os.remove(file_path)


def process_video(video_file, progress=gr.Progress()):
    """Run the full table-heatmap pipeline on an uploaded video.

    The steps are:
    - clear the previous intermediate frames;
    - load the default table mask;
    - extract keyframes and mask/crop them;
    - run GMM inference;
    - compose the final heatmap video.

    Args:
        video_file: Path of the uploaded video as supplied by ``gr.Video``
            (``None`` when nothing has been uploaded).
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        Tuple ``(status_message, video_path, video_path)`` feeding the status
        markdown, the output video player and the download file component.

    Raises:
        gr.Error: If no video was provided, the default mask is missing,
            unreadable or has no non-zero pixels, or no keyframe could be
            extracted/read.
    """
    # Check the input before wiping anything, so a premature click keeps old results.
    if not video_file:
        raise gr.Error("❌ Please upload a video first.")

    base_dir = "video_outputs"
    extracted_dir = os.path.join(base_dir, "extracted_frames")
    masked_dir = os.path.join(base_dir, "masked_frames")
    heatmap_dir = os.path.join(base_dir, "output_heatmap")

    # Clear frames left over from the previous run.
    for folder in (extracted_dir, masked_dir, heatmap_dir):
        _clear_folder(folder)

    # Load default mask
    mask_path = "default_mask.png"
    if not os.path.exists(mask_path):
        # Report the path that was actually checked.
        raise gr.Error(f"❌ Default mask not found at '{mask_path}'")
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise gr.Error("❌ Failed to load default mask.")

    progress(0, desc="Extracting keyframes...")
    video_to_keyframes(video_file, extracted_dir)

    # Load first frame to align mask size
    frame_names = sorted(os.listdir(extracted_dir))
    if not frame_names:
        raise gr.Error("❌ No keyframes were extracted from the video.")
    first_frame = cv2.imread(os.path.join(extracted_dir, frame_names[0]))
    if first_frame is None:
        raise gr.Error("❌ Failed to read first extracted keyframe.")
    if mask.shape != first_frame.shape[:2]:
        # cv2.resize takes dsize as (width, height).
        # NOTE(review): default bilinear interpolation creates intermediate
        # values along mask edges; consider cv2.INTER_NEAREST if downstream
        # code expects a strictly binary mask.
        mask = cv2.resize(mask, (first_frame.shape[1], first_frame.shape[0]))

    # The mask must mark at least one pixel as table region.
    if cv2.findNonZero(mask) is None:
        raise gr.Error("❌ No table region detected in default mask.")

    progress(0.3, desc="Applying mask and cropping...")
    apply_mask_and_crop(extracted_dir, mask, masked_dir)

    progress(0.6, desc="Running inference on frames...")
    run_gmm_inference(masked_dir, heatmap_dir)

    progress(0.85, desc="Composing final video...")
    # Random suffix so successive outputs don't overwrite each other.
    video_name = f"heatmap_output_{uuid.uuid4().hex[:6]}.mp4"
    result_path = os.path.join(base_dir, video_name)
    compose_final_video(mask, heatmap_dir, extracted_dir, result_path)

    progress(1.0, desc="Done ✅")
    return "✅ Heatmap video generated successfully!", result_path, result_path
# Layout
# Custom CSS for the Blocks UI. It sets a full-page background image on the
# main container and defines a translucent, centered "panel" class. The layout
# components opt into that class via elem_classes="panel".
# The background image is requested through Gradio's `/gradio_api/file=` route.
# NOTE(review): Gradio only serves local files on that route when they are
# explicitly allowed (e.g. launch(allowed_paths=[...])) — confirm background.jpg is.
custom_css = """
.gradio-container {
background: url('/gradio_api/file=background.jpg') center/cover no-repeat !important;
background-color: #000 !important;
}
.panel {
max-width: 800px;
margin: 2rem auto;
padding: 2rem;
background: rgba(30,30,30, 0.8);
border-radius: 8px;
}
"""
with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css, title="UV Scan - Table Heatmap") as demo:
    gr.Markdown("## 🎥 UV Scan – Table Heatmap Generator", elem_classes="panel")

    with gr.Row(elem_classes="panel"):
        video_input = gr.Video(label="Upload Video", format="mp4")
        # Displays the first frame of the selected video (filled by on_video_upload).
        preview_image = gr.Image(label="First Frame Preview", interactive=False)

    with gr.Row(elem_classes="panel"):
        generate_btn = gr.Button("🔥 Generate Heatmap", variant="primary")
        reset_btn = gr.Button("Reset")
        download_btn = gr.File(label="⬇️ Download Video")

    with gr.Row(elem_classes="panel"):
        status_text = gr.Markdown("")

    with gr.Row(elem_classes="panel"):
        output_video = gr.Video(label="Output Video")

    def on_video_upload(video_file):
        """Return a preview image of the first frame of the newly selected video."""
        return get_first_frame(video_file)

    # Route the preview into a visible component; the decoded frame is otherwise wasted work.
    video_input.change(fn=on_video_upload, inputs=video_input, outputs=preview_image)

    generate_btn.click(
        fn=process_video,
        inputs=[video_input],
        outputs=[status_text, output_video, download_btn],
    )

    # Reset clears every input/output component, including the preview.
    reset_btn.click(
        fn=lambda: (None, "", None, None, None),
        inputs=[],
        outputs=[video_input, status_text, output_video, download_btn, preview_image],
    )
if __name__ == "__main__":
    # The custom CSS loads background.jpg through the `/gradio_api/file=` route.
    # Gradio only serves local files there when they are explicitly allowed,
    # so whitelist that single image rather than an entire directory.
    demo.launch(allowed_paths=["background.jpg"])