Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -20,18 +20,18 @@ from io import BytesIO
|
|
| 20 |
import re
|
| 21 |
import json
|
| 22 |
|
| 23 |
-
#
|
| 24 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 25 |
login(token=HF_TOKEN)
|
| 26 |
import diffusers
|
| 27 |
print(diffusers.__version__)
|
| 28 |
|
| 29 |
-
#
|
| 30 |
-
dtype = torch.float16 #
|
| 31 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 32 |
-
base_model = "black-forest-labs/FLUX.1-dev"
|
| 33 |
|
| 34 |
-
#
|
| 35 |
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
|
| 36 |
|
| 37 |
MAX_SEED = 2**32 - 1
|
|
@@ -56,10 +56,7 @@ class calculateDuration:
|
|
| 56 |
else:
|
| 57 |
print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
# 生成图像的函数
|
| 62 |
-
@spaces.GPU
|
| 63 |
@torch.inference_mode()
|
| 64 |
def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
|
| 65 |
pipe.to(device)
|
|
@@ -121,10 +118,11 @@ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed,
|
|
| 121 |
adapter_name = lora_info.get("adapter_name")
|
| 122 |
adapter_weight = lora_info.get("adapter_weight")
|
| 123 |
if lora_repo and weights and adapter_name:
|
| 124 |
-
#
|
| 125 |
pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
|
| 126 |
adapter_names.append(adapter_name)
|
| 127 |
adapter_weights.append(adapter_weight)
|
|
|
|
| 128 |
pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
|
| 129 |
|
| 130 |
# Set random seed for reproducibility
|
|
@@ -133,8 +131,11 @@ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed,
|
|
| 133 |
seed = random.randint(0, MAX_SEED)
|
| 134 |
|
| 135 |
# Generate image
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
| 138 |
if final_image:
|
| 139 |
if upload_to_r2:
|
| 140 |
with calculateDuration("Upload image"):
|
|
@@ -142,12 +143,15 @@ def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed,
|
|
| 142 |
result = {"status": "success", "message": "upload image success", "url": url}
|
| 143 |
else:
|
| 144 |
result = {"status": "success", "message": "Image generated but not uploaded"}
|
| 145 |
-
|
|
|
|
|
|
|
| 146 |
progress(100, "Completed!")
|
| 147 |
|
| 148 |
return final_image, seed, json.dumps(result)
|
| 149 |
|
| 150 |
-
# Gradio
|
|
|
|
| 151 |
css="""
|
| 152 |
#col-container {
|
| 153 |
margin: 0 auto;
|
|
@@ -156,7 +160,7 @@ css="""
|
|
| 156 |
"""
|
| 157 |
|
| 158 |
with gr.Blocks(css=css) as demo:
|
| 159 |
-
gr.Markdown("
|
| 160 |
with gr.Row():
|
| 161 |
|
| 162 |
with gr.Column():
|
|
|
|
| 20 |
import re
|
| 21 |
import json
|
| 22 |
|
| 23 |
+
# Login Hugging Face Hub
|
| 24 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 25 |
login(token=HF_TOKEN)
|
| 26 |
import diffusers
|
| 27 |
print(diffusers.__version__)
|
| 28 |
|
| 29 |
+
# init
|
| 30 |
+
dtype = torch.float16 # use float16 for fast generate
|
| 31 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 32 |
+
base_model = "black-forest-labs/FLUX.1-dev"
|
| 33 |
|
| 34 |
+
# load pipe
|
| 35 |
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
|
| 36 |
|
| 37 |
MAX_SEED = 2**32 - 1
|
|
|
|
| 56 |
else:
|
| 57 |
print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
|
| 58 |
|
| 59 |
+
@spaces.GPU(duration=120)
|
|
|
|
|
|
|
|
|
|
| 60 |
@torch.inference_mode()
|
| 61 |
def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
|
| 62 |
pipe.to(device)
|
|
|
|
| 118 |
adapter_name = lora_info.get("adapter_name")
|
| 119 |
adapter_weight = lora_info.get("adapter_weight")
|
| 120 |
if lora_repo and weights and adapter_name:
|
| 121 |
+
# load lora
|
| 122 |
pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
|
| 123 |
adapter_names.append(adapter_name)
|
| 124 |
adapter_weights.append(adapter_weight)
|
| 125 |
+
# set lora weights
|
| 126 |
pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
|
| 127 |
|
| 128 |
# Set random seed for reproducibility
|
|
|
|
| 131 |
seed = random.randint(0, MAX_SEED)
|
| 132 |
|
| 133 |
# Generate image
|
| 134 |
+
try:
|
| 135 |
+
final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)
|
| 136 |
+
except:
|
| 137 |
+
final_image = None
|
| 138 |
+
|
| 139 |
if final_image:
|
| 140 |
if upload_to_r2:
|
| 141 |
with calculateDuration("Upload image"):
|
|
|
|
| 143 |
result = {"status": "success", "message": "upload image success", "url": url}
|
| 144 |
else:
|
| 145 |
result = {"status": "success", "message": "Image generated but not uploaded"}
|
| 146 |
+
else:
|
| 147 |
+
result = {"status": "failed", "message": "Image generate failed"}
|
| 148 |
+
|
| 149 |
progress(100, "Completed!")
|
| 150 |
|
| 151 |
return final_image, seed, json.dumps(result)
|
| 152 |
|
| 153 |
+
# Gradio interface
|
| 154 |
+
|
| 155 |
css="""
|
| 156 |
#col-container {
|
| 157 |
margin: 0 auto;
|
|
|
|
| 160 |
"""
|
| 161 |
|
| 162 |
with gr.Blocks(css=css) as demo:
|
| 163 |
+
gr.Markdown("flux-dev-multi-lora")
|
| 164 |
with gr.Row():
|
| 165 |
|
| 166 |
with gr.Column():
|