Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -9,23 +9,17 @@ import spaces
|
|
9 |
dtype = torch.bfloat16
|
10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
base_model = "black-forest-labs/FLUX.1-dev"
|
12 |
-
|
13 |
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
|
14 |
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
|
15 |
-
|
16 |
MAX_SEED = 2**32-1
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
pipe.load_lora_weights(LORA_PATH)
|
24 |
-
|
25 |
-
@spaces.GPU(duration=70)
|
26 |
-
def generate_image(prompt, width, height):
|
27 |
# Combine prompt with trigger word
|
28 |
-
full_prompt = f"{
|
29 |
|
30 |
# Set up generation parameters
|
31 |
seed = random.randint(0, MAX_SEED)
|
@@ -43,8 +37,8 @@ def generate_image(prompt, width, height):
|
|
43 |
|
44 |
return image
|
45 |
|
46 |
-
def run_lora(prompt, width, height):
|
47 |
-
return generate_image(prompt, width, height)
|
48 |
|
49 |
# Set up the Gradio interface
|
50 |
with gr.Blocks() as app:
|
@@ -57,13 +51,17 @@ with gr.Blocks() as app:
|
|
57 |
width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
|
58 |
height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
|
59 |
|
|
|
|
|
|
|
|
|
60 |
generate_button = gr.Button("Generate Image")
|
61 |
|
62 |
output_image = gr.Image(label="Generated Image")
|
63 |
|
64 |
generate_button.click(
|
65 |
fn=run_lora,
|
66 |
-
inputs=[prompt, width, height],
|
67 |
outputs=[output_image]
|
68 |
)
|
69 |
|
|
|
9 |
dtype = torch.bfloat16
|
10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
base_model = "black-forest-labs/FLUX.1-dev"
|
|
|
12 |
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
|
13 |
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
|
|
|
14 |
MAX_SEED = 2**32-1
|
15 |
|
16 |
+
@spaces.GPU()
|
17 |
+
def generate_image(prompt, width, height, lora_path, trigger_word):
|
18 |
+
# Load LoRA weights
|
19 |
+
pipe.load_lora_weights(lora_path)
|
20 |
+
|
|
|
|
|
|
|
|
|
21 |
# Combine prompt with trigger word
|
22 |
+
full_prompt = f"{trigger_word} {prompt}"
|
23 |
|
24 |
# Set up generation parameters
|
25 |
seed = random.randint(0, MAX_SEED)
|
|
|
37 |
|
38 |
return image
|
39 |
|
40 |
+
def run_lora(prompt, width, height, lora_path, trigger_word):
|
41 |
+
return generate_image(prompt, width, height, lora_path, trigger_word)
|
42 |
|
43 |
# Set up the Gradio interface
|
44 |
with gr.Blocks() as app:
|
|
|
51 |
width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
|
52 |
height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
|
53 |
|
54 |
+
with gr.Row():
|
55 |
+
lora_path = gr.Textbox(label="LoRA Path", value="SebastianBodza/Flux_Aquarell_Watercolor_v2")
|
56 |
+
trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")
|
57 |
+
|
58 |
generate_button = gr.Button("Generate Image")
|
59 |
|
60 |
output_image = gr.Image(label="Generated Image")
|
61 |
|
62 |
generate_button.click(
|
63 |
fn=run_lora,
|
64 |
+
inputs=[prompt, width, height, lora_path, trigger_word],
|
65 |
outputs=[output_image]
|
66 |
)
|
67 |
|