File size: 4,689 Bytes
c9473c9 623e1bf b7a75e4 623e1bf 4d7e87d b7a75e4 4d7e87d 623e1bf c333b0b 623e1bf 4d7e87d c9473c9 f397a20 c9473c9 b7a75e4 4d7e87d b7a75e4 4d7e87d b7a75e4 f397a20 c9473c9 f397a20 c9473c9 b7a75e4 4d7e87d c333b0b 4d7e87d b7a75e4 4d7e87d c333b0b f397a20 36a76ae c333b0b 4d7e87d 36a76ae 4d7e87d 36a76ae b7a75e4 4d7e87d 36a76ae 4d7e87d 36a76ae 1018e38 36a76ae 4d7e87d 36a76ae 4d7e87d 36a76ae 623e1bf c9473c9 e041428 c9473c9 36a76ae b7a75e4 623e1bf b7a75e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import time
import torch
import gc
from transformers import AutoConfig, AutoModelForImageSegmentation
from PIL import Image
from torchvision import transforms
import gradio as gr
def load_model(model_id="zhengpeng7/BiRefNet_lite"):
    """Load the BiRefNet segmentation model and move it to the best device.

    Args:
        model_id: Hugging Face Hub repository id used for BOTH the config and
            the weights (kept in one place so they cannot drift apart).
            Defaults to the checkpoint this app was written for.

    Returns:
        ``(model, device)``: the model in eval mode on ``device``, where
        ``device`` is ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
    """
    # NOTE: trust_remote_code=True executes Python code shipped in the model
    # repository; only point model_id at repositories you trust.
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    # Ensure transformers does not treat this as a seq2seq (encoder-decoder)
    # model when building it from the config.
    config.is_encoder_decoder = False
    model = AutoModelForImageSegmentation.from_pretrained(
        model_id,
        config=config,
        trust_remote_code=True,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()
    return model, device
# Load the model once at import time; run_inference / extract_objects read
# these module-level names on every request.
birefnet, device = load_model()

# Preprocessing: resize to a fixed square input, convert to a float tensor,
# then normalize each channel with the mean/std below.
image_size = (1024, 1024)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)
def run_inference(images, model, device):
    """Cut the foreground out of every image in one batched forward pass.

    Each image is preprocessed with the module-level ``transform_image``. The
    whole batch then goes through ``model`` once. The sigmoid of the model's
    last output is resized back to each image's original size and used as the
    paste mask onto a fully transparent RGBA canvas.

    Args:
        images: sequence of PIL images (extract_objects passes RGB images).
        model: segmentation model; ``model(batch)[-1]`` is taken as the final
            prediction. Assumes one single-channel map per image — TODO
            confirm against the model's output structure.
        device: device string the batch is moved to (load_model returns
            ``"cuda"`` or ``"cpu"``).

    Returns:
        List of RGBA PIL images, one per input, in input order. An empty
        input returns ``[]`` without calling the model.

    Raises:
        torch.OutOfMemoryError: re-raised after dropping the input batch and
            emptying the CUDA cache, so callers can retry with fewer images.
    """
    if not images:
        # torch.stack rejects an empty list, and there is nothing to run.
        return []

    original_sizes = [img.size for img in images]
    input_tensor = torch.stack([transform_image(img) for img in images]).to(device)
    try:
        with torch.no_grad():
            preds = model(input_tensor)[-1].sigmoid().cpu()
    except torch.OutOfMemoryError:
        del input_tensor
        torch.cuda.empty_cache()
        raise

    # Post-process: one converter reused for every prediction.
    to_pil = transforms.ToPILImage()
    results = []
    for img, size, pred in zip(images, original_sizes, preds):
        mask = to_pil(pred.squeeze()).resize(size)
        cutout = Image.new("RGBA", size, (0, 0, 0, 0))
        cutout.paste(img, mask=mask)
        results.append(cutout)

    # Cleanup: drop batch tensors and return cached blocks to the driver.
    del input_tensor, preds
    gc.collect()
    torch.cuda.empty_cache()
    return results
def binary_search_max(images):
    """Find the largest leading subset of ``images`` that runs without OOM.

    Intended for use after a full-batch attempt ran out of memory. It binary
    searches the batch size ``k`` in ``[1, len(images)]``, running
    ``run_inference`` on ``images[:k]`` with the module-level ``birefnet``
    and ``device``.

    The model is deliberately NOT reloaded between probes. The original code
    called load_model() while the old model was still referenced. That held
    two copies of the weights at once (raising peak memory) and made every
    probe pay the full load cost. An OOM during loading was also miscounted
    as "batch too large".

    Args:
        images: list of PIL images, in the order they should be kept.

    Returns:
        ``(results, count)`` for the largest successful prefix, or
        ``(None, 0)`` if not even a single image could be processed.
    """
    low, high = 1, len(images)
    best = None
    best_count = 0
    while low <= high:
        mid = (low + high) // 2
        # Release cached allocator blocks left by the previous probe. This
        # runs outside the except clause below, so a failed probe's
        # exception (and the frames it references) has been cleared by now.
        gc.collect()
        torch.cuda.empty_cache()
        try:
            res = run_inference(images[:mid], birefnet, device)
        except torch.OutOfMemoryError:
            high = mid - 1
        else:
            best, best_count = res, mid
            low = mid + 1
    return best, best_count
def extract_objects(filepaths):
    """Remove the background from every uploaded image (Gradio handler).

    First tries all images in a single batch. If that raises
    ``torch.OutOfMemoryError``, ``binary_search_max`` finds the largest
    leading subset that fits, and only that subset's results are returned.

    Args:
        filepaths: image file paths from the ``gr.Files`` input. May be
            ``None`` or empty when nothing was uploaded (the component
            presumably sends ``None`` in that case — verify).

    Returns:
        ``(results, summary)``: a list of RGBA cut-out images and a
        human-readable timing summary. After an OOM fallback, the list covers
        only the first ``k`` inputs. If nothing could be processed, the list
        is empty.
    """
    if not filepaths:
        return [], "No images were uploaded."

    images = []
    for path in filepaths:
        # Close each file handle once its pixels are decoded into an RGB copy.
        with Image.open(path) as img:
            images.append(img.convert("RGB"))

    # perf_counter is monotonic, so durations cannot go negative.
    start_time = time.perf_counter()
    # First attempt: all images in one batch.
    try:
        results = run_inference(images, birefnet, device)
    except torch.OutOfMemoryError:
        # Only record when the failure happened. While this clause is active,
        # the exception's traceback keeps the failed forward pass's frames
        # (and the GPU tensors they own) alive. The fallback must therefore
        # run after the try statement, not in here.
        oom_time = time.perf_counter()
    else:
        total_time = time.perf_counter() - start_time
        summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
        return results, summary

    # Reached only after an OOM, with the exception already cleared.
    initial_attempt_time = oom_time - start_time
    best, best_count = binary_search_max(images)
    total_time = time.perf_counter() - start_time
    if best is None:
        # Not even a single image fits.
        summary = (
            f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
            f"Could not process even a single image.\n"
            f"Total time including fallback attempts: {total_time:.2f}s."
        )
        return [], summary

    summary = (
        f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
        f"Found that {best_count} images can be processed without OOM.\n"
        f"Total time including fallback attempts: {total_time:.2f}s.\n"
        f"Next time, try using up to {best_count} images."
    )
    return best, summary
# Gradio UI: upload many image files at once; shows the cut-outs and the
# timing / fallback summary returned by extract_objects.
_upload_input = gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple")
_gallery_output = gr.Gallery(label="Processed Images")
_timing_output = gr.Textbox(label="Timing Info")

iface = gr.Interface(
    fn=extract_objects,
    inputs=_upload_input,
    outputs=[_gallery_output, _timing_output],
    title="BiRefNet Bulk Background Removal with On-Demand Fallback",
    description="Upload as many images as you want. If OOM occurs, fallback logic will find the max feasible number.",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()
|