imagetoimage / app.py
Manireddy1508's picture
Update app.py
6e52c0e verified
raw
history blame
4.89 kB
# app.py
import os

import cv2
import gradio as gr
import numpy as np
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
)
from PIL import Image

from utils.planner import extract_scene_plan, generate_prompt_variations_from_scene  # 🧠 Brain Layer
# ----------------------------
# 🔧 Device Setup
# ----------------------------
# Run on the GPU in half precision when CUDA is available; otherwise run on
# the CPU in float32.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32
# ----------------------------
# 📦 Load ControlNet + SDXL Model
# ----------------------------
# Canny ControlNet weights published for SDXL.
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0",
    torch_dtype=dtype,
)

# Both the base checkpoint and the ControlNet are SDXL models, so they must be
# loaded through the SDXL ControlNet img2img pipeline. The non-XL
# StableDiffusionControlNetImg2ImgPipeline targets SD 1.x/2.x checkpoints.
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=dtype,
    # The "fp16" weight variant is only requested when running in half precision.
    variant="fp16" if dtype == torch.float16 else None,
).to(device)

if device == "cuda":
    # xformers is an optional dependency: treat it as a best-effort
    # optimisation and keep the default attention if it cannot be enabled.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        print("⚠️ xformers attention unavailable, using default attention:", exc)
# NOTE(review): there is intentionally no CPU branch. enable_model_cpu_offload()
# moves sub-models between CPU and an accelerator, so on a CPU-only host there
# is nothing to offload to — confirm against the installed diffusers version.
# ----------------------------
# 🖼 Canny Edge Generator
# ----------------------------
def generate_canny_map(
    image: Image.Image,
    low_threshold: int = 100,
    high_threshold: int = 200,
    size: tuple = (1024, 1024),
) -> Image.Image:
    """Return a Canny edge map of *image* as an RGB PIL image.

    The input is resized to *size*, converted to RGB and then to grayscale,
    run through ``cv2.Canny``, and the resulting single-channel edge array is
    converted back into an RGB image.

    Args:
        image: Source picture.
        low_threshold: First threshold passed to ``cv2.Canny``.
        high_threshold: Second threshold passed to ``cv2.Canny``.
        size: ``(width, height)`` the image is resized to before edge
            detection.

    Returns:
        An RGB image holding the detected edges.

    Raises:
        ValueError: If *image* is ``None``.
    """
    print("🔍 Generating Canny map...")
    if image is None:
        raise ValueError("🚫 No image passed to Canny generator")

    rgb_pixels = np.array(image.resize(size).convert("RGB"))
    gray = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2GRAY)
    # cv2.Canny signals failure by raising, so its result needs no None-check.
    edges = cv2.Canny(gray, low_threshold, high_threshold)
    return Image.fromarray(edges).convert("RGB")
# ----------------------------
# 🎨 Image Generation Function
# ----------------------------
def process_image(prompt, image, num_variations):
    """Generate stylized variations of an uploaded image from a text prompt.

    Pipeline:
      1. ``extract_scene_plan(prompt)`` builds a scene plan.
      2. ``generate_prompt_variations_from_scene`` expands it into a list of
         enriched prompts.
      3. The image is resized to 1024x1024 RGB and a Canny edge map is made
         from it.
      4. The diffusion pipeline is run once per enriched prompt.

    Args:
        prompt: User prompt text.
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        num_variations: Requested number of variations; forwarded unchanged
            to ``generate_prompt_variations_from_scene``.

    Returns:
        A tuple ``(images, scene_plan, canny_map)``. A prompt whose generation
        fails contributes a 512x512 red placeholder image instead of aborting
        the whole request. If anything outside the per-prompt loop fails, the
        result is ``([], {"error": <message>}, None)``.
    """
    try:
        print("🧠 Prompt received:", prompt)
        if image is None:
            raise ValueError("🚫 Uploaded image is missing or invalid.")

        # Step 1: Brain Layer (Scene Plan)
        scene_plan = extract_scene_plan(prompt)
        print("🧠 Scene plan extracted:", scene_plan)

        # Step 2: Expand the scene plan into enriched prompt variations
        prompt_list = generate_prompt_variations_from_scene(scene_plan, prompt, num_variations)
        print("🧠 Enriched Prompts:")
        for index, enriched in enumerate(prompt_list, start=1):
            print(f" {index}: {enriched}")

        # Step 3: Resize image + generate canny map (1024x1024 for SDXL)
        image = image.resize((1024, 1024)).convert("RGB")
        # generate_canny_map raises on failure, so its result is always usable.
        canny_map = generate_canny_map(image)

        # Step 4: Generate images
        outputs = []
        for index, enriched_prompt in enumerate(prompt_list, start=1):
            print(f"🎨 Generating image {index} with enriched prompt")
            try:
                result = pipe(
                    prompt=enriched_prompt,
                    image=image,
                    # diffusers ControlNet img2img pipelines take the
                    # conditioning image under the name `control_image`.
                    control_image=canny_map,
                    num_inference_steps=40,
                    strength=0.5,
                    guidance_scale=7.5,
                )
                if result is None or not hasattr(result, "images") or not result.images:
                    raise ValueError("⚠️ No image returned from pipeline")
                outputs.append(result.images[0])
            except Exception as inner:
                # Keep the gallery aligned with the prompt list: one failed
                # variation becomes a red placeholder rather than an abort.
                print(f"❌ Failed to generate image {index}:", inner)
                outputs.append(Image.new("RGB", (512, 512), color="red"))
        return outputs, scene_plan, canny_map
    except Exception as e:
        print("❌ Generation failed:", e)
        # The gallery output must contain images only; a bare message string
        # would be interpreted as a file path. The error text is surfaced
        # through the JSON output instead.
        return [], {"error": str(e)}, None
# ----------------------------
# 🖼 Gradio UI
# ----------------------------
# Two-column layout: the first column collects the inputs, the second shows
# the generated gallery, the scene-plan JSON and the Canny preview.
with gr.Blocks() as demo:
    gr.Markdown(
        "## 🧠 NewCrux AI — SDXL + Canny Inference\n"
        "Upload a product image, enter a prompt, and generate stylized scenes while preserving structure."
    )
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt")
            # type="pil" hands process_image a PIL image (or None if empty).
            image_input = gr.Image(type="pil", label="Upload Product Image")
            variation_slider = gr.Slider(1, 4, step=1, value=1, label="Number of Variations")
            generate_btn = gr.Button("Generate")
        with gr.Column():
            output_gallery = gr.Gallery(
                label="Generated Variations",
                columns=2,
                rows=2,
                height="auto",
            )
            json_output = gr.JSON(label="🧠 Brain Layer Reasoning")
            canny_preview = gr.Image(label="🔍 Canny Edge Preview")

    # The three outputs line up with process_image's 3-tuple return value.
    generate_btn.click(
        fn=process_image,
        inputs=[prompt_input, image_input, variation_slider],
        outputs=[output_gallery, json_output, canny_preview],
    )

demo.launch()