File size: 5,084 Bytes
82df48f
 
 
407a5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82df48f
 
 
 
 
e2bc0a8
407a5fa
e2bc0a8
 
 
 
 
 
 
 
 
 
 
407a5fa
 
 
e2bc0a8
 
 
407a5fa
e2bc0a8
 
82df48f
 
 
 
 
 
 
 
 
 
 
 
 
407a5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2bc0a8
82df48f
e2bc0a8
82df48f
 
 
 
 
 
e2bc0a8
 
82df48f
e2bc0a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82df48f
e2bc0a8
 
 
 
 
 
 
 
 
 
407a5fa
e2bc0a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82df48f
407a5fa
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import gradio as gr
import numpy as np
import random
import logging
import sys

# Configure logging: INFO level, timestamped records written to stdout.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Work around a Gradio JSON-schema error by wrapping gradio_client's get_type.
try:
    import gradio_client.utils

    # Keep a handle on the original so dict schemas can still be delegated.
    original_get_type = gradio_client.utils.get_type

    def patched_get_type(schema):
        """Return a type name for ``schema``, tolerating non-dict schemas."""
        # Dict schemas are passed through to the original implementation.
        if isinstance(schema, dict):
            return original_get_type(schema)
        # A bare boolean schema maps to "bool"; anything else to "any".
        return "bool" if isinstance(schema, bool) else "any"

    gradio_client.utils.get_type = patched_get_type
    logger.info("Successfully patched Gradio JSON schema processing")
except Exception as e:
    logger.error(f"Failed to patch Gradio: {e!s}")

# import spaces #[uncomment to use ZeroGPU]
from diffusers import DiffusionPipeline
import torch

# Limits used by the UI sliders and by seed randomization. They are defined
# before the model is loaded so they exist even when loading fails; otherwise
# building the interface would raise NameError.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

model_repo_id = "stabilityai/sdxl-turbo"

# Fallback so `pipe` is always bound; replaced below once the model is loaded.
pipe = None

# Load the model inside try/except so a failure does not abort the module
# import and the Gradio interface can still be built.
try:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision when CUDA is available, full precision otherwise.
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    logger.info(f"Using device: {device}")
    logger.info(f"Loading model: {model_repo_id}")

    pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
    pipe = pipe.to(device)
    logger.info("Model loaded successfully")
except Exception as e:
    # Deliberately not re-raised so the Gradio interface can still load;
    # logger.exception records the traceback alongside the message.
    logger.exception(f"Error during setup: {str(e)}")

# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image for ``prompt`` and report the seed that was used.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is truthy.
        randomize_seed: When truthy, draw a random seed in ``[0, MAX_SEED]``.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of inference steps passed to the pipeline.
        progress: Gradio progress tracker (created with ``track_tqdm=True``);
            not referenced in the body.

    Returns:
        ``(image, seed)`` on success, or ``(None, seed)`` if anything fails.
        Failures are logged, not raised.
    """
    try:
        logger.info(f"Processing prompt: {prompt}")

        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        else:
            # manual_seed() requires an integer; a numeric widget value may
            # arrive as a float (e.g. 42.0), so normalise it here.
            seed = int(seed)

        # Same normalisation for the other whole-number settings.
        width = int(width)
        height = int(height)
        num_inference_steps = int(num_inference_steps)

        logger.info(f"Using seed: {seed}, width: {width}, height: {height}")

        generator = torch.Generator().manual_seed(seed)

        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

        logger.info("Image generation successful")
        return image, seed
    except Exception as e:
        # Best-effort by design: return None instead of raising.
        # logger.exception keeps the traceback so failures stay diagnosable.
        logger.exception(f"Error in inference: {str(e)}")
        return None, seed

# Example prompts.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Simplified CSS: centre the main column and cap its width.
css = (
    "#col-container { "
    "margin: 0 auto; "
    "max-width: 640px; "
    "}"
)

# Build the simplified Gradio interface.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Text-to-Image Generator")

        # Main input area.
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt",
        )
        run_button = gr.Button("Generate Image")

        # Result display.
        result = gr.Image(label="Generated Image")
        seed_text = gr.Number(label="Seed Used")

        # Advanced settings (collapsed by default).
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="What to exclude from the image",
            )

            # Seed controls.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

            # Image dimensions.
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )

            # Generation parameters.
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=0.0,
                )
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=2,
                )

        # Clickable example prompts that fill the prompt box.
        gr.Examples(examples, inputs=prompt)

    # Wire the button to the inference function.
    run_button.click(
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed_text],
    )

# Launch the app when this file is run as a script.
if __name__ == "__main__":
    try:
        logger.info("Starting Gradio app")
        demo.launch()
    except Exception as e:
        # Still logged rather than re-raised, but logger.exception keeps the
        # traceback so a failed launch can be diagnosed from the log.
        logger.exception(f"Error launching app: {str(e)}")