Spaces:
Paused
Paused
# app.py
import os

import cv2
import gradio as gr
import numpy as np
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
)
from PIL import Image

from utils.planner import extract_scene_plan, generate_prompt_variations_from_scene  # 🧠 Brain Layer
# ----------------------------
# Device Setup
# ----------------------------
# Half precision is only used on CUDA; CPU inference stays in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# ----------------------------
# Load ControlNet + SDXL Model
# ----------------------------
# Canny-conditioned ControlNet trained for SDXL; it must be paired with an
# SDXL base checkpoint and an SDXL-aware pipeline.
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0",
    torch_dtype=dtype,
)

# The SDXL checkpoint has two text encoders and a UNet that expects extra
# text/time embeddings, so it has to be driven by the XL ControlNet img2img
# pipeline; the SD 1.x `StableDiffusionControlNetImg2ImgPipeline` cannot run it.
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=dtype,
    # The fp16 weight variant is only requested when loading in half precision.
    variant="fp16" if dtype == torch.float16 else None,
).to(device)

if device == "cuda":
    # xformers is an optional speed/memory optimization: diffusers raises if
    # the package is missing or unusable on this GPU. Fall back to the
    # pipeline's default attention instead of crashing at startup.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as err:  # best-effort optimization boundary
        print("xformers unavailable, using default attention:", err)
# On CPU the pipeline already lives on `device`. Model CPU offload moves
# sub-models onto an accelerator on demand, so it is not enabled without CUDA.
# ----------------------------
# Canny Edge Generator
# ----------------------------
def generate_canny_map(
    image: Image.Image,
    size: tuple = (1024, 1024),
    low_threshold: int = 100,
    high_threshold: int = 200,
) -> Image.Image:
    """Return a 3-channel Canny edge map of *image*.

    The input is converted to RGB, resized to *size* (1024x1024 by default,
    the working resolution used for SDXL in this app), converted to grayscale
    and passed through OpenCV's Canny detector. The single-channel edge map
    is expanded back to RGB so it can be used as a ControlNet conditioning
    image.

    Args:
        image: Source PIL image.
        size: Output ``(width, height)`` in pixels.
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.

    Returns:
        An RGB PIL image of size *size* containing the detected edges.

    Raises:
        ValueError: If *image* is ``None``.
    """
    print("π Generating Canny map...")
    if image is None:
        raise ValueError("π« No image passed to Canny generator")
    # Convert before resizing so palette/alpha images are resampled in RGB.
    rgb = image.convert("RGB").resize(size)
    gray = cv2.cvtColor(np.array(rgb), cv2.COLOR_RGB2GRAY)
    # cv2.Canny raises cv2.error on failure; it does not return None, so no
    # None check is needed on its result.
    edges = cv2.Canny(gray, low_threshold, high_threshold)
    return Image.fromarray(edges).convert("RGB")
# ----------------------------
# Image Generation Function
# ----------------------------
def process_image(prompt, image, num_variations):
    """Generate stylized variations of *image* guided by *prompt*.

    Steps:
      1. ``extract_scene_plan`` turns the prompt into a scene plan.
      2. ``generate_prompt_variations_from_scene`` expands the plan into a
         list of enriched prompts (presumably ``num_variations`` of them —
         confirm against ``utils.planner``).
      3. The image is resized to 1024x1024 and a Canny edge map is built.
      4. Each enriched prompt is run through the module-level ControlNet
         img2img ``pipe`` with the edge map as the control image.

    Args:
        prompt: User prompt text.
        image: Uploaded PIL image, or ``None`` if nothing was uploaded.
        num_variations: Requested number of variations. The Gradio slider may
            deliver this as a float, so it is coerced to ``int``.

    Returns:
        A ``(gallery_images, scene_plan, canny_map)`` tuple matching the three
        Gradio outputs. A variation that fails is replaced by a solid red
        placeholder. If the whole run fails, the gallery is empty, the JSON
        panel shows ``{"error": <message>}`` and the Canny preview is ``None``.
    """
    try:
        print("π§ Prompt received:", prompt)
        if image is None:
            raise ValueError("π« Uploaded image is missing or invalid.")
        num_variations = int(num_variations)

        # Step 1: Brain Layer (Scene Plan)
        scene_plan = extract_scene_plan(prompt)
        print("π§ Scene plan extracted:", scene_plan)

        # Step 2: Generate prompt variations from GPT
        prompt_list = generate_prompt_variations_from_scene(scene_plan, prompt, num_variations)
        print("π§ Enriched Prompts:")
        for i, p in enumerate(prompt_list, start=1):
            print(f" {i}: {p}")

        # Step 3: Resize image + generate canny map (1024x1024 for SDXL)
        image = image.resize((1024, 1024)).convert("RGB")
        canny_map = generate_canny_map(image)

        # Step 4: Generate images
        outputs = []
        for i, enriched_prompt in enumerate(prompt_list, start=1):
            print(f"π¨ Generating image {i} with enriched prompt")
            try:
                # diffusers' ControlNet img2img pipelines take the conditioning
                # image as `control_image`. The old `controlnet_conditioning_image`
                # keyword is not accepted and made every call fail.
                result = pipe(
                    prompt=enriched_prompt,
                    image=image,
                    control_image=canny_map,
                    num_inference_steps=40,
                    strength=0.5,
                    guidance_scale=7.5,
                )
                if result is None or not hasattr(result, "images") or not result.images:
                    raise ValueError("β οΈ No image returned from pipeline")
                outputs.append(result.images[0])
            except Exception as inner:
                print(f"β Failed to generate image {i}:", inner)
                # Placeholder matches the working resolution so the gallery stays uniform.
                outputs.append(Image.new("RGB", image.size, color="red"))
        return outputs, scene_plan, canny_map
    except Exception as e:
        print("β Generation failed:", e)
        # gr.Gallery treats plain strings as image paths/URLs, so an error
        # string there would cause a second failure. Return an empty gallery
        # and report the error through the JSON panel instead.
        return [], {"error": str(e)}, None
# ----------------------------
# Gradio UI
# ----------------------------
# Two-column layout: inputs on the left; generated gallery, scene-plan JSON
# and the Canny preview on the right. `process_image` returns exactly these
# three outputs, in this order.
with gr.Blocks() as demo:
    gr.Markdown("## π§ NewCrux AI β SDXL + Canny Inference\nUpload a product image, enter a prompt, and generate stylized scenes while preserving structure.")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt")
            image_input = gr.Image(type="pil", label="Upload Product Image")
            variation_slider = gr.Slider(1, 4, step=1, value=1, label="Number of Variations")
            generate_btn = gr.Button("Generate")
        with gr.Column():
            output_gallery = gr.Gallery(
                label="Generated Variations",
                columns=2,
                rows=2,
                height="auto",
            )
            json_output = gr.JSON(label="π§ Brain Layer Reasoning")
            canny_preview = gr.Image(label="π Canny Edge Preview")
    # Event listeners must be registered inside the Blocks context.
    generate_btn.click(
        fn=process_image,
        inputs=[prompt_input, image_input, variation_slider],
        outputs=[output_gallery, json_output, canny_preview],
    )

# Launch only when run as a script (as Spaces does with `python app.py`), so
# importing this module, e.g. under Gradio's reload mode, does not start a server.
if __name__ == "__main__":
    demo.launch()