Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -116,25 +116,49 @@ def initialize_models():
|
|
| 116 |
ae = ae.to(device)
|
| 117 |
|
| 118 |
print("Loading Flux model...")
|
| 119 |
-
# Use the standard Flux model instead of quantized version
|
| 120 |
-
# This will use more memory but avoid compatibility issues
|
| 121 |
from huggingface_hub import hf_hub_download
|
| 122 |
from safetensors.torch import load_file
|
| 123 |
|
| 124 |
try:
|
| 125 |
-
# Try to load from
|
| 126 |
-
|
|
|
|
| 127 |
model = Flux()
|
| 128 |
-
model = model.to(dtype=torch.bfloat16, device=device)
|
| 129 |
|
| 130 |
-
#
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
except Exception as e:
|
| 136 |
print(f"Error initializing Flux model: {e}")
|
| 137 |
-
|
|
|
|
| 138 |
|
| 139 |
model_initialized = True
|
| 140 |
print("Models initialized successfully!")
|
|
|
|
| 116 |
ae = ae.to(device)
|
| 117 |
|
| 118 |
print("Loading Flux model...")
|
|
|
|
|
|
|
| 119 |
from huggingface_hub import hf_hub_download
|
| 120 |
from safetensors.torch import load_file
|
| 121 |
|
| 122 |
try:
|
| 123 |
+
# Try to load from a standard Flux checkpoint
|
| 124 |
+
# First, let's try the schnell version which might be smaller
|
| 125 |
+
print("Attempting to load Flux model weights...")
|
| 126 |
model = Flux()
|
|
|
|
| 127 |
|
| 128 |
+
# Try loading from black-forest-labs directly
|
| 129 |
+
try:
|
| 130 |
+
# Note: You might need to authenticate with HuggingFace for this
|
| 131 |
+
sd = load_file(hf_hub_download(repo_id="black-forest-labs/FLUX.1-schnell", filename="flux1-schnell.safetensors"))
|
| 132 |
+
# Adjust state dict keys if needed
|
| 133 |
+
model.load_state_dict(sd, strict=False)
|
| 134 |
+
print("Loaded Flux schnell model successfully!")
|
| 135 |
+
except Exception as e1:
|
| 136 |
+
print(f"Could not load Flux schnell: {e1}")
|
| 137 |
+
|
| 138 |
+
# Try the dev version
|
| 139 |
+
try:
|
| 140 |
+
sd = load_file(hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="flux1-dev.safetensors"))
|
| 141 |
+
model.load_state_dict(sd, strict=False)
|
| 142 |
+
print("Loaded Flux dev model successfully!")
|
| 143 |
+
except Exception as e2:
|
| 144 |
+
print(f"Could not load Flux dev: {e2}")
|
| 145 |
+
|
| 146 |
+
# If no pretrained weights are available, warn the user
|
| 147 |
+
print("\n" + "="*50)
|
| 148 |
+
print("WARNING: Could not load pretrained Flux weights!")
|
| 149 |
+
print("The model will use random initialization.")
|
| 150 |
+
print("For proper results, you need to:")
|
| 151 |
+
print("1. Authenticate with HuggingFace: huggingface-cli login")
|
| 152 |
+
print("2. Accept the Flux model license agreement")
|
| 153 |
+
print("3. Or use a publicly available Flux checkpoint")
|
| 154 |
+
print("="*50 + "\n")
|
| 155 |
+
|
| 156 |
+
model = model.to(dtype=torch.bfloat16, device=device)
|
| 157 |
|
| 158 |
except Exception as e:
|
| 159 |
print(f"Error initializing Flux model: {e}")
|
| 160 |
+
# Continue with random initialization for now
|
| 161 |
+
model = Flux().to(dtype=torch.bfloat16, device=device)
|
| 162 |
|
| 163 |
model_initialized = True
|
| 164 |
print("Models initialized successfully!")
|