AlekseyCalvin committed (verified)
Commit a09a2cb · 1 Parent(s): 86ac5ff

Create app.py

Files changed (1): app.py (+237, -0)
app.py ADDED
@@ -0,0 +1,237 @@
+ import os
+ import gradio as gr
+ import numpy as np
+ import json
+ import logging
+ import subprocess
+ import torch
+ import transformers
+ import diffusers
+ from PIL import Image
+ from os import path
+ from torchvision import transforms
+ from dataclasses import dataclass
+ import math
+ from typing import Callable
+ import spaces
+ import bitsandbytes
+ from diffusers.quantizers import PipelineQuantizationConfig
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+ from transformers import CLIPModel, CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPConfig, T5EncoderModel, T5Tokenizer
+ from diffusers.models.transformers import FluxTransformer2DModel
+ import copy
+ import random
+ import time
+ import safetensors.torch
+ from tqdm import tqdm
+ from safetensors.torch import load_file
+ from huggingface_hub import HfFileSystem, ModelCard
+ from huggingface_hub import login, hf_hub_download
+ from huggingface_hub.utils._runtime import dump_environment_info
+
+ hf_token = os.environ.get("HF_TOKEN")
+ login(token=hf_token)
+
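+ # Route all Hugging Face caches into a local ./models folder next to this script and opt out of Hub telemetry.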
+ cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
+ os.environ["HF_HUB_CACHE"] = cache_path
+ os.environ["HF_HOME"] = cache_path
+
+ os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')
+
+ dump_environment_info()
+ logging.basicConfig(level=logging.DEBUG)
+ logger = logging.getLogger(__name__)
+
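+ # 4-bit NF4 quantization (bitsandbytes) of the Flux transformer only, with bfloat16 compute.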
+ quant_config = PipelineQuantizationConfig(
+     quant_backend="bitsandbytes_4bit",
+     quant_kwargs={"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.bfloat16, "bnb_4bit_quant_type": "nf4"},
+     components_to_quantize=["transformer"]
+ )
+
+ try:
+     # Set max memory usage for ZeroGPU
+     torch.cuda.set_per_process_memory_fraction(1.0)
+     torch.set_float32_matmul_precision("medium")
+ except Exception as e:
+     print(f"Error setting memory usage: {e}")
+
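+ # Load the Flux Krea Blaze base pipeline in bfloat16, applying the 4-bit transformer config, and move it to the GPU.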
+ dtype = torch.bfloat16
+ base_model = "AlekseyCalvin/Flux-Krea-Blaze_byMintLab_fp8_Diffusers"
+ pipe = DiffusionPipeline.from_pretrained(
+     base_model,
+     quantization_config=quant_config,
+     torch_dtype=dtype
+ ).to("cuda")
+ torch.cuda.empty_cache()
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
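+ # Swap the pipeline's stock CLIP-L tokenizer/text encoder for LongCLIP (GmP ViT-L/14), extending the usable prompt length to 248 tokens.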
+ model_id = "zer0int/LongCLIP-GmP-ViT-L-14"
+ config = CLIPConfig.from_pretrained(model_id)
+ config.text_config.max_position_embeddings = 248
+ clip_model = CLIPModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, config=config, ignore_mismatched_sizes=True)
+ clip_processor = CLIPProcessor.from_pretrained(model_id, padding="max_length", max_length=248)
+ pipe.tokenizer = clip_processor.tokenizer
+ pipe.text_encoder = clip_model.text_model
+ pipe.tokenizer_max_length = 248
+ pipe.text_encoder.dtype = torch.bfloat16
+ #pipe.text_encoder_2 = t5.text_model
+
+ MAX_SEED = 2**32 - 1
+
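+ # NOTE (assumption): a global `loras` list is referenced below (gallery thumbnails,
+ # trigger words, repo IDs) but is never defined in this file. A minimal sketch,
+ # assuming the metadata lives in a local loras.json whose entries provide "image",
+ # "title", "repo", "trigger_word", and optionally "weights" and "aspect":
+ with open("loras.json", "r") as f:
+     loras = json.load(f)
+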
+ class calculateDuration:
+     def __init__(self, activity_name=""):
+         self.activity_name = activity_name
+
+     def __enter__(self):
+         self.start_time = time.time()
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.end_time = time.time()
+         self.elapsed_time = self.end_time - self.start_time
+         if self.activity_name:
+             print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
+         else:
+             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+
+
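+ # Gallery selection handler: surfaces the chosen LoRA's trigger word and repo link,
+ # and presets portrait/landscape dimensions when the entry declares an "aspect".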
+ def update_selection(evt: gr.SelectData, width, height):
+     selected_lora = loras[evt.index]
+     new_placeholder = f"Prompt with activator word(s): '{selected_lora['trigger_word']}'! "
+     lora_repo = selected_lora["repo"]
+     lora_trigger = selected_lora['trigger_word']
+     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}). Prompt using: '{lora_trigger}'!"
+     if "aspect" in selected_lora:
+         if selected_lora["aspect"] == "portrait":
+             width = 768
+             height = 1024
+         elif selected_lora["aspect"] == "landscape":
+             width = 1024
+             height = 768
+     return (
+         gr.update(placeholder=new_placeholder),
+         updated_text,
+         evt.index,
+         width,
+         height,
+     )
+
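+ # GPU-decorated inference step (ZeroGPU slot of up to 80 seconds): appends the trigger
+ # word to the prompt and passes the LoRA scale through joint_attention_kwargs.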
+ @spaces.GPU(duration=80)
+ def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress):
+     pipe.to("cuda")
+     generator = torch.Generator(device="cuda").manual_seed(seed)
+
+     with calculateDuration("Generating image"):
+         # Generate image
+         image = pipe(
+             prompt=f"{prompt} {trigger_word}",
+             num_inference_steps=steps,
+             guidance_scale=cfg_scale,
+             width=width,
+             height=height,
+             generator=generator,
+             joint_attention_kwargs={"scale": lora_scale},
+         ).images[0]
+     return image
+
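+ # Top-level handler: loads the selected LoRA's weights, optionally randomizes the seed,
+ # generates the image, then unloads the LoRA and parks the pipeline back on the CPU.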
+ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
+     if selected_index is None:
+         raise gr.Error("You must select a LoRA before proceeding.")
+
+     selected_lora = loras[selected_index]
+     lora_path = selected_lora["repo"]
+     trigger_word = selected_lora['trigger_word']
+
+     # Load LoRA weights
+     with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
+         if "weights" in selected_lora:
+             pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
+         else:
+             pipe.load_lora_weights(lora_path)
+
+     # Set random seed for reproducibility
+     with calculateDuration("Randomizing seed"):
+         if randomize_seed:
+             seed = random.randint(0, MAX_SEED)
+
+     image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
+     pipe.to("cpu")
+     pipe.unload_lora_weights()
+     return image, seed
+
+ run_lora.zerogpu = True
+
+ css = '''
+ #gen_btn{height: 100%}
+ #title{text-align: center}
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
+ #title img{width: 100px; margin-right: 0.5em}
+ #gallery .grid-wrap{height: 10vh}
+ '''
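+
+ # Gradio UI: prompt box and LoRA gallery on the left, generated image on the right,
+ # plus an "Advanced Settings" accordion; events are wired to update_selection and run_lora.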
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
+     title = gr.HTML(
+         """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> SOONfactory </h1>""",
+         elem_id="title",
+     )
+     # Info blob stating what the app is running
+     info_blob = gr.HTML(
+         """<div id="info_blob"> Img. Manufactory Running On: Flux Krea Blaze (a fast modification of Flux Krea). Nearly all of the LoRA adapters accessible via this space were trained by us in an extensive progression of inspired experiments and conceptual mini-projects. Check out our poetry translations at WWW.SILVERagePOETS.com. Find our music on SoundCloud @ AlekseyCalvin & YouTube @ SilverAgePoets / AlekseyCalvin! </div>"""
+     )
+
+     # Info blob explaining LoRA trigger-word usage
+     info_blob = gr.HTML(
+         """<div id="info_blob"> To reinforce/focus on selected fine-tuned LoRAs (Low-Rank Adapters), add special "trigger" words/phrases to your prompts. </div>"""
+     )
+     selected_index = gr.State(None)
+     with gr.Row():
+         with gr.Column(scale=3):
+             prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Select LoRA/Style & type prompt!")
+         with gr.Column(scale=1, elem_id="gen_column"):
+             generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
+     with gr.Row():
+         with gr.Column(scale=3):
+             selected_info = gr.Markdown("")
+             gallery = gr.Gallery(
+                 [(item["image"], item["title"]) for item in loras],
+                 label="LoRA Inventory",
+                 allow_preview=False,
+                 columns=3,
+                 elem_id="gallery"
+             )
+
+         with gr.Column(scale=4):
+             result = gr.Image(label="Generated Image")
+
+     with gr.Row():
+         with gr.Accordion("Advanced Settings", open=True):
+             with gr.Column():
+                 with gr.Row():
+                     cfg_scale = gr.Slider(label="CFG Scale", minimum=0, maximum=20, step=0.1, value=2.5)
+                     steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=9)
+
+                 with gr.Row():
+                     width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
+                     height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
+
+                 with gr.Row():
+                     randomize_seed = gr.Checkbox(True, label="Randomize seed")
+                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
+                     lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2.5, step=0.01, value=1.0)
+
+     gallery.select(
+         update_selection,
+         inputs=[width, height],
+         outputs=[prompt, selected_info, selected_index, width, height]
+     )
+
+     gr.on(
+         triggers=[generate_button.click, prompt.submit],
+         fn=run_lora,
+         inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
+         outputs=[result, seed]
+     )
+
+ app.queue(default_concurrency_limit=2).launch(show_error=True)