ma4389 commited on
Commit
f97ce46
·
verified ·
1 Parent(s): d636062

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -2,26 +2,33 @@ import torch
2
  from diffusers import DiffusionPipeline
3
  import gradio as gr
4
 
5
- # Load model with float16 for GPU or float32 for CPU
6
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
8
 
 
9
  pipe = DiffusionPipeline.from_pretrained(
10
- "CompVis/stable-diffusion-v1-4", torch_dtype=dtype
 
11
  )
12
- pipe.load_lora_weights("EliKet/train_text_to_img")
13
  pipe.to(device)
14
 
 
 
 
 
15
  def generate_image(prompt):
16
  image = pipe(prompt).images[0]
17
  return image
18
 
 
19
  demo = gr.Interface(
20
  fn=generate_image,
21
- inputs=gr.Textbox(placeholder="Enter your image prompt here..."),
22
  outputs="image",
23
- title="Text-to-Image Generator (Lynx)",
24
- description="Type a prompt like 'a lynx in the snowy forest, ultra-detailed'."
25
  )
26
 
 
27
  demo.launch()
 
2
  from diffusers import DiffusionPipeline
3
  import gradio as gr
4
 
5
+ # Detect device
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
+ dtype = torch.float16 if device == "cuda" else torch.float32
8
 
9
+ # Load pipeline
10
  pipe = DiffusionPipeline.from_pretrained(
11
+ "CompVis/stable-diffusion-v1-4",
12
+ torch_dtype=dtype
13
  )
 
14
  pipe.to(device)
15
 
16
+ # Load LoRA weights (requires `peft` installed)
17
+ pipe.load_lora_weights("EliKet/train_text_to_img")
18
+
19
+ # Inference function
20
  def generate_image(prompt):
21
  image = pipe(prompt).images[0]
22
  return image
23
 
24
+ # Gradio Interface
25
  demo = gr.Interface(
26
  fn=generate_image,
27
+ inputs=gr.Textbox(lines=2, placeholder="Describe the image you want..."),
28
  outputs="image",
29
+ title="🖼️ LoRA Text-to-Image Generator",
30
+ description="Enter a prompt to generate an image using Stable Diffusion with LoRA (EliKet/train_text_to_img)."
31
  )
32
 
33
+ # Launch app
34
  demo.launch()