lisonallen committed
Commit a0be010 · 1 Parent(s): 65762c4

Extreme simplification for stability

Files changed (2):
  1. app.py +80 -124
  2. requirements.txt +6 -7
app.py CHANGED
@@ -5,7 +5,6 @@ import logging
 import sys
 import os
 from PIL import Image as PILImage
-import io

 # Set up logging
 logging.basicConfig(level=logging.INFO,
@@ -13,88 +12,61 @@ logging.basicConfig(level=logging.INFO,
                     stream=sys.stdout)
 logger = logging.getLogger(__name__)

-# Make sure the example-cache directory exists
-os.makedirs(".gradio/cached_examples", exist_ok=True)
-
-# Work around the Gradio JSON schema error
-try:
-    import gradio_client.utils
-
-    # Save the original functions
-    original_get_type = gradio_client.utils.get_type
-    original_json_schema_to_python_type = gradio_client.utils._json_schema_to_python_type
-
-    # Patch the get_type function
-    def patched_get_type(schema):
-        if isinstance(schema, bool):
-            return "bool"
-        if not isinstance(schema, dict):
-            return "any"
-        return original_get_type(schema)
-
-    # Patch the _json_schema_to_python_type function
-    def patched_json_schema_to_python_type(schema, defs=None):
-        # Handle primitive types
-        if schema is True or schema is False:
-            return "bool"
-        if schema is None:
-            return "None"
-        if not isinstance(schema, dict):
-            return "any"
-
-        try:
-            return original_json_schema_to_python_type(schema, defs)
-        except Exception as e:
-            logger.warning(f"Error in JSON schema parsing: {str(e)}")
-            return "any"  # return as a fallback type
-
-    # Apply the patches
-    gradio_client.utils.get_type = patched_get_type
-    gradio_client.utils._json_schema_to_python_type = patched_json_schema_to_python_type
-
-    logger.info("Successfully patched Gradio JSON schema processing")
-except Exception as e:
-    logger.error(f"Failed to patch Gradio: {str(e)}")
-
 # Create a simple placeholder image, used when model loading or generation fails
 def create_dummy_image():
-    # Create a 256x256 red image
-    img = PILImage.new('RGB', (256, 256), color = (255, 0, 0))
+    # Create a simple solid-color image
+    img = PILImage.new('RGB', (256, 256), color = (255, 100, 100))
     return img

-# Load the model
-try:
-    from diffusers import DiffusionPipeline
-    import torch
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model_repo_id = "stabilityai/sdxl-turbo"
-
-    logger.info(f"Using device: {device}")
-    logger.info(f"Loading model: {model_repo_id}")
-
-    if torch.cuda.is_available():
-        torch_dtype = torch.float16
-    else:
-        torch_dtype = torch.float32
-
-    pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-    pipe = pipe.to(device)
-    logger.info("Model loaded successfully")
-
-    MAX_SEED = np.iinfo(np.int32).max
-except Exception as e:
-    logger.error(f"Error loading model: {str(e)}")
-    # Fall back to None to avoid errors
-    pipe = None
+# Global variables that track model state
+pipe = None
+MAX_SEED = np.iinfo(np.int32).max

-# Simple inference function
+# Simplified inference function
 def generate_image(prompt):
+    global pipe
+
     try:
+        # Lazy-load the model - only on the first call
         if pipe is None:
-            logger.error("Model not loaded, cannot generate image")
-            return create_dummy_image()
+            try:
+                logger.info("First request - loading model...")
+                from diffusers import DiffusionPipeline
+                import torch
+
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+                model_repo_id = "stabilityai/sdxl-turbo"
+
+                logger.info(f"Using device: {device}")
+                logger.info(f"Loading model: {model_repo_id}")
+
+                if torch.cuda.is_available():
+                    torch_dtype = torch.float16
+                else:
+                    torch_dtype = torch.float32
+
+                # Optimize memory usage
+                pipe = DiffusionPipeline.from_pretrained(
+                    model_repo_id,
+                    torch_dtype=torch_dtype,
+                    variant="fp16" if torch.cuda.is_available() else None,
+                    use_safetensors=True
+                )
+                pipe = pipe.to(device)
+
+                # Optimize memory
+                if torch.cuda.is_available():
+                    pipe.enable_attention_slicing()
+                    # Free unneeded memory
+                    torch.cuda.empty_cache()
+
+                logger.info("Model loaded successfully")
+
+            except Exception as e:
+                logger.error(f"Error loading model: {str(e)}")
+                return create_dummy_image()

+        # Handle empty prompts
         if not prompt or prompt.strip() == "":
             prompt = "A beautiful landscape"
             logger.info(f"Empty prompt, using default: {prompt}")
@@ -103,79 +75,63 @@ def generate_image(prompt):
         seed = random.randint(0, MAX_SEED)
         generator = torch.Generator().manual_seed(seed)

-        # Add exception handling to guard against index-out-of-range errors
+        # Simplified parameters and exception handling
        try:
+            # Use minimal inference steps to reduce resource pressure
             image = pipe(
                 prompt=prompt,
-                guidance_scale=0.0,
-                num_inference_steps=2,
-                generator=generator
+                guidance_scale=0.0,     # 0 for the fastest inference
+                num_inference_steps=1,  # fewer steps
+                generator=generator,
+                height=256,             # smaller image size
+                width=256               # smaller image size
             ).images[0]

-            # Make sure the image is a PIL.Image
+            # Make sure the image is a valid PIL image
             if not isinstance(image, PILImage.Image):
-                logger.warning(f"Converting image from {type(image)} to PIL.Image")
-                if hasattr(image, 'numpy'):
-                    image = PILImage.fromarray(image.numpy())
-                else:
-                    image = PILImage.fromarray(np.array(image))
-
-            # Convert to RGB mode for compatibility
+                logger.warning("Converting image to PIL format")
+                image = PILImage.fromarray(np.array(image))
+
+            # Convert the image mode
             if image.mode != 'RGB':
                 image = image.convert('RGB')

             logger.info("Image generation successful")
-            # Save the image for debugging
-            debug_path = "debug_image.jpg"
-            image.save(debug_path)
-            logger.info(f"Debug image saved to {debug_path}")

+            # Free memory
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
             return image
-        except IndexError as e:
-            logger.error(f"Index error in pipe: {str(e)}")
+
+        except Exception as e:
+            logger.error(f"Error in image generation: {str(e)}")
             return create_dummy_image()

     except Exception as e:
-        logger.error(f"Error generating image: {str(e)}")
+        logger.error(f"Unexpected error: {str(e)}")
         return create_dummy_image()

-# Create a simple Gradio interface with example caching disabled
-with gr.Blocks(title="SDXL Turbo Text-to-Image") as demo:
-    gr.Markdown("# SDXL Turbo Text-to-Image Generator")
-
-    with gr.Row():
-        with gr.Column():
-            prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your text prompt here...")
-            generate_button = gr.Button("Generate Image")
-
-        with gr.Column():
-            image_output = gr.Image(label="Generated Image", type="pil")
-
-    # Examples
-    gr.Examples(
-        examples=["A cute cat", "Sunset over mountains"],
-        inputs=prompt_input,
-        outputs=image_output,
-        fn=generate_image,
-        cache_examples=False
-    )
-
-    # Wire up the generate button
-    generate_button.click(fn=generate_image, inputs=prompt_input, outputs=image_output)
-
-    # Also bind the textbox submit event
-    prompt_input.submit(fn=generate_image, inputs=prompt_input, outputs=image_output)
+# Use the simplest possible interface
+demo = gr.Interface(
+    fn=generate_image,
+    inputs=gr.Textbox(label="Enter your prompt"),
+    outputs=gr.Image(label="Generated Image", type="pil"),
+    title="SDXL Turbo Text-to-Image",
+    description="Enter a text prompt to generate an image.",
+    examples=["A cute cat"],  # keep just one simple example
+    cache_examples=False
+)

 # Launch the app
 if __name__ == "__main__":
     try:
         logger.info("Starting Gradio app")
-        # Add more launch options to improve stability
         demo.launch(
-            debug=True,
-            show_error=True,
-            server_name="0.0.0.0",  # make sure it is reachable externally
-            max_threads=1  # reduce concurrency to avoid race conditions
+            server_name="0.0.0.0",
+            server_port=7860,
+            show_api=False,  # disable the API
+            share=False
         )
     except Exception as e:
         logger.error(f"Error launching app: {str(e)}")
requirements.txt CHANGED
@@ -1,7 +1,6 @@
-accelerate
-diffusers
-invisible_watermark
-torch
-transformers
-xformers
-gradio==3.34.0
+accelerate==0.21.0
+diffusers==0.20.0
+torch==2.0.1
+transformers==4.34.0
+gradio==3.32.0
+Pillow==10.0.0
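
As a quick sanity check (not part of this commit), the versions actually resolved in the Space can be compared against these pins from Python:

# Hypothetical version check against the pinned requirements.txt above.
import importlib.metadata as md

for pkg in ["accelerate", "diffusers", "torch", "transformers", "gradio", "Pillow"]:
    print(pkg, md.version(pkg))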