aiqtech committed
Commit 9335941 · verified · 1 Parent(s): 90b7eda

Update app.py

Files changed (1)
  1. app.py +475 -296
app.py CHANGED
@@ -1,337 +1,516 @@
-import os
-import random
 import numpy as np
 import torch
-import gradio as gr
-from PIL import Image
-
 import spaces

-from huggingface_hub import snapshot_download, login
-
-from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
-from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import StableDiffusionXLPipeline
-from kolors.models.modeling_chatglm import ChatGLMModel
-from kolors.models.tokenization_chatglm import ChatGLMTokenizer
-from kolors.models.unet_2d_condition import UNet2DConditionModel
-from diffusers import AutoencoderKL, EulerDiscreteScheduler
-
-
-# ============= Runtime & Auth =============
-HF_TOKEN = os.getenv("HF_TOKEN")
-if HF_TOKEN:
-    login(token=HF_TOKEN)
-    print("Successfully logged in to Hugging Face Hub")
-
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
-print(f"Device: {DEVICE}, DType: {DTYPE}")
-
-# ============= Weights =============
-# Follow the original code structure, but reuse the snapshot_download paths as-is.
-print("Downloading models...")
-ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors", token=HF_TOKEN)
-ckpt_dir_ip = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus", token=HF_TOKEN)
-
-# ============= Load Models (IP-Adapter, not FaceID) =============
-# fp16 can produce NaNs on CPU, so standardize on DTYPE
-text_encoder = ChatGLMModel.from_pretrained(
-    f"{ckpt_dir}/text_encoder",
-    torch_dtype=DTYPE,
-    trust_remote_code=True
-)
-tokenizer = ChatGLMTokenizer.from_pretrained(
-    f"{ckpt_dir}/text_encoder",
-    trust_remote_code=True
-)
-vae = AutoencoderKL.from_pretrained(
-    f"{ckpt_dir}/vae",
-    revision=None,
-    torch_dtype=DTYPE
-)
-scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
-unet = UNet2DConditionModel.from_pretrained(
-    f"{ckpt_dir}/unet",
-    revision=None,
-    torch_dtype=DTYPE
-)
-
-# CLIP image encoder for IP-Adapter-Plus
-image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-    f"{ckpt_dir_ip}/image_encoder",
-    ignore_mismatched_sizes=True
-).to(dtype=DTYPE, device=DEVICE)
-
-ip_img_size = 336
-clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
-
-# StableDiffusionXL pipeline with IP-Adapter (image reference)
-pipe = StableDiffusionXLPipeline(
-    vae=vae,
-    text_encoder=text_encoder,
-    tokenizer=tokenizer,
-    unet=unet,
-    scheduler=scheduler,
-    image_encoder=image_encoder,
-    feature_extractor=clip_image_processor,
-    force_zeros_for_empty_prompt=False
-)
-
-# Move core modules to device/dtype
-pipe.vae = pipe.vae.to(DEVICE, dtype=DTYPE)
-pipe.text_encoder = pipe.text_encoder.to(DEVICE, dtype=DTYPE)
-pipe.unet = pipe.unet.to(DEVICE, dtype=DTYPE)
-
-# Compatibility shim for the Kolors UNet
-if hasattr(pipe.unet, "encoder_hid_proj"):
-    pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj

-# Load IP-Adapter weights (general)
-pipe.load_ip_adapter(
-    f"{ckpt_dir_ip}",
-    subfolder="",
-    weight_name=["ip_adapter_plus_general.bin"]
 )

-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
-
-def _to_multiple_of_8(x: int) -> int:
-    return int(x // 8 * 8)
-
-
-def _ensure_even(x: int) -> int:
-    return x if x % 2 == 0 else x - 1
-
-
-def _prepare_dims(width: int, height: int) -> tuple[int, int]:
-    # SDXL recommendation: resolutions that are multiples of 8
-    w = _to_multiple_of_8(width)
-    h = _to_multiple_of_8(height)
-    # Keep dimensions even for compatibility with codecs such as H.264 (optional)
-    w = _ensure_even(w)
-    h = _ensure_even(h)
-    return max(256, min(MAX_IMAGE_SIZE, w)), max(256, min(MAX_IMAGE_SIZE, h))
-
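
As a quick sanity check of the three helpers above (illustrative values): a width of 1023 rounds down to 1016, the nearest multiple of 8, while a height of 250 rounds to 248 and is then clamped up to the 256 floor.

    assert _to_multiple_of_8(1023) == 1016
    assert _prepare_dims(1023, 250) == (1016, 256)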
-
-def _move_to_device():
-    # Ensure correct device placement at call time
-    global pipe, image_encoder
-    pipe.vae = pipe.vae.to(DEVICE, dtype=DTYPE)
-    pipe.text_encoder = pipe.text_encoder.to(DEVICE, dtype=DTYPE)
-    pipe.unet = pipe.unet.to(DEVICE, dtype=DTYPE)
-    image_encoder = image_encoder.to(device=DEVICE, dtype=DTYPE)
-    pipe.image_encoder = image_encoder
-
-
-def _generate(
-    prompt: str,
-    ip_adapter_image: Image.Image,
-    ip_adapter_scale: float,
-    negative_prompt: str,
-    seed: int,
-    width: int,
-    height: int,
-    guidance_scale: float,
-    num_inference_steps: int,
-):
-    _move_to_device()
-    pipe.set_ip_adapter_scale([ip_adapter_scale])
-
-    # Normalize the resolution
-    width, height = _prepare_dims(width, height)
-
-    generator = torch.Generator(device=DEVICE).manual_seed(seed)
-
-    with torch.no_grad():
-        if DEVICE == "cuda":
-            with torch.autocast(device_type="cuda", dtype=torch.float16):
-                images = pipe(
-                    prompt=prompt,
-                    ip_adapter_image=[ip_adapter_image],
-                    negative_prompt=negative_prompt,
-                    height=height,
-                    width=width,
-                    num_inference_steps=int(num_inference_steps),
-                    guidance_scale=float(guidance_scale),
-                    num_images_per_prompt=1,
-                    generator=generator,
-                ).images
-        else:
-            images = pipe(
-                prompt=prompt,
-                ip_adapter_image=[ip_adapter_image],
-                negative_prompt=negative_prompt,
-                height=height,
-                width=width,
-                num_inference_steps=int(num_inference_steps),
-                guidance_scale=float(guidance_scale),
-                num_images_per_prompt=1,
-                generator=generator,
-            ).images
-
-    return images[0]
-

-# Wrap with the Spaces GPU scheduler only when CUDA is available.
-def _infer_core(
     prompt,
-    ip_adapter_image,
-    ip_adapter_scale=0.5,
-    negative_prompt="",
-    seed=100,
     randomize_seed=False,
-    width=1024,
-    height=1024,
-    guidance_scale=5.0,
-    num_inference_steps=50,
     progress=gr.Progress(track_tqdm=True),
 ):
-    if ip_adapter_image is None:
-        gr.Warning("Please upload an IP-Adapter reference image.")
-        return None, 0
-
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)

-    image = _generate(
-        prompt=prompt or "",
-        ip_adapter_image=ip_adapter_image,
-        ip_adapter_scale=float(ip_adapter_scale),
-        negative_prompt=negative_prompt or "",
-        seed=int(seed),
-        width=int(width),
-        height=int(height),
-        guidance_scale=float(guidance_scale),
-        num_inference_steps=int(num_inference_steps),
-    )
-    return image, seed
-
-
-if torch.cuda.is_available():
-    infer = spaces.GPU(duration=80)(_infer_core)
-else:
-    infer = _infer_core
-

 examples = [
-    ["A dog", "minta.jpeg", 0.4],
-    ["A capybara", "king-min.png", 0.5],
-    ["A cat", "blue_hair.png", 0.5],
-    ["", "meow.jpeg", 1.0],
 ]

-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 720px;
 }
-#result img{
-    object-position: top;
 }
-#result .image-container{
-    height: 100%
 }
 """

-with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown("# Kolors IP-Adapter - image reference and variations")

-        with gr.Row():
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-            run_button = gr.Button("Run", scale=0)
-
-        with gr.Row():
-            with gr.Column():
-                ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil")
-                ip_adapter_scale = gr.Slider(
-                    label="Image influence scale",
-                    info="Use 1 for creating variations",
-                    minimum=0.0,
-                    maximum=1.0,
-                    step=0.05,
-                    value=0.5,
                 )
-            result = gr.Image(label="Result", elem_id="result")
-
-        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-            )
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,
                 )
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,
                 )
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=5.0,
                 )
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=100,
-                    step=1,
-                    value=25,
                 )
-
-        # Example files may not exist locally, so keep cache_examples="lazy"
-        gr.Examples(
-            examples=examples,
-            fn=infer,
-            inputs=[prompt, ip_adapter_image, ip_adapter_scale],
-            outputs=[result, seed],
-            cache_examples="lazy",
-        )
-
     gr.on(
         triggers=[run_button.click, prompt.submit],
         fn=infer,
         inputs=[
             prompt,
-            ip_adapter_image,
-            ip_adapter_scale,
-            negative_prompt,
             seed,
             randomize_seed,
-            width,
-            height,
-            guidance_scale,
             num_inference_steps,
         ],
-        outputs=[result, seed],
     )

-demo.queue().launch()
 
 
+import gradio as gr
 import numpy as np
+import random
 import torch
 import spaces
+from PIL import Image
+from diffusers import QwenImageEditPipeline
+from diffusers.utils import is_xformers_available
+import os
+import base64
+import json
+from huggingface_hub import InferenceClient
+import logging
+
+#############################
+os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
+os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+#############################
+
+def get_caption_language(prompt):
+    """Detects if the prompt contains Chinese characters."""
+    ranges = [
+        ('\u4e00', '\u9fff'),  # CJK Unified Ideographs
+    ]
+    for char in prompt:
+        if any(start <= char <= end for start, end in ranges):
+            return 'zh'
+    return 'en'
+
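
The helper above scans the prompt one character at a time and returns 'zh' for the first code point inside the CJK Unified Ideographs block (U+4E00 to U+9FFF), otherwise 'en'. For example:

    print(get_caption_language("Add a dog"))                   # -> 'en'
    print(get_caption_language('Replace "Qwen" with "通义"'))  # -> 'zh'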
+def polish_prompt(original_prompt, system_prompt, hf_token):
+    """
+    Rewrites the prompt using a Hugging Face InferenceClient.
+    Requires a user-provided HF token for API access.
+    """
+    if not hf_token or not hf_token.strip():
+        gr.Warning("HF Token is required for prompt rewriting but was not provided!")
+        return original_prompt
+    client = InferenceClient(
+        provider="cerebras",
+        api_key=hf_token,
+    )
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": original_prompt}
+    ]
+    try:
+        completion = client.chat.completions.create(
+            model="Qwen/Qwen3-235B-A22B-Instruct-2507",
+            messages=messages,
+            max_tokens=512,
+        )
+        polished_prompt = completion.choices[0].message.content
+        polished_prompt = polished_prompt.strip().replace("\n", " ")
+        return polished_prompt
+    except Exception as e:
+        print(f"Error during Hugging Face API call: {e}")
+        gr.Warning("Failed to rewrite prompt. Using original.")
+        return original_prompt
+
+SYSTEM_PROMPT_EDIT = '''
+# Edit Instruction Rewriter
+You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable instruction based on the user's intent and the input image.
+## 1. General Principles
+- Keep the rewritten instruction **concise** and clear.
+- Avoid contradictions, vagueness, or unachievable instructions.
+- Maintain the core logic of the original instruction; only enhance clarity and feasibility.
+- Ensure newly added elements or modifications align with the image's original context and art style.
+## 2. Task Types
+### Add, Delete, Replace:
+- When the input is detailed, only refine grammar and clarity.
+- For vague instructions, infer minimal but sufficient details.
+- For replacement, use the format: `"Replace X with Y"`.
+### Text Editing (e.g., text replacement):
+- Enclose text content in quotes, e.g., `Replace "abc" with "xyz"`.
+- Preserve the original structure and language—**do not translate** or alter style.
+### Human Editing (e.g., change a person's face/hair):
+- Preserve core visual identity (gender, ethnic features).
+- Describe expressions in subtle and natural terms.
+- Maintain key clothing or styling details unless explicitly replaced.
+### Style Transformation:
+- If a style is specified, e.g., `Disco style`, rewrite it to encapsulate the essential visual traits.
+- Use a fixed template for **coloring/restoration**:
+  `"Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"`
+  if applicable.
+## 3. Output Format
+Please provide the rewritten instruction in a clean `json` format as:
+{
+    "Rewritten": "..."
+}
+'''
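
One subtlety: SYSTEM_PROMPT_EDIT asks the model to answer with a JSON object of the form {"Rewritten": "..."}, but polish_prompt passes the completion text through verbatim. A defensive caller could unwrap the field along these lines (a sketch, not part of the commit; it reuses the json import above):

    def extract_rewritten(raw: str) -> str:
        # Best-effort unwrap of {"Rewritten": "..."}; fall back to the raw text.
        try:
            parsed = json.loads(raw)
            return parsed.get("Rewritten", raw) if isinstance(parsed, dict) else raw
        except json.JSONDecodeError:
            return raw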

+dtype = torch.bfloat16
+device = "cuda" if torch.cuda.is_available() else "cpu"
+pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=dtype).to(device)

+# Load LoRA weights for acceleration
+pipe.load_lora_weights(
+    "lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
 )
+pipe.fuse_lora()

+if is_xformers_available():
+    pipe.enable_xformers_memory_efficient_attention()
+else:
+    print("xformers not available or failed to load.")

+@spaces.GPU(duration=60)
+def infer(
+    image,
     prompt,
+    seed=42,
     randomize_seed=False,
+    true_guidance_scale=1.0,
+    num_inference_steps=8,
+    rewrite_prompt=False,
+    hf_token="",
+    num_images_per_prompt=1,
     progress=gr.Progress(track_tqdm=True),
 ):
+    """
+    Requires a user-provided HF token for prompt rewriting.
+    """
+    original_prompt = prompt  # Save the original prompt for display
+    negative_prompt = " "
+    prompt_info = ""  # Initialize info text
+
+    # Handle prompt rewriting with status messages
+    if rewrite_prompt:
+        if not hf_token.strip():
+            gr.Warning("HF Token is required for prompt rewriting but was not provided!")
+            prompt_info = f"""<div class="prompt-info-box warning">
+                <h3>⚠️ Prompt Rewriting Skipped</h3>
+                <p><strong>Original:</strong> {original_prompt}</p>
+                <p class="note">HF Token required for enhancement</p>
+            </div>"""
+            rewritten_prompt = original_prompt
+        else:
+            try:
+                rewritten_prompt = polish_prompt(original_prompt, SYSTEM_PROMPT_EDIT, hf_token)
+                prompt_info = f"""<div class="prompt-info-box success">
+                    <h3>✨ Enhanced Successfully</h3>
+                    <p><strong>Original:</strong> {original_prompt}</p>
+                    <p><strong>Enhanced:</strong> {rewritten_prompt}</p>
+                </div>"""
+            except Exception as e:
+                gr.Warning(f"Prompt rewriting failed: {str(e)}")
+                rewritten_prompt = original_prompt
+                prompt_info = f"""<div class="prompt-info-box error">
+                    <h3>❌ Enhancement Failed</h3>
+                    <p><strong>Original:</strong> {original_prompt}</p>
+                    <p class="note">Error: {str(e)}</p>
+                </div>"""
+    else:
+        rewritten_prompt = original_prompt
+        prompt_info = f"""<div class="prompt-info-box default">
+            <h3>📝 Original Prompt</h3>
+            <p>{original_prompt}</p>
+        </div>"""
+
+    # Generate images
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
+    generator = torch.Generator(device=device).manual_seed(seed)
+
+    edited_images = pipe(
+        image,
+        prompt=rewritten_prompt,
+        negative_prompt=negative_prompt,
+        num_inference_steps=num_inference_steps,
+        generator=generator,
+        true_cfg_scale=true_guidance_scale,
+        num_images_per_prompt=num_images_per_prompt,
+    ).images
+
+    return edited_images, seed, prompt_info
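
Note that infer reads MAX_SEED even though the assignment appears below the function; this is fine because Python resolves module-level globals when the function is called, not when it is defined:

    def f():
        return LIMIT   # looked up when f() runs (toy names)
    LIMIT = 3          # assigned after the definition, before any call
    assert f() == 3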

+MAX_SEED = np.iinfo(np.int32).max

 examples = [
+    "Replace the cat with a friendly golden retriever. Make it look happier, and add more background details.",
+    "Add text 'Qwen - AI for image editing' in Chinese at the bottom center with a small shadow.",
+    "Change the style to 1970s vintage, add old photo effect, restore any scratches on the wall or window.",
+    "Remove the blue sky and replace it with a dark night cityscape.",
+    """Replace "Qwen" with "通义" in the Image. Ensure Chinese font is used for "通义" and position it to the top left with a light heading-style font."""
 ]

+# Custom CSS for enhanced visual design
+custom_css = """
+/* Gradient background */
+.gradio-container {
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 25%, #f093fb 50%, #fecfef 75%, #fecfef 100%);
+    min-height: 100vh;
+}
+/* Main container styling */
+.container {
+    max-width: 1400px !important;
+    margin: 0 auto !important;
+    padding: 2rem !important;
+}
+/* Card-like sections */
+.gr-box {
+    background: rgba(255, 255, 255, 0.95) !important;
+    backdrop-filter: blur(10px) !important;
+    border-radius: 20px !important;
+    box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1) !important;
+    border: 1px solid rgba(255, 255, 255, 0.5) !important;
+    padding: 1.5rem !important;
+    margin-bottom: 1.5rem !important;
+}
+/* Header styling */
+h1 {
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+    -webkit-background-clip: text;
+    -webkit-text-fill-color: transparent;
+    background-clip: text;
+    font-size: 3rem !important;
+    font-weight: 800 !important;
+    text-align: center;
+    margin-bottom: 0.5rem !important;
+    text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
+}
+h2 {
+    color: #4a5568 !important;
+    font-size: 1.5rem !important;
+    font-weight: 600 !important;
+    margin-bottom: 1rem !important;
+}
+/* Button styling */
+.gr-button-primary {
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
+    border: none !important;
+    color: white !important;
+    font-weight: 600 !important;
+    font-size: 1.1rem !important;
+    padding: 0.8rem 2rem !important;
+    border-radius: 12px !important;
+    box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
+    transition: all 0.3s ease !important;
+}
+.gr-button-primary:hover {
+    transform: translateY(-2px) !important;
+    box-shadow: 0 6px 20px rgba(102, 126, 234, 0.5) !important;
+}
+/* Input fields styling */
+.gr-input, .gr-text-input, .gr-slider, .gr-dropdown {
+    border-radius: 10px !important;
+    border: 2px solid #e2e8f0 !important;
+    background: white !important;
+    transition: all 0.3s ease !important;
+}
+.gr-input:focus, .gr-text-input:focus {
+    border-color: #667eea !important;
+    box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
+}
+/* Accordion styling */
+.gr-accordion {
+    background: rgba(255, 255, 255, 0.8) !important;
+    border-radius: 12px !important;
+    border: 1px solid rgba(102, 126, 234, 0.2) !important;
+    overflow: hidden !important;
+}
+/* Gallery styling */
+.gr-gallery {
+    border-radius: 12px !important;
+    overflow: hidden !important;
+}
+/* Prompt info boxes */
+.prompt-info-box {
+    padding: 1.5rem;
+    border-radius: 12px;
+    margin: 1rem 0;
+    animation: fadeIn 0.5s ease;
+}
+.prompt-info-box h3 {
+    margin: 0 0 0.75rem 0;
+    font-size: 1.2rem;
+    font-weight: 600;
+}
+.prompt-info-box p {
+    margin: 0.5rem 0;
+    line-height: 1.6;
+}
+.prompt-info-box.success {
+    background: linear-gradient(135deg, #d4f4dd 0%, #e3f9e5 100%);
+    border-left: 4px solid #48bb78;
+}
+.prompt-info-box.warning {
+    background: linear-gradient(135deg, #fef5e7 0%, #fff9ec 100%);
+    border-left: 4px solid #f6ad55;
+}
+.prompt-info-box.error {
+    background: linear-gradient(135deg, #fed7d7 0%, #fee5e5 100%);
+    border-left: 4px solid #fc8181;
 }
+.prompt-info-box.default {
+    background: linear-gradient(135deg, #e6f3ff 0%, #f0f7ff 100%);
+    border-left: 4px solid #667eea;
 }
+.prompt-info-box .note {
+    font-size: 0.9rem;
+    color: #718096;
+    font-style: italic;
+}
+/* Checkbox styling */
+.gr-checkbox {
+    background: white !important;
+    border-radius: 8px !important;
+    padding: 0.5rem !important;
+}
+/* Token input field */
+input[type="password"] {
+    font-family: monospace !important;
+    letter-spacing: 0.05em !important;
+}
+/* Info badges */
+.gr-markdown p {
+    color: #4a5568;
+    line-height: 1.6;
+}
+.gr-markdown a {
+    color: #667eea !important;
+    text-decoration: none !important;
+    font-weight: 500 !important;
+    transition: color 0.3s ease !important;
+}
+.gr-markdown a:hover {
+    color: #764ba2 !important;
+    text-decoration: underline !important;
+}
+/* Animation */
+@keyframes fadeIn {
+    from {
+        opacity: 0;
+        transform: translateY(10px);
+    }
+    to {
+        opacity: 1;
+        transform: translateY(0);
+    }
+}
+/* Slider styling */
+.gr-slider input[type="range"] {
+    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%) !important;
+}
+/* Group styling */
+.gr-group {
+    background: rgba(249, 250, 251, 0.8) !important;
+    border-radius: 12px !important;
+    padding: 1rem !important;
+    margin-top: 1rem !important;
+}
+/* Loading spinner customization */
+.gr-loading {
+    color: #667eea !important;
+}
+/* Example buttons */
+.gr-examples button {
+    background: white !important;
+    border: 2px solid #e2e8f0 !important;
+    border-radius: 8px !important;
+    padding: 0.5rem 1rem !important;
+    transition: all 0.3s ease !important;
+}
+.gr-examples button:hover {
+    border-color: #667eea !important;
+    background: rgba(102, 126, 234, 0.05) !important;
 }
 """

+with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🎨 Nano-Banana")
+    gr.Markdown("✨ **Ultra-fast 8-step image editing with AI-powered prompt enhancement**")
+    gr.Markdown("🔐 **Secure prompt rewriting with your [Hugging Face token](https://huggingface.co/settings/tokens)**")
+
+    # Center the badges and lay them out side by side
+    gr.HTML("""
+    <div style="display: flex; justify-content: center; align-items: center; gap: 20px; margin: 20px 0;">
+        <a href="https://huggingface.co/spaces/Heartsync/Nano-Banana" target="_blank">
+            <img src="https://img.shields.io/static/v1?label=OPEN%20NANO-BANANA&message=Image%20EDITOR&color=%230000ff&labelColor=%23800080&logo=huggingface&logoColor=white&style=for-the-badge" alt="badge">
+        </a>
+    </div>
+    """)
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            with gr.Group():
+                input_image = gr.Image(
+                    label="📸 Input Image",
+                    type="pil",
+                    elem_classes="gr-box"
                 )
+                prompt = gr.Text(
+                    label="✏️ Edit Instruction",
+                    placeholder="e.g. Add a dog to the right side, change the sky to sunset...",
+                    lines=3,
+                    elem_classes="gr-box"
                 )
+
+            with gr.Accordion("⚙️ Advanced Settings", open=False):
+                seed = gr.Slider(
+                    label="Seed",
+                    minimum=0,
+                    maximum=MAX_SEED,
+                    step=1,
+                    value=0
+                )
+                randomize_seed = gr.Checkbox(label="🎲 Randomize Seed", value=True)
+
+                with gr.Row():
+                    true_guidance_scale = gr.Slider(
+                        label="Guidance Scale",
+                        minimum=1.0,
+                        maximum=5.0,
+                        step=0.1,
+                        value=4.0
+                    )
+                    num_inference_steps = gr.Slider(
+                        label="Inference Steps",
+                        minimum=4,
+                        maximum=16,
+                        step=1,
+                        value=8
+                    )
+
+                num_images_per_prompt = gr.Slider(
+                    label="Images per Prompt",
+                    minimum=1,
+                    maximum=4,
+                    step=1,
+                    value=1
+                )
+
+            run_button = gr.Button("🚀 Generate Edit", variant="primary", size="lg")
+
+        with gr.Column(scale=1):
+            result = gr.Gallery(
+                label="🖼️ Output Images",
+                show_label=True,
+                columns=2,
+                rows=2,
+                elem_classes="gr-box"
+            )
+
+            # Prompt display component
+            prompt_info = gr.HTML(visible=False)
+
+            with gr.Group():
+                rewrite_toggle = gr.Checkbox(
+                    label="🤖 Enable AI Prompt Enhancement",
+                    value=False,
+                    interactive=True
                 )
+                hf_token_input = gr.Textbox(
+                    label="🔑 Hugging Face API Token",
+                    type="password",
+                    placeholder="hf_xxxxxxxxxxxxxxxx",
+                    visible=False,
+                    info="Your token is secure and only used for API calls. Get yours from HuggingFace settings.",
+                    elem_classes="gr-box"
                 )
+
+            def toggle_token_visibility(checked):
+                return gr.update(visible=checked)
+
+            rewrite_toggle.change(
+                toggle_token_visibility,
+                inputs=[rewrite_toggle],
+                outputs=[hf_token_input]
             )
+
+    # Examples section
+    gr.Examples(
+        examples=examples,
+        inputs=prompt,
+        label="💡 Example Prompts"
+    )
+
     gr.on(
         triggers=[run_button.click, prompt.submit],
         fn=infer,
         inputs=[
+            input_image,
             prompt,
             seed,
             randomize_seed,
+            true_guidance_scale,
             num_inference_steps,
+            rewrite_toggle,
+            hf_token_input,
+            num_images_per_prompt
         ],
+        outputs=[result, seed, prompt_info]
+    )
+
+    # Show prompt info box after processing
+    def set_prompt_visible():
+        return gr.update(visible=True)
+
+    run_button.click(
+        fn=set_prompt_visible,
+        inputs=None,
+        outputs=[prompt_info],
+        queue=False
+    )
+    prompt.submit(
+        fn=set_prompt_visible,
+        inputs=None,
+        outputs=[prompt_info],
+        queue=False
     )

+if __name__ == "__main__":
+    demo.launch()
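
For reference, the editing core without the Gradio layer, built only from calls that appear in the file above (the input path, prompt, and seed are placeholders):

    import torch
    from PIL import Image
    from diffusers import QwenImageEditPipeline

    pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16).to("cuda")
    pipe.load_lora_weights(
        "lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
    )
    pipe.fuse_lora()

    image = Image.open("input.png").convert("RGB")  # placeholder input image
    result = pipe(
        image,
        prompt="Replace the cat with a friendly golden retriever.",
        negative_prompt=" ",
        num_inference_steps=8,
        true_cfg_scale=4.0,
        generator=torch.Generator(device="cuda").manual_seed(0),
    ).images[0]
    result.save("edited.png")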