import gradio as gr
import numpy as np
import random
import logging
import sys
import os
from PIL import Image as PILImage

# Set up logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    stream=sys.stdout)
logger = logging.getLogger(__name__)
# Work around a Gradio JSON Schema error
try:
    import gradio_client.utils

    # Keep a reference to the original function
    original_get_type = gradio_client.utils.get_type

    # New get_type that tolerates boolean/None schemas
    def patched_get_type(schema):
        if schema is True or schema is False or schema is None:
            return "any"
        if not isinstance(schema, dict):
            return "any"
        return original_get_type(schema)

    # Swap in the patched function
    gradio_client.utils.get_type = patched_get_type
    logger.info("Successfully patched gradio_client.utils.get_type")

    # Patch _json_schema_to_python_type the same way
    original_json_schema_to_python_type = gradio_client.utils._json_schema_to_python_type

    def patched_json_schema_to_python_type(schema, defs=None):
        if schema is True or schema is False:
            return "bool"
        if schema is None:
            return "None"
        if not isinstance(schema, dict):
            return "any"
        try:
            return original_json_schema_to_python_type(schema, defs)
        except Exception as e:
            logger.warning(f"Error in json_schema_to_python_type: {e}")
            return "any"

    gradio_client.utils._json_schema_to_python_type = patched_json_schema_to_python_type
    logger.info("Successfully patched gradio_client.utils._json_schema_to_python_type")
except Exception as e:
    logger.error(f"Failed to patch Gradio utils: {e}")
# Create a simple placeholder image used as a fallback
def create_dummy_image():
    logger.info("Creating dummy image")
    img = PILImage.new('RGB', (256, 256), color=(255, 100, 100))
    return img
# Global pipeline handle
pipe = None

# Lazily load the diffusion model
def get_model():
    try:
        import torch
        from diffusers import StableDiffusionPipeline

        logger.info("Loading model...")

        # Use a smaller model instead of SDXL
        model_id = "runwayml/stable-diffusion-v1-5"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        # Tune settings to reduce memory usage
        if torch.cuda.is_available():
            # Half precision on GPU
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                safety_checker=None,  # disable the safety checker to save memory
                requires_safety_checker=False,
                use_safetensors=True
            )
            pipe = pipe.to(device)
            pipe.enable_attention_slicing()  # reduce VRAM usage
            # Free any unused GPU memory
            torch.cuda.empty_cache()
        else:
            # CPU version: larger memory footprint, but only one request is handled at a time
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                safety_checker=None,
                requires_safety_checker=False,
                use_safetensors=True
            )
            pipe = pipe.to(device)

        logger.info("Model loaded successfully")
        return pipe
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return None
# Generate an image from a text prompt
def generate_image(prompt):
    global pipe

    # Fall back to a default prompt if the input is empty
    if not prompt or prompt.strip() == "":
        prompt = "a beautiful landscape"
        logger.info(f"Empty input, using default prompt: {prompt}")

    logger.info(f"Received prompt: {prompt}")

    # Load the model on the first call
    if pipe is None:
        pipe = get_model()
        if pipe is None:
            logger.error("Model failed to load, returning fallback image")
            return create_dummy_image()

    try:
        import torch  # torch is imported lazily above, so bring it into scope here as well

        # Use lightweight generation settings to keep memory usage low
        logger.info("Generating image...")

        # Pick a random seed (logged so a result can be reproduced later)
        seed = random.randint(0, 2147483647)
        generator = torch.Generator(device=pipe.device).manual_seed(seed)

        # Minimal, low-cost generation parameters
        image = pipe(
            prompt=prompt,
            num_inference_steps=3,  # very few steps
            guidance_scale=7.5,
            height=256,  # small output size
            width=256,   # small output size
            generator=generator
        ).images[0]

        # Release cached memory so usage does not grow between requests
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info(f"Image generated successfully, seed: {seed}")
        return image
    except Exception as e:
        logger.error(f"Error during generation: {e}")
        return create_dummy_image()
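# Minimal local smoke test (illustrative helper, not referenced anywhere else in the app;
# the prompt and output path are arbitrary examples). It calls generate_image once and
# saves the result, so the pipeline can be checked without launching the Gradio UI.
def _smoke_test(prompt="a cute cat", path="smoke_test.png"):
    img = generate_image(prompt)
    img.save(path)
    logger.info(f"Smoke test image saved to {path}")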
# Build the Gradio interface
def create_demo():
    # Keep the interface simple and avoid complex components
    demo = gr.Interface(
        fn=generate_image,
        inputs=gr.Textbox(label="Prompt"),
        outputs=gr.Image(type="pil", label="Generated image"),
        title="Text-to-Image Generation",
        description="Enter a text description and the AI will generate a matching image (the first run loads the model, so please be patient).",
        examples=["a cute cat", "mountain landscape"],
        cache_examples=False,
        allow_flagging="never"  # disable flagging to keep things simple
    )
    return demo
# Create the demo interface
demo = create_demo()

# Launch the app
if __name__ == "__main__":
    try:
        logger.info("Launching Gradio interface...")
        # Use a minimal launch configuration
        demo.launch(
            server_name="0.0.0.0",
            show_api=False,  # disable the API page
            share=False,     # do not create a public share link
            debug=False,     # disable debug mode
            quiet=True       # reduce log output
        )
    except Exception as e:
        logger.error(f"Launch failed: {e}")