erikbeltran committed on
Commit
d06075d
·
verified ·
1 Parent(s): f645c51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -9
app.py CHANGED
@@ -4,6 +4,9 @@ import torch
4
  from diffusers import DiffusionPipeline, AutoencoderTiny
5
  import random
6
  import spaces
 
 
 
7
 
8
  # Initialize the base model
9
  dtype = torch.bfloat16
@@ -13,8 +16,11 @@ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).
13
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
14
  MAX_SEED = 2**32-1
15
 
 
 
 
16
  @spaces.GPU()
17
- def generate_image(prompt, width, height, lora_path, trigger_word):
18
  # Load LoRA weights
19
  pipe.load_lora_weights(lora_path)
20
 
@@ -35,13 +41,21 @@ def generate_image(prompt, width, height, lora_path, trigger_word):
35
  generator=generator,
36
  ).images[0]
37
 
38
- return image
 
 
 
 
 
 
 
39
 
40
- def run_lora(prompt, width, height, lora_path, trigger_word):
41
- return generate_image(prompt, width, height, lora_path, trigger_word)
 
42
 
43
  # Set up the Gradio interface
44
- with gr.Blocks() as app:
45
  gr.Markdown("# LoRA Image Generator")
46
 
47
  with gr.Row():
@@ -54,17 +68,38 @@ with gr.Blocks() as app:
54
  with gr.Row():
55
  lora_path = gr.Textbox(label="LoRA Path", value="SebastianBodza/Flux_Aquarell_Watercolor_v2")
56
  trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")
 
57
 
58
  generate_button = gr.Button("Generate Image")
59
 
60
  output_image = gr.Image(label="Generated Image")
 
61
 
62
  generate_button.click(
63
  fn=run_lora,
64
- inputs=[prompt, width, height, lora_path, trigger_word],
65
- outputs=[output_image]
66
  )
67
 
 
 
 
 
 
 
 
 
 
68
  if __name__ == "__main__":
69
- app.queue()
70
- app.launch()
 
 
 
 
 
 
 
 
 
 
 
import os
import random
import re

from diffusers import DiffusionPipeline, AutoencoderTiny
import spaces
from PIL import Image
from fastapi import FastAPI
from fastapi.responses import FileResponse
10
 
11
  # Initialize the base model
12
  dtype = torch.bfloat16
 
16
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
17
  MAX_SEED = 2**32-1
18
 
19
+ # Create a FastAPI app
20
+ app = FastAPI()
21
+
22
  @spaces.GPU()
23
+ def generate_image(prompt, width, height, lora_path, trigger_word, hash_value):
24
  # Load LoRA weights
25
  pipe.load_lora_weights(lora_path)
26
 
 
41
  generator=generator,
42
  ).images[0]
43
 
44
+ # Create ./tmp/ directory if it doesn't exist
45
+ os.makedirs('./tmp', exist_ok=True)
46
+
47
+ # Save the image with the provided hash
48
+ image_path = f'./tmp/{hash_value}.png'
49
+ image.save(image_path)
50
+
51
+ return image, image_path
52
 
53
+ def run_lora(prompt, width, height, lora_path, trigger_word, hash_value):
54
+ image, image_path = generate_image(prompt, width, height, lora_path, trigger_word, hash_value)
55
+ return image, image_path
56
 
57
  # Set up the Gradio interface
58
+ with gr.Blocks() as gradio_app:
59
  gr.Markdown("# LoRA Image Generator")
60
 
61
  with gr.Row():
 
68
  with gr.Row():
69
  lora_path = gr.Textbox(label="LoRA Path", value="SebastianBodza/Flux_Aquarell_Watercolor_v2")
70
  trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")
71
+ hash_value = gr.Textbox(label="Hash", placeholder="Enter a unique hash for the image")
72
 
73
  generate_button = gr.Button("Generate Image")
74
 
75
  output_image = gr.Image(label="Generated Image")
76
+ output_path = gr.Textbox(label="Saved Image Path")
77
 
78
  generate_button.click(
79
  fn=run_lora,
80
+ inputs=[prompt, width, height, lora_path, trigger_word, hash_value],
81
+ outputs=[output_image, output_path]
82
  )
83
 
84
# FastAPI endpoint for downloading previously generated images by hash.
@app.get("/download/{hash_value}")
async def download_image(hash_value: str):
    """Serve the PNG saved under ./tmp/<hash_value>.png, or a JSON error if absent.

    Security: hash_value comes from an untrusted URL path, so it is restricted
    to filename-safe characters before being interpolated into the filesystem
    path — otherwise a request like "/download/..%2F..%2Fetc%2Fpasswd" could
    traverse out of the ./tmp directory.
    """
    # Reject anything that is not a plain token; keeps the lookup inside ./tmp.
    if not re.fullmatch(r"[A-Za-z0-9_-]+", hash_value):
        return {"error": "Image not found"}
    image_path = f"./tmp/{hash_value}.png"
    if os.path.exists(image_path):
        return FileResponse(image_path, media_type="image/png", filename=f"{hash_value}.png")
    return {"error": "Image not found"}
91
+
92
# Launch both Gradio and FastAPI
if __name__ == "__main__":
    import uvicorn
    from threading import Thread

    def _serve_gradio():
        # Queue must be enabled before launch so concurrent requests are handled.
        gradio_app.queue()
        gradio_app.launch()

    # Gradio runs on a daemon thread so it terminates with the main process.
    Thread(target=_serve_gradio, daemon=True).start()

    # uvicorn serves the FastAPI download endpoint in the main thread.
    uvicorn.run(app, host="0.0.0.0", port=8000)