Manireddy1508 commited on
Commit
4895c45
·
verified ·
1 Parent(s): b2ef34b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -14,17 +14,16 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
14
  dtype = torch.float16 if device == "cuda" else torch.float32
15
 
16
  # ----------------------------
17
- # 📦 Load Inpainting Model
18
  # ----------------------------
19
-
20
-
21
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
22
  "runwayml/stable-diffusion-inpainting",
23
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
24
- ).to("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
- pipe.enable_attention_slicing()
27
- pipe.enable_model_cpu_offload()
 
28
 
29
  # ----------------------------
30
  # 🎨 Image Generation Function
@@ -37,9 +36,9 @@ def process_image(prompt, image, mask, num_variations):
37
  reasoning_json = extract_scene_plan(prompt)
38
  print("🧠 Scene plan extracted:", reasoning_json)
39
 
40
- # Resize inputs to 1024x1024 (required for SDXL)
41
- image = image.resize((1024, 1024)).convert("RGB")
42
- mask = mask.resize((1024, 1024)).convert("L")
43
 
44
  results = []
45
  for i in range(num_variations):
@@ -64,7 +63,7 @@ def process_image(prompt, image, mask, num_variations):
64
  # 🖼️ Gradio UI
65
  # ----------------------------
66
  with gr.Blocks() as demo:
67
- gr.Markdown("## 🧠 NewCrux Inpainting Demo (SDXL)\nUpload a product image, a mask, and a prompt to generate realistic content.")
68
 
69
  with gr.Row():
70
  with gr.Column():
@@ -93,4 +92,3 @@ demo.launch()
93
 
94
 
95
 
96
-
 
14
  dtype = torch.float16 if device == "cuda" else torch.float32
15
 
16
  # ----------------------------
17
+ # 📦 Load Inpainting Model (SD 1.5-based for now)
18
  # ----------------------------
 
 
19
  pipe = StableDiffusionInpaintPipeline.from_pretrained(
20
  "runwayml/stable-diffusion-inpainting",
21
+ torch_dtype=dtype
22
+ ).to(device)
23
 
24
+ if device == "cuda":
25
+ pipe.enable_attention_slicing()
26
+ pipe.enable_model_cpu_offload() # ✅ Only on GPU!
27
 
28
  # ----------------------------
29
  # 🎨 Image Generation Function
 
36
  reasoning_json = extract_scene_plan(prompt)
37
  print("🧠 Scene plan extracted:", reasoning_json)
38
 
39
+ # Step 2: Resize inputs (SD 1.5 = 512x512, SDXL = 1024x1024)
40
+ image = image.resize((512, 512)).convert("RGB")
41
+ mask = mask.resize((512, 512)).convert("L")
42
 
43
  results = []
44
  for i in range(num_variations):
 
63
  # 🖼️ Gradio UI
64
  # ----------------------------
65
  with gr.Blocks() as demo:
66
+ gr.Markdown("## 🧠 NewCrux Inpainting Demo (SD 1.5)\nUpload a product image, a mask, and a prompt to generate realistic content.")
67
 
68
  with gr.Row():
69
  with gr.Column():
 
92
 
93
 
94