SmilingTree commited on
Commit
2486a6e
·
verified ·
1 Parent(s): 85f4321

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ from diffusers import DiffusionPipeline
5
+ from rembg import remove
6
+ import torch
7
+
8
+ # ===== 初始化模型 =====
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ MAX_SEED = np.iinfo(np.int32).max
11
+ MAX_IMAGE_SIZE = 1024
12
+
13
+ if torch.cuda.is_available():
14
+ pipe = DiffusionPipeline.from_pretrained(
15
+ "stabilityai/sdxl-turbo",
16
+ torch_dtype=torch.float16,
17
+ variant="fp16",
18
+ use_safetensors=True
19
+ )
20
+ pipe.enable_xformers_memory_efficient_attention()
21
+ pipe = pipe.to(device)
22
+ else:
23
+ pipe = DiffusionPipeline.from_pretrained(
24
+ "stabilityai/sdxl-turbo",
25
+ use_safetensors=True
26
+ )
27
+ pipe = pipe.to(device)
28
+
29
+ # ===== 功能函數 =====
30
+ def generate_anime(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
31
+ if randomize_seed:
32
+ seed = random.randint(0, MAX_SEED)
33
+ generator = torch.Generator().manual_seed(seed)
34
+
35
+ image = pipe(
36
+ prompt=f"{prompt}, Anime",
37
+ negative_prompt=negative_prompt,
38
+ guidance_scale=guidance_scale,
39
+ num_inference_steps=num_inference_steps,
40
+ width=width,
41
+ height=height,
42
+ generator=generator
43
+ ).images[0]
44
+
45
+ return image
46
+
47
+ def remove_background(input_img):
48
+ if input_img is None:
49
+ return None
50
+ return remove(input_img)
51
+
52
+ # ===== Gradio 介面設計 =====
53
+ examples = [
54
+ "A well-behaved schoolgirl with glasses",
55
+ "Astronaut in a jungle, cold color palette, 8k",
56
+ "An astronaut riding a green horse",
57
+ ]
58
+
59
+ css = """
60
+ #col-container {
61
+ margin: 0 auto;
62
+ max-width: 520px;
63
+ }
64
+ """
65
+
66
+ with gr.Blocks(css=css) as demo:
67
+ with gr.Column(elem_id="col-container"):
68
+ gr.Markdown("## 🧠 Anime Character Generator + Background Remover")
69
+
70
+ # Prompt row
71
+ with gr.Row():
72
+ prompt = gr.Text(
73
+ label="Prompt",
74
+ show_label=False,
75
+ max_lines=1,
76
+ placeholder="Describe your anime character...",
77
+ container=False,
78
+ )
79
+ run_button = gr.Button("🎨 Generate Anime")
80
+
81
+ # Output image (before and after remove background)
82
+ with gr.Row():
83
+ result_img = gr.Image(label="Generated Image")
84
+ removed_img = gr.Image(label="Background Removed")
85
+
86
+ # Advanced settings
87
+ with gr.Accordion("Advanced Settings", open=False):
88
+ negative_prompt = gr.Text(
89
+ label="Negative prompt",
90
+ max_lines=1,
91
+ placeholder="Enter a negative prompt",
92
+ visible=True
93
+ )
94
+
95
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
96
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
97
+
98
+ with gr.Row():
99
+ width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
100
+ height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
101
+
102
+ with gr.Row():
103
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=0.0)
104
+ num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=12, step=1, value=2)
105
+
106
+ # 範例按鈕區
107
+ gr.Markdown("#### ✨ Prompt Examples")
108
+ with gr.Row():
109
+ for example in examples:
110
+ gr.Button(example).click(lambda x=example: x, outputs=prompt)
111
+
112
+ # 主按鈕 callback
113
+ run_button.click(
114
+ fn=generate_anime,
115
+ inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
116
+ outputs=[result_img]
117
+ ).then(
118
+ fn=remove_background,
119
+ inputs=[result_img],
120
+ outputs=[removed_img]
121
+ )
122
+
123
+ demo.queue().launch(share=True)