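# NOTE(review): the lines below look like file-viewer page residue rather than
# Python; they are kept verbatim as comments so the module can be parsed.
# hellohf / app.py
# lisonallen's picture
# Extreme simplification for stability
# a0be010
# raw
# history blame
# 4.72 kB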
import gradio as gr
import numpy as np
import random
import logging
import sys
import os
from PIL import Image as PILImage
# Logging setup: send INFO-and-above records to stdout with a timestamp,
# logger name and level in front of each message.
logging.basicConfig(
    stream=sys.stdout,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
# Placeholder image, used when model loading or image generation fails.
def create_dummy_image():
    """Return a solid-color 256x256 RGB placeholder image."""
    placeholder_size = (256, 256)
    placeholder_color = (255, 100, 100)
    return PILImage.new('RGB', placeholder_size, color=placeholder_color)
# Module-level state and constants.
# Upper bound for randomly drawn seeds: the largest value of a 32-bit signed int.
MAX_SEED = np.iinfo(np.int32).max
# Handle to the loaded pipeline; None until a model has been loaded.
pipe = None
# Model loading is kept in its own helper so that a failed load can never leave
# a half-initialised pipeline behind in the module-level ``pipe``.
def _load_pipeline():
    """Build the SDXL Turbo pipeline and move it to the available device.

    The heavy dependencies (``torch`` and ``diffusers``) are imported here
    rather than at module import time, so they are only paid for on the first
    request.

    Returns:
        The loaded pipeline: float16 weights on CUDA when a GPU is available,
        float32 weights on CPU otherwise.

    Raises:
        Exception: whatever the imports, ``from_pretrained`` or ``.to()``
            raise; the caller decides how to recover.
    """
    from diffusers import DiffusionPipeline
    import torch

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    model_repo_id = "stabilityai/sdxl-turbo"
    logger.info(f"Using device: {device}")
    logger.info(f"Loading model: {model_repo_id}")

    loaded = DiffusionPipeline.from_pretrained(
        model_repo_id,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        # The fp16 weight variant is only requested when running on a GPU.
        variant="fp16" if use_cuda else None,
        use_safetensors=True,
    )
    loaded = loaded.to(device)

    if use_cuda:
        # Reduce GPU memory use, then release cached blocks left from loading.
        loaded.enable_attention_slicing()
        torch.cuda.empty_cache()

    return loaded


# Simplified inference entry point used by the UI.
def generate_image(prompt):
    """Generate a 256x256 RGB image for ``prompt`` with SDXL Turbo.

    The model is loaded on the first call and cached in the module-level
    ``pipe``. Every failure (model load, generation, or anything unexpected)
    is logged and answered with the placeholder from ``create_dummy_image``
    instead of raising.

    Args:
        prompt: Text prompt. An empty or whitespace-only prompt is replaced
            by "A beautiful landscape".

    Returns:
        A PIL image in RGB mode: the generated picture, or the placeholder
        image when something went wrong.
    """
    global pipe
    try:
        # Lazy loading: the model is only loaded on the first call.
        if pipe is None:
            try:
                logger.info("First request - loading model...")
                # Publish to the global only once the whole load succeeded.
                pipe = _load_pipeline()
                logger.info("Model loaded successfully")
            except Exception as e:
                logger.exception("Error loading model: %s", e)
                return create_dummy_image()

        # Bind ``torch`` on EVERY call. A function-scope import makes the name
        # local, so importing it only inside the first-call branch left it
        # unbound on later calls (UnboundLocalError -> placeholder image).
        # After the first import this is just a cheap module-cache lookup.
        import torch

        # Fall back to a default prompt when none was given.
        if not prompt or prompt.strip() == "":
            prompt = "A beautiful landscape"
            logger.info(f"Empty prompt, using default: {prompt}")
        logger.info(f"Generating image for prompt: {prompt}")

        seed = random.randint(0, MAX_SEED)
        generator = torch.Generator().manual_seed(seed)

        try:
            # Minimal settings to keep resource usage low: one inference step,
            # no guidance, and a small output size.
            image = pipe(
                prompt=prompt,
                guidance_scale=0.0,
                num_inference_steps=1,
                generator=generator,
                height=256,
                width=256,
            ).images[0]

            # Make sure the UI receives a PIL image.
            if not isinstance(image, PILImage.Image):
                logger.warning("Converting image to PIL format")
                image = PILImage.fromarray(np.array(image))

            # Normalise the colour mode.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            logger.info("Image generation successful")

            # Release cached GPU memory between requests.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return image
        except Exception as e:
            logger.exception("Error in image generation: %s", e)
            return create_dummy_image()
    except Exception as e:
        # Last-resort boundary: the UI must always get an image back.
        logger.exception("Unexpected error: %s", e)
        return create_dummy_image()
# Minimal UI: one prompt text box in, one generated image out.
demo = gr.Interface(
    title="SDXL Turbo Text-to-Image",
    description="Enter a text prompt to generate an image.",
    fn=generate_image,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=gr.Image(label="Generated Image", type="pil"),
    # Only one simple example is offered, and example caching is switched off.
    cache_examples=False,
    examples=["A cute cat"],
)
# Start the app when this file is run as a script.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": False,  # API disabled
        "share": False,
    }
    try:
        logger.info("Starting Gradio app")
        demo.launch(**launch_options)
    except Exception as e:
        # A launch failure is logged rather than propagated.
        logger.error(f"Error launching app: {str(e)}")