mishiawan committed on
Commit 6d6633e · verified · 1 Parent(s): f32e9c5

Create app.py

Files changed (1)
  1. app.py +222 -0
app.py ADDED
@@ -0,0 +1,222 @@
import glob
import os
from copy import deepcopy

import gradio as gr
import numpy as np
import PIL
import spaces
import torch
import yaml
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import ToPILImage, ToTensor
from transformers import AutoModelForImageSegmentation
from utils import extract_object, get_model_from_config, resize_and_center_crop

huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

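# supported (width, height) buckets, keyed by their width/height aspect ratio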
ASPECT_RATIOS = {
    str(512 / 2048): (512, 2048),
    str(1024 / 1024): (1024, 1024),
    str(2048 / 512): (2048, 512),
    str(896 / 1152): (896, 1152),
    str(1152 / 896): (1152, 896),
    str(512 / 1920): (512, 1920),
    str(640 / 1536): (640, 1536),
    str(768 / 1280): (768, 1280),
    str(1280 / 768): (1280, 768),
    str(1536 / 640): (1536, 640),
    str(1920 / 512): (1920, 512),
}

# download the config and model
MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "relight.safetensors", token=huggingface_token)
CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "relight.yaml", token=huggingface_token)

with open(CONFIG_PATH, "r") as f:
    config = yaml.safe_load(f)
model = get_model_from_config(**config)
sd = load_file(MODEL_PATH)
model.load_state_dict(sd, strict=True)
model.to("cuda").to(torch.bfloat16)
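
# BiRefNet produces the foreground mask used to composite the object onto the new background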
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
).cuda()
image_size = (1024, 1024)


@spaces.GPU
def evaluate(
    fg_image: PIL.Image.Image,
    bg_image: PIL.Image.Image,
    num_sampling_steps: int = 1,
):
    gr.Info("Relighting Image...", duration=3)

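    # map the input to the supported resolution bucket with the closest aspect ratio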
    ori_w_bg, ori_h_bg = fg_image.size  # PIL's .size is (width, height)
    ar_bg = ori_w_bg / ori_h_bg
    closest_ar_bg = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar_bg))
    dimensions_bg = ASPECT_RATIOS[closest_ar_bg]

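    # segment the foreground, resize everything to the bucket size, and composite it onto the target background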
    _, fg_mask = extract_object(birefnet, deepcopy(fg_image))

    fg_image = resize_and_center_crop(fg_image, dimensions_bg[0], dimensions_bg[1])
    fg_mask = resize_and_center_crop(fg_mask, dimensions_bg[0], dimensions_bg[1])
    bg_image = resize_and_center_crop(bg_image, dimensions_bg[0], dimensions_bg[1])

    img_pasted = Image.composite(fg_image, bg_image, fg_mask)

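    # scale the composite to [-1, 1] and encode it into the VAE latent space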
    img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
    batch = {
        "source_image": img_pasted_tensor.cuda().to(torch.bfloat16),
    }

    z_source = model.vae.encode(batch[model.source_key])

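    # run the LBM sampler from the source latent (a single step is the UI default)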
    output_image = model.sample(
        z=z_source,
        num_steps=num_sampling_steps,
        conditioner_inputs=batch,
        max_samples=1,
    ).clamp(-1, 1)

    output_image = (output_image[0].float().cpu() + 1) / 2
    output_image = ToPILImage()(output_image)

    # paste the output image on the background image
    output_image = Image.composite(output_image, bg_image, fg_mask)

    # resize both views back to the input image's original resolution
    img_pasted = img_pasted.resize((ori_w_bg, ori_h_bg))
    output_image = output_image.resize((ori_w_bg, ori_h_bg))

    return (np.array(img_pasted), np.array(output_image))


with gr.Blocks(title="LBM Object Relighting") as demo:
    gr.Markdown(
        """
        # Object Relighting with Latent Bridge Matching
        This is an interactive demo of [LBM: Latent Bridge Matching for Fast Image-to-Image Translation](https://arxiv.org/abs/2503.07535) *by Jasper Research*. We are internally exploring the possibility of releasing the model. If you enjoy the space, please also promote *open source* by giving a ⭐ to the <a href='https://github.com/gojasper/LBM' target='_blank'>GitHub repo</a>.
        """
    )
    gr.Markdown(
        "💡 *Hint:* To better appreciate the low latency of our method, run the demo locally!"
    )
    with gr.Row():
        with gr.Column():
            with gr.Row():
                fg_image = gr.Image(
                    type="pil",
                    label="Input Image",
                    image_mode="RGB",
                    height=360,
                    # width=360,
                )
                bg_image = gr.Image(
                    type="pil",
                    label="Target Background",
                    image_mode="RGB",
                    height=360,
                    # width=360,
                )

            with gr.Row():
                submit_button = gr.Button("Relight", variant="primary")
            with gr.Row():
                num_inference_steps = gr.Slider(
                    minimum=1,
                    maximum=4,
                    value=1,
                    step=1,
                    label="Number of Inference Steps",
                )

            bg_gallery = gr.Gallery(
                # height=450,
                object_fit="contain",
                label="Background List",
                value=glob.glob("examples/backgrounds/*.jpg"),
                columns=5,
                allow_preview=False,
            )

        with gr.Column():
            output_slider = ImageSlider(label="Composite vs LBM", type="numpy")
            output_slider.upload(
                fn=evaluate,
                inputs=[fg_image, bg_image, num_inference_steps],
                outputs=[output_slider],
            )

    submit_button.click(
        evaluate,
        inputs=[fg_image, bg_image, num_inference_steps],
        outputs=[output_slider],
        show_progress="full",
        show_api=False,
    )

    with gr.Row():
        gr.Examples(
            fn=evaluate,
            examples=[
                ["examples/foregrounds/2.jpg", "examples/backgrounds/14.jpg", 1],
                ["examples/foregrounds/10.jpg", "examples/backgrounds/4.jpg", 1],
                ["examples/foregrounds/11.jpg", "examples/backgrounds/24.jpg", 1],
                ["examples/foregrounds/19.jpg", "examples/backgrounds/3.jpg", 1],
                ["examples/foregrounds/4.jpg", "examples/backgrounds/6.jpg", 1],
                ["examples/foregrounds/14.jpg", "examples/backgrounds/22.jpg", 1],
                ["examples/foregrounds/12.jpg", "examples/backgrounds/1.jpg", 1],
            ],
            inputs=[fg_image, bg_image, num_inference_steps],
            outputs=[output_slider],
            run_on_click=True,
        )

    gr.Markdown("**Disclaimer:**")
    gr.Markdown(
        "This demo is for research purposes only. Jasper cannot be held responsible for the generation of NSFW (Not Safe For Work) content through the use of this demo. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards. Jasper provides the tools, but the responsibility for their use lies with the individual user."
    )
    gr.Markdown("**Note:** Some background examples are taken from the [IC-Light repo](https://github.com/lllyasviel/IC-Light).")

    # clicking a gallery thumbnail loads that background into the Target Background input
    def bg_gallery_selected(gal, evt: gr.SelectData):
        return gal[evt.index][0]

    bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)

if __name__ == "__main__":
    demo.queue().launch(show_api=False)
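
Note: `utils.py` is not part of this commit, so `extract_object`, `get_model_from_config`, and `resize_and_center_crop` come from elsewhere in the repo. For reference, here is a minimal sketch of what the two image helpers presumably do, inferred only from their call sites above (the bodies are assumptions, not the actual implementation; the BiRefNet call follows the usage shown on its model card):

import torch
import PIL.Image
from torchvision import transforms

def resize_and_center_crop(img, target_w: int, target_h: int):
    # Assumed behavior: scale the image until it covers the target box, then center-crop.
    scale = max(target_w / img.width, target_h / img.height)
    new_w = max(target_w, round(img.width * scale))
    new_h = max(target_h, round(img.height * scale))
    img = img.resize((new_w, new_h), PIL.Image.LANCZOS)
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    return img.crop((left, top, left + target_w, top + target_h))

def extract_object(birefnet, img):
    # Assumed behavior: segment the subject at 1024x1024 and return (image, soft mask).
    transform = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    x = transform(img).unsqueeze(0).cuda()
    with torch.no_grad():
        pred = birefnet(x)[-1].sigmoid().cpu()[0].squeeze()
    # convert the prediction to an 8-bit "L" mask PIL can use for compositing
    mask = transforms.ToPILImage()((pred * 255).byte()).resize(img.size)
    return img, mask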