Avanish11 committed
Commit d20bc22 · verified · 1 Parent(s): 9fb8310

Update app.py

Files changed (1):
  1. app.py +34 -12
app.py CHANGED
@@ -6,30 +6,52 @@ from PIL import Image
 # ✅ Base model (commercial use allowed)
 BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"

-# LoRAs
-LORA_1 = "gh1bli-style.safetensors"
-LORA_2 = "ghibli_landscape_lora.safetensors"
+# ✅ Local LoRA weights
+LORA_1 = "./gh1bli-style.safetensors"
+LORA_2 = "./ghibli_landscape_lora.safetensors"

+# ✅ Device setup
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.float16 if torch.cuda.is_available() else torch.float32

-# 🧠 Load base text-to-image pipeline
+# 🧠 Load base pipeline
 print("🔹 Loading SDXL base model...")
 pipe_txt2img = DiffusionPipeline.from_pretrained(
     BASE_MODEL,
-    torch_dtype=dtype,
+    dtype=dtype,
     use_safetensors=True,
 ).to(device)

-# 🧩 Apply both LoRAs
+# ✅ Enable CPU/GPU memory optimization
+if device == "cuda":
+    print("🚀 Using GPU optimization")
+    pipe_txt2img.enable_model_cpu_offload()  # For big SDXL weights
+else:
+    print("🧩 Using CPU memory optimization")
+    pipe_txt2img.enable_attention_slicing()  # Reduce RAM spikes
+    pipe_txt2img.enable_sequential_cpu_offload()
+
+# 🧩 Load both LoRAs (PEFT-compatible)
 print("🔹 Applying Ghibli-style LoRAs...")
-pipe_txt2img.load_lora_weights(LORA_1)
-pipe_txt2img.load_lora_weights(LORA_2)
+pipe_txt2img.load_lora_weights(LORA_1, adapter_name="ghibli_style")
+pipe_txt2img.load_lora_weights(LORA_2, adapter_name="ghibli_landscape")
+
+# 🧠 Merge both styles
+pipe_txt2img.set_adapters(["ghibli_style", "ghibli_landscape"], adapter_weights=[0.7, 0.6])

-# 🖼️ Image-to-image pipeline (same model + LoRAs)
+# 🖼️ Image-to-Image pipeline (inherits adapters)
 print("🔹 Setting up image-to-image pipeline...")
 pipe_img2img = StableDiffusionXLImg2ImgPipeline(**pipe_txt2img.components)
+pipe_img2img.set_adapters(["ghibli_style", "ghibli_landscape"], adapter_weights=[0.7, 0.6])
+
+# ✅ Add same memory optimization for img2img
+if device == "cuda":
+    pipe_img2img.enable_model_cpu_offload()
+else:
+    pipe_img2img.enable_attention_slicing()
+    pipe_img2img.enable_sequential_cpu_offload()

+# 🎨 Generation function
 def generate(prompt, steps=30, guidance=7.5, seed=42, strength=0.6, image=None):
     generator = torch.Generator(device=device).manual_seed(int(seed))

@@ -55,7 +77,7 @@ def generate(prompt, steps=30, guidance=7.5, seed=42, strength=0.6, image=None):

     return result

-# 🎨 Gradio Interface
+# 🎛️ Gradio UI
 demo = gr.Interface(
     fn=generate,
     inputs=[
@@ -67,8 +89,8 @@ demo = gr.Interface(
         gr.Image(label="Upload Image (optional)", type="filepath"),
     ],
     outputs=gr.Image(label="Generated Image"),
-    title="Ghibli Style Maker – Text & Image to Image",
-    description="Create Ghibli-style art from a text prompt or transform any photo into Ghibli-inspired scenery using SDXL + LoRAs.",
+    title="🎨 Ghibli Style Maker – Text & Image to Image",
+    description="Generate or transform images in Studio Ghibli-inspired style using SDXL and LoRAs.",
 )

 if __name__ == "__main__":
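
Note: the hunks above elide the body of generate(). Below is a minimal sketch of how the merged adapters are typically consumed at inference time with diffusers; the text-to-image vs. image-to-image branch, the 1024x1024 resize, and the helper name generate_sketch are illustrative assumptions, not the committed code.

# Hypothetical usage sketch (not part of this commit). Assumes the objects
# defined above in app.py (device, pipe_txt2img, pipe_img2img) and the imports
# already at the top of the file (torch, PIL.Image).

def generate_sketch(prompt, steps=30, guidance=7.5, seed=42, strength=0.6, image=None):
    generator = torch.Generator(device=device).manual_seed(int(seed))
    if image is None:
        # Text-to-image: the merged LoRA adapters are already active on the pipeline.
        result = pipe_txt2img(
            prompt=prompt,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator,
        ).images[0]
    else:
        # Image-to-image: transform an uploaded photo; `strength` controls how far
        # the output may drift from the input image.
        init_image = Image.open(image).convert("RGB").resize((1024, 1024))
        result = pipe_img2img(
            prompt=prompt,
            image=init_image,
            strength=float(strength),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator,
        ).images[0]
    return result

Because pipe_img2img is built from pipe_txt2img.components, both pipelines share the same UNet and text encoders, so the set_adapters call with weights [0.7, 0.6] blends the two LoRAs for both the text-to-image and image-to-image paths.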