File size: 3,453 Bytes
2caf84c
0e0ee20
 
4989e93
607d766
4989e93
d06075d
 
 
0e0ee20
 
c724573
 
463aefd
c724573
 
f3e96f9
c59400c
d06075d
 
 
f645c51
d06075d
f645c51
 
 
4989e93
f645c51
4989e93
 
 
5b82e60
fd8e800
4989e93
 
 
 
 
 
 
 
 
 
d06075d
 
 
 
 
 
 
 
2caf84c
d06075d
 
 
0b93385
4989e93
d06075d
4989e93
 
0e0ee20
4989e93
 
0e0ee20
4989e93
 
 
f645c51
 
 
d06075d
f645c51
4989e93
 
 
d06075d
4989e93
 
0e0ee20
d06075d
 
0e0ee20
 
d06075d
 
 
 
 
 
 
 
 
4989e93
d06075d
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import random

import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from PIL import Image

# Initialize the base model once at import time; the pipeline is shared by
# every request handled by this process.
dtype = torch.bfloat16
# Prefer CUDA, fall back to CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"
# Tiny autoencoder (TAEF1) used in place of the full VAE for faster decoding.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
# Upper bound for randomly drawn seeds (largest unsigned 32-bit value).
MAX_SEED = 2**32-1

# Create a FastAPI app (serves the /download endpoint defined below)
app = FastAPI()

@spaces.GPU()
def generate_image(prompt, width, height, lora_path, trigger_word, hash_value):
    """Generate one image with the given LoRA applied and save it to ./tmp/.

    Args:
        prompt: Free-text user prompt.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_path: Hugging Face repo id (or local path) of the LoRA weights.
        trigger_word: Token prepended to the prompt to activate the LoRA style.
        hash_value: Unique identifier used as the saved file's base name.

    Returns:
        tuple: (PIL image, path of the saved PNG under ./tmp/).
    """
    # Load LoRA weights onto the shared pipeline.
    # NOTE(review): repeated calls load new adapters without unloading old
    # ones — consider pipe.unload_lora_weights() if memory growth is observed.
    pipe.load_lora_weights(lora_path)

    # Prepend the trigger word so the LoRA style is activated.
    full_prompt = f"{trigger_word} {prompt}"

    # Fresh random seed per call. Use the detected device instead of a
    # hard-coded "cuda" — the original crashed on CPU-only hosts.
    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Generate image
    image = pipe(
        prompt=full_prompt,
        num_inference_steps=28,
        guidance_scale=3.5,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    # Persist the result as ./tmp/<hash>.png for the /download endpoint.
    os.makedirs('./tmp', exist_ok=True)
    image_path = f'./tmp/{hash_value}.png'
    image.save(image_path)

    return image, image_path

def run_lora(prompt, width, height, lora_path, trigger_word, hash_value):
    """Gradio click handler: delegate straight to generate_image.

    Returns the (image, saved path) pair produced by generate_image.
    """
    return generate_image(prompt, width, height, lora_path, trigger_word, hash_value)

# Set up the Gradio interface. Component creation order determines the
# rendered layout, so statements here must not be reordered.
with gr.Blocks() as gradio_app:
    gr.Markdown("# LoRA Image Generator")
    
    # Free-text prompt input.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here")
    
    # Output dimensions, constrained to multiples of 64 between 256 and 1024.
    with gr.Row():
        width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
        height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
    
    # LoRA selection: repo path, its trigger word, and the hash used as the
    # saved file name (also the key for the /download/{hash} endpoint).
    with gr.Row():
        lora_path = gr.Textbox(label="LoRA Path", value="SebastianBodza/Flux_Aquarell_Watercolor_v2")
        trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")
        hash_value = gr.Textbox(label="Hash", placeholder="Enter a unique hash for the image")
    
    generate_button = gr.Button("Generate Image")
    
    # Outputs: the generated image plus the path it was saved under.
    output_image = gr.Image(label="Generated Image")
    output_path = gr.Textbox(label="Saved Image Path")
    
    # Wire the button to run_lora; inputs/outputs map positionally to its
    # parameters and return values.
    generate_button.click(
        fn=run_lora,
        inputs=[prompt, width, height, lora_path, trigger_word, hash_value],
        outputs=[output_image, output_path]
    )

# FastAPI endpoint for downloading previously generated images
@app.get("/download/{hash_value}")
async def download_image(hash_value: str):
    """Serve the PNG saved under ./tmp/ for the given hash.

    Args:
        hash_value: The hash supplied when the image was generated.

    Raises:
        HTTPException: 404 when no image exists for the given hash.
    """
    # Use only the basename so a crafted hash cannot traverse out of ./tmp/.
    safe_name = os.path.basename(hash_value)
    image_path = f"./tmp/{safe_name}.png"
    if os.path.exists(image_path):
        return FileResponse(image_path, media_type="image/png", filename=f"{safe_name}.png")
    # Return a real 404 status (the original sent a 200 with a JSON error body,
    # which HTTP clients would treat as success).
    raise HTTPException(status_code=404, detail="Image not found")

# Launch both Gradio and FastAPI from one process: Gradio runs on a daemon
# thread (its default port), uvicorn owns the main thread on port 8000.
if __name__ == "__main__":
    import uvicorn
    from threading import Thread
    
    def run_gradio():
        # launch() blocks, which is why this runs on its own thread.
        gradio_app.queue()
        gradio_app.launch()
    
    # Start Gradio in a separate daemon thread so it dies with the process.
    Thread(target=run_gradio, daemon=True).start()
    
    # Run FastAPI with uvicorn; serves the /download/{hash_value} endpoint.
    uvicorn.run(app, host="0.0.0.0", port=8000)