gokaygokay committed on
Commit
eaf8b52
·
1 Parent(s): f76a568
Files changed (2) hide show
  1. matte_processor.py +8 -19
  2. requirements.txt +1 -0
matte_processor.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  from transformers import VitMatteImageProcessor, VitMatteForImageMatting
5
  import math
6
  from pathlib import Path
 
7
 
8
  class VITMatteModel:
9
  def __init__(self, model, processor):
@@ -11,11 +12,14 @@ class VITMatteModel:
11
  self.processor = processor
12
 
13
  def load_VITMatte_model(local_files_only=False):
14
- model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k", local_files_only=local_files_only)
15
  processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k", local_files_only=local_files_only)
16
  return VITMatteModel(model, processor)
17
 
18
- def generate_VITMatte(image, trimap, local_files_only=False, device="cpu", max_megapixels=2.0):
 
 
 
19
  if image is None or trimap is None:
20
  return None
21
 
@@ -45,19 +49,11 @@ def generate_VITMatte(image, trimap, local_files_only=False, device="cpu", max_m
45
  trimap = trimap.resize((target_width, target_height), Image.BILINEAR)
46
  resized = True
47
 
48
- # Set device
49
- if device == "cuda" and not torch.cuda.is_available():
50
- device = "cpu"
51
- device = torch.device(device)
52
-
53
- # Load and process
54
- vit_matte_model = load_VITMatte_model(local_files_only=local_files_only)
55
- vit_matte_model.model.to(device)
56
-
57
  inputs = vit_matte_model.processor(images=image, trimaps=trimap, return_tensors="pt")
58
 
59
  with torch.no_grad():
60
- inputs = {k: v.to(device) for k, v in inputs.items()}
61
  predictions = vit_matte_model.model(**inputs).alphas
62
 
63
  if torch.cuda.is_available():
@@ -83,12 +79,6 @@ def create_matte_tab():
83
  input_image = gr.Image(label="Input Image", type="numpy", height=256)
84
  trimap_image = gr.Image(label="Trimap Image", type="numpy", height=256)
85
 
86
- device = gr.Radio(
87
- choices=["cpu", "cuda"],
88
- value="cpu",
89
- label="Device"
90
- )
91
-
92
  max_megapixels = gr.Slider(
93
  minimum=0.5,
94
  maximum=8.0,
@@ -113,7 +103,6 @@ def create_matte_tab():
113
  input_image,
114
  trimap_image,
115
  local_files,
116
- device,
117
  max_megapixels
118
  ],
119
  outputs=output_image
 
4
  from transformers import VitMatteImageProcessor, VitMatteForImageMatting
5
  import math
6
  from pathlib import Path
7
+ import numpy as np
8
 
9
  class VITMatteModel:
10
  def __init__(self, model, processor):
 
12
  self.processor = processor
13
 
14
  def load_VITMatte_model(local_files_only=False):
15
+ model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k", local_files_only=local_files_only).to("cuda")
16
  processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k", local_files_only=local_files_only)
17
  return VITMatteModel(model, processor)
18
 
19
+ # Load model globally
20
+ vit_matte_model = load_VITMatte_model(local_files_only=False)
21
+
22
+ def generate_VITMatte(image, trimap, local_files_only=False, max_megapixels=2.0):
23
  if image is None or trimap is None:
24
  return None
25
 
 
49
  trimap = trimap.resize((target_width, target_height), Image.BILINEAR)
50
  resized = True
51
 
52
+ # Use global model
 
 
 
 
 
 
 
 
53
  inputs = vit_matte_model.processor(images=image, trimaps=trimap, return_tensors="pt")
54
 
55
  with torch.no_grad():
56
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
57
  predictions = vit_matte_model.model(**inputs).alphas
58
 
59
  if torch.cuda.is_available():
 
79
  input_image = gr.Image(label="Input Image", type="numpy", height=256)
80
  trimap_image = gr.Image(label="Trimap Image", type="numpy", height=256)
81
 
 
 
 
 
 
 
82
  max_megapixels = gr.Slider(
83
  minimum=0.5,
84
  maximum=8.0,
 
103
  input_image,
104
  trimap_image,
105
  local_files,
 
106
  max_megapixels
107
  ],
108
  outputs=output_image
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  pixeloe
2
  torch
3
  torchvision
 
1
+ transformers
2
  pixeloe
3
  torch
4
  torchvision