# imagetoimage / app.py
# Last commit: b574e01 "Update app.py" by Manireddy1508 (verified)
# File size: 3.04 kB
# app.py
import traceback

import gradio as gr
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from PIL import Image

from utils.planner import (
    extract_scene_plan,
    generate_prompt_variations_from_scene,
    generate_negative_prompt_from_scene,
)
# ----------------------------
# 🔧 Device Setup
# ----------------------------
# Prefer the GPU with half precision; fall back to full precision on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# ----------------------------
# ✅ Load SDXL Only Pipeline
# ----------------------------
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    # The fp16 weight variant only matches the float16 dtype chosen on GPU.
    variant="fp16" if device == "cuda" else None,
    use_safetensors=True,
)
if device == "cuda":
    # Model CPU offload manages device placement itself: it keeps the
    # sub-models on the CPU and moves each one to the GPU only while it runs.
    # It therefore needs CUDA, and it should be used *instead of*
    # pipe.to("cuda"), not after it.
    pipe.enable_model_cpu_offload()
else:
    # No accelerator: keep the whole pipeline on the CPU.
    pipe.to(device)
# Compute attention in slices to lower peak memory at some speed cost.
pipe.enable_attention_slicing()
# ----------------------------
# 🎨 Image Generation Function
# ----------------------------
def process_image(
    prompt,
    image,
    num_variations,
    strength=0.7,
    guidance_scale=7.5,
    num_inference_steps=30,
):
    """Generate SDXL img2img variations of an uploaded product image.

    The prompt and image are first turned into a scene plan by the planner
    helpers. The plan then yields a list of enriched prompts and a negative
    prompt, and each enriched prompt drives one img2img pass over the image.

    Args:
        prompt: Free-text user prompt.
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        num_variations: Requested number of variations. It is coerced to
            ``int`` because UI sliders may deliver a float such as ``3.0``.
        strength: img2img strength forwarded to the pipeline. Higher values
            let the output depart further from the input image.
        guidance_scale: Guidance weight forwarded to the pipeline.
        num_inference_steps: Denoising steps forwarded to the pipeline.

    Returns:
        A list with one generated PIL image per enriched prompt. On any
        failure, including a missing image, the error and its traceback are
        printed and a single 512x512 red placeholder is returned instead.
    """
    try:
        print("🧠 User Prompt:", prompt)
        if image is None:
            raise ValueError("🚫 Uploaded image is missing.")
        # Sliders can deliver floats; the planner is asked for a count.
        num_variations = int(num_variations)

        # Step 1: Extract scene plan
        scene_plan = extract_scene_plan(prompt, image)
        print("📋 Scene Plan:", scene_plan)

        # Step 2: Generate enriched prompts
        prompt_list = generate_prompt_variations_from_scene(
            scene_plan, prompt, num_variations
        )
        print("✅ Enriched Prompts:", prompt_list)

        # Step 3: Generate negative prompt
        negative_prompt = generate_negative_prompt_from_scene(scene_plan)
        print("🚫 Negative Prompt:", negative_prompt)

        # Step 4: Convert to RGB *before* resizing to the SDXL resolution.
        # Pillow always resizes "P" and "1" images with nearest-neighbour,
        # so converting first lets the resize actually filter the pixels.
        init_image = image.convert("RGB").resize((1024, 1024))

        # Step 5: One img2img pass per enriched prompt.
        outputs = []
        for i, enriched_prompt in enumerate(prompt_list):
            print(f"🎨 Generating variation {i+1}...")
            result = pipe(
                prompt=enriched_prompt,
                negative_prompt=negative_prompt,
                image=init_image,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
            )
            outputs.append(result.images[0])
        return outputs
    except Exception as e:
        # UI boundary: report the failure together with its traceback, then
        # return a red placeholder so the gallery still has something to show.
        print("❌ Generation Error:", e)
        traceback.print_exc()
        return [Image.new("RGB", (512, 512), color="red")]
# ----------------------------
# 🖼️ Gradio Interface
# ----------------------------
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Image(type="pil", label="Product Image"),
        gr.Slider(1, 5, value=3, step=1, label="Number of Variations"),
    ],
    # Layout goes through constructor arguments. The component `.style()`
    # method was deprecated in Gradio 3.x and removed in Gradio 4, where
    # calling it raises AttributeError when the module is imported.
    outputs=gr.Gallery(label="Generated Images", columns=2, height="auto"),
    title="NewCrux Product Image Generator (SDXL Only)",
    description="Upload a product image and enter a prompt. SDXL will generate enriched variations using AI."
)
def main():
    """Start the Gradio demo server with its default launch settings."""
    demo.launch()


if __name__ == "__main__":
    main()