# Spaces: Sleeping / Sleeping  (hosting-page status header captured with this file; not program code)
import logging
import os
import random
import sys

import gradio as gr
import numpy as np
from PIL import Image as PILImage
# Configure logging: timestamped records at INFO level and above are written
# to stdout rather than the default stderr.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)
# Placeholder image used whenever the model cannot be loaded or generation fails.
def create_dummy_image(size=(256, 256), color=(255, 100, 100)) -> "PILImage.Image":
    """Return a solid-colour RGB placeholder image.

    Args:
        size: ``(width, height)`` of the image in pixels. Defaults to 256x256.
        color: ``(R, G, B)`` fill colour. Defaults to a light red.

    Returns:
        A new PIL image in ``RGB`` mode filled with ``color``.
    """
    return PILImage.new('RGB', size, color=color)
# Module-level model state: the pipeline is created lazily by the first
# request and then reused, so it starts out as None.
pipe = None
# Upper bound for randomly drawn seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Simplified inference entry point plus its model-loading helper.
def _load_pipeline():
    """Create the SDXL Turbo pipeline on CUDA when available, otherwise on CPU."""
    # Heavy dependencies are imported lazily, on the first request only.
    import torch
    from diffusers import DiffusionPipeline

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    model_repo_id = "stabilityai/sdxl-turbo"
    logger.info(f"Using device: {device}")
    logger.info(f"Loading model: {model_repo_id}")

    # Half precision (and the fp16 weight variant) only when running on CUDA.
    loaded = DiffusionPipeline.from_pretrained(
        model_repo_id,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        variant="fp16" if use_cuda else None,
        use_safetensors=True,
    )
    loaded = loaded.to(device)

    if use_cuda:
        # Reduce memory pressure, then release cached blocks left over from loading.
        loaded.enable_attention_slicing()
        torch.cuda.empty_cache()
    return loaded


def generate_image(prompt):
    """Generate a 256x256 RGB image for ``prompt``.

    The pipeline is loaded on the first call and cached in the module-level
    ``pipe``. This function never raises: any failure is logged and the
    placeholder from ``create_dummy_image()`` is returned instead.

    Args:
        prompt: Text prompt. ``None``, empty or whitespace-only input is
            replaced by a default prompt.

    Returns:
        A PIL image in ``RGB`` mode: the generated picture, or the placeholder
        when loading or generation fails.
    """
    global pipe
    try:
        # Lazy loading: only the first successful call pays the model-load cost.
        if pipe is None:
            try:
                logger.info("First request - loading model...")
                # Assign the global only once loading fully succeeded, so a
                # partial failure never leaves a half-initialised pipeline.
                pipe = _load_pipeline()
                logger.info("Model loaded successfully")
            except Exception as e:
                logger.exception("Error loading model: %s", e)
                return create_dummy_image()

        # The import must run on EVERY call. An `import` inside a function
        # makes `torch` a local name; importing it only in the first-call
        # branch above would leave it unbound on later calls and raise
        # UnboundLocalError at the first use below.
        import torch

        # Fall back to a default prompt for empty input.
        if not prompt or not prompt.strip():
            prompt = "A beautiful landscape"
            logger.info(f"Empty prompt, using default: {prompt}")

        logger.info(f"Generating image for prompt: {prompt}")
        seed = random.randint(0, MAX_SEED)
        generator = torch.Generator().manual_seed(seed)

        try:
            # Cheapest settings: a single step, no guidance, small output.
            image = pipe(
                prompt=prompt,
                guidance_scale=0.0,
                num_inference_steps=1,
                generator=generator,
                height=256,
                width=256,
            ).images[0]

            # Make sure the UI receives a PIL image in RGB mode.
            if not isinstance(image, PILImage.Image):
                logger.warning("Converting image to PIL format")
                image = PILImage.fromarray(np.array(image))
            if image.mode != 'RGB':
                image = image.convert('RGB')

            logger.info("Image generation successful")
            return image
        except Exception as e:
            logger.exception("Error in image generation: %s", e)
            return create_dummy_image()
        finally:
            # Release cached GPU memory after success and after failure alike.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        return create_dummy_image()
# Simplest possible UI: one prompt textbox in, one image out.
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=gr.Image(label="Generated Image", type="pil"),
    title="SDXL Turbo Text-to-Image",
    description="Enter a text prompt to generate an image.",
    examples=["A cute cat"],  # keep a single simple example
    cache_examples=False,
)
# Start the app when this file is run as a script.
if __name__ == "__main__":
    try:
        logger.info("Starting Gradio app")
        demo.launch(
            server_name="0.0.0.0",  # listen on all network interfaces
            server_port=7860,
            show_api=False,  # disable the API
            share=False,
        )
    except Exception as e:
        # Log the launch failure together with its traceback.
        logger.exception("Error launching app: %s", e)