# Hugging Face Space status banner: "Running on Zero" (ZeroGPU hardware).
# NOTE(review): `spaces` (ZeroGPU) is kept as the first import, as in the
# original; it is presumably required to load before torch touches CUDA.
import spaces

import os
import random
import sys
import tempfile

import gradio as gr
import huggingface_hub
import numpy as np
import torch
from loadimg import load_img

from ben_base import BEN_Base
def set_random_seed(seed):
    """Seed every RNG this app uses and force deterministic cuDNN kernels.

    Args:
        seed: Value passed to Python's ``random``, NumPy's global RNG, the
            torch CPU generator, and the torch CUDA generators (current
            device and all devices).
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
# Reproducible start-up; allow faster reduced-precision float32 matmuls.
set_random_seed(9)
torch.set_float32_matmul_precision("high")

model = BEN_Base()

# Fetch the BEN2 base checkpoint from the Hugging Face Hub.
model_path = huggingface_hub.hf_hub_download(
    repo_id="PramaLLC/BEN2",
    filename="BEN2_Base.pth",
)

# Run on the GPU when torch can see one, otherwise on the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
print(f"Using device: {device}")

# Restore the downloaded weights, move the network, switch to eval mode.
model.loadcheckpoints(model_path)
model.to(device)
model.eval()
# Directory (relative to the working directory) that holds generated PNGs.
output_folder = 'output_images'
# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(output_folder, exist_ok=True)
def fn(image):
    """Remove the background from an uploaded image (Gradio image tab handler).

    Args:
        image: The value delivered by the ``gr.Image`` input; it is handed to
            ``load_img(..., output_type="pil")`` and converted to RGB.

    Returns:
        ``(result_image, image_path)``: the foreground returned by
        ``process`` and the path of the PNG it was saved to, for download.
    """
    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    result_image = process(im)
    # A fresh per-request directory stops concurrent users from overwriting
    # each other's download; the file keeps the familiar "foreground.png" name.
    # NOTE(review): these directories are never cleaned up — consider pruning.
    request_dir = tempfile.mkdtemp(dir=output_folder)
    image_path = os.path.join(request_dir, "foreground.png")
    result_image.save(image_path)
    return result_image, image_path
def process_video(video_path):
    """Run BEN2 background removal on a video (Gradio video tab handler).

    Args:
        video_path: Path of the uploaded video, as delivered by ``gr.Video``.

    Returns:
        The fixed path ``./foreground.mp4``, where ``model.segment_video`` is
        relied on to write its result.
    """
    # NOTE(review): the output path is implied rather than passed in, so all
    # concurrent requests share ./foreground.mp4 — confirm whether
    # segment_video accepts an output location and use a per-request one.
    model.segment_video(video_path)
    return "./foreground.mp4"
def process(image):
    """Run BEN2 inference on one image and return the extracted foreground.

    Args:
        image: PIL image; every caller in this file converts it to RGB first.

    Returns:
        The result of ``model.inference(image)``. Callers here call ``.save``
        on it, so it is presumably a PIL image (likely with alpha) — confirm.
    """
    return model.inference(image)
def process_file(f):
    """Remove the background from an image file and save the result as PNG.

    Args:
        f: Path to an image file readable by ``load_img``.

    Returns:
        Path of the written PNG: ``f`` with its extension replaced by
        ``.png``, or with ``.png`` appended when ``f`` has no extension.
    """
    # splitext only strips an extension from the final path component, so
    # dots in directory names or leading-dot filenames are left intact,
    # unlike str.rsplit(".", 1).
    # NOTE(review): if ``f`` is already a .png, the input file is overwritten.
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil").convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path
# Locate the bundled example image next to this script, independent of CWD.
current_dir = os.path.dirname(os.path.abspath(__file__))
image_path = os.path.join(current_dir, "image.jpg")

# Input widgets used by the two tabs below.
image = gr.Image(label="Upload an image")
video = gr.Video(label="Upload a video")

# Pre-loaded PIL copy of the example, offered in the image tab.
examples = load_img(image_path, output_type="pil")
# Image tab: one image in; foreground preview plus a downloadable PNG out.
_image_tab_outputs = [
    gr.Image(label="Result Foreground"),
    gr.File(label="Download PNG"),
]
tab1 = gr.Interface(
    fn=fn,
    inputs=image,
    outputs=_image_tab_outputs,
    examples=[examples],
    api_name="image",
)
# Video tab (experimental): one video in, background-removed video out.
# NOTE(review): the 100-frame limit stated in the description is not enforced
# in this file; presumably it happens inside BEN_Base.segment_video — confirm.
tab2 = gr.Interface(
    fn=process_video,
    inputs=video,
    outputs=gr.Video(label="Result Video"),
    title="Video Processing (experimental)",
    description=(
        "Note: For ZeroGPU timeout, videos are limited to processing the first 100 frames only."
    ),
    api_name="video",
)
# Single app exposing both tabs.
_tab_names = ["Image Processing", "Video Processing"]
demo = gr.TabbedInterface(
    interface_list=[tab1, tab2],
    tab_names=_tab_names,
    title="BEN2 for background removal. Download the image/video for higher quality foreground.",
)
def _launch_app():
    """Start the Gradio server for the tabbed BEN2 demo, showing errors in the UI."""
    demo.launch(show_error=True)


if __name__ == "__main__":
    _launch_app()