Spaces:
Runtime error
Runtime error
main
Browse files- gradio_app.py +32 -61
gradio_app.py
CHANGED
|
@@ -94,35 +94,33 @@ def load_target_model(selected_model):
|
|
| 94 |
AE_PATH = download_file(ae_repo_id, ae_file)
|
| 95 |
LORA_WEIGHTS_PATH = download_file(lora_repo, lora_file)
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
# logger.error(f"Error loading models: {e}")
|
| 125 |
-
# return f"Error loading models: {e}"
|
| 126 |
|
| 127 |
# Image pre-processing (resize and padding)
|
| 128 |
class ResizeWithPadding:
|
|
@@ -156,37 +154,10 @@ class ResizeWithPadding:
|
|
| 156 |
# The function to generate image from a prompt and conditional image
|
| 157 |
@spaces.GPU(duration=180)
|
| 158 |
def infer(prompt, sample_image, recraft_model, seed=0):
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
logger.info("Loading models...")
|
| 164 |
-
try:
|
| 165 |
-
if model is None is None or clip_l is None or t5xxl is None or ae is None:
|
| 166 |
-
_, model = flux_utils.load_flow_model(
|
| 167 |
-
BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cuda", disable_mmap=False
|
| 168 |
-
)
|
| 169 |
-
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cuda", disable_mmap=False)
|
| 170 |
-
clip_l.eval()
|
| 171 |
-
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cuda", disable_mmap=False)
|
| 172 |
-
t5xxl.eval()
|
| 173 |
-
ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cuda", disable_mmap=False)
|
| 174 |
-
|
| 175 |
-
# Load LoRA weights
|
| 176 |
-
multiplier = 1.0
|
| 177 |
-
weights_sd = load_file(LORA_WEIGHTS_PATH)
|
| 178 |
-
lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
|
| 179 |
-
lora_model.apply_to([clip_l, t5xxl], model)
|
| 180 |
-
info = lora_model.load_state_dict(weights_sd, strict=True)
|
| 181 |
-
logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
|
| 182 |
-
lora_model.eval()
|
| 183 |
-
|
| 184 |
-
logger.info("Models loaded successfully.")
|
| 185 |
-
# return "Models loaded successfully. Using Recraft: {}".format(selected_model)
|
| 186 |
-
|
| 187 |
-
except Exception as e:
|
| 188 |
-
logger.error(f"Error loading models: {e}")
|
| 189 |
-
return f"Error loading models: {e}"
|
| 190 |
|
| 191 |
model_path = model_paths[recraft_model]
|
| 192 |
frame_num = model_path['Frame']
|
|
@@ -317,7 +288,7 @@ def infer(prompt, sample_image, recraft_model, seed=0):
|
|
| 317 |
|
| 318 |
# Gradio interface
|
| 319 |
with gr.Blocks() as demo:
|
| 320 |
-
gr.Markdown("##
|
| 321 |
|
| 322 |
with gr.Row():
|
| 323 |
with gr.Column(scale=1):
|
|
|
|
| 94 |
AE_PATH = download_file(ae_repo_id, ae_file)
|
| 95 |
LORA_WEIGHTS_PATH = download_file(lora_repo, lora_file)
|
| 96 |
|
| 97 |
+
logger.info("Loading models...")
|
| 98 |
+
try:
|
| 99 |
+
if model is None is None or clip_l is None or t5xxl is None or ae is None:
|
| 100 |
+
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
| 101 |
+
clip_l.eval()
|
| 102 |
+
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
| 103 |
+
t5xxl.eval()
|
| 104 |
+
ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
| 105 |
+
|
| 106 |
+
# Load flux & LoRA weights
|
| 107 |
+
_, model = flux_utils.load_flow_model(
|
| 108 |
+
BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
|
| 109 |
+
)
|
| 110 |
+
multiplier = 1.0
|
| 111 |
+
weights_sd = load_file(LORA_WEIGHTS_PATH)
|
| 112 |
+
lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
|
| 113 |
+
lora_model.apply_to([clip_l, t5xxl], model)
|
| 114 |
+
info = lora_model.load_state_dict(weights_sd, strict=True)
|
| 115 |
+
logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
|
| 116 |
+
lora_model.eval()
|
| 117 |
+
|
| 118 |
+
logger.info("Models loaded successfully.")
|
| 119 |
+
return "Models loaded successfully. Using Recraft: {}".format(selected_model)
|
| 120 |
+
|
| 121 |
+
except Exception as e:
|
| 122 |
+
logger.error(f"Error loading models: {e}")
|
| 123 |
+
return f"Error loading models: {e}"
|
|
|
|
|
|
|
| 124 |
|
| 125 |
# Image pre-processing (resize and padding)
|
| 126 |
class ResizeWithPadding:
|
|
|
|
| 154 |
# The function to generate image from a prompt and conditional image
|
| 155 |
@spaces.GPU(duration=180)
|
| 156 |
def infer(prompt, sample_image, recraft_model, seed=0):
|
| 157 |
+
global model, clip_l, t5xxl, ae, lora_model
|
| 158 |
+
if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
|
| 159 |
+
logger.error("Models not loaded. Please load the models first.")
|
| 160 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
model_path = model_paths[recraft_model]
|
| 163 |
frame_num = model_path['Frame']
|
|
|
|
| 288 |
|
| 289 |
# Gradio interface
|
| 290 |
with gr.Blocks() as demo:
|
| 291 |
+
gr.Markdown("## Recraft Generation")
|
| 292 |
|
| 293 |
with gr.Row():
|
| 294 |
with gr.Column(scale=1):
|