Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,9 @@ import torch
|
|
4 |
from diffusers import DiffusionPipeline, AutoencoderTiny
|
5 |
import random
|
6 |
import spaces
|
|
|
|
|
|
|
7 |
|
8 |
# Initialize the base model
|
9 |
dtype = torch.bfloat16
|
@@ -13,8 +16,11 @@ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).
|
|
13 |
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
|
14 |
MAX_SEED = 2**32-1
|
15 |
|
|
|
|
|
|
|
16 |
@spaces.GPU()
|
17 |
-
def generate_image(prompt, width, height, lora_path, trigger_word):
|
18 |
# Load LoRA weights
|
19 |
pipe.load_lora_weights(lora_path)
|
20 |
|
@@ -35,13 +41,21 @@ def generate_image(prompt, width, height, lora_path, trigger_word):
|
|
35 |
generator=generator,
|
36 |
).images[0]
|
37 |
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
-
def run_lora(prompt, width, height, lora_path, trigger_word):
|
41 |
-
|
|
|
42 |
|
43 |
# Set up the Gradio interface
|
44 |
-
with gr.Blocks() as
|
45 |
gr.Markdown("# LoRA Image Generator")
|
46 |
|
47 |
with gr.Row():
|
@@ -54,17 +68,38 @@ with gr.Blocks() as app:
|
|
54 |
with gr.Row():
|
55 |
lora_path = gr.Textbox(label="LoRA Path", value="SebastianBodza/Flux_Aquarell_Watercolor_v2")
|
56 |
trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")
|
|
|
57 |
|
58 |
generate_button = gr.Button("Generate Image")
|
59 |
|
60 |
output_image = gr.Image(label="Generated Image")
|
|
|
61 |
|
62 |
generate_button.click(
|
63 |
fn=run_lora,
|
64 |
-
inputs=[prompt, width, height, lora_path, trigger_word],
|
65 |
-
outputs=[output_image]
|
66 |
)
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
if __name__ == "__main__":
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from diffusers import DiffusionPipeline, AutoencoderTiny
|
5 |
import random
|
6 |
import spaces
|
7 |
+
from PIL import Image
|
8 |
+
from fastapi import FastAPI
|
9 |
+
from fastapi.responses import FileResponse
|
10 |
|
11 |
# Initialize the base model
|
12 |
dtype = torch.bfloat16
|
|
|
16 |
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
|
17 |
MAX_SEED = 2**32-1
|
18 |
|
19 |
+
# Create a FastAPI app
|
20 |
+
app = FastAPI()
|
21 |
+
|
22 |
@spaces.GPU()
|
23 |
+
def generate_image(prompt, width, height, lora_path, trigger_word, hash_value):
|
24 |
# Load LoRA weights
|
25 |
pipe.load_lora_weights(lora_path)
|
26 |
|
|
|
41 |
generator=generator,
|
42 |
).images[0]
|
43 |
|
44 |
+
# Create ./tmp/ directory if it doesn't exist
|
45 |
+
os.makedirs('./tmp', exist_ok=True)
|
46 |
+
|
47 |
+
# Save the image with the provided hash
|
48 |
+
image_path = f'./tmp/{hash_value}.png'
|
49 |
+
image.save(image_path)
|
50 |
+
|
51 |
+
return image, image_path
|
52 |
|
53 |
+
def run_lora(prompt, width, height, lora_path, trigger_word, hash_value):
|
54 |
+
image, image_path = generate_image(prompt, width, height, lora_path, trigger_word, hash_value)
|
55 |
+
return image, image_path
|
56 |
|
57 |
# Set up the Gradio interface
|
58 |
+
with gr.Blocks() as gradio_app:
|
59 |
gr.Markdown("# LoRA Image Generator")
|
60 |
|
61 |
with gr.Row():
|
|
|
68 |
with gr.Row():
|
69 |
lora_path = gr.Textbox(label="LoRA Path", value="SebastianBodza/Flux_Aquarell_Watercolor_v2")
|
70 |
trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")
|
71 |
+
hash_value = gr.Textbox(label="Hash", placeholder="Enter a unique hash for the image")
|
72 |
|
73 |
generate_button = gr.Button("Generate Image")
|
74 |
|
75 |
output_image = gr.Image(label="Generated Image")
|
76 |
+
output_path = gr.Textbox(label="Saved Image Path")
|
77 |
|
78 |
generate_button.click(
|
79 |
fn=run_lora,
|
80 |
+
inputs=[prompt, width, height, lora_path, trigger_word, hash_value],
|
81 |
+
outputs=[output_image, output_path]
|
82 |
)
|
83 |
|
84 |
+
# FastAPI endpoint for downloading images
|
85 |
+
@app.get("/download/{hash_value}")
|
86 |
+
async def download_image(hash_value: str):
|
87 |
+
image_path = f"./tmp/{hash_value}.png"
|
88 |
+
if os.path.exists(image_path):
|
89 |
+
return FileResponse(image_path, media_type="image/png", filename=f"{hash_value}.png")
|
90 |
+
return {"error": "Image not found"}
|
91 |
+
|
92 |
+
# Launch both Gradio and FastAPI
|
93 |
if __name__ == "__main__":
|
94 |
+
import uvicorn
|
95 |
+
from threading import Thread
|
96 |
+
|
97 |
+
def run_gradio():
|
98 |
+
gradio_app.queue()
|
99 |
+
gradio_app.launch()
|
100 |
+
|
101 |
+
# Start Gradio in a separate thread
|
102 |
+
Thread(target=run_gradio, daemon=True).start()
|
103 |
+
|
104 |
+
# Run FastAPI with uvicorn
|
105 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|