fantaxy commited on
Commit
048ba10
ยท
verified ยท
1 Parent(s): 645ebcd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -15
app.py CHANGED
@@ -12,7 +12,19 @@ HF_TOKEN = os.getenv("HF_TOKEN")
12
  if not HF_TOKEN:
13
  raise ValueError("HF_TOKEN environment variable is not set")
14
 
15
- def query(prompt, model, custom_lora, is_negative=False, steps=35, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, width=1024, height=1024):
 
 
 
 
 
 
 
 
 
 
 
 
16
  print("Starting query function...")
17
 
18
  if not prompt:
@@ -235,18 +247,22 @@ def query(prompt, model, custom_lora, is_negative=False, steps=35, cfg_scale=7,
235
  else:
236
  API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
237
 
238
- # Prepare payload
 
239
  payload = {
240
  "inputs": prompt,
241
- "is_negative": is_negative,
242
- "steps": steps,
243
- "cfg_scale": cfg_scale,
244
- "seed": seed if seed != -1 else random.randint(1, 1000000000),
245
- "strength": strength,
246
  "parameters": {
 
 
 
247
  "width": width,
248
- "height": height
249
- }
 
 
 
 
 
250
  }
251
 
252
  # Improved retry logic with exponential backoff
@@ -256,20 +272,27 @@ def query(prompt, model, custom_lora, is_negative=False, steps=35, cfg_scale=7,
256
 
257
  while current_retry < max_retries:
258
  try:
259
- response = requests.post(API_URL, headers=headers, json=payload, timeout=180) # 3-minute timeout
260
- response.raise_for_status()
261
 
 
 
 
 
 
262
  image = Image.open(io.BytesIO(response.content))
 
263
  print(f'Generation {key} completed successfully')
264
  return image
265
 
266
- except (requests.exceptions.Timeout, requests.exceptions.ConnectionError,
267
- requests.exceptions.HTTPError, requests.exceptions.RequestException) as e:
 
 
268
  current_retry += 1
269
  if current_retry < max_retries:
270
  wait_time = backoff_factor ** current_retry # Exponential backoff
271
  print(f"Network error occurred: {str(e)}. Retrying in {wait_time} seconds... (Attempt {current_retry + 1}/{max_retries})")
272
- time.sleep(wait_time) # Add delay before retry
273
  continue
274
  else:
275
  # Detailed error message based on exception type
@@ -513,4 +536,4 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as dalle:
513
  dalle.load(fn=update_network_status, outputs=network_status)
514
 
515
  if __name__ == "__main__":
516
- dalle.launch(show_api=False, share=False)
 
12
  if not HF_TOKEN:
13
  raise ValueError("HF_TOKEN environment variable is not set")
14
 
15
+ def query(
16
+ prompt,
17
+ model,
18
+ custom_lora,
19
+ negative_prompt="", # โ† ๊ธฐ์กด is_negative=False โ†’ negative_prompt="" ๋กœ ๋ณ€๊ฒฝ
20
+ steps=35,
21
+ cfg_scale=7,
22
+ sampler="DPM++ 2M Karras",
23
+ seed=-1,
24
+ strength=0.7,
25
+ width=1024,
26
+ height=1024
27
+ ):
28
  print("Starting query function...")
29
 
30
  if not prompt:
 
247
  else:
248
  API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
249
 
250
+ # Prepare payload in Hugging Face Inference API style
251
+ # (negative_prompt, steps, cfg_scale, seed, strength ๋“ฑ์€ parameters ์•ˆ์— ๋ฐฐ์น˜)
252
  payload = {
253
  "inputs": prompt,
 
 
 
 
 
254
  "parameters": {
255
+ "negative_prompt": negative_prompt,
256
+ "num_inference_steps": steps,
257
+ "guidance_scale": cfg_scale,
258
  "width": width,
259
+ "height": height,
260
+ "strength": strength,
261
+ # seed๋ฅผ ์ง€์›ํ•˜๋Š” ๋ชจ๋ธ/์—”๋“œํฌ์ธํŠธ์— ๋”ฐ๋ผ ๋ฌด์‹œ๋  ์ˆ˜๋„ ์žˆ์Œ
262
+ "seed": seed if seed != -1 else random.randint(1, 1000000000),
263
+ },
264
+ # ๋ชจ๋ธ์ด ๋กœ๋”ฉ ์ค‘์ผ ๊ฒฝ์šฐ ๊ธฐ๋‹ค๋ฆฌ๋„๋ก ์„ค์ •
265
+ "options": {"wait_for_model": True}
266
  }
267
 
268
  # Improved retry logic with exponential backoff
 
272
 
273
  while current_retry < max_retries:
274
  try:
275
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=180)
 
276
 
277
+ # ๋””๋ฒ„๊น…์šฉ ์ •๋ณด ์ถœ๋ ฅ
278
+ print("Response Content-Type:", response.headers.get("Content-Type"))
279
+ print("Response Text (snippet):", response.text[:500])
280
+
281
+ response.raise_for_status() # HTTP ์—๋Ÿฌ ์ฝ”๋“œ ์‹œ ์˜ˆ์™ธ ๋ฐœ์ƒ
282
  image = Image.open(io.BytesIO(response.content))
283
+
284
  print(f'Generation {key} completed successfully')
285
  return image
286
 
287
+ except (requests.exceptions.Timeout,
288
+ requests.exceptions.ConnectionError,
289
+ requests.exceptions.HTTPError,
290
+ requests.exceptions.RequestException) as e:
291
  current_retry += 1
292
  if current_retry < max_retries:
293
  wait_time = backoff_factor ** current_retry # Exponential backoff
294
  print(f"Network error occurred: {str(e)}. Retrying in {wait_time} seconds... (Attempt {current_retry + 1}/{max_retries})")
295
+ time.sleep(wait_time)
296
  continue
297
  else:
298
  # Detailed error message based on exception type
 
536
  dalle.load(fn=update_network_status, outputs=network_status)
537
 
538
  if __name__ == "__main__":
539
+ dalle.launch(show_api=False, share=False)