Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import random | |
import logging | |
import sys | |
import os | |
from PIL import Image as PILImage | |
# Configure logging: INFO and above go to stdout with a timestamped format.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Module-level logger used by everything below.
logger = logging.getLogger(__name__)
# Monkey-patch gradio_client.utils to work around a Gradio JSON Schema error:
# the wrappers below answer directly for schemas that are not dicts (booleans,
# None, anything else) and only delegate real dict schemas to the originals.
try:
    import gradio_client.utils

    # Keep a handle on the stock implementation so the wrapper can delegate to it.
    original_get_type = gradio_client.utils.get_type

    def patched_get_type(schema):
        """Return "any" for every non-dict schema, else defer to the original."""
        # True, False and None are not dicts, so this single check covers them too.
        if isinstance(schema, dict):
            return original_get_type(schema)
        return "any"

    # Install the wrapper in place of the stock function.
    gradio_client.utils.get_type = patched_get_type
    logger.info("Successfully patched gradio_client.utils.get_type")

    # Apply the same treatment to _json_schema_to_python_type.
    original_json_schema_to_python_type = gradio_client.utils._json_schema_to_python_type

    def patched_json_schema_to_python_type(schema, defs=None):
        """Map boolean/None/non-dict schemas to fixed names; wrap the original otherwise.

        Returns "bool" for a boolean schema, "None" for None and "any" for any
        other non-dict value. For dict schemas the original function is called;
        if it raises, the error is logged as a warning and "any" is returned.
        """
        if isinstance(schema, bool):
            return "bool"
        if schema is None:
            return "None"
        if not isinstance(schema, dict):
            return "any"
        try:
            return original_json_schema_to_python_type(schema, defs)
        except Exception as e:
            logger.warning(f"Error in json_schema_to_python_type: {e}")
            return "any"

    gradio_client.utils._json_schema_to_python_type = patched_json_schema_to_python_type
    logger.info("Successfully patched gradio_client.utils._json_schema_to_python_type")
except Exception as e:
    # Patching is best-effort: the app continues with whatever was (not) patched.
    logger.error(f"Failed to patch Gradio utils: {e}")
# Build a fallback image
def create_backup_image(prompt=""):
    """Return a 512x512 placeholder image labelled with *prompt*.

    The image has a light background with two captions: the prompt (black)
    and a fixed "model loading failed" notice (red). Drawing is best-effort:
    any failure is logged and the image is returned with whatever was drawn
    so far, so this function does not raise for drawing problems.

    Args:
        prompt: Text shown after "Prompt: " in the first caption.

    Returns:
        A PIL image in RGB mode.
    """
    logger.info(f"Creating backup image for: {prompt}")
    img = PILImage.new('RGB', (512, 512), color=(240, 240, 250))
    try:
        from PIL import ImageDraw, ImageFont
        draw = ImageDraw.Draw(img)
        font = ImageFont.load_default()
    except Exception as e:
        logger.error(f"Error creating backup image: {e}")
        return img

    # (position, text, fill colour). The fixed notice is kept in English to
    # avoid encoding problems with the default font.
    captions = (
        ((20, 20), f"Prompt: {prompt}", (0, 0, 0)),
        ((20, 60), "Model loading failed. Showing placeholder image.", (255, 0, 0)),
    )
    # Each caption is guarded separately so that a prompt the font cannot
    # render does not also suppress the fixed notice drawn after it.
    for position, text, fill in captions:
        try:
            try:
                draw.text(position, text, fill=fill, font=font)
            except UnicodeEncodeError:
                # Pillow's legacy built-in bitmap font can raise this for text
                # outside Latin-1; retry with unsupported characters replaced.
                printable = text.encode("latin-1", "replace").decode("latin-1")
                draw.text(position, printable, fill=fill, font=font)
        except Exception as e:
            logger.error(f"Error creating backup image: {e}")
    return img
# Build one placeholder image up front so it can be returned without delay.
PLACEHOLDER_IMAGE = create_backup_image(prompt="placeholder")
# Try to import the AI libraries; HAS_AI_LIBS records whether that worked.
try:
    import torch
    from diffusers import StableDiffusionPipeline
except ImportError as e:
    logger.error(f"Failed to import AI libraries: {e}")
    HAS_AI_LIBS = False
else:
    HAS_AI_LIBS = True
    logger.info("Successfully imported AI libraries")
# AI model loading and image generation

# Pipelines already loaded by _get_pipeline, keyed by (model_id, device, dtype).
_PIPELINE_CACHE = {}


def _get_pipeline(model_id, device, torch_dtype):
    """Return the Stable Diffusion pipeline for *model_id*, loading it on first use.

    Loading is done once per (model_id, device, torch_dtype) combination and
    the result is kept in ``_PIPELINE_CACHE`` so later requests reuse it
    instead of reloading the weights on every call.
    """
    cache_key = (model_id, device, torch_dtype)
    pipe = _PIPELINE_CACHE.get(cache_key)
    if pipe is not None:
        return pipe

    logger.info(f"Loading model: {model_id}")
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        use_auth_token=False,  # explicitly do not use authentication
        revision="main",       # use the main branch
        safety_checker=None,   # disable the safety checker
    )
    pipe = pipe.to(device)
    # Reduce memory use on GPU.
    if torch.cuda.is_available():
        pipe.enable_attention_slicing()
        torch.cuda.empty_cache()
    # Only cache after the pipeline is fully set up, so a failed load is retried.
    _PIPELINE_CACHE[cache_key] = pipe
    return pipe


def generate_ai_image(prompt, seed=None):
    """Generate a 512x512 image for *prompt* with Stable Diffusion v1.5.

    Args:
        prompt: Text description passed to the pipeline.
        seed: Seed for the torch generator. When None, a random value in
            [0, 2147483647] is drawn.

    Returns:
        The first image produced by the pipeline. If the AI libraries are not
        available, ``PLACEHOLDER_IMAGE`` is returned. If loading or generation
        raises, the error is logged and ``create_backup_image(prompt)`` is
        returned, so this function itself does not raise for those failures.
    """
    if not HAS_AI_LIBS:
        logger.error("AI libraries not available")
        return PLACEHOLDER_IMAGE

    # Pick a random seed when the caller did not supply one.
    if seed is None:
        seed = random.randint(0, 2147483647)

    try:
        logger.info(f"Generating image for: {prompt}")
        model_id = "runwayml/stable-diffusion-v1-5"
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        # Half precision on GPU, full precision on CPU.
        torch_dtype = torch.float16 if use_cuda else torch.float32

        pipe = _get_pipeline(model_id, device, torch_dtype)
        logger.info("Model loaded, generating image...")

        generator = torch.Generator(device).manual_seed(seed)
        image = pipe(
            prompt=prompt,
            guidance_scale=7.5,
            num_inference_steps=4,  # minimal number of steps
            generator=generator,
            height=512,
            width=512,
        ).images[0]

        # Release cached GPU memory left over from this request.
        if use_cuda:
            torch.cuda.empty_cache()

        logger.info(f"Image generation successful with seed: {seed}")
        return image
    except Exception as e:
        # logger.exception keeps the traceback, which the plain message loses.
        logger.exception(f"AI image generation failed: {e}")
        return create_backup_image(prompt)
# Rule-based image generation used as a fallback
def _pick_shape_color(prompt_lower):
    """Return the RGB fill for the main shape from the first colour keyword found."""
    # Checked in this order; the first keyword contained in the prompt wins.
    keyword_colors = (
        ("red", (200, 50, 50)),
        ("blue", (50, 50, 200)),
        ("green", (50, 200, 50)),
        ("yellow", (200, 200, 50)),
    )
    for keyword, rgb in keyword_colors:
        if keyword in prompt_lower:
            return rgb
    return (64, 64, 128)  # default: dark blue


def _draw_cat(draw, shape_color):
    """Draw a simple cat face: head, eyes, pupils, nose and ears."""
    # Head
    draw.ellipse((156, 156, 356, 356), fill=shape_color)
    # Eyes
    draw.ellipse((206, 206, 236, 236), fill=(255, 255, 255))
    draw.ellipse((276, 206, 306, 236), fill=(255, 255, 255))
    # Pupils
    draw.ellipse((216, 216, 226, 226), fill=(0, 0, 0))
    draw.ellipse((286, 216, 296, 226), fill=(0, 0, 0))
    # Nose
    draw.polygon([(256, 256), (246, 276), (266, 276)], fill=(255, 150, 150))
    # Ears
    draw.polygon([(156, 156), (176, 96), (216, 156)], fill=shape_color)
    draw.polygon([(356, 156), (336, 96), (296, 156)], fill=shape_color)


def _draw_landscape(draw):
    """Draw sky, two mountains and ground."""
    # Sky
    draw.rectangle([(0, 0), (512, 300)], fill=(100, 150, 250))
    # Mountains
    draw.polygon([(0, 300), (150, 100), (300, 300)], fill=(100, 100, 100))
    draw.polygon([(200, 300), (400, 150), (512, 300)], fill=(80, 80, 80))
    # Ground
    draw.rectangle([(0, 300), (512, 512)], fill=(100, 200, 100))


def _draw_castle(draw, shape_color):
    """Draw sky, a main tower with roof, two side towers, a wall and ground."""
    # Sky
    draw.rectangle([(0, 0), (512, 200)], fill=(150, 200, 250))
    # Main tower
    draw.rectangle([(200, 200), (312, 400)], fill=shape_color)
    # Tower roof
    draw.polygon([(180, 200), (256, 100), (332, 200)], fill=(180, 0, 0))
    # Side towers
    draw.rectangle([(150, 300), (200, 400)], fill=shape_color)
    draw.rectangle([(312, 300), (362, 400)], fill=shape_color)
    # Wall
    draw.rectangle([(100, 400), (412, 450)], fill=shape_color)
    # Ground
    draw.rectangle([(0, 450), (512, 512)], fill=(100, 150, 100))


def _draw_default_shapes(draw, shape_color):
    """Draw the generic composition: outlined square, circle and darker triangle."""
    draw.rectangle([(100, 100), (412, 412)], outline=(0, 0, 0), width=2)
    draw.ellipse((150, 150, 362, 362), fill=shape_color)
    # The triangle uses the shape colour at half intensity.
    darker = (shape_color[0] // 2, shape_color[1] // 2, shape_color[2] // 2)
    draw.polygon([(256, 100), (412, 412), (100, 412)], fill=darker)


def generate_rule_based_image(prompt):
    """Generate an image with simple rules when the AI model is unavailable.

    Keywords found in the lower-cased prompt (substring match) select a shape
    colour (red/blue/green/yellow) and a scene: cat/kitten, landscape/mountain,
    castle/building, or a generic geometric composition. Two captions are then
    drawn at the top-left. Drawing is best-effort: failures are logged and the
    image is returned with whatever was drawn so far.

    Args:
        prompt: Text description; expected to be a string.

    Returns:
        A 512x512 PIL image in RGB mode.
    """
    logger.info(f"Using rule-based generator for: {prompt}")
    # Base image with a light background.
    img = PILImage.new('RGB', (512, 512), color=(240, 240, 250))
    try:
        from PIL import ImageDraw, ImageFont
        draw = ImageDraw.Draw(img)
        prompt_lower = prompt.lower()
        shape_color = _pick_shape_color(prompt_lower)

        # Choose the scene; the first matching rule wins.
        if "cat" in prompt_lower or "kitten" in prompt_lower:
            _draw_cat(draw, shape_color)
        elif "landscape" in prompt_lower or "mountain" in prompt_lower:
            _draw_landscape(draw)
        elif "castle" in prompt_lower or "building" in prompt_lower:
            _draw_castle(draw, shape_color)
        else:
            _draw_default_shapes(draw, shape_color)

        font = ImageFont.load_default()
    except Exception as e:
        logger.error(f"Error in rule-based image generation: {e}")
        return img

    # Captions: the prompt and a note that this is not AI output.
    captions = (
        ((10, 10), f"Prompt: {prompt}", (0, 0, 0)),
        ((10, 30), "Generated with rules (AI model unavailable)", (100, 100, 100)),
    )
    # Each caption is guarded separately so that a prompt the font cannot
    # render does not also suppress the explanatory note drawn after it.
    for position, text, fill in captions:
        try:
            try:
                draw.text(position, text, fill=fill, font=font)
            except UnicodeEncodeError:
                # Pillow's legacy built-in bitmap font can raise this for text
                # outside Latin-1; retry with unsupported characters replaced.
                printable = text.encode("latin-1", "replace").decode("latin-1")
                draw.text(position, printable, fill=fill, font=font)
        except Exception as e:
            logger.error(f"Error in rule-based image generation: {e}")
    return img
# Entry point: handle a request and produce an image
def generate_image(prompt):
    """Return an image for *prompt*, preferring the AI model over the rule-based drawer.

    An empty or whitespace-only prompt is replaced by "a beautiful landscape".
    When HAS_AI_LIBS is true, generate_ai_image is tried first and its result
    is returned if it is not None; an exception from it is logged. Otherwise
    the rule-based generator is used.

    NOTE(review): generate_ai_image catches its own errors and returns a
    placeholder image, so the rule-based path is effectively only reached
    when HAS_AI_LIBS is false.
    """
    # Substitute a default for a missing or blank prompt.
    if not (prompt and prompt.strip()):
        prompt = "a beautiful landscape"
        logger.info(f"Empty prompt, using default: {prompt}")
    logger.info(f"Received prompt: {prompt}")

    # Try the AI generator first, if its libraries were imported.
    ai_result = None
    if HAS_AI_LIBS:
        try:
            ai_result = generate_ai_image(prompt)
        except Exception as e:
            logger.error(f"Error using AI generation: {e}")
    if ai_result is not None:
        return ai_result

    # AI unavailable or produced nothing: fall back to the rule-based drawer.
    logger.warning("Using rule-based image generation")
    return generate_rule_based_image(prompt)
# Build the Gradio interface
def create_demo():
    """Build and return the Gradio Blocks UI (it is not launched here).

    The layout is a single row with an input column (prompt box, generate
    button, example prompts) and an output column (the generated image).
    Both the button click and pressing Enter in the prompt box call
    generate_image with the prompt and show the result in the image.
    """
    example_prompts = [
        "a cute cat sitting on a windowsill",
        "beautiful sunset over mountains",
        "an astronaut riding a horse in space",
        "a fantasy castle on a floating island",
    ]

    with gr.Blocks(title="Text to Image Generator") as demo:
        gr.Markdown("# Text to Image Generator")
        gr.Markdown("Enter a text description to generate an image.")

        with gr.Row():
            # Input side
            with gr.Column(scale=3):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want, e.g.: a cute cat, sunset over mountains...",
                    lines=2,
                )
                generate_button = gr.Button("Generate Image", variant="primary")
                gr.Examples(examples=example_prompts, inputs=prompt_input)

            # Output side
            with gr.Column(scale=5):
                output_image = gr.Image(label="Generated Image", type="pil")

        # Wire the button click first, then Enter in the textbox, to the same handler.
        for register_event in (generate_button.click, prompt_input.submit):
            register_event(
                fn=generate_image,
                inputs=prompt_input,
                outputs=output_image,
            )

    return demo
# Build the interface at import time (outside the __main__ guard).
demo = create_demo()

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "show_api": False,
        "share": False,
    }
    try:
        logger.info("Starting Gradio interface...")
        demo.launch(**launch_options)
    except Exception as e:
        logger.error(f"Failed to launch: {e}")