# app.py (11.3 kB) - copied from the repository's web file view.
# Last commit: "Update app.py" by Alessio Grancini (2f27314, verified).
import functools
import inspect
import os
import traceback

import cv2
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import spaces
import torch

import utils
from image_segmenter import ImageSegmenter
from monocular_depth_estimator import MonocularDepthEstimator
from point_cloud_generator import display_pcd
# Module-level state shared by the UI callbacks below.

# Cancellation flag; cancel() switches it to True.
CANCEL_PROCESSING = False

# Model handles start out unset and are created on first use by
# initialize_models() rather than at import time.
img_seg = depth_estimator = None
def initialize_models():
    """Create the segmentation and depth models the first time they are needed.

    The target device is "cuda" when ``torch.cuda.is_available()`` reports a
    GPU and "cpu" otherwise. Each of the module-level handles ``img_seg`` and
    ``depth_estimator`` is only created while it is still ``None``, so
    repeated calls are cheap and never replace an already selected model.
    """
    global img_seg, depth_estimator
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if img_seg is None:
        print(f"🔹 Loading ImageSegmenter model on {device}...")
        img_seg = ImageSegmenter(model_type="yolov8s-seg", device=device)
    if depth_estimator is None:
        print(f"🔹 Loading Depth Estimator model on {device}...")
        depth_estimator = MonocularDepthEstimator(model_type="midas_v21_small_256", device=device)
def safe_gpu_decorator(func):
    """Wrap *func* so that a failed CUDA device query triggers one CPU retry.

    When the wrapped callable raises a ``RuntimeError`` whose message contains
    ``"cudaGetDeviceCount"``, ``CUDA_VISIBLE_DEVICES`` is set to an empty
    string and the callable is invoked once more with the same arguments.
    Every other exception, and any failure of the retry itself, propagates
    unchanged.

    Generator functions get a generator wrapper. That keeps
    ``inspect.isgeneratorfunction`` true for the decorated callable and lets
    the handler see errors raised while the generator is being iterated, not
    only while it is being created. On retry the generator restarts from its
    first item.

    Args:
        func: Function or generator function to protect.

    Returns:
        A wrapper carrying *func*'s name, docstring and signature metadata.
    """
    def _prepare_cpu_retry(error):
        """Hide all GPUs and return True if *error* is the device-query failure."""
        if "cudaGetDeviceCount" not in str(error):
            return False
        print("GPU initialization failed, falling back to CPU")
        # Set environment variable to force CPU
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        return True

    if inspect.isgeneratorfunction(func):
        @functools.wraps(func)
        def generator_wrapper(*args, **kwargs):
            try:
                return (yield from func(*args, **kwargs))
            except RuntimeError as error:
                if not _prepare_cpu_retry(error):
                    raise
            return (yield from func(*args, **kwargs))
        return generator_wrapper

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RuntimeError as error:
            if not _prepare_cpu_retry(error):
                raise
        return func(*args, **kwargs)
    return wrapper
@safe_gpu_decorator
def process_image(image):
    """Run segmentation, depth estimation and point-cloud plotting on one image.

    Args:
        image: Image delivered by the UI's image input; ``None`` when the
            user has not supplied one.

    Returns:
        A 4-tuple ``(image_segmentation, depth_colormap, dist_image,
        plot_fig)`` in the order the UI outputs expect: the segmentation
        result, the depth colormap, the image annotated by
        ``utils.draw_depth_info`` and the figure from ``display_pcd``.

    Raises:
        gr.Error: If no image was provided.
        RuntimeError: Re-raised after being logged with its traceback. When
            the message mentions CUDA, ``CUDA_VISIBLE_DEVICES`` is set to
            ``'-1'`` first.
    """
    if image is None:
        # Fail with a message the UI can show instead of an opaque error
        # from deep inside the resize helper.
        raise gr.Error("Please provide an image before running predictions.")

    try:
        print("🚀 Starting image processing...")
        initialize_models()
        if torch.cuda.is_available():
            print("✅ Using GPU for processing")
            # NOTE(review): this changes the process-wide default tensor type
            # on every call, and newer PyTorch releases deprecate it. Confirm
            # whether it is still needed, given that the models are created
            # with an explicit device.
            torch.set_default_tensor_type(torch.cuda.FloatTensor)
        else:
            print("⚠️ Using CPU for processing")

        # Process image
        image = utils.resize(image)
        image_segmentation, objects_data = img_seg.predict(image)
        depthmap, depth_colormap = depth_estimator.make_prediction(image)
        dist_image = utils.draw_depth_info(image, depthmap, objects_data)
        objs_pcd = utils.generate_obj_pcd(depthmap, objects_data)
        plot_fig = display_pcd(objs_pcd)
        return image_segmentation, depth_colormap, dist_image, plot_fig
    except RuntimeError as e:
        print(f"🚨 RuntimeError in process_image: {e}")
        if "cuda" in str(e).lower():
            print("⚠️ CUDA error detected. Switching to CPU mode.")
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
        print(traceback.format_exc())
        raise
@safe_gpu_decorator
def test_process_img(image):
    """Run only the segmentation and depth predictions on a resized image.

    Args:
        image: Image to analyse; it is passed through ``utils.resize`` first.

    Returns:
        ``(image_segmentation, objects_data, depthmap, depth_colormap)``:
        the two values from ``img_seg.predict`` followed by the two values
        from ``depth_estimator.make_prediction``.
    """
    initialize_models()
    resized = utils.resize(image)
    seg_overlay, detections = img_seg.predict(resized)
    depth_values, depth_visual = depth_estimator.make_prediction(resized)
    return seg_overlay, detections, depth_values, depth_visual
@safe_gpu_decorator
def process_video(vid_path=None):
    """Stream per-frame segmentation, depth and distance images for a video.

    Frames are read with ``cv2.VideoCapture`` until the stream ends, a read
    fails, or the module-level ``CANCEL_PROCESSING`` flag becomes true.

    Args:
        vid_path: Path of the video file to process.

    Yields:
        ``(segmentation, depth_colormap, distance)`` for every decoded frame.
        The segmentation and distance images are converted from BGR to RGB
        before being yielded; the depth colormap is passed through as is.

    Raises:
        gr.Error: If no video path was provided.
        Exception: Any processing error is logged with its traceback and
            re-raised.
    """
    global CANCEL_PROCESSING

    if not vid_path:
        raise gr.Error("Please provide a video before running predictions.")

    # A previous run may have left the flag set; every run starts uncancelled.
    CANCEL_PROCESSING = False
    vid_cap = None
    try:
        initialize_models()
        vid_cap = cv2.VideoCapture(vid_path)
        while vid_cap.isOpened() and not CANCEL_PROCESSING:
            ret, frame = vid_cap.read()
            if not ret:
                # End of stream or unreadable frame. The capture still reports
                # itself as opened, so leave the loop explicitly instead of
                # spinning forever.
                break
            print("making predictions ....")
            frame = utils.resize(frame)
            image_segmentation, objects_data = img_seg.predict(frame)
            depthmap, depth_colormap = depth_estimator.make_prediction(frame)
            dist_image = utils.draw_depth_info(frame, depthmap, objects_data)
            yield (
                cv2.cvtColor(image_segmentation, cv2.COLOR_BGR2RGB),
                depth_colormap,
                cv2.cvtColor(dist_image, cv2.COLOR_BGR2RGB),
            )
    except Exception as e:
        print(f"Error in process_video: {str(e)}")
        print(traceback.format_exc())
        raise
    finally:
        # Runs on normal completion, on errors and when the consumer closes
        # the generator early, so the capture handle is always released.
        if vid_cap is not None:
            vid_cap.release()
def update_segmentation_options(options):
    """Toggle the segmenter's display flags from the selected checkbox labels.

    Args:
        options: Collection of the selected option labels. ``None`` is
            treated as "nothing selected".
    """
    initialize_models()
    selected = options or []
    # A membership test already yields a bool; no conditional expression needed.
    img_seg.is_show_bounding_boxes = 'Show Boundary Box' in selected
    img_seg.is_show_segmentation = 'Show Segmentation Region' in selected
    img_seg.is_show_segmentation_boundary = 'Show Segmentation Boundary' in selected
def update_confidence_threshold(thres_val):
    """Store the UI threshold on the segmenter as a fraction.

    Args:
        thres_val: Threshold expressed as a percentage; it is divided by 100
            before being assigned to ``img_seg.confidence_threshold``.
    """
    initialize_models()
    fraction = thres_val / 100
    img_seg.confidence_threshold = fraction
@safe_gpu_decorator
def model_selector(model_type):
    """Replace both models with the pair matching the chosen size label.

    Args:
        model_type: One of the dropdown labels. An unknown label falls back
            to the small model pair.

    Both new models are constructed before either module-level handle is
    rebound, so a failure while loading the second model cannot leave a
    mismatched pair behind. Display flags and the confidence threshold
    already present on the current segmenter are copied to the new one, so
    switching models does not silently discard them.
    """
    global img_seg, depth_estimator
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_dict = {
        "Small - Better performance and less accuracy": ("midas_v21_small_256", "yolov8s-seg"),
        "Medium - Balanced performance and accuracy": ("dpt_hybrid_384", "yolov8m-seg"),
        "Large - Slow performance and high accuracy": ("dpt_large_384", "yolov8l-seg"),
    }
    midas_model, yolo_model = model_dict.get(model_type, ("midas_v21_small_256", "yolov8s-seg"))
    print(f"🔹 Switching to models: YOLO={yolo_model}, MiDaS={midas_model} on {device}")

    new_segmenter = ImageSegmenter(model_type=yolo_model, device=device)
    new_depth_estimator = MonocularDepthEstimator(model_type=midas_model, device=device)

    previous_segmenter = img_seg
    if previous_segmenter is not None:
        # These are the attributes the option and threshold callbacks write.
        for attr in (
            "is_show_bounding_boxes",
            "is_show_segmentation",
            "is_show_segmentation_boundary",
            "confidence_threshold",
        ):
            if hasattr(previous_segmenter, attr):
                setattr(new_segmenter, attr, getattr(previous_segmenter, attr))

    img_seg = new_segmenter
    depth_estimator = new_depth_estimator
def cancel() -> None:
    """Raise the module-level cancellation flag."""
    global CANCEL_PROCESSING
    CANCEL_PROCESSING = True
if __name__ == "__main__":
    # Probe CUDA once at startup so a broken driver is reported before the UI
    # is built. When no usable GPU is found, hide all devices.
    try:
        if torch.cuda.is_available():
            print(f"✅ CUDA is available: {torch.cuda.get_device_name(0)}")
            torch.cuda.empty_cache()  # Clear GPU cache
        else:
            print("❌ No CUDA available. Falling back to CPU.")
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    except RuntimeError as e:
        print(f"🚨 CUDA initialization failed: {e}. Switching to CPU mode.")
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    # Values shared by the Image and Video tabs.
    model_choices = [
        "Small - Better performance and less accuracy",
        "Medium - Balanced performance and accuracy",
        "Large - Slow performance and high accuracy",
    ]
    display_options = ["Show Boundary Box", "Show Segmentation Region", "Show Segmentation Boundary"]

    base_dir = os.path.dirname(__file__)
    image_examples = [
        os.path.join(base_dir, f"assets/images/{name}")
        for name in ("baggage_claim.jpg", "kitchen_2.png", "soccer.jpg", "room_2.png", "living_room.jpg")
    ]
    video_examples = [
        os.path.join(base_dir, f"assets/videos/{name}")
        for name in ("input_video.mp4", "driving.mp4", "overpass.mp4", "walking.mp4")
    ]

    with gr.Blocks() as my_app:
        # title
        gr.Markdown("<h1><center>Simultaneous Segmentation and Depth Estimation</center></h1>")
        gr.Markdown("<h3><center>Created by Vaishanth</center></h3>")
        gr.Markdown("<h3><center>This model estimates the depth of segmented objects.</center></h3>")

        # tabs
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_input = gr.Image()
                    model_type_img = gr.Dropdown(
                        list(model_choices),
                        label="Model Type", value="Small - Better performance and less accuracy",
                        info="Select the inference model before running predictions!")
                    options_checkbox_img = gr.CheckboxGroup(list(display_options), label="Options")
                    conf_thres_img = gr.Slider(1, 100, value=60, label="Confidence Threshold", info="Choose the threshold above which objects should be detected")
                    submit_btn_img = gr.Button(value="Predict")
                with gr.Column(scale=2):
                    with gr.Row():
                        segmentation_img_output = gr.Image(height=300, label="Segmentation")
                        depth_img_output = gr.Image(height=300, label="Depth Estimation")
                    with gr.Row():
                        dist_img_output = gr.Image(height=300, label="Distance")
                        pcd_img_output = gr.Plot(label="Point Cloud")
            gr.Examples(
                examples=image_examples,
                inputs=img_input,
                outputs=[segmentation_img_output, depth_img_output, dist_img_output, pcd_img_output],
                fn=process_image,
                cache_examples=True,
            )

        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column(scale=1):
                    vid_input = gr.Video()
                    model_type_vid = gr.Dropdown(
                        list(model_choices),
                        label="Model Type", value="Small - Better performance and less accuracy",
                        info="Select the inference model before running predictions!")
                    options_checkbox_vid = gr.CheckboxGroup(list(display_options), label="Options")
                    conf_thres_vid = gr.Slider(1, 100, value=60, label="Confidence Threshold", info="Choose the threshold above which objects should be detected")
                    with gr.Row():
                        cancel_btn = gr.Button(value="Cancel")
                        submit_btn_vid = gr.Button(value="Predict")
                with gr.Column(scale=2):
                    with gr.Row():
                        segmentation_vid_output = gr.Image(height=300, label="Segmentation")
                        depth_vid_output = gr.Image(height=300, label="Depth Estimation")
                    with gr.Row():
                        dist_vid_output = gr.Image(height=300, label="Distance")
            gr.Examples(
                examples=video_examples,
                inputs=vid_input,
            )

        # image tab logic
        submit_btn_img.click(process_image, inputs=img_input, outputs=[segmentation_img_output, depth_img_output, dist_img_output, pcd_img_output])
        options_checkbox_img.change(update_segmentation_options, options_checkbox_img, [])
        conf_thres_img.change(update_confidence_threshold, conf_thres_img, [])
        model_type_img.change(model_selector, model_type_img, [])

        # video tab logic
        video_event = submit_btn_vid.click(process_video, inputs=vid_input, outputs=[segmentation_vid_output, depth_vid_output, dist_vid_output])
        model_type_vid.change(model_selector, model_type_vid, [])
        # Besides raising the flag, ask the UI framework to stop the running
        # video event so the Cancel button actually interrupts streaming.
        cancel_btn.click(cancel, inputs=[], outputs=[], cancels=[video_event])
        options_checkbox_vid.change(update_segmentation_options, options_checkbox_vid, [])
        conf_thres_vid.change(update_confidence_threshold, conf_thres_vid, [])

    my_app.queue(max_size=10).launch()