gokaygokay commited on
Commit
f76a568
Β·
1 Parent(s): e538d2e
Files changed (2) hide show
  1. app.py +4 -2
  2. matte_processor.py +120 -0
app.py CHANGED
@@ -6,6 +6,7 @@ from color_match_processor import create_color_match_tab
6
  from simple_effects_processor import create_effects_tab
7
  from histogram_processor import create_histogram_tab
8
  from blend_processor import create_blend_tab
 
9
 
10
  with gr.Blocks(title="Image Processing Suite") as demo:
11
  gr.Markdown("# Image Processing Suite")
@@ -17,6 +18,7 @@ with gr.Blocks(title="Image Processing Suite") as demo:
17
  create_effects_tab()
18
  create_histogram_tab()
19
  create_blend_tab()
 
20
 
21
-
22
- demo.launch(debug=True)
 
6
  from simple_effects_processor import create_effects_tab
7
  from histogram_processor import create_histogram_tab
8
  from blend_processor import create_blend_tab
9
+ from matte_processor import create_matte_tab
10
 
11
  with gr.Blocks(title="Image Processing Suite") as demo:
12
  gr.Markdown("# Image Processing Suite")
 
18
  create_effects_tab()
19
  create_histogram_tab()
20
  create_blend_tab()
21
+ create_matte_tab() # Add this line
22
 
23
+ if __name__ == "__main__":
24
+ demo.launch(share=True)
matte_processor.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ from transformers import VitMatteImageProcessor, VitMatteForImageMatting
5
+ import math
6
+ from pathlib import Path
7
+
8
+ class VITMatteModel:
9
+ def __init__(self, model, processor):
10
+ self.model = model
11
+ self.processor = processor
12
+
13
+ def load_VITMatte_model(local_files_only=False):
14
+ model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k", local_files_only=local_files_only)
15
+ processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k", local_files_only=local_files_only)
16
+ return VITMatteModel(model, processor)
17
+
18
+ def generate_VITMatte(image, trimap, local_files_only=False, device="cpu", max_megapixels=2.0):
19
+ if image is None or trimap is None:
20
+ return None
21
+
22
+ # Convert to proper formats
23
+ if isinstance(image, np.ndarray):
24
+ image = Image.fromarray(image)
25
+ if isinstance(trimap, np.ndarray):
26
+ trimap = Image.fromarray(trimap)
27
+
28
+ if image.mode != 'RGB':
29
+ image = image.convert('RGB')
30
+ if trimap.mode != 'L':
31
+ trimap = trimap.convert('L')
32
+
33
+ # Calculate resize if needed
34
+ max_megapixels *= 1048576
35
+ width, height = image.size
36
+ ratio = width / height
37
+ target_width = math.sqrt(ratio * max_megapixels)
38
+ target_height = target_width / ratio
39
+ target_width = int(target_width)
40
+ target_height = int(target_height)
41
+
42
+ resized = False
43
+ if width * height > max_megapixels:
44
+ image = image.resize((target_width, target_height), Image.BILINEAR)
45
+ trimap = trimap.resize((target_width, target_height), Image.BILINEAR)
46
+ resized = True
47
+
48
+ # Set device
49
+ if device == "cuda" and not torch.cuda.is_available():
50
+ device = "cpu"
51
+ device = torch.device(device)
52
+
53
+ # Load and process
54
+ vit_matte_model = load_VITMatte_model(local_files_only=local_files_only)
55
+ vit_matte_model.model.to(device)
56
+
57
+ inputs = vit_matte_model.processor(images=image, trimaps=trimap, return_tensors="pt")
58
+
59
+ with torch.no_grad():
60
+ inputs = {k: v.to(device) for k, v in inputs.items()}
61
+ predictions = vit_matte_model.model(**inputs).alphas
62
+
63
+ if torch.cuda.is_available():
64
+ torch.cuda.empty_cache()
65
+ torch.cuda.ipc_collect()
66
+
67
+ # Convert prediction to image
68
+ mask = predictions.cpu().squeeze().numpy()
69
+ mask = (mask * 255).astype(np.uint8)
70
+ mask = Image.fromarray(mask).convert('L')
71
+
72
+ mask = mask.crop((0, 0, image.width, image.height))
73
+
74
+ if resized:
75
+ mask = mask.resize((width, height), Image.BILINEAR)
76
+
77
+ return np.array(mask)
78
+
79
+ def create_matte_tab():
80
+ with gr.Tab("Image Matting"):
81
+ with gr.Row():
82
+ with gr.Column():
83
+ input_image = gr.Image(label="Input Image", type="numpy", height=256)
84
+ trimap_image = gr.Image(label="Trimap Image", type="numpy", height=256)
85
+
86
+ device = gr.Radio(
87
+ choices=["cpu", "cuda"],
88
+ value="cpu",
89
+ label="Device"
90
+ )
91
+
92
+ max_megapixels = gr.Slider(
93
+ minimum=0.5,
94
+ maximum=8.0,
95
+ value=2.0,
96
+ step=0.5,
97
+ label="Max Megapixels"
98
+ )
99
+
100
+ local_files = gr.Checkbox(
101
+ value=False,
102
+ label="Use Local Files Only"
103
+ )
104
+
105
+ process_btn = gr.Button("Generate Matte")
106
+
107
+ with gr.Column():
108
+ output_image = gr.Image(label="Generated Matte")
109
+
110
+ process_btn.click(
111
+ fn=generate_VITMatte,
112
+ inputs=[
113
+ input_image,
114
+ trimap_image,
115
+ local_files,
116
+ device,
117
+ max_megapixels
118
+ ],
119
+ outputs=output_image
120
+ )