Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,288 +1,329 @@
|
|
1 |
-
import spaces
|
2 |
-
import torch
|
3 |
import gradio as gr
|
4 |
-
|
5 |
-
from PIL import Image
|
6 |
import random
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
import user_history
|
23 |
-
from illusion_style import css
|
24 |
-
import os
|
25 |
-
from transformers import CLIPImageProcessor
|
26 |
-
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
27 |
-
|
28 |
-
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
|
29 |
-
|
30 |
-
# Initialize both pipelines
|
31 |
-
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
|
32 |
-
controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float16)
|
33 |
-
|
34 |
-
# Initialize the safety checker conditionally
|
35 |
-
SAFETY_CHECKER_ENABLED = os.environ.get("SAFETY_CHECKER", "0") == "1"
|
36 |
-
safety_checker = None
|
37 |
-
feature_extractor = None
|
38 |
-
if SAFETY_CHECKER_ENABLED:
|
39 |
-
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cuda")
|
40 |
-
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
41 |
-
|
42 |
-
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
43 |
-
BASE_MODEL,
|
44 |
-
controlnet=controlnet,
|
45 |
-
vae=vae,
|
46 |
-
safety_checker=safety_checker,
|
47 |
-
feature_extractor=feature_extractor,
|
48 |
-
torch_dtype=torch.float16,
|
49 |
-
).to("cuda")
|
50 |
-
|
51 |
-
# Function to check NSFW images
|
52 |
-
#def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
|
53 |
-
# if SAFETY_CHECKER_ENABLED:
|
54 |
-
# safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
|
55 |
-
# has_nsfw_concepts = safety_checker(
|
56 |
-
# images=[images],
|
57 |
-
# clip_input=safety_checker_input.pixel_values.to("cuda")
|
58 |
-
# )
|
59 |
-
# return images, has_nsfw_concepts
|
60 |
-
# else:
|
61 |
-
# return images, [False] * len(images)
|
62 |
-
|
63 |
-
#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
64 |
-
#main_pipe.unet.to(memory_format=torch.channels_last)
|
65 |
-
#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
66 |
-
#model_id = "stabilityai/sd-x2-latent-upscaler"
|
67 |
-
image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
|
68 |
-
|
69 |
-
|
70 |
-
#image_pipe.unet = torch.compile(image_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
71 |
-
#upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
72 |
-
#upscaler.to("cuda")
|
73 |
-
|
74 |
-
|
75 |
-
# Sampler map
|
76 |
-
SAMPLER_MAP = {
|
77 |
-
"DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True, algorithm_type="sde-dpmsolver++"),
|
78 |
-
"Euler": lambda config: EulerDiscreteScheduler.from_config(config),
|
79 |
}
|
80 |
|
81 |
-
def
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
):
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
# Generate the initial image
|
159 |
-
#init_image = init_pipe(prompt).images[0]
|
160 |
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
control_guidance_start=float(control_guidance_start),
|
177 |
-
control_guidance_end=float(control_guidance_end),
|
178 |
-
num_inference_steps=15,
|
179 |
-
output_type="latent"
|
180 |
-
)
|
181 |
-
upscaled_latents = upscale(out, "nearest-exact", 2)
|
182 |
-
out_image = image_pipe(
|
183 |
-
prompt=prompt,
|
184 |
-
negative_prompt=negative_prompt,
|
185 |
-
control_image=control_image_large,
|
186 |
-
image=upscaled_latents,
|
187 |
-
guidance_scale=float(guidance_scale),
|
188 |
-
generator=generator,
|
189 |
-
num_inference_steps=20,
|
190 |
-
strength=upscaler_strength,
|
191 |
-
control_guidance_start=float(control_guidance_start),
|
192 |
-
control_guidance_end=float(control_guidance_end),
|
193 |
-
controlnet_conditioning_scale=float(controlnet_conditioning_scale)
|
194 |
-
)
|
195 |
-
end_time = time.time()
|
196 |
-
end_time_struct = time.localtime(end_time)
|
197 |
-
end_time_formatted = time.strftime("%H:%M:%S", end_time_struct)
|
198 |
-
print(f"Inference ended at {end_time_formatted}, taking {end_time-start_time}s")
|
199 |
-
|
200 |
-
# Save image + metadata
|
201 |
-
user_history.save_image(
|
202 |
-
label=prompt,
|
203 |
-
image=out_image["images"][0],
|
204 |
-
profile=profile,
|
205 |
-
metadata={
|
206 |
-
"prompt": prompt,
|
207 |
-
"negative_prompt": negative_prompt,
|
208 |
-
"guidance_scale": guidance_scale,
|
209 |
-
"controlnet_conditioning_scale": controlnet_conditioning_scale,
|
210 |
-
"control_guidance_start": control_guidance_start,
|
211 |
-
"control_guidance_end": control_guidance_end,
|
212 |
-
"upscaler_strength": upscaler_strength,
|
213 |
-
"seed": seed,
|
214 |
-
"sampler": sampler,
|
215 |
-
},
|
216 |
)
|
217 |
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
</div>
|
230 |
-
|
231 |
-
)
|
232 |
|
|
|
|
|
233 |
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
with gr.Accordion(
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
|
|
284 |
|
285 |
-
app_with_history.queue(max_size=20,api_open=False )
|
286 |
|
287 |
if __name__ == "__main__":
|
288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import os
|
|
|
3 |
import random
|
4 |
+
import httpx
|
5 |
+
import asyncio
|
6 |
+
from dataclasses import dataclass, field
|
7 |
+
from typing import Any
|
8 |
|
9 |
+
# 常量定义
|
10 |
+
HTTP_STATUS_CENSORED = 451
|
11 |
+
HTTP_STATUS_OK = 200
|
12 |
+
MAX_SEED = 2147483647 # (2**31 - 1)
|
13 |
+
MAX_IMAGE_SIZE = 2048
|
14 |
+
MIN_IMAGE_SIZE = 256 # Smallest dimension for SDXL like models often 512, but API might support smaller. Adjusted to API's limits.
|
15 |
+
|
16 |
+
# 调试模式
|
17 |
+
DEBUG_MODE = os.environ.get("DEBUG_MODE", "false").lower() == "true"
|
18 |
+
|
19 |
+
# 模型配置映射
|
20 |
+
MODEL_CONFIGS = {
|
21 |
+
"ep3": "ep3.pth",
|
22 |
+
"ep3latest": "ep3latest.pth"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
}
|
24 |
|
25 |
+
def validate_dimensions(width: int, height: int) -> tuple[int, int]:
|
26 |
+
"""验证并调整图片尺寸"""
|
27 |
+
width = max(MIN_IMAGE_SIZE, min(int(width), MAX_IMAGE_SIZE))
|
28 |
+
height = max(MIN_IMAGE_SIZE, min(int(height), MAX_IMAGE_SIZE))
|
29 |
+
width = (width // 32) * 32
|
30 |
+
height = (height // 32) * 32
|
31 |
+
return width, height
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class LuminaConfig:
|
35 |
+
"""Lumina模型配置"""
|
36 |
+
model_name: str | None = None
|
37 |
+
cfg: float | None = None
|
38 |
+
step: int | None = None
|
39 |
+
|
40 |
+
@dataclass
|
41 |
+
class ImageGenerationConfig:
|
42 |
+
"""图像生成配置"""
|
43 |
+
prompts: list[dict[str, Any]] = field(default_factory=list)
|
44 |
+
width: int = 1024
|
45 |
+
height: int = 1024
|
46 |
+
seed: int | None = None
|
47 |
+
use_polish: bool = False # This wasn't exposed in UI, assuming false
|
48 |
+
is_lumina: bool = True
|
49 |
+
lumina_config: LuminaConfig = field(default_factory=LuminaConfig)
|
50 |
+
|
51 |
+
class ImageClient:
|
52 |
+
"""图像生成客户端"""
|
53 |
+
def __init__(self) -> None:
|
54 |
+
self.x_token = os.environ.get("API_TOKEN", "")
|
55 |
+
if not self.x_token:
|
56 |
+
print("Warning: API_TOKEN environment variable not set. Using a placeholder. API calls will likely fail.")
|
57 |
+
self.x_token = "YOUR_API_TOKEN_PLACEHOLDER" # Placeholder for app to load
|
58 |
+
|
59 |
+
self.lumina_api_url = "https://ops.api.talesofai.cn/v3/make_image"
|
60 |
+
self.lumina_task_status_url = "https://ops.api.talesofai.cn/v1/artifact/task/{task_uuid}"
|
61 |
+
self.max_polling_attempts = 100
|
62 |
+
self.polling_interval = 3.0
|
63 |
+
self.default_headers = {
|
64 |
+
"Content-Type": "application/json",
|
65 |
+
"x-platform": "nieta-app/web", # Or a generic identifier if preferred
|
66 |
+
"X-Token": self.x_token,
|
67 |
+
}
|
68 |
+
|
69 |
+
def _prepare_prompt_data(self, prompt: str, negative_prompt: str = "") -> list[dict[str, Any]]:
|
70 |
+
prompts_data = [{"type": "freetext", "value": prompt, "weight": 1.0}]
|
71 |
+
if negative_prompt:
|
72 |
+
prompts_data.append({"type": "freetext", "value": negative_prompt, "weight": -1.0})
|
73 |
+
prompts_data.append({
|
74 |
+
"type": "elementum", "value": "b5edccfe-46a2-4a14-a8ff-f4d430343805",
|
75 |
+
"uuid": "b5edccfe-46a2-4a14-a8ff-f4d430343805", "weight": 1.0, "name": "lumina1",
|
76 |
+
"img_url": "https://oss.talesofai.cn/picture_s/1y7f53e6itfn_0.jpeg",
|
77 |
+
"domain": "", "parent": "", "label": None, "sort_index": 0, "status": "IN_USE",
|
78 |
+
"polymorphi_values": {}, "sub_type": None,
|
79 |
+
})
|
80 |
+
return prompts_data
|
81 |
+
|
82 |
+
def _build_payload(self, config: ImageGenerationConfig) -> dict[str, Any]:
|
83 |
+
payload = {
|
84 |
+
"storyId": "", "jobType": "universal", "width": config.width, "height": config.height,
|
85 |
+
"rawPrompt": config.prompts, "seed": config.seed, "meta": {"entrance": "PICTURE,PURE"},
|
86 |
+
"context_model_series": None, "negative_freetext": "", # Negative handled in rawPrompt
|
87 |
+
"advanced_translator": config.use_polish,
|
88 |
+
}
|
89 |
+
if config.is_lumina:
|
90 |
+
client_args = {}
|
91 |
+
if config.lumina_config.model_name: client_args["ckpt_name"] = config.lumina_config.model_name
|
92 |
+
if config.lumina_config.cfg is not None: client_args["cfg"] = str(config.lumina_config.cfg)
|
93 |
+
if config.lumina_config.step is not None: client_args["steps"] = str(config.lumina_config.step)
|
94 |
+
if client_args: payload["client_args"] = client_args
|
95 |
+
return payload
|
96 |
+
|
97 |
+
async def _poll_task_status(self, task_uuid: str, progress: gr.Progress | None = None) -> dict[str, Any]:
|
98 |
+
status_url = self.lumina_task_status_url.format(task_uuid=task_uuid)
|
99 |
+
last_status_message = ""
|
100 |
+
async with httpx.AsyncClient(timeout=30.0) as client: # Timeout for individual poll request
|
101 |
+
for attempt in range(self.max_polling_attempts):
|
102 |
+
if progress:
|
103 |
+
progress(attempt / self.max_polling_attempts, desc=f"Polling task status ({attempt+1}/{self.max_polling_attempts})... {last_status_message}")
|
104 |
+
|
105 |
+
try:
|
106 |
+
response = await client.get(status_url, headers=self.default_headers)
|
107 |
+
response.raise_for_status() # Will raise HTTPError for 4xx/5xx
|
108 |
+
result = response.json()
|
109 |
+
except httpx.HTTPStatusError as e:
|
110 |
+
return {"success": False, "error": f"获取任务状态失败: {e.response.status_code} - {e.response.text}"}
|
111 |
+
except httpx.RequestError as e: # Catches network errors, timeouts for this specific request
|
112 |
+
return {"success": False, "error": f"网络请求错误: {str(e)}"}
|
113 |
+
except Exception as e: # Catch JSON parsing errors or other unexpected issues
|
114 |
+
return {"success": False, "error": f"任务状态响应处理失败: {str(e)}"}
|
115 |
+
|
116 |
+
task_status = result.get("task_status")
|
117 |
+
last_status_message = f"Status: {task_status}"
|
118 |
+
if DEBUG_MODE: print(f"DEBUG: Poll {attempt+1}, Task {task_uuid}, Status: {task_status}, Result: {result}")
|
119 |
+
|
120 |
+
if task_status == "SUCCESS":
|
121 |
+
artifacts = result.get("artifacts", [])
|
122 |
+
if artifacts and "url" in artifacts[0]:
|
123 |
+
return {"success": True, "image_url": artifacts[0]["url"]}
|
124 |
+
return {"success": False, "error": "任务成功但未找到图像URL。"}
|
125 |
+
elif task_status in ["FAILURE", "ILLEGAL_IMAGE", "TIMEOUT"]:
|
126 |
+
error_msg = result.get("error", f"任务失败,状态: {task_status}")
|
127 |
+
if "error_message" in result: error_msg = result["error_message"] # API specific field
|
128 |
+
return {"success": False, "error": error_msg}
|
129 |
+
|
130 |
+
# For PENDING, RUNNING, QUEUED, or unknown statuses, continue polling
|
131 |
+
await asyncio.sleep(self.polling_interval)
|
132 |
+
return {"success": False, "error": "⏳ 生图任务轮询超时(5分钟),请稍后重试。"}
|
133 |
+
|
134 |
+
async def generate_image(self, prompt_str: str, negative_prompt_str: str, seed_val: int, width_val: int, height_val: int, cfg_val: float, steps_val: int, model_name_str: str = "ep3", progress: gr.Progress | None = None) -> tuple[str | None, str | None]:
|
135 |
+
if not self.x_token or self.x_token == "YOUR_API_TOKEN_PLACEHOLDER":
|
136 |
+
return None, "API_TOKEN未配置。请在环境变量中设置API_TOKEN以使用此功能。"
|
137 |
+
try:
|
138 |
+
if progress: progress(0.05, desc="准备请求...")
|
139 |
+
model_path = MODEL_CONFIGS.get(model_name_str, MODEL_CONFIGS["ep3"])
|
140 |
+
config = ImageGenerationConfig(
|
141 |
+
prompts=self._prepare_prompt_data(prompt_str, negative_prompt_str),
|
142 |
+
width=width_val, height=height_val, seed=seed_val,
|
143 |
+
lumina_config=LuminaConfig(model_name=model_path, cfg=cfg_val, step=steps_val)
|
144 |
+
)
|
145 |
+
payload = self._build_payload(config)
|
146 |
+
if DEBUG_MODE: print(f"DEBUG: API Payload: {payload}, Headers: {self.default_headers}")
|
147 |
+
|
148 |
+
if progress: progress(0.1, desc="发送生成请求...")
|
149 |
+
async with httpx.AsyncClient(timeout=60.0) as client: # Timeout for initial POST request
|
150 |
+
response = await client.post(self.lumina_api_url, json=payload, headers=self.default_headers)
|
151 |
+
|
152 |
+
if DEBUG_MODE: print(f"DEBUG: API Initial Response: {response.status_code}, {response.text[:500]}")
|
153 |
+
|
154 |
+
if response.status_code == HTTP_STATUS_CENSORED: return None, "内容不合规,请修改提示词。"
|
155 |
+
if response.status_code == 433: return None, "⏳ 服务器繁忙(达到并发上限),请稍后重试。"
|
156 |
+
|
157 |
+
try:
|
158 |
+
response.raise_for_status() # Check for other HTTP errors
|
159 |
+
task_uuid = response.text.strip().replace('"', "")
|
160 |
+
if not task_uuid or len(task_uuid) < 10: # Basic UUID validation
|
161 |
+
return None, f"未能获取有效的任务ID。API响应: {response.text[:200]}"
|
162 |
+
except httpx.HTTPStatusError as e:
|
163 |
+
err_text = e.response.text
|
164 |
+
try: err_json = e.response.json(); err_text = err_json.get("message", err_text)
|
165 |
+
except: pass
|
166 |
+
return None, f"API请求失败: {e.response.status_code} - {err_text[:200]}"
|
167 |
+
|
168 |
+
|
169 |
+
if progress: progress(0.2, desc=f"任务已提交 (ID: {task_uuid[:8]}...), 开始轮询状态...")
|
170 |
+
poll_result = await self._poll_task_status(task_uuid, progress)
|
171 |
+
if poll_result["success"]:
|
172 |
+
if progress: progress(1, desc="图片生成成功!")
|
173 |
+
return poll_result["image_url"], None
|
174 |
+
else:
|
175 |
+
return None, poll_result["error"]
|
176 |
+
except httpx.TimeoutException:
|
177 |
+
return None, "API请求超时,请检查网络连接或稍后再试。"
|
178 |
+
except httpx.RequestError as e:
|
179 |
+
return None, f"网络请求错误: {str(e)}"
|
180 |
+
except Exception as e:
|
181 |
+
if DEBUG_MODE: import traceback; traceback.print_exc()
|
182 |
+
return None, f"生成图片时发生意外错误: {str(e)}"
|
183 |
+
|
184 |
+
# Initialize client
|
185 |
+
try:
|
186 |
+
image_client = ImageClient()
|
187 |
+
except Exception as e: # Catch any init error
|
188 |
+
print(f"Failed to initialize ImageClient: {e}")
|
189 |
+
image_client = None
|
190 |
+
|
191 |
+
# Example prompts
|
192 |
+
example_titles = [
|
193 |
+
"A stylized female demon with red hair and glitch effects",
|
194 |
+
"A young man relaxes on a hazy urban rooftop",
|
195 |
+
"A gentle, freckled girl embraces a goat in a meadow"
|
196 |
+
]
|
197 |
+
full_prompts = {
|
198 |
+
example_titles[0]: "Stylized anime illustration of a female demon or supernatural character with vibrant red hair in twintails/pigtails and glowing purple eyes. Character has black horns and features bandage-like cross markings on face. Subject wears a black sleeveless top and holds a pink bubblegum or candy sphere near mouth. Digital glitch effects create pixelated elements in her hair and around background. Dramatic lighting with stark white/black contrasting background featuring cracks or lightning patterns. Character has gold/yellow accessories including bracelets and hair decorations. Modern anime art style with sharp contrast and vivid colors. Portrait composition showing three-quarter view of character with confident or playful expression. Color palette dominated by reds, blacks, whites, purple and pink accents. Surreal or otherworldly atmosphere enhanced by particle effects and lighting. Professional digital illustration combining traditional anime aesthetics with contemporary glitch art elements. Character design suggests edgy or alternative styling with possible cyberpunk or modern demon girl influences.",
|
199 |
+
example_titles[1]: "Atmospheric anime illustration of young man with messy brown hair on urban rooftop during overcast day. Character wears white dress shirt and dark trousers, leaning back against railing while holding canned drink. Scene set on building rooftop with industrial elements like water tower, power lines, and metal structures visible. Cityscape background shows apartment buildings and urban architecture through soft hazy lighting. Subject has relaxed pose suggesting brief break or moment of contemplation. Color palette uses muted whites, grays, and industrial tones creating realistic urban atmosphere. Art style combines detailed architectural elements with soft, painterly technique. Composition emphasizes vertical lines of city buildings and metal structures. Professional digital artwork capturing slice-of-life moment in urban setting. Scene suggests peaceful solitude amid busy city environment. Lighting creates gentle, overcast mood with subtle shadows and highlights. Character design and setting reflect contemporary Japanese salary-man or office worker aesthetic.",
|
200 |
+
example_titles[2]: "Enchanting anime illustration of a gentle, freckled girl with long, wavy orange hair and elegant ram horns, tenderly embracing a white baby goat in a sunlit meadow. The composition is a close-up, focusing on the upper body and faces of both the girl and the goat, capturing an intimate and heartwarming moment. She wears a vintage-inspired dress with a high collar, puffed sleeves, and a delicate white headband, adorned with golden ribbons and lace details. The sunlight bathes the scene in warm, golden tones, casting soft shadows and creating a dreamy, pastoral atmosphere. The background is filled with lush green grass and scattered white flowers, enhancing the idyllic countryside setting. The art style is painterly and vibrant, with expressive brushwork and a focus on light and texture, evoking a sense of peace, innocence, and connection with nature."
|
201 |
+
}
|
202 |
+
|
203 |
+
async def infer(
|
204 |
+
prompt_text, seed_val, randomize_seed_val, width_val, height_val,
|
205 |
+
cfg_val, steps_val, model_name_val, progress=gr.Progress(track_tqdm=True)
|
206 |
):
|
207 |
+
if image_client is None:
|
208 |
+
raise gr.Error("ImageClient 未正确初始化。请检查应用日志和API_TOKEN配置。")
|
209 |
+
if not prompt_text.strip():
|
210 |
+
raise gr.Error("提示词不能为空。请输入您想生成的图像描述。")
|
|
|
|
|
|
|
211 |
|
212 |
+
current_seed = int(seed_val)
|
213 |
+
if randomize_seed_val:
|
214 |
+
current_seed = random.randint(0, MAX_SEED)
|
215 |
|
216 |
+
width_val, height_val = validate_dimensions(width_val, height_val)
|
217 |
+
|
218 |
+
if not (1.0 <= float(cfg_val) <= 20.0): raise gr.Error("CFG Scale 必须在 1.0 到 20.0 之间。")
|
219 |
+
if not (1 <= int(steps_val) <= 50): raise gr.Error("Steps 必须在 1 到 50 之间。")
|
220 |
+
|
221 |
+
progress(0, desc="开始生成...")
|
222 |
+
image_url, error = await image_client.generate_image(
|
223 |
+
prompt_str=prompt_text, negative_prompt_str="", # Negative prompt not exposed, can be added
|
224 |
+
seed_val=current_seed, width_val=width_val, height_val=height_val,
|
225 |
+
cfg_val=float(cfg_val), steps_val=int(steps_val), model_name_str=model_name_val,
|
226 |
+
progress=progress
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
)
|
228 |
|
229 |
+
if error:
|
230 |
+
# Check if the error is already user-friendly, if not, provide a generic one
|
231 |
+
if "API请求失败" in error or "内容不合规" in error or "服务器繁忙" in error or "任务轮询超时" in error or "API_TOKEN" in error:
|
232 |
+
raise gr.Error(error)
|
233 |
+
else:
|
234 |
+
# For less clear errors, provide a generic message and log the detail if in debug mode
|
235 |
+
if DEBUG_MODE: print(f"Internal error during image generation: {error}")
|
236 |
+
raise gr.Error(f"图片生成失败: {error}. 请稍后再试或检查提示词。")
|
237 |
+
|
238 |
+
|
239 |
+
return image_url, current_seed
|
240 |
+
|
241 |
+
|
242 |
+
# Links for HTML header
|
243 |
+
DISCORD_LINK = os.environ.get("DISCORD_LINK", "https://discord.gg/your-community") # Example
|
244 |
+
APP_INDEX_LINK = os.environ.get("APP_INDEX_LINK", "https://huggingface.co/spaces") # Example
|
245 |
+
APP_INDEX_ICON = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg" # Using HF logo
|
246 |
+
|
247 |
+
|
248 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="Lumina Image Playground") as demo:
|
249 |
+
gr.HTML(f"""
|
250 |
+
<div style="display: flex; justify-content: flex-end; align-items: center; gap: 15px; margin-bottom: 10px; padding: 5px;">
|
251 |
+
<a href="{DISCORD_LINK}" target="_blank" style="text-decoration: none; color: #5865F2; font-weight: 500; display: inline-flex; align-items: center; gap: 5px;">
|
252 |
+
<img src="https://assets-global.website-files.com/6257adef93867e50d84d30e2/636e0a69f118df70ad7828d4_icon_clyde_blurple_RGB.svg" alt="Discord" style="height: 20px;">
|
253 |
+
Join Discord
|
254 |
+
</a>
|
255 |
+
<a href="{APP_INDEX_LINK}" target="_blank" style="text-decoration: none; color: #333; font-weight: 500; display: inline-flex; align-items: center; gap: 5px;">
|
256 |
+
<img src="{APP_INDEX_ICON}" alt="App Index" style="height: 20px; border-radius: 3px;">
|
257 |
+
More Apps
|
258 |
+
</a>
|
259 |
</div>
|
260 |
+
""")
|
|
|
261 |
|
262 |
+
gr.Markdown("<h1>🎨 Lumina Text-to-Image Playground</h1>")
|
263 |
+
gr.Markdown("Describe your vision and let the AI bring it to life! Uses an external API for image generation.")
|
264 |
|
265 |
+
with gr.Row(variant="panel"):
|
266 |
+
with gr.Column(scale=2): # Controls Panel
|
267 |
+
gr.Markdown("## ⚙️ Generation Controls")
|
268 |
+
prompt = gr.Textbox(
|
269 |
+
label="Prompt", lines=5,
|
270 |
+
placeholder="e.g., A majestic dragon soaring through a cyberpunk city skyline, neon lights reflecting off its scales, intricate details.",
|
271 |
+
info="Describe the image you want to create."
|
272 |
+
)
|
273 |
+
|
274 |
+
with gr.Accordion("🔧 Advanced Settings", open=True):
|
275 |
+
model_name = gr.Dropdown(
|
276 |
+
label="Model Version", choices=list(MODEL_CONFIGS.keys()), value="ep3",
|
277 |
+
info="Select the generation model."
|
278 |
+
)
|
279 |
+
with gr.Row():
|
280 |
+
cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=20.0, step=0.1, value=5.5, info="Guidance strength. Higher values adhere more to prompt.")
|
281 |
+
steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=50, step=1, value=30, info="Number of steps. More steps can improve quality but take longer.")
|
282 |
+
|
283 |
+
with gr.Row():
|
284 |
+
width = gr.Slider(label="Width", minimum=MIN_IMAGE_SIZE, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
|
285 |
+
height = gr.Slider(label="Height", minimum=MIN_IMAGE_SIZE, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
|
286 |
+
|
287 |
+
with gr.Row():
|
288 |
+
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=random.randint(0, MAX_SEED))
|
289 |
+
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True, info="Use a new random seed for each generation if checked.")
|
290 |
+
|
291 |
+
run_button = gr.Button("🚀 Generate Image", variant="primary", scale=0) # scale=0 for button to not take full width in some cases if alone
|
292 |
+
|
293 |
+
with gr.Group():
|
294 |
+
gr.Markdown("### ✨ Example Prompts")
|
295 |
+
for i, title in enumerate(example_titles):
|
296 |
+
btn = gr.Button(title)
|
297 |
+
btn.click(lambda t=title: full_prompts[t], outputs=[prompt])
|
298 |
+
|
299 |
+
|
300 |
+
with gr.Column(scale=3): # Output Panel
|
301 |
+
gr.Markdown("## 🖼️ Generated Image")
|
302 |
+
result_image = gr.Image(
|
303 |
+
label="Output Image", show_label=False, type="filepath",
|
304 |
+
height=600, # Max display height
|
305 |
+
show_download_button=True, interactive=False,
|
306 |
+
elem_id="result_image_display" # for potential CSS targeting if needed
|
307 |
+
)
|
308 |
+
generated_seed_info = gr.Textbox(label="Seed Used", interactive=False, placeholder="The seed for the generated image will appear here.")
|
309 |
|
310 |
+
# Event Handlers
|
311 |
+
inputs_list = [prompt, seed, randomize_seed, width, height, cfg, steps, model_name]
|
312 |
+
outputs_list = [result_image, generated_seed_info]
|
313 |
+
|
314 |
+
run_button.click(fn=infer, inputs=inputs_list, outputs=outputs_list, api_name="generate_image")
|
315 |
+
prompt.submit(fn=infer, inputs=inputs_list, outputs=outputs_list, api_name="generate_image_submit")
|
316 |
|
|
|
317 |
|
318 |
if __name__ == "__main__":
|
319 |
+
if DEBUG_MODE:
|
320 |
+
print("DEBUG_MODE is enabled.")
|
321 |
+
if not os.environ.get("API_TOKEN"):
|
322 |
+
print("**************************************************************************************")
|
323 |
+
print("WARNING: API_TOKEN environment variable is not set locally.")
|
324 |
+
print("The application will run, but image generation will fail until API_TOKEN is provided.")
|
325 |
+
print("You can set it by running: export API_TOKEN='your_actual_token_here'")
|
326 |
+
print("Or if using a .env file, ensure it's loaded or API_TOKEN is set in your run config.")
|
327 |
+
print("**************************************************************************************")
|
328 |
+
|
329 |
+
demo.launch(debug=DEBUG_MODE, show_error=True)
|