danhtran2mind committed
Commit 08653ff · verified · 1 Parent(s): 57ea55b

Update app.py

Files changed (1)
    app.py  +20 -191
app.py CHANGED
@@ -1,204 +1,33 @@
-import dataclasses
-import json
-from pathlib import Path
-
-import gradio as gr
-import torch
-from PIL import Image
-import numpy as np
-from transformers import CLIPTextModel, CLIPTokenizer
-from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
-from tqdm import tqdm
-
-def get_examples(examples_dir: str = "assets/examples") -> list:
-    """
-    Load example data from the assets/examples directory.
-    Each example is a subdirectory containing a config.json and an image file.
-    Returns a list of [prompt, height, width, num_inference_steps, guidance_scale, seed, image_path].
-    """
-    examples = Path(examples_dir)
-    ans = []
-    for example in examples.iterdir():
-        if not example.is_dir():
-            continue
-        with open(example / "config.json") as f:
-            example_dict = json.load(f)
-
-        required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
-        if not all(key in example_dict for key in required_keys):
-            continue
-
-        example_list = [
-            example_dict["prompt"],
-            example_dict["height"],
-            example_dict["width"],
-            example_dict["num_inference_steps"],
-            example_dict["guidance_scale"],
-            example_dict["seed"],
-            str(example / example_dict["image"]) # Path to the image file
-        ]
-        ans.append(example_list)
-
-    if not ans:
-        ans = [
-            ["a serene landscape in Ghibli style", 64, 64, 50, 3.5, 42, None]
-        ]
-    return ans
-
-def create_demo(
-    model_name: str = "danhtran2mind/ghibli-fine-tuned-sd-2.1",
-    device: str = "cuda" if torch.cuda.is_available() else "cpu",
-):
-    # Convert device string to torch.device
-    device = torch.device(device)
-    dtype = torch.float16 if device.type == "cuda" else torch.float32
-
-    # Load models with consistent dtype
-    vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", torch_dtype=dtype).to(device)
-    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
-    text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder", torch_dtype=dtype).to(device)
-    unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=dtype).to(device)
-    scheduler = PNDMScheduler.from_pretrained(model_name, subfolder="scheduler")
-
-    def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed):
-        if not prompt:
-            return None, "Prompt cannot be empty."
-        if height % 8 != 0 or width % 8 != 0:
-            return None, "Height and width must be divisible by 8 (e.g., 256, 512, 1024)."
-        if num_inference_steps < 1 or num_inference_steps > 100:
-            return None, "Number of inference steps must be between 1 and 100."
-        if guidance_scale < 1.0 or guidance_scale > 20.0:
-            return None, "Guidance scale must be between 1.0 and 20.0."
-        if seed < 0 or seed > 4294967295:
-            return None, "Seed must be between 0 and 4294967295."
-
-        batch_size = 1
-        if random_seed:
-            seed = torch.randint(0, 4294967295, (1,)).item()
-        generator = torch.Generator(device=device).manual_seed(int(seed))
+if __name__ == "__main__":
+    from transformers import HfArgumentParser
 
-        text_input = tokenizer(
-            [prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
+    @dataclasses.dataclass
+    class AppArgs:
+        local_model: bool = dataclasses.field(
+            default=True, metadata={"help": "Use local model path instead of Hugging Face model."}
         )
-        with torch.no_grad():
-            text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
-
-        max_length = text_input.input_ids.shape[-1]
-        uncond_input = tokenizer(
-            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+        model_name: str = dataclasses.field(
+            default="danhtran2mind/ghibli-fine-tuned-sd-2.1",
+            metadata={"help": "Model name or path for the fine-tuned Stable Diffusion model."}
        )
-        with torch.no_grad():
-            uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
-
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-        latents = torch.randn(
-            (batch_size, unet.config.in_channels, height // 8, width // 8),
-            generator=generator,
-            dtype=dtype,
-            device=device
+        device: str = dataclasses.field(
+            default="cuda" if torch.cuda.is_available() else "cpu",
+            metadata={"help": "Device to run the model on (e.g., 'cuda', 'cpu')."}
         )
-
-        scheduler.set_timesteps(num_inference_steps)
-        latents = latents * scheduler.init_noise_sigma
-
-        for t in tqdm(scheduler.timesteps, desc="Generating image"):
-            latent_model_input = torch.cat([latents] * 2)
-            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
-
-            with torch.no_grad():
-                if device.type == "cuda":
-                    with torch.autocast(device_type="cuda", dtype=torch.float16):
-                        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-                else:
-                    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-            latents = scheduler.step(noise_pred, t, latents).prev_sample
-
-        with torch.no_grad():
-            latents = latents / vae.config.scaling_factor
-            image = vae.decode(latents).sample
-
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-        image = (image * 255).round().astype("uint8")
-        pil_image = Image.fromarray(image[0])
-
-        return pil_image, f"Image generated successfully! Seed used: {seed}"
-
-    def load_example_image(prompt, height, width, num_inference_steps, guidance_scale, seed, image_path):
-        """
-        Load the image for the selected example and update input fields.
-        """
-        if image_path and Path(image_path).exists():
-            try:
-                image = Image.open(image_path)
-                return prompt, height, width, num_inference_steps, guidance_scale, seed, image, f"Loaded image: {image_path}"
-            except Exception as e:
-                return prompt, height, width, num_inference_steps, guidance_scale, seed, None, f"Error loading image: {e}"
-        return prompt, height, width, num_inference_steps, guidance_scale, seed, None, "No image available"
-
-    badges_text = r"""
-    <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
-    <a href="https://huggingface.co/spaces/danhtran2mind/ghibli-fine-tuned-sd-2.1"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Space&color=orange"></a>
-    </div>
-    """.strip()
-
-    with gr.Blocks() as demo:
-        gr.Markdown("# Ghibli-Style Image Generator")
-        gr.Markdown(badges_text)
-        gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model. Select an example below to load a pre-generated image or enter a prompt to generate a new one.")
-        gr.Markdown("""**Note:** For CPU inference, execution time is long (e.g., for resolution 512 × 512) with 50 inference steps, time is approximately 1700 seconds).""")
-
-        with gr.Row():
-            with gr.Column():
-                prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'")
-                with gr.Row():
-                    width = gr.Slider(32, 4096, 512, step=8, label="Generation Width")
-                    height = gr.Slider(32, 4096, 512, step=8, label="Generation Height")
-                with gr.Accordion("Advanced Options", open=False):
-                    num_inference_steps = gr.Slider(1, 100, 50, step=1, label="Number of Inference Steps")
-                    guidance_scale = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
-                    seed = gr.Number(42, label="Seed (0 to 4294967295)")
-                    random_seed = gr.Checkbox(label="Use Random Seed", value=False)
-                generate_btn = gr.Button("Generate Image")
-
-            with gr.Column():
-                output_image = gr.Image(label="Generated Image")
-                output_text = gr.Textbox(label="Status")
-
-        examples = get_examples("assets/examples")
-        gr.Examples(
-            examples=examples,
-            inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image],
-            outputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image, output_text],
-            fn=load_example_image,
-            cache_examples=False
+        port: int = dataclasses.field(
+            default=7860, metadata={"help": "Port to run the Gradio app on."}
         )
-
-        generate_btn.click(
-            fn=generate_image,
-            inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed],
-            outputs=[output_image, output_text]
+        share: bool = dataclasses.field(
+            default=False, metadata={"help": "Set to True for public sharing (Hugging Face Spaces)."}
         )
 
-    return demo
-
-if __name__ == "__main__":
-    from transformers import HfArgumentParser
-
-    @dataclasses.dataclass
-    class AppArgs:
-        model_name: str = "danhtran2mind/ghibli-fine-tuned-sd-2.1"
-        device: str = "cuda" if torch.cuda.is_available() else "cpu"
-        port: int = 7860
-        share: bool = False # Set to True for public sharing (Hugging Face Spaces)
-
     parser = HfArgumentParser([AppArgs])
     args_tuple = parser.parse_args_into_dataclasses()
     args = args_tuple[0]
 
+    # Set model_name based on local_model flag
+    if args.local_model:
+        args.model_name = "ghibli-fine-tuned-sd-2.1"
+
     demo = create_demo(args.model_name, args.device)
     demo.launch(server_port=args.port, share=args.share)
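
Not part of the commit, but as a quick sanity check of the new argument handling: a minimal sketch of how the local_model override resolves the checkpoint name. The simulated sys.argv values below are illustrative, and the "--local_model False" syntax assumes HfArgumentParser's standard string-to-bool parsing for boolean dataclass fields.

import dataclasses
import sys

from transformers import HfArgumentParser

@dataclasses.dataclass
class AppArgs:
    # Trimmed copy of the fields from the diff above (device omitted to avoid the torch import)
    local_model: bool = True
    model_name: str = "danhtran2mind/ghibli-fine-tuned-sd-2.1"
    port: int = 7860
    share: bool = False

# Simulate: python app.py --local_model False --port 7861
sys.argv = ["app.py", "--local_model", "False", "--port", "7861"]
(args,) = HfArgumentParser([AppArgs]).parse_args_into_dataclasses()

# With local_model disabled, the Hub model id is kept; otherwise it is swapped
# for the local directory name, mirroring the new block in app.py.
if args.local_model:
    args.model_name = "ghibli-fine-tuned-sd-2.1"

print(args.model_name, args.port)  # -> danhtran2mind/ghibli-fine-tuned-sd-2.1 7861

If that parsing holds, the equivalent real invocation would be something like "python app.py --local_model False --port 7861", while a bare "python app.py" falls back to the local ghibli-fine-tuned-sd-2.1 checkpoint directory.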