Spaces:
Sleeping
Sleeping
File size: 6,159 Bytes
82df48f 407a5fa 8718399 65762c4 407a5fa 7a3e379 3b80791 65762c4 3b80791 a0be010 65762c4 4564834 7a3e379 4564834 7a3e379 4564834 7a3e379 82df48f 4564834 e2bc0a8 82df48f 407a5fa 4564834 7a3e379 4564834 407a5fa 4564834 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import gradio as gr
import numpy as np
import random
import logging
import sys
import os
from PIL import Image as PILImage
# 设置日志记录
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
stream=sys.stdout)
logger = logging.getLogger(__name__)
# 修复 Gradio JSON Schema 错误
try:
import gradio_client.utils
# 保存原始函数
original_get_type = gradio_client.utils.get_type
# 创建新的 get_type 函数,处理布尔值
def patched_get_type(schema):
if schema is True or schema is False or schema is None:
return "any"
if not isinstance(schema, dict):
return "any"
return original_get_type(schema)
# 替换原始函数
gradio_client.utils.get_type = patched_get_type
logger.info("Successfully patched gradio_client.utils.get_type")
# 同样修补 _json_schema_to_python_type 函数
original_json_schema_to_python_type = gradio_client.utils._json_schema_to_python_type
def patched_json_schema_to_python_type(schema, defs=None):
if schema is True or schema is False:
return "bool"
if schema is None:
return "None"
if not isinstance(schema, dict):
return "any"
try:
return original_json_schema_to_python_type(schema, defs)
except Exception as e:
logger.warning(f"Error in json_schema_to_python_type: {e}")
return "any"
gradio_client.utils._json_schema_to_python_type = patched_json_schema_to_python_type
logger.info("Successfully patched gradio_client.utils._json_schema_to_python_type")
except Exception as e:
logger.error(f"Failed to patch Gradio utils: {e}")
# 创建一个简单的示例图像
def create_dummy_image():
logger.info("Creating dummy image")
img = PILImage.new('RGB', (256, 256), color = (255, 100, 100))
return img
# 全局变量
pipe = None
# 懒加载AI模型函数
def get_model():
try:
import torch
from diffusers import StableDiffusionPipeline
logger.info("开始加载模型...")
# 使用较小的模型而不是SDXL
model_id = "runwayml/stable-diffusion-v1-5"
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"使用设备: {device}")
# 优化设置以减少内存使用
if torch.cuda.is_available():
# 使用半精度
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16,
safety_checker=None, # 禁用安全检查器以节省内存
requires_safety_checker=False,
use_safetensors=True
)
pipe = pipe.to(device)
pipe.enable_attention_slicing() # 减少显存使用
# 释放不必要的内存
torch.cuda.empty_cache()
else:
# CPU版本,占用内存较大,但现在只用处理一个请求
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
safety_checker=None,
requires_safety_checker=False,
use_safetensors=True
)
pipe = pipe.to(device)
logger.info("模型加载成功")
return pipe
except Exception as e:
logger.error(f"模型加载失败: {e}")
return None
# 生成图像函数
def generate_image(prompt):
global pipe
# 如果提示为空,使用默认提示
if not prompt or prompt.strip() == "":
prompt = "a beautiful landscape"
logger.info(f"输入为空,使用默认提示词: {prompt}")
logger.info(f"收到提示词: {prompt}")
# 第一次调用时加载模型
if pipe is None:
pipe = get_model()
if pipe is None:
logger.error("模型加载失败,返回默认图像")
return create_dummy_image()
try:
# 优化生成参数,减少内存需求
logger.info("开始生成图像...")
# 设置随机种子以确保结果一致性
seed = random.randint(0, 2147483647)
generator = torch.Generator(device=pipe.device).manual_seed(seed)
# 使用最轻量级的参数
image = pipe(
prompt=prompt,
num_inference_steps=3, # 极少的步骤
guidance_scale=7.5,
height=256, # 小尺寸
width=256, # 小尺寸
generator=generator
).images[0]
# 释放缓存,避免内存增长
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info(f"图像生成成功,种子: {seed}")
return image
except Exception as e:
logger.error(f"生成过程发生错误: {e}")
return create_dummy_image()
# 创建Gradio界面
def create_demo():
# 使用简单界面,避免复杂组件
demo = gr.Interface(
fn=generate_image,
inputs=gr.Textbox(label="输入提示词"),
outputs=gr.Image(type="pil", label="生成的图像"),
title="文本到图像生成",
description="输入文本描述,AI将生成相应的图像(会加载较长时间,请耐心等待)",
examples=["a cute cat", "mountain landscape"],
cache_examples=False,
allow_flagging="never" # 禁用标记功能以减少复杂性
)
return demo
# 创建演示界面
demo = create_demo()
# 启动应用
if __name__ == "__main__":
try:
logger.info("启动Gradio界面...")
# 使用最小配置
demo.launch(
server_name="0.0.0.0",
show_api=False, # 禁用API
share=False, # 不创建公共链接
debug=False, # 禁用调试模式
quiet=True # 减少日志输出
)
except Exception as e:
logger.error(f"启动失败: {e}")
|