hellohf / app.py
lisonallen's picture
恢复AI图像生成功能,使用更轻量的方式
4564834
raw
history blame
6.16 kB
import gradio as gr
import numpy as np
import random
import logging
import sys
import os
from PIL import Image as PILImage
# Configure root logging once: INFO level and above, timestamped, written to stdout.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
# Work around gradio_client JSON-schema helpers that cannot cope with non-dict
# schemas (booleans / None). Both helpers are wrapped so that such schemas get a
# fixed type string instead of reaching the upstream code. Any failure while
# patching is logged and otherwise ignored (best-effort).
try:
    import gradio_client.utils

    # Keep a handle on the upstream implementation before replacing it.
    original_get_type = gradio_client.utils.get_type

    def patched_get_type(schema):
        """Return "any" for any non-dict schema; otherwise defer to upstream.

        True, False and None are never dicts, so this single check covers them.
        """
        if isinstance(schema, dict):
            return original_get_type(schema)
        return "any"

    gradio_client.utils.get_type = patched_get_type
    logger.info("Successfully patched gradio_client.utils.get_type")

    # Same treatment for the (private) recursive schema-to-type converter.
    original_json_schema_to_python_type = gradio_client.utils._json_schema_to_python_type

    def patched_json_schema_to_python_type(schema, defs=None):
        """Map bool -> "bool", None -> "None", other non-dicts -> "any".

        Dict schemas go to the upstream converter; if that raises, the error is
        logged as a warning and "any" is returned instead.
        """
        if isinstance(schema, bool):
            return "bool"
        if schema is None:
            return "None"
        if not isinstance(schema, dict):
            return "any"
        try:
            return original_json_schema_to_python_type(schema, defs)
        except Exception as e:
            logger.warning(f"Error in json_schema_to_python_type: {e}")
            return "any"

    gradio_client.utils._json_schema_to_python_type = patched_json_schema_to_python_type
    logger.info("Successfully patched gradio_client.utils._json_schema_to_python_type")
except Exception as e:
    logger.error(f"Failed to patch Gradio utils: {e}")
def create_dummy_image():
    """Return a solid-colour 256x256 RGB placeholder image.

    Used as the fallback output whenever the model cannot be loaded or
    image generation fails.
    """
    logger.info("Creating dummy image")
    placeholder_size = (256, 256)
    placeholder_color = (255, 100, 100)
    return PILImage.new('RGB', placeholder_size, color=placeholder_color)
# Lazily-loaded diffusion pipeline shared across requests. It stays None until
# generate_image() successfully calls get_model().
pipe = None


def get_model():
    """Load the Stable Diffusion v1.5 pipeline and move it to the best device.

    torch and diffusers are imported here, not at module level, so the app can
    start (and serve placeholder images) even if they are missing or slow to import.

    On CUDA the pipeline is loaded in float16 and attention slicing is enabled
    to reduce GPU memory use. On CPU it is loaded at default precision. The
    safety checker is disabled in both cases to save memory.

    Returns:
        The loaded pipeline, or ``None`` if importing the dependencies or
        loading the weights failed (the error is logged with its traceback).
    """
    try:
        import torch
        from diffusers import StableDiffusionPipeline

        logger.info("开始加载模型...")
        # A smaller model than SDXL, to keep memory use down.
        # NOTE(review): the "runwayml/..." Hub repo may no longer resolve; a mirror
        # exists at "stable-diffusion-v1-5/stable-diffusion-v1-5" — confirm.
        model_id = "runwayml/stable-diffusion-v1-5"

        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        logger.info(f"使用设备: {device}")

        load_kwargs = {
            "safety_checker": None,  # skip the safety checker to save memory
            "requires_safety_checker": False,
            "use_safetensors": True,
        }
        if use_cuda:
            # Half precision roughly halves GPU memory for the weights.
            load_kwargs["torch_dtype"] = torch.float16

        pipeline = StableDiffusionPipeline.from_pretrained(model_id, **load_kwargs)
        pipeline = pipeline.to(device)

        if use_cuda:
            pipeline.enable_attention_slicing()  # lower peak VRAM
            torch.cuda.empty_cache()  # release cached blocks left over from loading

        logger.info("模型加载成功")
        return pipeline
    except Exception as e:
        # Best-effort: the caller falls back to a placeholder image on None.
        logger.exception(f"模型加载失败: {e}")
        return None
def generate_image(prompt):
    """Generate an image for *prompt*, falling back to a placeholder on failure.

    On the first call (and on later calls while loading keeps failing) the model
    is loaded via get_model() and cached in the module-level ``pipe``.

    Args:
        prompt: Text prompt from the UI. If it is empty, whitespace-only or
            None, "a beautiful landscape" is used instead.

    Returns:
        A PIL image: the generated image, or create_dummy_image() if the model
        could not be loaded or generation raised.
    """
    global pipe

    # Substitute a default prompt when the input is empty.
    if not prompt or not prompt.strip():
        prompt = "a beautiful landscape"
        logger.info(f"输入为空,使用默认提示词: {prompt}")

    logger.info(f"收到提示词: {prompt}")

    # Lazy-load the model on first use.
    if pipe is None:
        pipe = get_model()
        if pipe is None:
            logger.error("模型加载失败,返回默认图像")
            return create_dummy_image()

    try:
        # torch is not imported at module level (get_model imports it locally).
        # Without this import, the torch.* calls below raised NameError, which
        # the except clause swallowed, so every request returned the placeholder.
        import torch

        logger.info("开始生成图像...")
        # Draw a fresh random seed per request and log it, so a particular
        # result can be reproduced later by reusing the logged seed.
        seed = random.randint(0, 2147483647)
        generator = torch.Generator(device=pipe.device).manual_seed(seed)

        # Deliberately minimal settings to keep memory and latency low.
        image = pipe(
            prompt=prompt,
            num_inference_steps=3,  # very few denoising steps
            guidance_scale=7.5,
            height=256,  # small output size
            width=256,
            generator=generator,
        ).images[0]

        # Release cached GPU memory so it does not grow across requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info(f"图像生成成功,种子: {seed}")
        return image
    except Exception as e:
        logger.exception(f"生成过程发生错误: {e}")
        return create_dummy_image()
def create_demo():
    """Build the Gradio interface: one prompt textbox in, one PIL image out.

    A plain gr.Interface is used to avoid complex components. Example outputs
    are not cached and flagging is disabled.
    """
    prompt_box = gr.Textbox(label="输入提示词")
    image_view = gr.Image(type="pil", label="生成的图像")
    return gr.Interface(
        fn=generate_image,
        inputs=prompt_box,
        outputs=image_view,
        title="文本到图像生成",
        description="输入文本描述,AI将生成相应的图像(会加载较长时间,请耐心等待)",
        examples=["a cute cat", "mountain landscape"],
        cache_examples=False,
        allow_flagging="never",
    )
# Build the interface at import time (not only under __main__), so `demo`
# is available to anything that imports this module.
demo = create_demo()

# Launch the app when run as a script.
if __name__ == "__main__":
    try:
        logger.info("启动Gradio界面...")
        # Minimal launch configuration.
        demo.launch(
            server_name="0.0.0.0",  # listen on all interfaces
            show_api=False,  # hide the API page
            share=False,  # no public share link
            debug=False,
            quiet=True,  # less console output
        )
    except Exception as e:
        # Previously the error was logged without a traceback and the process
        # exited 0. Keep the traceback and exit non-zero so a failed launch is
        # visible.
        logger.exception(f"启动失败: {e}")
        sys.exit(1)