maxorange committed · verified
Commit c7d90e7 · 1 Parent(s): 089ca2d

Upload 3 files

Files changed (4)
  1. .gitattributes +1 -0
  2. app.py +57 -0
  3. chameleon.jpg +3 -0
  4. requirements.txt +16 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+chameleon.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,57 @@
+import gradio as gr
+from gradio_imageslider import ImageSlider
+from loadimg import load_img
+import spaces
+from transformers import AutoModelForImageSegmentation
+import torch
+from torchvision import transforms
+
+torch.set_float32_matmul_precision(["high", "highest"][0])
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+birefnet = AutoModelForImageSegmentation.from_pretrained(
+    "ZhengPeng7/BiRefNet", trust_remote_code=True
+)
+birefnet.to(device)
+transform_image = transforms.Compose(
+    [
+        transforms.Resize((1024, 1024)),
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+    ]
+)
+
+
+@spaces.GPU
+def fn(image):
+    im = load_img(image, output_type="pil")
+    im = im.convert("RGB")
+    image_size = im.size
+    origin = im.copy()
+    image = load_img(im)
+    input_images = transform_image(image).unsqueeze(0).to(device)
+    # Prediction
+    with torch.no_grad():
+        preds = birefnet(input_images)[-1].sigmoid().cpu()
+    pred = preds[0].squeeze()
+    pred_pil = transforms.ToPILImage()(pred)
+    mask = pred_pil.resize(image_size)
+    image.putalpha(mask)
+    return image
+
+
+chameleon = load_img("chameleon.jpg", output_type="pil")
+
+url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
+demo = gr.Interface(
+    fn,
+    inputs=gr.Image(label="Upload an image"),
+    outputs=gr.Image(label="birefnet", format="png"),
+    examples=[chameleon],
+    api_name="image",
+    flagging_mode="never",
+    cache_mode="lazy",
+)
+
+demo.queue(default_concurrency_limit=1).launch(show_error=True)
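
Because the Interface registers the prediction function under api_name="image", a deployed copy of this app exposes a /image endpoint that can be called programmatically with gradio_client. The sketch below assumes such a Space is running and publicly accessible; the Space id and the input filename are placeholders, not values taken from this commit.

    # Client-side sketch: call the /image endpoint of a deployed copy of this app.
    # ASSUMPTION: "user/space-name" and "input.jpg" are placeholders, not values
    # from this commit.
    from gradio_client import Client, handle_file

    client = Client("user/space-name")  # placeholder Space id
    result = client.predict(
        handle_file("input.jpg"),  # local image to send for background removal
        api_name="/image",         # matches api_name="image" in gr.Interface above
    )
    print(result)  # local path to the PNG cutout returned by the Space
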
chameleon.jpg ADDED

Git LFS Details

  • SHA256: f340841ad7379ef9ad15d94cf4a5bc35fb58435a6b84da9efbfc2ee8fa6a2621
  • Pointer size: 131 Bytes
  • Size of remote file: 477 kB
requirements.txt ADDED
@@ -0,0 +1,16 @@
+torch
+accelerate
+opencv-python
+spaces
+pillow
+numpy
+timm
+kornia
+prettytable
+typing
+scikit-image
+huggingface_hub
+transformers>=4.39.1
+gradio_imageslider
+loadimg>=0.1.1
+einops