# imagetoimage / app.py
# Author: Manireddy1508 — last commit b2ef34b ("Update app.py", verified), 3.03 kB
# app.py
import os

import gradio as gr
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageOps

from utils.planner import extract_scene_plan  # 🧠 Brain Layer
# ----------------------------
# 🔧 Device Setup
# ----------------------------
# Half precision on the GPU, full precision on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# ----------------------------
# 📦 Load Inpainting Model
# ----------------------------
# NOTE: this is the Stable Diffusion 1.x inpainting checkpoint, not SDXL.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=dtype,  # reuse the values computed above instead of re-deriving them
)
# Compute attention in slices to lower peak memory at some speed cost.
pipe.enable_attention_slicing()
if device == "cuda":
    # Model CPU offload handles device placement itself (each sub-model is
    # moved to the GPU only while it runs). The pipeline must NOT be moved to
    # CUDA with .to() beforehand, or most of the memory saving is lost.
    pipe.enable_model_cpu_offload()
else:
    # No GPU: offloading is meaningless, just keep everything on the CPU.
    pipe.to(device)
# ----------------------------
# 🎨 Image Generation Function
# ----------------------------
# The loaded checkpoint (runwayml/stable-diffusion-inpainting) is a Stable
# Diffusion 1.x model whose native resolution is 512x512 (it is not SDXL), so
# inputs are resized to that size and it is passed to the pipeline explicitly.
_GENERATION_SIZE = (512, 512)  # (width, height)


def process_image(prompt, image, mask, num_variations):
    """Plan the scene for ``prompt`` and inpaint ``image`` several times.

    Args:
        prompt: Text prompt; sent to ``extract_scene_plan`` and used as the
            inpainting prompt.
        image: PIL image from the Gradio upload, or None if nothing was
            uploaded.
        mask: PIL mask image following the UI convention
            (white = keep, black = replace), or None.
        num_variations: Number of images to generate (slider value).

    Returns:
        ``(images, plan)`` for the gallery and JSON outputs. On failure
        ``images`` is an empty list and ``plan`` is ``{"error": message}``.
    """
    # Fail fast with a readable message instead of an AttributeError on None.
    if image is None or mask is None:
        return [], {"error": "Please upload both a product image and a mask."}
    try:
        print("🧠 Prompt received:", prompt)
        # 🧠 Step 1: Brain Layer
        reasoning_json = extract_scene_plan(prompt)
        print("🧠 Scene plan extracted:", reasoning_json)

        width, height = _GENERATION_SIZE
        image = image.resize(_GENERATION_SIZE).convert("RGB")
        # Diffusers inpainting repaints WHITE mask pixels and preserves BLACK
        # ones; the UI asks for the opposite (white = keep), so invert.
        mask = ImageOps.invert(mask.resize(_GENERATION_SIZE).convert("L"))

        results = []
        # Gradio sliders may deliver a float (e.g. 2.0); range() needs an int.
        for i in range(int(num_variations)):
            print(f"🎨 Generating variation {i + 1}...")
            output = pipe(
                prompt=prompt,
                image=image,
                mask_image=mask,
                height=height,
                width=width,
                strength=0.98,
                guidance_scale=7.5,
                num_inference_steps=40,
            ).images[0]
            results.append(output)
        return results, reasoning_json
    except Exception as e:
        # UI boundary: report the failure in the JSON panel instead of failing
        # the request. The gallery gets an empty list because it can only
        # render images/paths, not an error string.
        print("❌ Error during generation:", e)
        return [], {"error": str(e)}
# ----------------------------
# 🖼️ Gradio UI
# ----------------------------
with gr.Blocks() as demo:
    # Header text; the loaded checkpoint is SD 1.x inpainting, so it is not
    # labelled SDXL.
    gr.Markdown(
        "## 🧠 NewCrux Inpainting Demo (Stable Diffusion Inpainting)\n"
        "Upload a product image, a mask, and a prompt to generate realistic content."
    )
    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt")
            image_input = gr.Image(type="pil", label="Upload Product Image")
            mask_input = gr.Image(type="pil", label="Upload Mask (white = keep, black = replace)")
            variation_slider = gr.Slider(1, 4, step=1, value=1, label="Number of Variations")
            generate_btn = gr.Button("Generate")
        # Right column: generated images and the planner output.
        with gr.Column():
            output_gallery = gr.Gallery(
                label="Generated Variations",
                columns=2,
                rows=2,
                height="auto",
            )
            json_output = gr.JSON(label="🧠 Brain Layer Reasoning")
    # Inputs are passed positionally, in this order, to process_image's
    # (prompt, image, mask, num_variations) parameters.
    generate_btn.click(
        fn=process_image,
        inputs=[prompt_input, image_input, mask_input, variation_slider],
        outputs=[output_gallery, json_output],
    )

# Start the server only when run as a script, so importing this module does
# not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()