hellohf / app.py
lisonallen's picture
恢复AI图像生成功能,降级依赖以解决兼容性问题
0cfce88
raw
history blame
7.33 kB
import gradio as gr
import numpy as np
import random
import logging
import sys
import os
from PIL import Image as PILImage
# Configure logging: INFO and above, timestamped records written to stdout.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
# Patch two gradio_client helpers to work around a Gradio JSON Schema error.
try:
    import gradio_client.utils

    # Keep a handle on the stock implementation so the wrapper can delegate.
    original_get_type = gradio_client.utils.get_type

    def patched_get_type(schema):
        """Return "any" for a non-dict schema, else defer to the original.

        Booleans and None are not dicts, so the single isinstance guard
        covers them as well as every other non-dict value.
        """
        if isinstance(schema, dict):
            return original_get_type(schema)
        return "any"

    # Install the wrapper in place of the stock function.
    gradio_client.utils.get_type = patched_get_type
    logger.info("Successfully patched gradio_client.utils.get_type")

    # Apply the same treatment to _json_schema_to_python_type.
    original_json_schema_to_python_type = gradio_client.utils._json_schema_to_python_type

    def patched_json_schema_to_python_type(schema, defs=None):
        """Map a JSON schema to a type name, tolerating non-dict schemas.

        Returns "bool" for a boolean schema, "None" for None, and "any" for
        any other non-dict value. Dict schemas go to the original function;
        if that raises, the error is logged and "any" is returned.
        """
        if isinstance(schema, bool):
            return "bool"
        if schema is None:
            return "None"
        if not isinstance(schema, dict):
            return "any"
        try:
            return original_json_schema_to_python_type(schema, defs)
        except Exception as exc:
            logger.warning(f"Error in json_schema_to_python_type: {exc}")
            return "any"

    gradio_client.utils._json_schema_to_python_type = patched_json_schema_to_python_type
    logger.info("Successfully patched gradio_client.utils._json_schema_to_python_type")
except Exception as exc:
    # Patching is best-effort: log the failure and carry on.
    logger.error(f"Failed to patch Gradio utils: {exc}")
# Build a fallback image.
def create_backup_image(prompt=""):
    """Return a 512x512 placeholder image captioned with *prompt*.

    The canvas is a light solid colour. Two captions are drawn on it: the
    prompt, and a notice that the model failed to load. If anything in the
    drawing step raises, the error is logged and the canvas is returned as
    it stands at that point.

    Args:
        prompt: Text shown in the first caption.

    Returns:
        The PIL image.
    """
    logger.info(f"创建备用图像: {prompt}")
    canvas = PILImage.new('RGB', (512, 512), color=(240, 240, 250))
    try:
        from PIL import ImageDraw, ImageFont

        pen = ImageDraw.Draw(canvas)
        default_font = ImageFont.load_default()
        # NOTE(review): the default font may not cover the CJK characters in
        # these captions — confirm how they render in the deployed Pillow.
        captions = (
            ((20, 20), f"提示词: {prompt}", (0, 0, 0)),
            ((20, 60), "模型加载失败,无法生成图像", (255, 0, 0)),
        )
        for position, text, colour in captions:
            pen.text(position, text, fill=colour, font=default_font)
    except Exception as exc:
        logger.error(f"创建备用图像时出错: {exc}")
    return canvas
# Cached pipeline; stays None until load_model() succeeds.
model = None


def load_model():
    """Return the Stable Diffusion pipeline, loading it on first use.

    A successfully loaded pipeline is stored in the module-level ``model``
    and returned directly on later calls. With CUDA available the pipeline
    is loaded from the "fp16" revision in float16, moved to the GPU, and
    attention slicing is enabled; otherwise it is loaded in float32 on CPU.

    Returns:
        The pipeline, or None if loading failed (the error is logged).
    """
    global model
    if model is not None:
        return model
    try:
        logger.info("开始加载AI模型...")
        # Deferred imports: a missing dependency is caught below and logged
        # as a load failure.
        import torch
        from diffusers import StableDiffusionPipeline

        # Use an older model version.
        model_id = "CompVis/stable-diffusion-v1-4"

        # Pick device, precision and weights revision together.
        if torch.cuda.is_available():
            device, dtype, revision = "cuda", torch.float16, "fp16"
        else:
            device, dtype, revision = "cpu", torch.float32, None

        logger.info(f"使用模型: {model_id}")
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            safety_checker=None,
        )
        pipe = pipe.to(device)

        # Optimisation applied on GPU only.
        if device == "cuda":
            pipe.enable_attention_slicing()

        logger.info("AI模型加载成功")
        model = pipe
        return model
    except Exception as exc:
        logger.error(f"AI模型加载失败: {exc}")
        return None
# AI image generation.
def generate_ai_image(prompt, seed=None):
    """Generate one 512x512 image for *prompt* using the loaded pipeline.

    Args:
        prompt: Text description passed to the pipeline.
        seed: Seed for the torch random generator. When None, a random
            value in [0, 2147483647] is drawn.

    Returns:
        The first image returned by the pipeline, or None when the model is
        unavailable or generation raises (the error is logged).
    """
    # Try to obtain the pipeline.
    pipe = load_model()
    if pipe is None:
        logger.error("AI模型不可用")
        return None
    try:
        # torch is imported only inside load_model(), where it is a local
        # name. Without this import every reference below raised NameError,
        # which the broad except swallowed, so generation always failed.
        import torch

        logger.info(f"使用AI生成图像: {prompt}")
        if seed is None:
            seed = random.randint(0, 2147483647)

        # Seed a generator on the same device type the pipeline was moved to.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device).manual_seed(seed)

        image = pipe(
            prompt=prompt,
            guidance_scale=7.5,
            num_inference_steps=5,  # few steps, to keep generation fast
            generator=generator,
            height=512,
            width=512,
        ).images[0]

        # Release cached GPU memory after a generation.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info(f"AI图像生成成功,种子: {seed}")
        return image
    except Exception as e:
        logger.error(f"AI图像生成失败: {e}")
        return None
# Entry point: handle a request and produce an image.
def generate_image(prompt):
    """Return an image for *prompt*, falling back to a placeholder.

    An empty or whitespace-only prompt is replaced by a default one. If AI
    generation returns None, a backup image is created instead.

    Args:
        prompt: Text entered by the user.

    Returns:
        The generated image, or the backup image when generation failed.
    """
    # Substitute a default for a blank prompt.
    if not (prompt and prompt.strip() != ""):
        prompt = "a beautiful landscape"
        logger.info(f"输入为空,使用默认提示词: {prompt}")

    logger.info(f"收到提示词: {prompt}")

    # Try AI generation first.
    image = generate_ai_image(prompt)
    if image is None:
        logger.warning("使用备用生成器")
        return create_backup_image(prompt)
    return image
# Build the Gradio interface.
def create_demo():
    """Assemble the Blocks UI and return it without launching.

    The layout is one row with two columns: an input column (prompt textbox,
    generate button, example prompts) and an output column (image). Both
    the button click and the textbox submit event call ``generate_image``.

    Returns:
        The ``gr.Blocks`` instance.
    """
    sample_prompts = [
        "a cute cat sitting on a windowsill",
        "beautiful sunset over mountains",
        "an astronaut riding a horse in space",
        "a fantasy castle on a floating island",
    ]

    with gr.Blocks(title="AI 文本到图像生成器") as demo:
        gr.Markdown("# AI 文本到图像生成器")
        gr.Markdown("输入文本描述,AI将为你生成相应的图像。")

        with gr.Row():
            # Input area.
            # NOTE(review): the examples widget is assumed to belong to this
            # column — confirm against the intended layout.
            with gr.Column(scale=3):
                prompt_input = gr.Textbox(
                    label="输入提示词",
                    placeholder="描述你想要的图像,例如:一只可爱的猫,日落下的山脉...",
                    lines=2,
                )
                generate_button = gr.Button("生成图像", variant="primary")
                gr.Examples(examples=sample_prompts, inputs=prompt_input)

            # Output area.
            with gr.Column(scale=5):
                output_image = gr.Image(label="生成的图像", type="pil")

        # Wire the button click first, then Enter-to-submit, to one handler.
        handler_kwargs = dict(
            fn=generate_image,
            inputs=prompt_input,
            outputs=output_image,
        )
        for register in (generate_button.click, prompt_input.submit):
            register(**handler_kwargs)

    return demo
# Build the interface at import time and expose it as a module attribute.
demo = create_demo()

# Launch the app when this file is run as a script.
if __name__ == "__main__":
    try:
        logger.info("启动Gradio界面...")
        demo.launch(share=False, show_api=False, server_name="0.0.0.0")
    except Exception as exc:
        # A launch failure is logged rather than re-raised.
        logger.error(f"启动失败: {exc}")