Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -14,17 +14,16 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
14 |
dtype = torch.float16 if device == "cuda" else torch.float32
|
15 |
|
16 |
# ----------------------------
|
17 |
-
# 📦 Load Inpainting Model
|
18 |
# ----------------------------
|
19 |
-
|
20 |
-
|
21 |
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
22 |
"runwayml/stable-diffusion-inpainting",
|
23 |
-
torch_dtype=
|
24 |
-
).to(
|
25 |
|
26 |
-
|
27 |
-
pipe.
|
|
|
28 |
|
29 |
# ----------------------------
|
30 |
# 🎨 Image Generation Function
|
@@ -37,9 +36,9 @@ def process_image(prompt, image, mask, num_variations):
|
|
37 |
reasoning_json = extract_scene_plan(prompt)
|
38 |
print("🧠 Scene plan extracted:", reasoning_json)
|
39 |
|
40 |
-
# Resize inputs
|
41 |
-
image = image.resize((
|
42 |
-
mask = mask.resize((
|
43 |
|
44 |
results = []
|
45 |
for i in range(num_variations):
|
@@ -64,7 +63,7 @@ def process_image(prompt, image, mask, num_variations):
|
|
64 |
# 🖼️ Gradio UI
|
65 |
# ----------------------------
|
66 |
with gr.Blocks() as demo:
|
67 |
-
gr.Markdown("## 🧠 NewCrux Inpainting Demo (
|
68 |
|
69 |
with gr.Row():
|
70 |
with gr.Column():
|
@@ -93,4 +92,3 @@ demo.launch()
|
|
93 |
|
94 |
|
95 |
|
96 |
-
|
|
|
14 |
dtype = torch.float16 if device == "cuda" else torch.float32
|
15 |
|
16 |
# ----------------------------
|
17 |
+
# 📦 Load Inpainting Model (SD 1.5-based for now)
|
18 |
# ----------------------------
|
|
|
|
|
19 |
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
20 |
"runwayml/stable-diffusion-inpainting",
|
21 |
+
torch_dtype=dtype
|
22 |
+
).to(device)
|
23 |
|
24 |
+
if device == "cuda":
|
25 |
+
pipe.enable_attention_slicing()
|
26 |
+
pipe.enable_model_cpu_offload() # ✅ Only on GPU!
|
27 |
|
28 |
# ----------------------------
|
29 |
# 🎨 Image Generation Function
|
|
|
36 |
reasoning_json = extract_scene_plan(prompt)
|
37 |
print("🧠 Scene plan extracted:", reasoning_json)
|
38 |
|
39 |
+
# Step 2: Resize inputs (SD 1.5 = 512x512, SDXL = 1024x1024)
|
40 |
+
image = image.resize((512, 512)).convert("RGB")
|
41 |
+
mask = mask.resize((512, 512)).convert("L")
|
42 |
|
43 |
results = []
|
44 |
for i in range(num_variations):
|
|
|
63 |
# 🖼️ Gradio UI
|
64 |
# ----------------------------
|
65 |
with gr.Blocks() as demo:
|
66 |
+
gr.Markdown("## 🧠 NewCrux Inpainting Demo (SD 1.5)\nUpload a product image, a mask, and a prompt to generate realistic content.")
|
67 |
|
68 |
with gr.Row():
|
69 |
with gr.Column():
|
|
|
92 |
|
93 |
|
94 |
|
|