# app.py
import gradio as gr
from PIL import Image
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from utils.planner import (
    extract_scene_plan,
    generate_prompt_variations_from_scene,
    generate_negative_prompt_from_scene,
)
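
# NOTE: utils.planner is a project-local module; based on how they are used
# below, these helpers are assumed to return a scene-plan object, a list of
# enriched prompt strings, and a negative-prompt string, respectively.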
# ----------------------------
# Device Setup
# ----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# ----------------------------
# Load SDXL-Only Pipeline
# ----------------------------
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant="fp16" if device == "cuda" else None,
    use_safetensors=True,
)

if device == "cuda":
    # Model CPU offload manages device placement itself (requires accelerate),
    # so the explicit pipe.to("cuda") call is skipped in that case.
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)
pipe.enable_attention_slicing()
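
# Optional extra memory saving on smaller GPUs (assumption, not in the original code):
# pipe.enable_vae_slicing()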
# ----------------------------
# Image Generation Function
# ----------------------------
def process_image(prompt, image, num_variations):
    try:
        print("User Prompt:", prompt)
        if image is None:
            raise ValueError("Uploaded image is missing.")

        # Step 1: Extract scene plan
        scene_plan = extract_scene_plan(prompt, image)
        print("Scene Plan:", scene_plan)

        # Step 2: Generate enriched prompts
        prompt_list = generate_prompt_variations_from_scene(scene_plan, prompt, num_variations)
        print("Enriched Prompts:", prompt_list)

        # Step 3: Generate negative prompt
        negative_prompt = generate_negative_prompt_from_scene(scene_plan)
        print("Negative Prompt:", negative_prompt)

        # Step 4: Resize image to SDXL resolution
        image = image.resize((1024, 1024)).convert("RGB")

        # Step 5: Generate outputs with SDXL only
        outputs = []
        for i, enriched_prompt in enumerate(prompt_list):
            print(f"Generating variation {i + 1}...")
            result = pipe(
                prompt=enriched_prompt,
                negative_prompt=negative_prompt,
                image=image,
                strength=0.7,  # how far the output departs from the input; tune as needed
                guidance_scale=7.5,
                num_inference_steps=30,
            )
            outputs.append(result.images[0])

        return outputs

    except Exception as e:
        print("Generation Error:", e)
        # Return a solid red placeholder so the gallery still renders on failure
        return [Image.new("RGB", (512, 512), color="red")]
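
# Illustrative local smoke test (hypothetical file names, not part of the app):
#   variations = process_image("studio shot on a marble table", Image.open("sample.jpg"), 1)
#   variations[0].save("variation_1.png")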
# ----------------------------
# Gradio Interface
# ----------------------------
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Image(type="pil", label="Product Image"),
        gr.Slider(1, 5, value=3, step=1, label="Number of Variations"),
    ],
    outputs=gr.Gallery(label="Generated Images", columns=2, height="auto"),
    title="NewCrux Product Image Generator (SDXL Only)",
    description="Upload a product image and enter a prompt. SDXL will generate enriched variations using AI.",
)
if __name__ == "__main__":
    demo.launch()