Spaces:
Runtime error
Runtime error
Commit
Β·
eaf8b52
1
Parent(s):
f76a568
matte
Browse files- matte_processor.py +8 -19
- requirements.txt +1 -0
matte_processor.py
CHANGED
@@ -4,6 +4,7 @@ import torch
|
|
4 |
from transformers import VitMatteImageProcessor, VitMatteForImageMatting
|
5 |
import math
|
6 |
from pathlib import Path
|
|
|
7 |
|
8 |
class VITMatteModel:
|
9 |
def __init__(self, model, processor):
|
@@ -11,11 +12,14 @@ class VITMatteModel:
|
|
11 |
self.processor = processor
|
12 |
|
13 |
def load_VITMatte_model(local_files_only=False):
|
14 |
-
model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k", local_files_only=local_files_only)
|
15 |
processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k", local_files_only=local_files_only)
|
16 |
return VITMatteModel(model, processor)
|
17 |
|
18 |
-
|
|
|
|
|
|
|
19 |
if image is None or trimap is None:
|
20 |
return None
|
21 |
|
@@ -45,19 +49,11 @@ def generate_VITMatte(image, trimap, local_files_only=False, device="cpu", max_m
|
|
45 |
trimap = trimap.resize((target_width, target_height), Image.BILINEAR)
|
46 |
resized = True
|
47 |
|
48 |
-
#
|
49 |
-
if device == "cuda" and not torch.cuda.is_available():
|
50 |
-
device = "cpu"
|
51 |
-
device = torch.device(device)
|
52 |
-
|
53 |
-
# Load and process
|
54 |
-
vit_matte_model = load_VITMatte_model(local_files_only=local_files_only)
|
55 |
-
vit_matte_model.model.to(device)
|
56 |
-
|
57 |
inputs = vit_matte_model.processor(images=image, trimaps=trimap, return_tensors="pt")
|
58 |
|
59 |
with torch.no_grad():
|
60 |
-
inputs = {k: v.to(
|
61 |
predictions = vit_matte_model.model(**inputs).alphas
|
62 |
|
63 |
if torch.cuda.is_available():
|
@@ -83,12 +79,6 @@ def create_matte_tab():
|
|
83 |
input_image = gr.Image(label="Input Image", type="numpy", height=256)
|
84 |
trimap_image = gr.Image(label="Trimap Image", type="numpy", height=256)
|
85 |
|
86 |
-
device = gr.Radio(
|
87 |
-
choices=["cpu", "cuda"],
|
88 |
-
value="cpu",
|
89 |
-
label="Device"
|
90 |
-
)
|
91 |
-
|
92 |
max_megapixels = gr.Slider(
|
93 |
minimum=0.5,
|
94 |
maximum=8.0,
|
@@ -113,7 +103,6 @@ def create_matte_tab():
|
|
113 |
input_image,
|
114 |
trimap_image,
|
115 |
local_files,
|
116 |
-
device,
|
117 |
max_megapixels
|
118 |
],
|
119 |
outputs=output_image
|
|
|
4 |
from transformers import VitMatteImageProcessor, VitMatteForImageMatting
|
5 |
import math
|
6 |
from pathlib import Path
|
7 |
+
import numpy as np
|
8 |
|
9 |
class VITMatteModel:
|
10 |
def __init__(self, model, processor):
|
|
|
12 |
self.processor = processor
|
13 |
|
14 |
def load_VITMatte_model(local_files_only=False):
|
15 |
+
model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k", local_files_only=local_files_only).to("cuda")
|
16 |
processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k", local_files_only=local_files_only)
|
17 |
return VITMatteModel(model, processor)
|
18 |
|
19 |
+
# Load model globally
|
20 |
+
vit_matte_model = load_VITMatte_model(local_files_only=False)
|
21 |
+
|
22 |
+
def generate_VITMatte(image, trimap, local_files_only=False, max_megapixels=2.0):
|
23 |
if image is None or trimap is None:
|
24 |
return None
|
25 |
|
|
|
49 |
trimap = trimap.resize((target_width, target_height), Image.BILINEAR)
|
50 |
resized = True
|
51 |
|
52 |
+
# Use global model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
inputs = vit_matte_model.processor(images=image, trimaps=trimap, return_tensors="pt")
|
54 |
|
55 |
with torch.no_grad():
|
56 |
+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
57 |
predictions = vit_matte_model.model(**inputs).alphas
|
58 |
|
59 |
if torch.cuda.is_available():
|
|
|
79 |
input_image = gr.Image(label="Input Image", type="numpy", height=256)
|
80 |
trimap_image = gr.Image(label="Trimap Image", type="numpy", height=256)
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
max_megapixels = gr.Slider(
|
83 |
minimum=0.5,
|
84 |
maximum=8.0,
|
|
|
103 |
input_image,
|
104 |
trimap_image,
|
105 |
local_files,
|
|
|
106 |
max_megapixels
|
107 |
],
|
108 |
outputs=output_image
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
pixeloe
|
2 |
torch
|
3 |
torchvision
|
|
|
1 |
+
transformers
|
2 |
pixeloe
|
3 |
torch
|
4 |
torchvision
|