# app.py
import gradio as gr
from PIL import Image
import os
import torch
from diffusers import StableDiffusionInpaintPipeline
from utils.planner import extract_scene_plan  # 🧠 Brain Layer
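
# The Brain Layer helper lives in utils/planner.py, which is not shown in this
# file. From the way it is used below, it is assumed to take the raw prompt
# string and return a JSON-serializable dict (the result is printed, returned
# alongside the generated images, and rendered in a gr.JSON component).
# A hypothetical minimal stub with that shape might look like:
#
#     def extract_scene_plan(prompt: str) -> dict:
#         return {"prompt": prompt, "objects": [], "background": None}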
# ----------------------------
# 🔧 Device Setup
# ----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
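# float16 roughly halves memory use on the GPU; on CPU the pipeline falls back
# to float32, since half-precision inference support there is limited.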
# ----------------------------
# 📦 Load Inpainting Model
# ----------------------------
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=dtype
)
# Attention slicing trades a little speed for a lower peak memory footprint.
pipe.enable_attention_slicing()
if device == "cuda":
    # Model CPU offload handles device placement itself (and requires a GPU),
    # so an explicit .to("cuda") is not needed when it is enabled.
    pipe.enable_model_cpu_offload()
else:
    pipe = pipe.to(device)
# ----------------------------
# 🎨 Image Generation Function
# ----------------------------
def process_image(prompt, image, mask, num_variations):
    try:
        print("🧠 Prompt received:", prompt)

        # 🧠 Step 1: Brain Layer
        reasoning_json = extract_scene_plan(prompt)
        print("🧠 Scene plan extracted:", reasoning_json)
        # Resize inputs to 512x512, the native resolution of the
        # runwayml/stable-diffusion-inpainting checkpoint.
        image = image.resize((512, 512)).convert("RGB")
        mask = mask.resize((512, 512)).convert("L")
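        # diffusers inpainting convention: white mask pixels mark the region to
        # regenerate, black pixels mark the region to keep.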

        results = []
        for i in range(int(num_variations)):  # slider values arrive as numbers, so cast for range()
            print(f"🎨 Generating variation {i + 1}...")
            output = pipe(
                prompt=prompt,
                image=image,
                mask_image=mask,
                strength=0.98,           # how strongly the masked region is re-noised (1.0 = fully)
                guidance_scale=7.5,      # classifier-free guidance: higher = closer prompt adherence
                num_inference_steps=40
            ).images[0]
            results.append(output)

        return results, reasoning_json
    except Exception as e:
        print("❌ Error during generation:", e)
        # Return an empty gallery plus the error details so the UI stays usable.
        return [], {"error": str(e)}
# ----------------------------
# 🖼️ Gradio UI
# ----------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 NewCrux Inpainting Demo (Stable Diffusion Inpainting)\nUpload a product image, a mask, and a prompt to generate realistic content.")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt")
            image_input = gr.Image(type="pil", label="Upload Product Image")
            mask_input = gr.Image(type="pil", label="Upload Mask (white = replace, black = keep)")
            variation_slider = gr.Slider(1, 4, step=1, value=1, label="Number of Variations")
            generate_btn = gr.Button("Generate")
        with gr.Column():
            output_gallery = gr.Gallery(
                label="Generated Variations",
                columns=2,
                rows=2,
                height="auto"
            )
            json_output = gr.JSON(label="🧠 Brain Layer Reasoning")

    generate_btn.click(
        fn=process_image,
        inputs=[prompt_input, image_input, mask_input, variation_slider],
        outputs=[output_gallery, json_output]
    )

demo.launch()