File size: 7,329 Bytes
82df48f
 
 
407a5fa
 
8718399
65762c4
407a5fa
 
 
 
 
 
 
0cfce88
7a3e379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cfce88
 
 
5f71263
 
4564834
5f71263
 
0cfce88
4564834
0cfce88
 
4564834
5f71263
0cfce88
5f71263
 
 
0cfce88
 
 
 
 
 
 
5f71263
 
0cfce88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4564834
0cfce88
 
 
 
 
 
 
 
 
 
5f71263
0cfce88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4564834
0cfce88
4564834
0cfce88
4564834
 
 
7a3e379
4564834
 
0cfce88
 
 
 
 
 
 
 
 
4564834
 
 
0cfce88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a3e379
82df48f
4564834
 
e2bc0a8
 
82df48f
407a5fa
4564834
 
0cfce88
 
 
4564834
407a5fa
4564834
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import gradio as gr
import numpy as np
import random
import logging
import sys
import os
from PIL import Image as PILImage

# Configure logging: stream to stdout so messages are captured by hosted
# environments (e.g. Hugging Face Spaces) that collect standard output.
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    stream=sys.stdout)
logger = logging.getLogger(__name__)

# Patch gradio_client's JSON-schema helpers: some component schemas are bare
# booleans (True/False) or None, which the stock helpers treat as dicts and
# crash on. Wrap them so non-dict schemas degrade to generic types instead.
try:
    import gradio_client.utils

    # Keep a reference to the original so normal schemas are delegated.
    original_get_type = gradio_client.utils.get_type

    def patched_get_type(schema):
        """Return a generic type for non-dict schemas; delegate otherwise."""
        if schema is True or schema is False or schema is None:
            return "any"
        if not isinstance(schema, dict):
            return "any"
        return original_get_type(schema)

    gradio_client.utils.get_type = patched_get_type
    logger.info("Successfully patched gradio_client.utils.get_type")

    # Same guard for the private schema -> Python type converter.
    original_json_schema_to_python_type = gradio_client.utils._json_schema_to_python_type

    # Forward any extra positional/keyword arguments so the patch keeps
    # working if gradio_client changes this private function's signature
    # (newer versions pass additional parameters).
    def patched_json_schema_to_python_type(schema, defs=None, *args, **kwargs):
        if schema is True or schema is False:
            return "bool"
        if schema is None:
            return "None"
        if not isinstance(schema, dict):
            return "any"

        try:
            return original_json_schema_to_python_type(schema, defs, *args, **kwargs)
        except Exception as e:
            # Best-effort: fall back to "any" rather than breaking app startup.
            logger.warning(f"Error in json_schema_to_python_type: {e}")
            return "any"

    gradio_client.utils._json_schema_to_python_type = patched_json_schema_to_python_type
    logger.info("Successfully patched gradio_client.utils._json_schema_to_python_type")
except Exception as e:
    logger.error(f"Failed to patch Gradio utils: {e}")

# Fallback image used when the diffusion model cannot produce a result.
def create_backup_image(prompt=""):
    """Return a 512x512 placeholder PIL image carrying the prompt text.

    Drawing the text is best-effort: any failure (e.g. missing fonts) is
    logged and the plain background image is returned instead.
    """
    logger.info(f"创建备用图像: {prompt}")
    canvas = PILImage.new('RGB', (512, 512), color=(240, 240, 250))

    try:
        from PIL import ImageDraw, ImageFont

        painter = ImageDraw.Draw(canvas)
        default_font = ImageFont.load_default()
        painter.text((20, 20), f"提示词: {prompt}", fill=(0, 0, 0), font=default_font)
        painter.text((20, 60), "模型加载失败,无法生成图像", fill=(255, 0, 0), font=default_font)
    except Exception as e:
        logger.error(f"创建备用图像时出错: {e}")

    return canvas

# Lazily-loaded diffusion pipeline; populated by load_model() on first use.
model = None

def load_model():
    """Lazily load and cache the Stable Diffusion pipeline.

    Returns:
        The cached pipeline on success, or None when loading fails
        (missing dependencies, download errors, ...).
    """
    global model
    if model is not None:
        return model

    try:
        logger.info("开始加载AI模型...")

        # Import lazily so the app can still start when these are missing.
        import torch
        from diffusers import StableDiffusionPipeline

        model_id = "CompVis/stable-diffusion-v1-4"

        # Choose half precision on GPU, full precision on CPU.
        use_cuda = torch.cuda.is_available()
        load_options = {
            "revision": "fp16" if use_cuda else None,
            "torch_dtype": torch.float16 if use_cuda else torch.float32,
            "safety_checker": None
        }

        logger.info(f"使用模型: {model_id}")
        pipe = StableDiffusionPipeline.from_pretrained(model_id, **load_options)

        # Move the pipeline to the GPU when one is present.
        pipe = pipe.to("cuda" if use_cuda else "cpu")

        # Reduce VRAM usage on GPU.
        if use_cuda:
            pipe.enable_attention_slicing()

        logger.info("AI模型加载成功")
        model = pipe
        return model
    except Exception as e:
        logger.error(f"AI模型加载失败: {e}")
        return None

# AI image generation: run the diffusion pipeline for a prompt.
def generate_ai_image(prompt, seed=None):
    """Generate an image for `prompt` with the Stable Diffusion pipeline.

    Args:
        prompt: text description of the desired image.
        seed: optional RNG seed; a random one is chosen when None.

    Returns:
        A PIL image on success, or None when the model is unavailable or
        generation fails (caller falls back to a placeholder image).
    """
    # 尝试加载模型
    pipe = load_model()
    if pipe is None:
        logger.error("AI模型不可用")
        return None

    try:
        # BUG FIX: `torch` was previously only imported inside load_model()'s
        # local scope, so referencing it below raised NameError, which the
        # broad except swallowed — every request silently fell back to the
        # backup image. Import it in this function's scope as well.
        import torch

        logger.info(f"使用AI生成图像: {prompt}")

        # Pick a random seed when the caller did not supply one.
        if seed is None:
            seed = random.randint(0, 2147483647)

        # Seed a generator on the same device the pipeline runs on.
        generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed)

        # 生成图像
        image = pipe(
            prompt=prompt,
            guidance_scale=7.5,
            num_inference_steps=5,  # 降低步数以加快速度
            generator=generator,
            height=512,
            width=512
        ).images[0]

        # Release cached VRAM between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info(f"AI图像生成成功,种子: {seed}")
        return image

    except Exception as e:
        logger.error(f"AI图像生成失败: {e}")
        return None

# Entry point: handle a UI request and produce an image.
def generate_image(prompt):
    """Return an AI-generated image for `prompt`, or a placeholder on failure."""
    # Substitute a default prompt for empty/whitespace-only input.
    if not prompt or not prompt.strip():
        prompt = "a beautiful landscape"
        logger.info(f"输入为空,使用默认提示词: {prompt}")

    logger.info(f"收到提示词: {prompt}")

    # Prefer the AI pipeline; fall back to the placeholder renderer.
    image = generate_ai_image(prompt)
    if image is None:
        logger.warning("使用备用生成器")
        return create_backup_image(prompt)
    return image

# Build the Gradio UI
def create_demo():
    """Build and return the Gradio Blocks interface for text-to-image generation."""
    with gr.Blocks(title="AI 文本到图像生成器") as demo:
        gr.Markdown("# AI 文本到图像生成器")
        gr.Markdown("输入文本描述,AI将为你生成相应的图像。")
        
        with gr.Row():
            with gr.Column(scale=3):
                # Input area: prompt textbox + generate button
                prompt_input = gr.Textbox(
                    label="输入提示词", 
                    placeholder="描述你想要的图像,例如:一只可爱的猫,日落下的山脉...",
                    lines=2
                )
                generate_button = gr.Button("生成图像", variant="primary")
                
                # Clickable example prompts that fill the textbox
                gr.Examples(
                    examples=[
                        "a cute cat sitting on a windowsill",
                        "beautiful sunset over mountains",
                        "an astronaut riding a horse in space",
                        "a fantasy castle on a floating island"
                    ],
                    inputs=prompt_input
                )
            
            # Output area: the generated image (PIL, matching generate_image's return)
            with gr.Column(scale=5):
                output_image = gr.Image(label="生成的图像", type="pil")
        
        # Wire the button click to the generator
        generate_button.click(
            fn=generate_image,
            inputs=prompt_input,
            outputs=output_image
        )
        
        # Also trigger generation on Enter in the textbox
        prompt_input.submit(
            fn=generate_image,
            inputs=prompt_input,
            outputs=output_image
        )
    
    return demo

# Build the demo at import time so hosting platforms that import this module
# (rather than running it as a script) can pick up the `demo` object.
demo = create_demo()

# Launch the app when executed directly.
if __name__ == "__main__":
    try:
        logger.info("启动Gradio界面...")
        # Bind to all interfaces for container/remote access; show_api=False
        # skips rendering the API page (which relies on the JSON-schema
        # handling patched above).
        demo.launch(
            server_name="0.0.0.0",
            show_api=False,
            share=False
        )
    except Exception as e:
        logger.error(f"启动失败: {e}")