lisonallen commited on
Commit
4564834
·
1 Parent(s): 7a3e379

恢复AI图像生成功能,使用更轻量的方式

Browse files
Files changed (1) hide show
  1. app.py +114 -22
app.py CHANGED
@@ -59,36 +59,128 @@ def create_dummy_image():
59
  img = PILImage.new('RGB', (256, 256), color = (255, 100, 100))
60
  return img
61
 
62
- # 使用极简方法
63
- def simple_demo():
64
- # 定义一个极简单的生成函数
65
- def generate(text):
66
- logger.info(f"Received text: {text}")
67
- # 创建一个简单的图像作为响应
68
- image = PILImage.new('RGB', (256, 256), color=(100, 200, 100))
69
- logger.info("Created simple green image")
70
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- logger.info("Setting up simple demo without model loading")
73
- # 创建最小界面
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  demo = gr.Interface(
75
- fn=generate,
76
- inputs=gr.Textbox(label="Prompt"),
77
- outputs=gr.Image(type="pil", label="Generated Image"),
78
- title="Simple Text-to-Image",
79
- examples=["a cat"],
80
- cache_examples=False
 
 
81
  )
82
  return demo
83
 
84
- # 只使用极简应用,避免加载任何模型
85
- demo = simple_demo()
86
 
87
  # 启动应用
88
  if __name__ == "__main__":
89
  try:
90
- logger.info("Starting Gradio interface")
91
  # 使用最小配置
92
- demo.launch(server_name="0.0.0.0", show_api=False)
 
 
 
 
 
 
93
  except Exception as e:
94
- logger.error(f"Failed to launch: {e}")
 
59
  img = PILImage.new('RGB', (256, 256), color = (255, 100, 100))
60
  return img
61
 
62
+ # 全局变量
63
+ pipe = None
64
+
65
+ # 懒加载AI模型函数
66
+ def get_model():
67
+ try:
68
+ import torch
69
+ from diffusers import StableDiffusionPipeline
70
+
71
+ logger.info("开始加载模型...")
72
+
73
+ # 使用较小的模型而不是SDXL
74
+ model_id = "runwayml/stable-diffusion-v1-5"
75
+ device = "cuda" if torch.cuda.is_available() else "cpu"
76
+ logger.info(f"使用设备: {device}")
77
+
78
+ # 优化设置以减少内存使用
79
+ if torch.cuda.is_available():
80
+ # 使用半精度
81
+ pipe = StableDiffusionPipeline.from_pretrained(
82
+ model_id,
83
+ torch_dtype=torch.float16,
84
+ safety_checker=None, # 禁用安全检查器以节省内存
85
+ requires_safety_checker=False,
86
+ use_safetensors=True
87
+ )
88
+ pipe = pipe.to(device)
89
+ pipe.enable_attention_slicing() # 减少显存使用
90
+
91
+ # 释放不必要的内存
92
+ torch.cuda.empty_cache()
93
+ else:
94
+ # CPU版本,占用内存较大,但现在只用处理一个请求
95
+ pipe = StableDiffusionPipeline.from_pretrained(
96
+ model_id,
97
+ safety_checker=None,
98
+ requires_safety_checker=False,
99
+ use_safetensors=True
100
+ )
101
+ pipe = pipe.to(device)
102
+
103
+ logger.info("模型加载成功")
104
+ return pipe
105
+ except Exception as e:
106
+ logger.error(f"模型加载失败: {e}")
107
+ return None
108
+
109
+ # 生成图像函数
110
+ def generate_image(prompt):
111
+ global pipe
112
+
113
+ # 如果提示为空,使用默认提示
114
+ if not prompt or prompt.strip() == "":
115
+ prompt = "a beautiful landscape"
116
+ logger.info(f"输入为空,使用默认提示词: {prompt}")
117
 
118
+ logger.info(f"收到提示词: {prompt}")
119
+
120
+ # 第一次调用时加载模型
121
+ if pipe is None:
122
+ pipe = get_model()
123
+ if pipe is None:
124
+ logger.error("模型加载失败,返回默认图像")
125
+ return create_dummy_image()
126
+
127
+ try:
128
+ # 优化生成参数,减少内存需求
129
+ logger.info("开始生成图像...")
130
+
131
+ # 设置随机种子以确保结果一致性
132
+ seed = random.randint(0, 2147483647)
133
+ generator = torch.Generator(device=pipe.device).manual_seed(seed)
134
+
135
+ # 使用最轻量级的参数
136
+ image = pipe(
137
+ prompt=prompt,
138
+ num_inference_steps=3, # 极少的步骤
139
+ guidance_scale=7.5,
140
+ height=256, # 小尺寸
141
+ width=256, # 小尺寸
142
+ generator=generator
143
+ ).images[0]
144
+
145
+ # 释放缓存,避免内存增长
146
+ if torch.cuda.is_available():
147
+ torch.cuda.empty_cache()
148
+
149
+ logger.info(f"图像生成成功,种子: {seed}")
150
+ return image
151
+ except Exception as e:
152
+ logger.error(f"生成过程发生错误: {e}")
153
+ return create_dummy_image()
154
+
155
+ # 创建Gradio界面
156
+ def create_demo():
157
+ # 使用简单界面,避免复杂组件
158
  demo = gr.Interface(
159
+ fn=generate_image,
160
+ inputs=gr.Textbox(label="输入提示词"),
161
+ outputs=gr.Image(type="pil", label="生成的图像"),
162
+ title="文本到图像生成",
163
+ description="输入文本描述,AI将生成相应的图像(会加载较长时间,请耐心等待)",
164
+ examples=["a cute cat", "mountain landscape"],
165
+ cache_examples=False,
166
+ allow_flagging="never" # 禁用标记功能以减少复杂性
167
  )
168
  return demo
169
 
170
+ # 创建演示界面
171
+ demo = create_demo()
172
 
173
  # 启动应用
174
  if __name__ == "__main__":
175
  try:
176
+ logger.info("启动Gradio界面...")
177
  # 使用最小配置
178
+ demo.launch(
179
+ server_name="0.0.0.0",
180
+ show_api=False, # 禁用API
181
+ share=False, # 不创建公共链接
182
+ debug=False, # 禁用调试��式
183
+ quiet=True # 减少日志输出
184
+ )
185
  except Exception as e:
186
+ logger.error(f"启动失败: {e}")