File size: 6,159 Bytes
82df48f
 
 
407a5fa
 
8718399
65762c4
407a5fa
 
 
 
 
 
 
7a3e379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b80791
65762c4
3b80791
a0be010
65762c4
 
4564834
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a3e379
4564834
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a3e379
4564834
 
 
 
 
 
 
 
7a3e379
 
82df48f
4564834
 
e2bc0a8
 
82df48f
407a5fa
4564834
7a3e379
4564834
 
 
 
 
 
 
407a5fa
4564834
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import gradio as gr
import numpy as np
import random
import logging
import sys
import os
from PIL import Image as PILImage

# Configure root logging so all messages are written to stdout
# (visible in hosted-platform container logs).
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(stream=sys.stdout, format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Workaround for a Gradio JSON-schema bug: gradio_client's schema helpers
# crash when a JSON schema node is a bare bool (True/False) or None instead
# of a dict. Monkey-patch the two helpers to tolerate non-dict schemas and
# delegate everything else to the originals. Best-effort: failure to patch
# is logged, not fatal.
try:
    import gradio_client.utils
    
    # Keep a reference to the original so the patch can delegate to it.
    original_get_type = gradio_client.utils.get_type
    
    # Replacement get_type: map bool/None/non-dict schemas to "any".
    def patched_get_type(schema):
        if schema is True or schema is False or schema is None:
            return "any"
        if not isinstance(schema, dict):
            return "any"
        return original_get_type(schema)
    
    # Install the patched function in place of the original.
    gradio_client.utils.get_type = patched_get_type
    logger.info("Successfully patched gradio_client.utils.get_type")
    
    # Patch _json_schema_to_python_type the same way (also swallows any
    # exception from the original, falling back to "any").
    original_json_schema_to_python_type = gradio_client.utils._json_schema_to_python_type
    
    def patched_json_schema_to_python_type(schema, defs=None):
        if schema is True or schema is False:
            return "bool"
        if schema is None:
            return "None"
        if not isinstance(schema, dict):
            return "any"
        
        try:
            return original_json_schema_to_python_type(schema, defs)
        except Exception as e:
            logger.warning(f"Error in json_schema_to_python_type: {e}")
            return "any"
    
    gradio_client.utils._json_schema_to_python_type = patched_json_schema_to_python_type
    logger.info("Successfully patched gradio_client.utils._json_schema_to_python_type")
except Exception as e:
    logger.error(f"Failed to patch Gradio utils: {e}")

# Placeholder used whenever model loading or generation fails.
def create_dummy_image():
    """Return a flat light-red 256x256 RGB PIL image as a fallback result."""
    logger.info("Creating dummy image")
    return PILImage.new('RGB', (256, 256), color=(255, 100, 100))

# Lazily-initialized Stable Diffusion pipeline; stays None until the first
# request, when generate_image() populates it via get_model().
pipe = None

# Lazy model loader: imports the heavy ML dependencies only when needed.
def get_model():
    """Build and return the Stable Diffusion v1.5 pipeline.

    Returns None if anything fails (missing dependencies, download error,
    out of memory, ...) so the caller can fall back gracefully.
    """
    try:
        import torch
        from diffusers import StableDiffusionPipeline
        
        logger.info("开始加载模型...")
        
        # Deliberately the smaller SD 1.5 model rather than SDXL.
        model_id = "runwayml/stable-diffusion-v1-5"
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        logger.info(f"使用设备: {device}")
        
        # Options shared by the GPU and CPU paths; the safety checker is
        # dropped to save memory.
        shared_opts = {
            "safety_checker": None,
            "requires_safety_checker": False,
            "use_safetensors": True,
        }
        
        if use_cuda:
            # Half precision on GPU roughly halves VRAM usage.
            model = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                **shared_opts,
            ).to(device)
            model.enable_attention_slicing()  # lower peak VRAM per step
            
            # Release allocator slack left over from loading.
            torch.cuda.empty_cache()
        else:
            # CPU path: full precision, larger footprint, one request at a time.
            model = StableDiffusionPipeline.from_pretrained(
                model_id,
                **shared_opts,
            ).to(device)
        
        logger.info("模型加载成功")
        return model
    except Exception as e:
        logger.error(f"模型加载失败: {e}")
        return None

# 生成图像函数
def generate_image(prompt):
    global pipe
    
    # 如果提示为空,使用默认提示
    if not prompt or prompt.strip() == "":
        prompt = "a beautiful landscape"
        logger.info(f"输入为空,使用默认提示词: {prompt}")
    
    logger.info(f"收到提示词: {prompt}")
    
    # 第一次调用时加载模型
    if pipe is None:
        pipe = get_model()
        if pipe is None:
            logger.error("模型加载失败,返回默认图像")
            return create_dummy_image()
    
    try:
        # 优化生成参数,减少内存需求
        logger.info("开始生成图像...")
        
        # 设置随机种子以确保结果一致性
        seed = random.randint(0, 2147483647)
        generator = torch.Generator(device=pipe.device).manual_seed(seed)
        
        # 使用最轻量级的参数
        image = pipe(
            prompt=prompt,
            num_inference_steps=3,  # 极少的步骤
            guidance_scale=7.5,
            height=256,  # 小尺寸
            width=256,   # 小尺寸
            generator=generator
        ).images[0]
        
        # 释放缓存,避免内存增长
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            
        logger.info(f"图像生成成功,种子: {seed}")
        return image
    except Exception as e:
        logger.error(f"生成过程发生错误: {e}")
        return create_dummy_image()

# Build the Gradio UI.
def create_demo():
    """Return a minimal gr.Interface wrapping generate_image."""
    # A plain Interface (no Blocks, no custom components) keeps the
    # component graph small and robust.
    interface_kwargs = {
        "fn": generate_image,
        "inputs": gr.Textbox(label="输入提示词"),
        "outputs": gr.Image(type="pil", label="生成的图像"),
        "title": "文本到图像生成",
        "description": "输入文本描述,AI将生成相应的图像(会加载较长时间,请耐心等待)",
        "examples": ["a cute cat", "mountain landscape"],
        "cache_examples": False,
        "allow_flagging": "never",  # flagging disabled to reduce complexity
    }
    return gr.Interface(**interface_kwargs)

# Build the interface at import time so hosting platforms that import this
# module (e.g. HF Spaces) can find `demo` without running __main__.
demo = create_demo()

# Entry point when executed as a script.
if __name__ == "__main__":
    try:
        logger.info("启动Gradio界面...")
        # Minimal launch configuration; bind to all interfaces for containers.
        demo.launch(
            server_name="0.0.0.0", 
            show_api=False,      # hide the API page
            share=False,         # no public share link
            debug=False,         # debug mode off
            quiet=True           # reduce log output
        )
    except Exception as e:
        logger.error(f"启动失败: {e}")