"""UV Scan — Gradio app that turns an uploaded video of a table into a heatmap video."""
import os
import shutil
import uuid

import cv2
import gradio as gr
import numpy as np
from PIL import Image

from extract_frames import video_to_keyframes
from apply_mask import apply_mask_and_crop
from run_gmm import run_gmm_inference
from compose_video import compose_final_video
# Create every working directory at import time so the os.listdir() calls in
# process_video never hit a missing folder; exist_ok makes re-runs harmless.
for _required_dir in (
    "video_outputs/extracted_frames",
    "video_outputs/masked_frames",
    "video_outputs/output_heatmap",
    "video_inputs",
    "assets",
):
    os.makedirs(_required_dir, exist_ok=True)
def get_first_frame(video_path):
    """Return the first frame of a video as an RGB PIL image.

    Args:
        video_path: Filesystem path of the video, or None/"" when no video is
            selected (e.g. when the video input is reset to None).

    Returns:
        A ``PIL.Image.Image`` in RGB channel order, or ``None`` when no path
        was given, the file cannot be opened, or no frame can be decoded.
    """
    # Nothing selected (cleared/reset component): there is nothing to preview.
    if not video_path:
        return None

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        success, frame = cap.read()
    finally:
        # Always free the capture handle, even if read() raises.
        cap.release()

    if not success or frame is None:
        return None
    # OpenCV decodes frames as BGR; PIL expects RGB.
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
# Locations searched for the default table mask, in priority order. The
# working directory comes first (the historical location), then assets/.
_DEFAULT_MASK_CANDIDATES = (
    "default_mask.png",
    os.path.join("assets", "default_mask.png"),
)


def _reset_dir(folder):
    """Ensure *folder* exists and is empty, removing files and subfolders."""
    os.makedirs(folder, exist_ok=True)
    for entry in os.listdir(folder):
        entry_path = os.path.join(folder, entry)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)


def _load_default_mask():
    """Load the default table mask as a single-channel (grayscale) image.

    Raises:
        gr.Error: if no candidate file exists or the file cannot be decoded.
    """
    for candidate in _DEFAULT_MASK_CANDIDATES:
        if os.path.isfile(candidate):
            mask = cv2.imread(candidate, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise gr.Error(f"❌ Failed to load default mask from '{candidate}'.")
            return mask
    raise gr.Error(
        "❌ Default mask not found at 'default_mask.png' or 'assets/default_mask.png'"
    )


def _read_first_keyframe(extracted_dir):
    """Return the first extracted keyframe (by sorted filename) as loaded by cv2.imread.

    Raises:
        gr.Error: if no keyframes were extracted or the first one is unreadable.
    """
    frame_names = sorted(os.listdir(extracted_dir))
    if not frame_names:
        raise gr.Error("❌ No keyframes were extracted from the video.")
    first_frame = cv2.imread(os.path.join(extracted_dir, frame_names[0]))
    if first_frame is None:
        raise gr.Error("❌ Failed to read first extracted keyframe.")
    return first_frame


def process_video(video_file, progress=gr.Progress()):
    """Run the full heatmap pipeline on an uploaded video.

    Steps: extract keyframes, mask/crop them with the default table mask,
    run GMM inference to produce heatmap frames, then compose the final video.

    Args:
        video_file: Path of the uploaded video from the ``gr.Video`` input
            (None/empty when nothing has been uploaded).
        progress: Gradio progress tracker; Gradio injects it at call time.

    Returns:
        ``(status_message, video_path, video_path)`` — feeds the status text,
        the output video player and the download file component.

    Raises:
        gr.Error: if no video was supplied, the default mask is missing,
            unreadable or empty, or no keyframe could be extracted/read.
    """
    # Validate input before wiping the previous run's frames.
    if not video_file:
        raise gr.Error("❌ Please upload a video first.")

    base_dir = "video_outputs"
    extracted_dir = os.path.join(base_dir, "extracted_frames")
    masked_dir = os.path.join(base_dir, "masked_frames")
    heatmap_dir = os.path.join(base_dir, "output_heatmap")

    # These working folders are fixed paths shared by every run, so leftovers
    # from the previous job are cleared first.
    for folder in (extracted_dir, masked_dir, heatmap_dir):
        _reset_dir(folder)

    mask = _load_default_mask()

    progress(0, desc="Extracting keyframes...")
    video_to_keyframes(video_file, extracted_dir)

    # Align the mask to the keyframe resolution.
    first_frame = _read_first_keyframe(extracted_dir)
    frame_h, frame_w = first_frame.shape[:2]
    if mask.shape != (frame_h, frame_w):
        # Nearest-neighbour keeps the mask's existing pixel values (a 0/255
        # mask stays binary); bilinear would blend grey values along edges.
        mask = cv2.resize(mask, (frame_w, frame_h), interpolation=cv2.INTER_NEAREST)

    if cv2.countNonZero(mask) == 0:
        raise gr.Error("❌ No table region detected in default mask.")

    progress(0.3, desc="Applying mask and cropping...")
    apply_mask_and_crop(extracted_dir, mask, masked_dir)

    progress(0.6, desc="Running inference on frames...")
    run_gmm_inference(masked_dir, heatmap_dir)

    progress(0.85, desc="Composing final video...")
    # Random suffix so each run writes a distinct output file.
    video_name = f"heatmap_output_{uuid.uuid4().hex[:6]}.mp4"
    result_path = os.path.join(base_dir, video_name)
    compose_final_video(mask, heatmap_dir, extracted_dir, result_path)

    progress(1.0, desc="Done ✅")
    return "✅ Heatmap video generated successfully!", result_path, result_path
# Page styling: a full-bleed background image (served through Gradio's file
# route) with a black fallback colour, plus a centred translucent card style
# that layout elements opt into with elem_classes="panel".
custom_css = (
    "\n"
    ".gradio-container {\n"
    "    background: url('/gradio_api/file=background.jpg') center/cover no-repeat !important;\n"
    "    background-color: #000 !important;\n"
    "}\n"
    # Card look shared by every component tagged with the "panel" class.
    ".panel {\n"
    "    max-width: 800px;\n"
    "    margin: 2rem auto;\n"
    "    padding: 2rem;\n"
    "    background: rgba(30,30,30, 0.8);\n"
    "    border-radius: 8px;\n"
    "}\n"
)
with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css, title="UV Scan - Table Heatmap") as demo:
    gr.Markdown("## 🔥 UV Scan — Table Heatmap Generator", elem_classes="panel")

    with gr.Row(elem_classes="panel"):
        video_input = gr.Video(label="Upload Video", format="mp4")
        # Shows the first decoded frame so the user can sanity-check the upload.
        first_frame_preview = gr.Image(label="First Frame Preview", type="pil", interactive=False)

    with gr.Row(elem_classes="panel"):
        generate_btn = gr.Button("🔥 Generate Heatmap", variant="primary")
        reset_btn = gr.Button("Reset")
        download_btn = gr.File(label="⬇️ Download Video")

    with gr.Row(elem_classes="panel"):
        status_text = gr.Markdown("")

    with gr.Row(elem_classes="panel"):
        output_video = gr.Video(label="Output Video")

    def on_video_upload(video_file):
        """Return a preview of the video's first frame, or None when cleared."""
        # The reset button sets the video to None, which also fires .change.
        if not video_file:
            return None
        return get_first_frame(video_file)

    # Decode the first frame whenever the upload changes and show it.
    video_input.change(fn=on_video_upload, inputs=video_input, outputs=first_frame_preview)

    generate_btn.click(
        fn=process_video,
        inputs=[video_input],
        outputs=[status_text, output_video, download_btn],
    )

    # Clear every input/output component, including the preview.
    reset_btn.click(
        fn=lambda: (None, None, "", None, None),
        inputs=[],
        outputs=[video_input, first_frame_preview, status_text, output_video, download_btn],
    )
if __name__ == "__main__":
    # Gradio 5 only serves files from explicitly allowed locations. Whitelist
    # the image referenced by the CSS (/gradio_api/file=background.jpg) so the
    # page background actually loads.
    demo.launch(allowed_paths=["background.jpg"])