Jbowyer commited on
Commit
d0b9047
·
verified ·
1 Parent(s): fdd8db5

Upload birefnet_rembg.py

Browse files
hy3dshape/hy3dshape/birefnet_rembg.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ from transformers import AutoModelForImageSegmentation
5
+
6
+ class BiRefNetRemover:
7
+ def __init__(self, device="cuda"):
8
+ self.device = device
9
+ self.model = AutoModelForImageSegmentation.from_pretrained(
10
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
11
+ )
12
+ self.model.to(self.device)
13
+ self.transform_image = transforms.Compose(
14
+ [
15
+ transforms.Resize((1024, 1024)),
16
+ transforms.ToTensor(),
17
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
18
+ ]
19
+ )
20
+
21
+ @torch.no_grad()
22
+ def __call__(self, image: Image.Image) -> Image.Image:
23
+ """
24
+ Removes the background from a PIL image.
25
+ """
26
+ if image.mode != "RGB":
27
+ image = image.convert("RGB")
28
+
29
+ image_size = image.size
30
+ input_images = self.transform_image(image).unsqueeze(0).to(self.device)
31
+
32
+ preds = self.model(input_images)[-1].sigmoid().cpu()
33
+
34
+ pred = preds[0].squeeze()
35
+ pred_pil = transforms.ToPILImage()(pred)
36
+
37
+ mask = pred_pil.resize(image_size)
38
+
39
+ image.putalpha(mask)
40
+ return image