# app.py
"""Gradio demo for prompt-guided inpainting of product images.

At import time this script loads the ``runwayml/stable-diffusion-inpainting``
pipeline, builds a small Blocks UI and launches it. Each request first asks
the planner ("Brain Layer") for a scene plan, then generates one or more
inpainted variations of the uploaded image.
"""

import os

import gradio as gr
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image

from utils.planner import extract_scene_plan  # 🧠 Brain Layer

# ----------------------------
# 🔧 Device Setup
# ----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on GPU; CPU runs in float32.
dtype = torch.float16 if device == "cuda" else torch.float32

# ----------------------------
# 📦 Load Inpainting Model
# ----------------------------
MODEL_ID = "runwayml/stable-diffusion-inpainting"

# This checkpoint is a Stable Diffusion 1.x inpainting model (not SDXL); the
# pipeline's default output resolution for it is 512x512, so inputs are
# resized to that size once instead of being upscaled and downscaled again.
TARGET_SIZE = (512, 512)  # (width, height)

pipe = StableDiffusionInpaintPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
pipe.enable_attention_slicing()
if device == "cuda":
    # Model CPU offload manages device placement itself and needs an
    # accelerator, so the pipeline is not moved to CUDA manually first.
    pipe.enable_model_cpu_offload()
else:
    pipe = pipe.to(device)


# ----------------------------
# 🎨 Image Generation Function
# ----------------------------
def process_image(prompt, image, mask, num_variations):
    """Generate inpainted variations of ``image`` guided by ``prompt``.

    Args:
        prompt: Text prompt; passed to both the planner and the pipeline.
        image: PIL image to inpaint (``None`` if nothing was uploaded).
        mask: PIL mask image (``None`` if nothing was uploaded). It is
            converted to single-channel "L"; white areas are repainted and
            black areas are preserved by the pipeline.
        num_variations: Number of images to generate; coerced to ``int``
            because a slider may deliver a float.

    Returns:
        A ``(images, plan)`` tuple. ``images`` is a list of generated PIL
        images and ``plan`` is whatever ``extract_scene_plan`` returned. On
        failure ``images`` is an empty list and ``plan`` is
        ``{"error": <message>}``.
    """
    if image is None or mask is None:
        message = "Please upload both a product image and a mask."
        print("❌ Error during generation:", message)
        return [], {"error": message}

    try:
        print("🧠 Prompt received:", prompt)

        # 🧠 Step 1: Brain Layer
        reasoning_json = extract_scene_plan(prompt)
        print("🧠 Scene plan extracted:", reasoning_json)

        # Resize inputs to the model's working resolution.
        image = image.resize(TARGET_SIZE).convert("RGB")
        mask = mask.resize(TARGET_SIZE).convert("L")

        width, height = TARGET_SIZE
        results = []
        for i in range(int(num_variations)):
            print(f"🎨 Generating variation {i + 1}...")
            output = pipe(
                prompt=prompt,
                image=image,
                mask_image=mask,
                height=height,
                width=width,
                strength=0.98,
                guidance_scale=7.5,
                num_inference_steps=40,
            ).images[0]
            results.append(output)

        return results, reasoning_json

    except Exception as e:
        # Top-level UI boundary: report the failure in the JSON panel.
        # The gallery gets an empty list, because a plain string would be
        # treated as an image path and hide the real error.
        print("❌ Error during generation:", e)
        return [], {"error": str(e)}


# ----------------------------
# 🖼️ Gradio UI
# ----------------------------
with gr.Blocks() as demo:
    # NOTE(review): the title says "SDXL" but the loaded checkpoint is the
    # SD 1.x inpainting model — confirm which one is intended.
    gr.Markdown(
        "## 🧠 NewCrux Inpainting Demo (SDXL)\n"
        "Upload a product image, a mask, and a prompt to generate realistic content."
    )

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt")
            image_input = gr.Image(type="pil", label="Upload Product Image")
            # Label matches the pipeline's convention: white is repainted.
            mask_input = gr.Image(
                type="pil",
                label="Upload Mask (white = replace, black = keep)",
            )
            variation_slider = gr.Slider(
                1, 4, step=1, value=1, label="Number of Variations"
            )
            generate_btn = gr.Button("Generate")

        with gr.Column():
            output_gallery = gr.Gallery(
                label="Generated Variations",
                columns=2,
                rows=2,
                height="auto",
            )
            json_output = gr.JSON(label="🧠 Brain Layer Reasoning")

    generate_btn.click(
        fn=process_image,
        inputs=[prompt_input, image_input, mask_input, variation_slider],
        outputs=[output_gallery, json_output],
    )

demo.launch()