Boese0601 committed (verified)
Commit b215808 · 1 Parent(s): d574ec9

Delete app_old.py

Files changed (1)
  1. app_old.py +0 -374
app_old.py DELETED
@@ -1,374 +0,0 @@
-import gradio as gr
-import torch
-import spaces
-from PIL import Image, ImageDraw, ImageFont
-# from src.condition import Condition
-from diffusers.pipelines import FluxPipeline
-import numpy as np
-import requests
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file
-import torch.multiprocessing as mp
-###
-import argparse
-import logging
-import math
-import os
-import re
-import random
-import shutil
-from contextlib import nullcontext
-from pathlib import Path
-from PIL import Image
-import accelerate
-import datasets
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import Tensor, nn
-import torch.utils.checkpoint
-import transformers
-from accelerate import Accelerator
-from accelerate.logging import get_logger
-from accelerate.state import AcceleratorState
-from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
-from packaging import version
-from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
-from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
-from transformers.utils import ContextManagers
-from omegaconf import OmegaConf
-from copy import deepcopy
-import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline
-from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
-from diffusers.utils import check_min_version, deprecate, make_image_grid
-from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.torch_utils import is_compiled_module
-from einops import rearrange
-from src.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
-from src.flux.util import (configs, load_ae, load_clip,
-                           load_flow_model2, load_t5, save_image, tensor_to_pil_image, load_checkpoint)
-from src.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor, IPDoubleStreamBlockProcessor, IPSingleStreamBlockProcessor, ImageProjModel
-from src.flux.xflux_pipeline import XFluxSampler
-
-from image_datasets.dataset import loader, eval_image_pair_loader, image_resize
-
-from safetensors.torch import load_file
-import json
-
-
-# logger = get_logger(__name__, log_level="INFO")
-
-
-def get_models(name: str, device, offload: bool, is_schnell: bool):
-    t5 = load_t5(device, max_length=256 if is_schnell else 512)
-    clip = load_clip(device)
-    clip.requires_grad_(False)
-    model = load_flow_model2(name, device="cpu")
-    vae = load_ae(name, device="cpu" if offload else device)
-    return model, vae, t5, clip
-
-args = OmegaConf.load("inference_configs/inference.yaml") #OmegaConf.load(parse_args())
-is_schnell = args.model_name == "flux-schnell"
-set_seed(args.seed)
-# logging_dir = os.path.join(args.output_dir, args.logging_dir)
-device = "cuda"
-dit, vae, t5, clip = get_models(name=args.model_name, device=device, offload=False, is_schnell=is_schnell)
-
-# # load image encoder
-# ip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.getenv("CLIP_VIT")).to(
-# # accelerator.device, dtype=torch.bfloat16
-# device, dtype=torch.bfloat16
-# )
-# ip_clip_image_processor = CLIPImageProcessor()
-
-if args.use_ip:
-    sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=True, spatial_condition=False, clip_image_processor=ip_clip_image_processor, image_encoder=ip_image_encoder, improj=ip_improj)
-elif args.use_spatial_condition:
-    sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False, spatial_condition=True, clip_image_processor=None, image_encoder=None, improj=None,share_position_embedding=args.share_position_embedding)
-else:
-    sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=device, ip_loaded=False, spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None)
-
-
-# @spaces.GPU
-def generate(image, edit_prompt):
-    print("hello?????????!!!!!")
-
-    # accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-
-    # accelerator = Accelerator(
-    # gradient_accumulation_steps=1,
-    # mixed_precision=args.mixed_precision,
-    # log_with=args.report_to,
-    # project_config=accelerator_project_config,
-    # )
-
-    # Make one log on every process with the configuration for debugging.
-    # logging.basicConfig(
-    # format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-    # datefmt="%m/%d/%Y %H:%M:%S",
-    # level=logging.INFO,
-    # )
-    # logger.info(accelerator.state, main_process_only=False)
-    # if accelerator.is_local_main_process:
-    # datasets.utils.logging.set_verbosity_warning()
-    # transformers.utils.logging.set_verbosity_warning()
-    # diffusers.utils.logging.set_verbosity_info()
-    # else:
-    # datasets.utils.logging.set_verbosity_error()
-    # transformers.utils.logging.set_verbosity_error()
-    # diffusers.utils.logging.set_verbosity_error()
-
-
-    # if accelerator.is_main_process:
-    # if args.output_dir is not None:
-    # os.makedirs(args.output_dir, exist_ok=True)
-    # gpt_eval_path = os.path.join(args.output_dir,"Eval")
-    # os.makedirs(gpt_eval_path, exist_ok=True)
-
-    # dit, vae, t5, clip = get_models(name=args.model_name, device=accelerator.device, offload=False, is_schnell=is_schnell)
-    # dit, vae, t5, clip = get_models(name=args.model_name, device=device, offload=False, is_schnell=is_schnell)
-
-    if args.use_lora:
-        lora_attn_procs = {}
-    if args.use_ip:
-        ip_attn_procs = {}
-    if args.double_blocks is None:
-        double_blocks_idx = list(range(19))
-    else:
-        double_blocks_idx = [int(idx) for idx in args.double_blocks.split(",")]
-
-    if args.single_blocks is None:
-        single_blocks_idx = list(range(38))
-    elif args.single_blocks is not None:
-        single_blocks_idx = [int(idx) for idx in args.single_blocks.split(",")]
-
-    if args.use_lora:
-        for name, attn_processor in dit.attn_processors.items():
-            match = re.search(r'\.(\d+)\.', name)
-            if match:
-                layer_index = int(match.group(1))
-
-            if name.startswith("double_blocks") and layer_index in double_blocks_idx:
-                # if accelerator.is_main_process:
-                # print("setting LoRA Processor for", name)
-                lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
-                    dim=3072, rank=args.rank
-                )
-            elif name.startswith("single_blocks") and layer_index in single_blocks_idx:
-                # if accelerator.is_main_process:
-                # print("setting LoRA Processor for", name)
-                lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
-                    dim=3072, rank=args.rank
-                )
-            else:
-                lora_attn_procs[name] = attn_processor
-
-        dit.set_attn_processor(lora_attn_procs)
-
-    # if args.use_ip:
-    # # unpack checkpoint
-    # checkpoint = load_checkpoint(args.ip_local_path, args.ip_repo_id, args.ip_name)
-    # prefix = "double_blocks."
-    # # blocks = {}
-    # proj = {}
-
-    # for key, value in checkpoint.items():
-    # # if key.startswith(prefix):
-    # # blocks[key[len(prefix):].replace('.processor.', '.')] = value
-    # if key.startswith("ip_adapter_proj_model"):
-    # proj[key[len("ip_adapter_proj_model."):]] = value
-
-    # # # load image encoder
-    # # ip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(os.getenv("CLIP_VIT")).to(
-    # # # accelerator.device, dtype=torch.bfloat16
-    # # device, dtype=torch.bfloat16
-    # # )
-    # # ip_clip_image_processor = CLIPImageProcessor()
-
-    # # setup image embedding projection model
-    # ip_improj = ImageProjModel(4096, 768, 4)
-    # ip_improj.load_state_dict(proj)
-    # # ip_improj = ip_improj.to(accelerator.device, dtype=torch.bfloat16)
-    # ip_improj = ip_improj.to(device, dtype=torch.bfloat16)
-
-    # ip_attn_procs = {}
-
-    # for name, _ in dit.attn_processors.items():
-    # ip_state_dict = {}
-    # for k in checkpoint.keys():
-    # if name in k:
-    # ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k]
-    # if ip_state_dict:
-    # ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072)
-    # ip_attn_procs[name].load_state_dict(ip_state_dict)
-    # ip_attn_procs[name].to(accelerator.device, dtype=torch.bfloat16)
-    # else:
-    # ip_attn_procs[name] = dit.attn_processors[name]
-    # dit.set_attn_processor(ip_attn_procs)
-
-
-    vae.requires_grad_(False)
-    t5.requires_grad_(False)
-    clip.requires_grad_(False)
-
-
-
-    # weight_dtype = torch.float32
-    # if accelerator.mixed_precision == "fp16":
-    # weight_dtype = torch.float16
-    # args.mixed_precision = accelerator.mixed_precision
-    # elif accelerator.mixed_precision == "bf16":
-    # weight_dtype = torch.bfloat16
-    # args.mixed_precision = accelerator.mixed_precision
-
-
-    # print(f"Resuming from checkpoint {args.ckpt_dir}")
-    # dit_stat_dict = load_file(args.ckpt_dir)
-    # Get path from Hub
-    model_path = hf_hub_download(
-        repo_id="Boese0601/ByteMorpher",
-        filename="dit.safetensors"
-    )
-    state_dict = load_file(model_path)
-    dit.load_state_dict(state_dict)
-    dit = dit.to(weight_dtype)
-    dit.eval()
-
-    # test_dataloader = loader(**args.data_config)
-    test_dataloader = eval_image_pair_loader(**args.data_config)
-
-
-
-    # from deepspeed import initialize
-    dit = accelerator.prepare(dit)
-
-    # if accelerator.is_main_process:
-    # accelerator.init_trackers(args.tracker_project_name, {"test": None})
-
-    # logger.info("***** Running Evaluation *****")
-    # logger.info(f" Instantaneous batch size = {args.eval_batch_size}")
-
-
-
-    # progress_bar = tqdm(
-    # range(0, len(test_dataloader)),
-    # initial=0,
-    # desc="Steps",
-    # disable=not accelerator.is_local_main_process,
-    # )
-
-    # for step, batch in enumerate(test_dataloader):
-    # with accelerator.accumulate(dit):
-    # img, tgt_image, prompt, edit_prompt, img_name, edit_name = batch
-    img = image_resize(image, 512)
-    w, h = img.size
-    new_w = (w // 32) * 32
-    new_h = (h // 32) * 32
-    img = img.resize((new_w, new_h))
-    img = torch.from_numpy((np.array(img) / 127.5) - 1)
-    img = img.permute(2, 0, 1).unsqueeze(0)
-
-    edit_prompt = edit_prompt
-
-    # if args.use_ip:
-    # sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=True, spatial_condition=False, clip_image_processor=ip_clip_image_processor, image_encoder=ip_image_encoder, improj=ip_improj)
-    # elif args.use_spatial_condition:
-    # sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=False, spatial_condition=True, clip_image_processor=None, image_encoder=None, improj=None,share_position_embedding=args.share_position_embedding)
-    # else:
-    # sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device, ip_loaded=False, spatial_condition=False, clip_image_processor=None, image_encoder=None, improj=None)
-    with torch.no_grad():
-        result = sampler(prompt=edit_prompt,
-                         width=args.sample_width,
-                         height=args.sample_height,
-                         num_steps=args.sample_steps,
-                         image_prompt=None, # ip_adapter
-                         true_gs=args.cfg_scale,
-                         seed=args.seed,
-                         ip_scale=args.ip_scale if args.use_ip else 1.0,
-                         source_image=img if args.use_spatial_condition else None,
-                         )
-    gen_img = result
-
-
-
-    # progress_bar.update(1)
-
-    # accelerator.wait_for_everyone()
-    # accelerator.end_training()
-    return gen_img
-
-
-def get_samples():
-    sample_list = [
-        {
-            "image": "assets/0_camera_zoom/20486354.png",
-            "edit_prompt": "Zoom in on the coral and add a small blue fish in the background.",
-        },
-    ]
-    return [
-        [
-            Image.open(sample["image"]).resize((512, 512)),
-            sample["edit_prompt"],
-        ]
-        for sample in sample_list
-    ]
-
-
-header = """
-# ByteMoprh
-
-<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
-<a href=""><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
-<a href="https://huggingface.co/datasets/Boese0601/ByteMorph-Bench"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
-<a href="https://github.com/Boese0601/ByteMorph"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
-</div>
-"""


-def create_app():
-    with gr.Blocks() as app:
-        gr.Markdown(header, elem_id="header")
-        with gr.Row(equal_height=False):
-            with gr.Column(variant="panel", elem_classes="inputPanel"):
-                original_image = gr.Image(
-                    type="pil", label="Condition Image", width=300, elem_id="input"
-                )
-                edit_prompt = gr.Textbox(lines=2, label="Edit Prompt", elem_id="edit_prompt")
-                submit_btn = gr.Button("Run", elem_id="submit_btn")
-
-            with gr.Column(variant="panel", elem_classes="outputPanel"):
-                output_image = gr.Image(type="pil", elem_id="output")
-
-        with gr.Row():
-            examples = gr.Examples(
-                examples=get_samples(),
-                inputs=[original_image, edit_prompt],
-                label="Examples",
-            )
-
-        submit_btn.click(
-            fn=generate,
-            inputs=[original_image, edit_prompt],
-            outputs=output_image,
-        )
-        gr.HTML(
-            """
-            <div style="text-align: center;">
-            * This demo's template was modified from <a href="https://arxiv.org/abs/2411.15098" target="_blank">OminiControl</a>.
-            </div>
-            """
-        )
-    return app
-
-
-if __name__ == "__main__":
-    print("CUDA available:", torch.cuda.is_available())
-    print("CUDA version:", torch.version.cuda)
-    print("GPU device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")
-    # mp.set_start_method("spawn", force=True)
-    create_app().launch(debug=False, share=True, ssr_mode=False)