File size: 3,043 Bytes
fdd3761
 
8124057
 
37abfc2
ec84a8b
b574e01
e074d8e
 
 
b574e01
e074d8e
38d816c
37abfc2
 
 
b574e01
 
37abfc2
 
b574e01
37abfc2
b574e01
6e52c0e
b574e01
 
 
 
 
 
 
0a10f4f
37abfc2
 
 
a218b7f
d3a3bf1
b574e01
6e52c0e
b574e01
6e52c0e
b574e01
e074d8e
b574e01
d3a3bf1
b574e01
bb3cc4e
b574e01
129d6d6
b574e01
e074d8e
 
 
b574e01
0d5ecb1
d3a3bf1
b574e01
52b4e6d
129d6d6
b574e01
 
 
 
 
 
 
 
 
 
 
 
38d816c
d3a3bf1
b574e01
 
8124057
37abfc2
b574e01
37abfc2
b574e01
 
 
 
 
 
 
 
 
 
 
8124057
b574e01
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# app.py

import gradio as gr
from PIL import Image
import torch

from diffusers import StableDiffusionXLImg2ImgPipeline
from utils.planner import (
    extract_scene_plan,
    generate_prompt_variations_from_scene,
    generate_negative_prompt_from_scene
)

# ----------------------------
# 🔧 Device Setup
# ----------------------------
# Pick the compute device and a matching tensor precision in one place:
# half precision when CUDA is available, full precision on the CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

# ----------------------------
# ✅ Load SDXL Only Pipeline
# ----------------------------
# Load the SDXL base checkpoint as an img2img pipeline. The fp16 weight
# variant is requested only when running on CUDA, in line with `dtype`.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant="fp16" if device == "cuda" else None,
    use_safetensors=True,
)
if device == "cuda":
    # Model CPU offload handles device placement itself (sub-models are moved
    # to the GPU only while they run), so the pipeline must not be moved to
    # the GPU by hand beforehand.
    pipe.enable_model_cpu_offload()
else:
    # Offloading needs an accelerator to offload *to*; on a CPU-only host
    # just keep the whole pipeline on the CPU.
    pipe.to(device)
# Compute attention in slices to lower peak memory use.
pipe.enable_attention_slicing()

# ----------------------------
# 🎨 Image Generation Function
# ----------------------------
def process_image(prompt, image, num_variations):
    """Generate SDXL img2img variations of an uploaded image.

    The prompt and image are passed to ``extract_scene_plan``; the resulting
    scene plan feeds ``generate_prompt_variations_from_scene`` and
    ``generate_negative_prompt_from_scene``. Every enriched prompt is then
    run through the module-level ``pipe`` against the resized input image.

    Args:
        prompt: Text prompt entered by the user.
        image: Uploaded image object exposing ``resize``/``convert``
            (a PIL image), or ``None`` when nothing was uploaded.
        num_variations: Requested number of variations; forwarded unchanged
            to ``generate_prompt_variations_from_scene``.

    Returns:
        A list containing the first output image for each enriched prompt.
        If anything fails (including a missing image), the error and its
        traceback are printed and a list holding a single 512x512 solid red
        placeholder image is returned instead.
    """
    import traceback

    try:
        print("🧠 User Prompt:", prompt)
        if image is None:
            raise ValueError("🚫 Uploaded image is missing.")

        # Step 1: Extract scene plan
        scene_plan = extract_scene_plan(prompt, image)
        print("📋 Scene Plan:", scene_plan)

        # Step 2: Generate enriched prompts
        prompt_list = generate_prompt_variations_from_scene(
            scene_plan, prompt, num_variations
        )
        print("✅ Enriched Prompts:", prompt_list)

        # Step 3: Generate negative prompt
        negative_prompt = generate_negative_prompt_from_scene(scene_plan)
        print("🚫 Negative Prompt:", negative_prompt)

        # Step 4: Resize image to SDXL resolution (aspect ratio is not kept)
        image = image.resize((1024, 1024)).convert("RGB")

        # Step 5: Generate outputs with SDXL only
        outputs = []
        for i, enriched_prompt in enumerate(prompt_list):
            print(f"🎨 Generating variation {i+1}...")
            result = pipe(
                prompt=enriched_prompt,
                negative_prompt=negative_prompt,
                image=image,
                strength=0.7,                # ← You can fine-tune this
                guidance_scale=7.5,
                num_inference_steps=30,
            )
            outputs.append(result.images[0])

        return outputs

    except Exception as e:
        # Top-level UI boundary: keep the best-effort red placeholder, but
        # log the full traceback so failures can actually be diagnosed.
        print("❌ Generation Error:", e)
        traceback.print_exc()
        return [Image.new("RGB", (512, 512), color="red")]

# ----------------------------
# πŸ–ΌοΈ Gradio Interface
# ----------------------------
def _build_gallery():
    """Create the output gallery with a two-column grid.

    Older Gradio releases configure layout through ``.style()``; that method
    was removed in Gradio 4, where the same options are constructor
    arguments. Use whichever the installed version provides.
    """
    if hasattr(gr.Gallery, "style"):
        return gr.Gallery(label="Generated Images").style(grid=[2], height="auto")
    return gr.Gallery(label="Generated Images", columns=2, height="auto")


# Wire the generation function to a simple form: prompt + image + slider in,
# gallery of generated images out.
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Image(type="pil", label="Product Image"),
        gr.Slider(1, 5, value=3, step=1, label="Number of Variations")
    ],
    outputs=_build_gallery(),
    title="NewCrux Product Image Generator (SDXL Only)",
    description="Upload a product image and enter a prompt. SDXL will generate enriched variations using AI."
)

# Start the web UI only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()