Tanut committed on
Commit
a3657ef
·
1 Parent(s): 708ab1d
Files changed (1) hide show
  1. app.py +27 -32
app.py CHANGED
@@ -12,31 +12,14 @@ from diffusers import (
12
  DPMSolverMultistepScheduler,
13
  )
14
 
15
- # --- gradio_client bool-schema hotfix (prevents blank page on Spaces) ---
16
- try:
17
- import gradio_client.utils as _gcu
18
- _orig_get_type = _gcu.get_type
19
- def _get_type_safe(schema):
20
- if isinstance(schema, bool): # handle JSON Schema True/False
21
- return "any"
22
- return _orig_get_type(schema)
23
- _gcu.get_type = _get_type_safe
24
- except Exception:
25
- pass
26
- # -----------------------------------------------------------------------
27
-
28
  # Quiet matplotlib cache warning on Spaces
29
  os.environ.setdefault("MPLCONFIGDIR", "/tmp/mpl")
30
 
31
- # Token helper: add a Secret in your Space named HUGGINGFACE_HUB_TOKEN
32
- def _hf_auth():
33
- tok = os.getenv("HUGGINGFACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
34
- return {"token": tok, "use_auth_token": tok} if tok else {}
35
 
36
  # ---- base models for the two tabs ----
37
  BASE_MODELS = {
38
  "stable-diffusion-v1-5": "runwayml/stable-diffusion-v1-5",
39
- "dream": "Lykon/dreamshaper-8",
40
  }
41
 
42
  # ControlNet (QR Monster v2 for SD15)
@@ -65,6 +48,9 @@ def normalize_color(c):
65
  return "white"
66
 
67
  def make_qr(url="https://example.com", size=768, border=12, back_color="#FFFFFF", blur_radius=0.0):
 
 
 
68
  qr = qrcode.QRCode(version=None, error_correction=ERROR_CORRECT_H, box_size=10, border=int(border))
69
  qr.add_data(url.strip()); qr.make(fit=True)
70
  img = qr.make_image(fill_color="black", back_color=normalize_color(back_color)).convert("RGB")
@@ -74,6 +60,7 @@ def make_qr(url="https://example.com", size=768, border=12, back_color="#FFFFFF"
74
  return img
75
 
76
  def enforce_qr_contrast(stylized: Image.Image, qr_img: Image.Image, strength: float = 0.0, feather: float = 1.0) -> Image.Image:
 
77
  if strength <= 0: return stylized
78
  q = qr_img.convert("L")
79
  black_mask = q.point(lambda p: 255 if p < 128 else 0).filter(ImageFilter.GaussianBlur(radius=float(feather)))
@@ -102,7 +89,7 @@ def _base_scheduler_for(pipe):
102
  def get_cn():
103
  global _CN
104
  if _CN is None:
105
- _CN = ControlNetModel.from_pretrained(CN_QRMON, torch_dtype=DTYPE, use_safetensors=True, **_hf_auth())
106
  return _CN
107
 
108
  def get_qrmon_txt2img_pipe(model_id: str):
@@ -114,7 +101,6 @@ def get_qrmon_txt2img_pipe(model_id: str):
114
  safety_checker=None,
115
  use_safetensors=True,
116
  low_cpu_mem_usage=True,
117
- **_hf_auth(),
118
  )
119
  _CN_TXT2IMG[model_id] = _base_scheduler_for(pipe)
120
  return _CN_TXT2IMG[model_id]
@@ -128,7 +114,6 @@ def get_qrmon_img2img_pipe(model_id: str):
128
  safety_checker=None,
129
  use_safetensors=True,
130
  low_cpu_mem_usage=True,
131
- **_hf_auth(),
132
  )
133
  _CN_IMG2IMG[model_id] = _base_scheduler_for(pipe)
134
  return _CN_IMG2IMG[model_id]
@@ -143,12 +128,15 @@ def _qr_txt2img_core(model_id: str,
143
 
144
  s = snap8(size)
145
 
 
146
  qr_img = make_qr(url=url, size=s, border=int(border), back_color="#FFFFFF", blur_radius=0.0)
147
 
 
148
  if int(seed) < 0:
149
  seed = random.randint(0, 2**31 - 1)
150
  gen = torch.Generator(device="cuda").manual_seed(int(seed))
151
 
 
152
  pipe = get_qrmon_txt2img_pipe(model_id)
153
  if torch.cuda.is_available(): torch.cuda.empty_cache()
154
  gc.collect()
@@ -156,8 +144,8 @@ def _qr_txt2img_core(model_id: str,
156
  out = pipe(
157
  prompt=str(style_prompt),
158
  negative_prompt=str(negative or ""),
159
- image=qr_img,
160
- controlnet_conditioning_scale=float(qr_weight),
161
  control_guidance_start=0.0,
162
  control_guidance_end=1.0,
163
  num_inference_steps=int(steps),
@@ -168,6 +156,7 @@ def _qr_txt2img_core(model_id: str,
168
  lowres = out.images[0]
169
  lowres = enforce_qr_contrast(lowres, qr_img, strength=float(repair_strength), feather=float(feather))
170
 
 
171
  final = lowres
172
  if use_hires:
173
  up = max(1.0, min(2.0, float(hires_upscale)))
@@ -179,9 +168,9 @@ def _qr_txt2img_core(model_id: str,
179
  out2 = pipe2(
180
  prompt=str(style_prompt),
181
  negative_prompt=str(negative or ""),
182
- image=lowres,
183
- control_image=qr_img,
184
- strength=float(hires_strength),
185
  controlnet_conditioning_scale=float(qr_weight),
186
  control_guidance_start=0.0,
187
  control_guidance_end=1.0,
@@ -195,6 +184,7 @@ def _qr_txt2img_core(model_id: str,
195
  final = enforce_qr_contrast(final, qr_img, strength=float(repair_strength), feather=float(feather))
196
  return final, lowres, qr_img
197
 
 
198
  @spaces.GPU(duration=120)
199
  def qr_txt2img_anything(*args):
200
  return _qr_txt2img_core(BASE_MODELS["stable-diffusion-v1-5"], *args)
@@ -207,6 +197,7 @@ def qr_txt2img_dream(*args):
207
  with gr.Blocks() as demo:
208
  gr.Markdown("# ZeroGPU • Method 1: QR Control (two base models)")
209
 
 
210
  with gr.Tab("stable-diffusion-v1-5"):
211
  url1 = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
212
  s_prompt1 = gr.Textbox(label="Style prompt", value="japanese painting, elegant shrine and torii, distant mount fuji, autumn maple trees, warm sunlight, 1girl in kimono, highly detailed, intricate patterns, anime key visual, dramatic composition")
@@ -217,14 +208,18 @@ with gr.Blocks() as demo:
217
  border1 = gr.Slider(2, 16, value=4, step=1, label="QR border (quiet zone)")
218
  qr_w1 = gr.Slider(0.6, 1.6, value=1.5, step=0.05, label="QR control weight")
219
  seed1 = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
 
220
  use_hires1 = gr.Checkbox(value=True, label="Hi-Res Fix (img2img upscale)")
221
  hires_up1 = gr.Slider(1.0, 2.0, value=2.0, step=0.25, label="Hi-Res upscale (×)")
222
  hires_str1 = gr.Slider(0.3, 0.9, value=0.7, step=0.05, label="Hi-Res denoise strength")
 
223
  repair1 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Post repair strength (optional)")
224
  feather1 = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
 
225
  final_img1 = gr.Image(label="Final (or Hi-Res) image")
226
  low_img1 = gr.Image(label="Low-res (Stage A) preview")
227
  ctrl_img1 = gr.Image(label="Control QR used")
 
228
  gr.Button("Generate with stable-diffusion-v1-5").click(
229
  qr_txt2img_anything,
230
  [url1, s_prompt1, s_negative1, steps1, cfg1, size1, border1, qr_w1, seed1,
@@ -232,6 +227,7 @@ with gr.Blocks() as demo:
232
  [final_img1, low_img1, ctrl_img1]
233
  )
234
 
 
235
  with gr.Tab("DreamShaper 8"):
236
  url2 = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
237
  s_prompt2 = gr.Textbox(label="Style prompt", value="ornate baroque palace interior, gilded details, chandeliers, volumetric light, ultra detailed, cinematic")
@@ -242,14 +238,18 @@ with gr.Blocks() as demo:
242
  border2 = gr.Slider(2, 16, value=8, step=1, label="QR border (quiet zone)")
243
  qr_w2 = gr.Slider(0.6, 1.6, value=1.5, step=0.05, label="QR control weight")
244
  seed2 = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
 
245
  use_hires2 = gr.Checkbox(value=True, label="Hi-Res Fix (img2img upscale)")
246
  hires_up2 = gr.Slider(1.0, 2.0, value=2.0, step=0.25, label="Hi-Res upscale (×)")
247
  hires_str2 = gr.Slider(0.3, 0.9, value=0.7, step=0.05, label="Hi-Res denoise strength")
 
248
  repair2 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Post repair strength (optional)")
249
  feather2 = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
 
250
  final_img2 = gr.Image(label="Final (or Hi-Res) image")
251
  low_img2 = gr.Image(label="Low-res (Stage A) preview")
252
  ctrl_img2 = gr.Image(label="Control QR used")
 
253
  gr.Button("Generate with DreamShaper 8").click(
254
  qr_txt2img_dream,
255
  [url2, s_prompt2, s_negative2, steps2, cfg2, size2, border2, qr_w2, seed2,
@@ -258,9 +258,4 @@ with gr.Blocks() as demo:
258
  )
259
 
260
  if __name__ == "__main__":
261
- # Keep launch simple on Spaces
262
- demo.queue(max_size=12).launch(
263
- server_name="0.0.0.0",
264
- server_port=int(os.environ.get("PORT", 7860)),
265
- show_error=True,
266
- )
 
12
  DPMSolverMultistepScheduler,
13
  )
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Quiet matplotlib cache warning on Spaces
16
  os.environ.setdefault("MPLCONFIGDIR", "/tmp/mpl")
17
 
 
 
 
 
18
 
19
  # ---- base models for the two tabs ----
20
  BASE_MODELS = {
21
  "stable-diffusion-v1-5": "runwayml/stable-diffusion-v1-5",
22
+ "dream": "Lykon/dreamshaper-8",
23
  }
24
 
25
  # ControlNet (QR Monster v2 for SD15)
 
48
  return "white"
49
 
50
  def make_qr(url="https://example.com", size=768, border=12, back_color="#FFFFFF", blur_radius=0.0):
51
+ """
52
+ IMPORTANT for Method 1: give ControlNet a sharp, black-on-WHITE QR (no blur).
53
+ """
54
  qr = qrcode.QRCode(version=None, error_correction=ERROR_CORRECT_H, box_size=10, border=int(border))
55
  qr.add_data(url.strip()); qr.make(fit=True)
56
  img = qr.make_image(fill_color="black", back_color=normalize_color(back_color)).convert("RGB")
 
60
  return img
61
 
62
  def enforce_qr_contrast(stylized: Image.Image, qr_img: Image.Image, strength: float = 0.0, feather: float = 1.0) -> Image.Image:
63
+ """Optional gentle repair. Default OFF for Method 1."""
64
  if strength <= 0: return stylized
65
  q = qr_img.convert("L")
66
  black_mask = q.point(lambda p: 255 if p < 128 else 0).filter(ImageFilter.GaussianBlur(radius=float(feather)))
 
89
  def get_cn():
90
  global _CN
91
  if _CN is None:
92
+ _CN = ControlNetModel.from_pretrained(CN_QRMON, torch_dtype=DTYPE, use_safetensors=True)
93
  return _CN
94
 
95
  def get_qrmon_txt2img_pipe(model_id: str):
 
101
  safety_checker=None,
102
  use_safetensors=True,
103
  low_cpu_mem_usage=True,
 
104
  )
105
  _CN_TXT2IMG[model_id] = _base_scheduler_for(pipe)
106
  return _CN_TXT2IMG[model_id]
 
114
  safety_checker=None,
115
  use_safetensors=True,
116
  low_cpu_mem_usage=True,
 
117
  )
118
  _CN_IMG2IMG[model_id] = _base_scheduler_for(pipe)
119
  return _CN_IMG2IMG[model_id]
 
128
 
129
  s = snap8(size)
130
 
131
+ # Control image: crisp black-on-white QR
132
  qr_img = make_qr(url=url, size=s, border=int(border), back_color="#FFFFFF", blur_radius=0.0)
133
 
134
+ # Seed / generator
135
  if int(seed) < 0:
136
  seed = random.randint(0, 2**31 - 1)
137
  gen = torch.Generator(device="cuda").manual_seed(int(seed))
138
 
139
+ # --- Stage A: txt2img with ControlNet
140
  pipe = get_qrmon_txt2img_pipe(model_id)
141
  if torch.cuda.is_available(): torch.cuda.empty_cache()
142
  gc.collect()
 
144
  out = pipe(
145
  prompt=str(style_prompt),
146
  negative_prompt=str(negative or ""),
147
+ image=qr_img, # control image for txt2img
148
+ controlnet_conditioning_scale=float(qr_weight), # ~1.0–1.2 works well
149
  control_guidance_start=0.0,
150
  control_guidance_end=1.0,
151
  num_inference_steps=int(steps),
 
156
  lowres = out.images[0]
157
  lowres = enforce_qr_contrast(lowres, qr_img, strength=float(repair_strength), feather=float(feather))
158
 
159
+ # --- Optional Stage B: Hi-Res Fix (img2img with same QR)
160
  final = lowres
161
  if use_hires:
162
  up = max(1.0, min(2.0, float(hires_upscale)))
 
168
  out2 = pipe2(
169
  prompt=str(style_prompt),
170
  negative_prompt=str(negative or ""),
171
+ image=lowres, # init image
172
+ control_image=qr_img, # same QR
173
+ strength=float(hires_strength), # ~0.7 like "Hires Fix"
174
  controlnet_conditioning_scale=float(qr_weight),
175
  control_guidance_start=0.0,
176
  control_guidance_end=1.0,
 
184
  final = enforce_qr_contrast(final, qr_img, strength=float(repair_strength), feather=float(feather))
185
  return final, lowres, qr_img
186
 
187
+ # Wrappers for each tab (so Gradio can bind without passing the model id)
188
  @spaces.GPU(duration=120)
189
  def qr_txt2img_anything(*args):
190
  return _qr_txt2img_core(BASE_MODELS["stable-diffusion-v1-5"], *args)
 
197
  with gr.Blocks() as demo:
198
  gr.Markdown("# ZeroGPU • Method 1: QR Control (two base models)")
199
 
200
+ # ---- Tab 1: stable-diffusion-v1-5 (anime/illustration) ----
201
  with gr.Tab("stable-diffusion-v1-5"):
202
  url1 = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
203
  s_prompt1 = gr.Textbox(label="Style prompt", value="japanese painting, elegant shrine and torii, distant mount fuji, autumn maple trees, warm sunlight, 1girl in kimono, highly detailed, intricate patterns, anime key visual, dramatic composition")
 
208
  border1 = gr.Slider(2, 16, value=4, step=1, label="QR border (quiet zone)")
209
  qr_w1 = gr.Slider(0.6, 1.6, value=1.5, step=0.05, label="QR control weight")
210
  seed1 = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
211
+
212
  use_hires1 = gr.Checkbox(value=True, label="Hi-Res Fix (img2img upscale)")
213
  hires_up1 = gr.Slider(1.0, 2.0, value=2.0, step=0.25, label="Hi-Res upscale (×)")
214
  hires_str1 = gr.Slider(0.3, 0.9, value=0.7, step=0.05, label="Hi-Res denoise strength")
215
+
216
  repair1 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Post repair strength (optional)")
217
  feather1 = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
218
+
219
  final_img1 = gr.Image(label="Final (or Hi-Res) image")
220
  low_img1 = gr.Image(label="Low-res (Stage A) preview")
221
  ctrl_img1 = gr.Image(label="Control QR used")
222
+
223
  gr.Button("Generate with stable-diffusion-v1-5").click(
224
  qr_txt2img_anything,
225
  [url1, s_prompt1, s_negative1, steps1, cfg1, size1, border1, qr_w1, seed1,
 
227
  [final_img1, low_img1, ctrl_img1]
228
  )
229
 
230
+ # ---- Tab 2: DreamShaper (general art/painterly) ----
231
  with gr.Tab("DreamShaper 8"):
232
  url2 = gr.Textbox(label="URL/Text", value="http://www.mybirdfire.com")
233
  s_prompt2 = gr.Textbox(label="Style prompt", value="ornate baroque palace interior, gilded details, chandeliers, volumetric light, ultra detailed, cinematic")
 
238
  border2 = gr.Slider(2, 16, value=8, step=1, label="QR border (quiet zone)")
239
  qr_w2 = gr.Slider(0.6, 1.6, value=1.5, step=0.05, label="QR control weight")
240
  seed2 = gr.Number(value=-1, precision=0, label="Seed (-1 random)")
241
+
242
  use_hires2 = gr.Checkbox(value=True, label="Hi-Res Fix (img2img upscale)")
243
  hires_up2 = gr.Slider(1.0, 2.0, value=2.0, step=0.25, label="Hi-Res upscale (×)")
244
  hires_str2 = gr.Slider(0.3, 0.9, value=0.7, step=0.05, label="Hi-Res denoise strength")
245
+
246
  repair2 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Post repair strength (optional)")
247
  feather2 = gr.Slider(0.0, 3.0, value=1.0, step=0.1, label="Repair feather (px)")
248
+
249
  final_img2 = gr.Image(label="Final (or Hi-Res) image")
250
  low_img2 = gr.Image(label="Low-res (Stage A) preview")
251
  ctrl_img2 = gr.Image(label="Control QR used")
252
+
253
  gr.Button("Generate with DreamShaper 8").click(
254
  qr_txt2img_dream,
255
  [url2, s_prompt2, s_negative2, steps2, cfg2, size2, border2, qr_w2, seed2,
 
258
  )
259
 
260
  if __name__ == "__main__":
261
+ demo.queue(max_size=12).launch()