Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -14,6 +14,14 @@ CUSTOM_LORA_DIR = "./Custom_LoRAs"
|
|
14 |
os.makedirs(LOCAL_LORA_DIR, exist_ok=True)
|
15 |
os.makedirs(CUSTOM_LORA_DIR, exist_ok=True)
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
print("downloading OmniConsistency base LoRA …")
|
18 |
omni_consistency_path = hf_hub_download(
|
19 |
repo_id="showlab/OmniConsistency",
|
@@ -23,11 +31,13 @@ omni_consistency_path = hf_hub_download(
|
|
23 |
|
24 |
print("loading base pipeline …")
|
25 |
pipe = FluxPipeline.from_pretrained(
|
26 |
-
BASE_PATH, torch_dtype=
|
27 |
-
).to(
|
|
|
28 |
set_single_lora(pipe.transformer, omni_consistency_path,
|
29 |
lora_weights=[1], cond_size=512)
|
30 |
|
|
|
31 |
def download_all_loras():
|
32 |
lora_names = [
|
33 |
"3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
|
@@ -48,7 +58,8 @@ def clear_cache(transformer):
|
|
48 |
for _, attn_processor in transformer.attn_processors.items():
|
49 |
attn_processor.bank_kv.clear()
|
50 |
|
51 |
-
|
|
|
52 |
def generate_image(
|
53 |
lora_name,
|
54 |
custom_repo_id,
|
@@ -113,7 +124,7 @@ def generate_image(
|
|
113 |
clear_cache(pipe.transformer)
|
114 |
return uploaded_image, out_img
|
115 |
|
116 |
-
#
|
117 |
def create_interface():
|
118 |
demo_lora_names = [
|
119 |
"3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
|
@@ -130,7 +141,6 @@ def create_interface():
|
|
130 |
new_trigger = " ".join(lora_name.split("_"))+ " style,"
|
131 |
return new_trigger + prompt
|
132 |
|
133 |
-
# Example data
|
134 |
examples = [
|
135 |
["3D_Chibi", "", "3D Chibi style, Two smiling colleagues enthusiastically high-five in front of a whiteboard filled with technical notes about multimodal learning, reflecting a moment of success and collaboration at OpenAI.",
|
136 |
Image.open("./test_imgs/00.png"), 680, 1024, 3.5, 24, 42],
|
|
|
14 |
os.makedirs(LOCAL_LORA_DIR, exist_ok=True)
|
15 |
os.makedirs(CUSTOM_LORA_DIR, exist_ok=True)
|
16 |
|
17 |
+
# ------------------ DEVICE SETUP (✅ supports CPU-only Spaces) ------------------ #
|
18 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
19 |
+
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
|
20 |
+
print(f"🚀 Running on device: {device}")
|
21 |
+
if device.type == "cpu":
|
22 |
+
print("⚠️ WARNING: No CUDA detected. Running on CPU. Generation may be slow.")
|
23 |
+
|
24 |
+
# ------------------ Load Base LoRA ------------------ #
|
25 |
print("downloading OmniConsistency base LoRA …")
|
26 |
omni_consistency_path = hf_hub_download(
|
27 |
repo_id="showlab/OmniConsistency",
|
|
|
31 |
|
32 |
print("loading base pipeline …")
|
33 |
pipe = FluxPipeline.from_pretrained(
|
34 |
+
BASE_PATH, torch_dtype=dtype
|
35 |
+
).to(device)
|
36 |
+
|
37 |
set_single_lora(pipe.transformer, omni_consistency_path,
|
38 |
lora_weights=[1], cond_size=512)
|
39 |
|
40 |
+
# ------------------ Download LoRA Styles ------------------ #
|
41 |
def download_all_loras():
|
42 |
lora_names = [
|
43 |
"3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
|
|
|
58 |
for _, attn_processor in transformer.attn_processors.items():
|
59 |
attn_processor.bank_kv.clear()
|
60 |
|
61 |
+
# ------------------ Generation Function ------------------ #
|
62 |
+
@spaces.GPU() # Will fallback silently if GPU not available
|
63 |
def generate_image(
|
64 |
lora_name,
|
65 |
custom_repo_id,
|
|
|
124 |
clear_cache(pipe.transformer)
|
125 |
return uploaded_image, out_img
|
126 |
|
127 |
+
# ------------------ UI Interface ------------------ #
|
128 |
def create_interface():
|
129 |
demo_lora_names = [
|
130 |
"3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
|
|
|
141 |
new_trigger = " ".join(lora_name.split("_"))+ " style,"
|
142 |
return new_trigger + prompt
|
143 |
|
|
|
144 |
examples = [
|
145 |
["3D_Chibi", "", "3D Chibi style, Two smiling colleagues enthusiastically high-five in front of a whiteboard filled with technical notes about multimodal learning, reflecting a moment of success and collaboration at OpenAI.",
|
146 |
Image.open("./test_imgs/00.png"), 680, 1024, 3.5, 24, 42],
|