# NOTE(review): removed extraction artifact — a repository-viewer gutter
# (commit short-hashes and a line-number column) that was accidentally
# captured above the actual source and made the file invalid Python.
import gradio as gr
import numpy as np
import random
import logging
import sys
import os
from PIL import Image as PILImage

# Configure root logging: INFO level, timestamped format, written to stdout
# so messages show up in container / Spaces logs.
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    stream=sys.stdout)
# Module-level logger following the standard getLogger(__name__) convention.
logger = logging.getLogger(__name__)

# Fallback image returned whenever model loading or generation fails.
def create_dummy_image():
    """Return a solid-color 256x256 RGB placeholder image."""
    return PILImage.new('RGB', (256, 256), color=(255, 100, 100))

# Global model state.
# `pipe` holds the diffusers pipeline; it stays None until the first
# generate_image() call lazy-loads the model (keeps startup fast).
pipe = None
# Largest 32-bit signed integer — upper bound for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max

# Simplified inference function.
def generate_image(prompt):
    """Generate an image for ``prompt`` with SDXL-Turbo.

    The diffusers pipeline is lazy-loaded on the first call and cached in the
    module-level ``pipe``. Any failure (model loading or inference) is logged
    and a placeholder image is returned instead of raising, so the Gradio UI
    always receives a valid image.

    Args:
        prompt: Text prompt; empty or whitespace-only input falls back to a
            default prompt.

    Returns:
        PIL.Image.Image: The generated image, or a dummy image on failure.
    """
    global pipe

    try:
        # BUG FIX: torch is needed on *every* call (generator seeding and CUDA
        # cache management below), not only when the model is first loaded.
        # The original imported torch inside the lazy-load branch, which made
        # `torch` an unbound local on every call after the first and raised
        # UnboundLocalError at the torch.Generator() step — so only the very
        # first request could ever produce a real image.
        import torch

        # Lazy-load the model only on the first request.
        if pipe is None:
            try:
                logger.info("First request - loading model...")
                from diffusers import DiffusionPipeline

                device = "cuda" if torch.cuda.is_available() else "cpu"
                model_repo_id = "stabilityai/sdxl-turbo"

                logger.info(f"Using device: {device}")
                logger.info(f"Loading model: {model_repo_id}")

                # fp16 halves GPU memory use; CPU inference requires fp32.
                torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

                # Optimize memory usage during loading.
                pipe = DiffusionPipeline.from_pretrained(
                    model_repo_id, 
                    torch_dtype=torch_dtype,
                    variant="fp16" if torch.cuda.is_available() else None,
                    use_safetensors=True
                )
                pipe = pipe.to(device)

                # GPU-only memory optimizations.
                if torch.cuda.is_available():
                    pipe.enable_attention_slicing()
                    # Release allocator slack left over from loading.
                    torch.cuda.empty_cache()

                logger.info("Model loaded successfully")

            except Exception as e:
                logger.error(f"Error loading model: {str(e)}")
                return create_dummy_image()

        # Handle empty prompts with a sensible default.
        if not prompt or prompt.strip() == "":
            prompt = "A beautiful landscape"
            logger.info(f"Empty prompt, using default: {prompt}")

        logger.info(f"Generating image for prompt: {prompt}")
        seed = random.randint(0, MAX_SEED)
        generator = torch.Generator().manual_seed(seed)

        # Keep inference parameters minimal to reduce resource pressure.
        try:
            image = pipe(
                prompt=prompt,
                guidance_scale=0.0,     # 0.0 gives the fastest inference for turbo models
                num_inference_steps=1,  # turbo models are designed for very few steps
                generator=generator,
                height=256,  # small output size to limit memory use
                width=256
            ).images[0]

            # Ensure the result is a valid PIL image.
            if not isinstance(image, PILImage.Image):
                logger.warning("Converting image to PIL format")
                image = PILImage.fromarray(np.array(image))

            # Normalize the color mode for the Gradio Image component.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            logger.info("Image generation successful")

            # Free GPU memory between requests.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return image

        except Exception as e:
            logger.error(f"Error in image generation: {str(e)}")
            return create_dummy_image()

    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return create_dummy_image()

# Minimal Gradio interface: one textbox in, one PIL image out.
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=gr.Image(label="Generated Image", type="pil"),
    title="SDXL Turbo Text-to-Image",
    description="Enter a text prompt to generate an image.",
    examples=["A cute cat"],  # keep a single simple example
    cache_examples=False  # avoid running the model at startup to pre-render examples
)

# Application entry point: start the Gradio server.
if __name__ == "__main__":
    try:
        logger.info("Starting Gradio app")
        launch_options = {
            "server_name": "0.0.0.0",  # listen on all interfaces
            "server_port": 7860,
            "show_api": False,         # disable the auto-generated API page
            "share": False,
        }
        demo.launch(**launch_options)
    except Exception as e:
        logger.error(f"Error launching app: {str(e)}")