Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import random | |
import logging | |
import sys | |
import os | |
from PIL import Image as PILImage | |
# 设置日志记录 | |
logging.basicConfig(level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
stream=sys.stdout) | |
logger = logging.getLogger(__name__) | |
# 补丁修复 Gradio JSON Schema 错误 | |
try: | |
import gradio_client.utils | |
# 保存原始函数 | |
original_get_type = gradio_client.utils.get_type | |
# 创建新的 get_type 函数,处理布尔值 | |
def patched_get_type(schema): | |
if schema is True or schema is False or schema is None: | |
return "any" | |
if not isinstance(schema, dict): | |
return "any" | |
return original_get_type(schema) | |
# 替换原始函数 | |
gradio_client.utils.get_type = patched_get_type | |
logger.info("Successfully patched gradio_client.utils.get_type") | |
# 同样修补 _json_schema_to_python_type 函数 | |
original_json_schema_to_python_type = gradio_client.utils._json_schema_to_python_type | |
def patched_json_schema_to_python_type(schema, defs=None): | |
if schema is True or schema is False: | |
return "bool" | |
if schema is None: | |
return "None" | |
if not isinstance(schema, dict): | |
return "any" | |
try: | |
return original_json_schema_to_python_type(schema, defs) | |
except Exception as e: | |
logger.warning(f"Error in json_schema_to_python_type: {e}") | |
return "any" | |
gradio_client.utils._json_schema_to_python_type = patched_json_schema_to_python_type | |
logger.info("Successfully patched gradio_client.utils._json_schema_to_python_type") | |
except Exception as e: | |
logger.error(f"Failed to patch Gradio utils: {e}") | |
# 创建一个备用图像 | |
def create_backup_image(prompt=""): | |
logger.info(f"创建备用图像: {prompt}") | |
img = PILImage.new('RGB', (512, 512), color=(240, 240, 250)) | |
try: | |
from PIL import ImageDraw, ImageFont | |
draw = ImageDraw.Draw(img) | |
font = ImageFont.load_default() | |
draw.text((20, 20), f"提示词: {prompt}", fill=(0, 0, 0), font=font) | |
draw.text((20, 60), "模型加载失败,无法生成图像", fill=(255, 0, 0), font=font) | |
except Exception as e: | |
logger.error(f"创建备用图像时出错: {e}") | |
return img | |
# 预加载 AI 模型 | |
model = None | |
def load_model(): | |
global model | |
if model is not None: | |
return model | |
try: | |
logger.info("开始加载AI模型...") | |
# 延迟导入,确保所有依赖都已正确安装 | |
import torch | |
from diffusers import StableDiffusionPipeline | |
# 使用较低版本的模型 | |
model_id = "CompVis/stable-diffusion-v1-4" | |
# 设置加载参数 | |
load_options = { | |
"revision": "fp16" if torch.cuda.is_available() else None, | |
"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32, | |
"safety_checker": None | |
} | |
logger.info(f"使用模型: {model_id}") | |
pipe = StableDiffusionPipeline.from_pretrained(model_id, **load_options) | |
# 转移到适当的设备 | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
pipe = pipe.to(device) | |
# 优化 | |
if torch.cuda.is_available(): | |
pipe.enable_attention_slicing() | |
logger.info("AI模型加载成功") | |
model = pipe | |
return model | |
except Exception as e: | |
logger.error(f"AI模型加载失败: {e}") | |
return None | |
# AI 图像生成函数 | |
def generate_ai_image(prompt, seed=None): | |
# 尝试加载模型 | |
pipe = load_model() | |
if pipe is None: | |
logger.error("AI模型不可用") | |
return None | |
try: | |
logger.info(f"使用AI生成图像: {prompt}") | |
# 设置生成参数 | |
if seed is None: | |
seed = random.randint(0, 2147483647) | |
# 确定正确的设备 | |
generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed) | |
# 生成图像 | |
image = pipe( | |
prompt=prompt, | |
guidance_scale=7.5, | |
num_inference_steps=5, # 降低步数以加快速度 | |
generator=generator, | |
height=512, | |
width=512 | |
).images[0] | |
# 清理缓存 | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
logger.info(f"AI图像生成成功,种子: {seed}") | |
return image | |
except Exception as e: | |
logger.error(f"AI图像生成失败: {e}") | |
return None | |
# 入口点函数 - 处理请求并生成图像 | |
def generate_image(prompt): | |
# 处理空提示 | |
if not prompt or prompt.strip() == "": | |
prompt = "a beautiful landscape" | |
logger.info(f"输入为空,使用默认提示词: {prompt}") | |
logger.info(f"收到提示词: {prompt}") | |
# 尝试使用AI生成 | |
image = generate_ai_image(prompt) | |
# 检查结果 | |
if image is not None: | |
return image | |
else: | |
logger.warning("使用备用生成器") | |
return create_backup_image(prompt) | |
# 创建Gradio界面 | |
def create_demo(): | |
with gr.Blocks(title="AI 文本到图像生成器") as demo: | |
gr.Markdown("# AI 文本到图像生成器") | |
gr.Markdown("输入文本描述,AI将为你生成相应的图像。") | |
with gr.Row(): | |
with gr.Column(scale=3): | |
# 输入区域 | |
prompt_input = gr.Textbox( | |
label="输入提示词", | |
placeholder="描述你想要的图像,例如:一只可爱的猫,日落下的山脉...", | |
lines=2 | |
) | |
generate_button = gr.Button("生成图像", variant="primary") | |
# 示例 | |
gr.Examples( | |
examples=[ | |
"a cute cat sitting on a windowsill", | |
"beautiful sunset over mountains", | |
"an astronaut riding a horse in space", | |
"a fantasy castle on a floating island" | |
], | |
inputs=prompt_input | |
) | |
# 输出区域 | |
with gr.Column(scale=5): | |
output_image = gr.Image(label="生成的图像", type="pil") | |
# 绑定按钮事件 | |
generate_button.click( | |
fn=generate_image, | |
inputs=prompt_input, | |
outputs=output_image | |
) | |
# 也绑定Enter键提交 | |
prompt_input.submit( | |
fn=generate_image, | |
inputs=prompt_input, | |
outputs=output_image | |
) | |
return demo | |
# 创建演示界面 | |
demo = create_demo() | |
# 启动应用 | |
if __name__ == "__main__": | |
try: | |
logger.info("启动Gradio界面...") | |
demo.launch( | |
server_name="0.0.0.0", | |
show_api=False, | |
share=False | |
) | |
except Exception as e: | |
logger.error(f"启动失败: {e}") | |