Spaces:
Paused
Paused
Update utils/planner.py
Browse files- utils/planner.py +33 -8
utils/planner.py
CHANGED
@@ -1,22 +1,44 @@
|
|
1 |
-
# utils/planner.py
|
2 |
-
|
3 |
import os
|
4 |
import json
|
5 |
from dotenv import load_dotenv
|
6 |
from openai import OpenAI
|
|
|
|
|
|
|
7 |
|
8 |
# ----------------------------
|
9 |
-
# 🔐 Load Environment & Client
|
10 |
# ----------------------------
|
11 |
load_dotenv()
|
12 |
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
13 |
|
14 |
# ----------------------------
|
15 |
-
# 🧠
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
# ----------------------------
|
17 |
SCENE_SYSTEM_INSTRUCTIONS = """
|
18 |
You are a scene planning assistant for an AI image generation system.
|
19 |
-
Your job is to take
|
20 |
- scene (environment, setting)
|
21 |
- subject (main actor)
|
22 |
- objects (main product or items)
|
@@ -25,13 +47,16 @@ Your job is to take the user's prompt and return a structured JSON with:
|
|
25 |
Respond ONLY in raw JSON format. Do NOT include explanations.
|
26 |
"""
|
27 |
|
28 |
-
def extract_scene_plan(prompt: str) -> dict:
|
29 |
try:
|
|
|
|
|
|
|
30 |
response = client.chat.completions.create(
|
31 |
model="gpt-4o-mini-2024-07-18",
|
32 |
messages=[
|
33 |
{"role": "system", "content": SCENE_SYSTEM_INSTRUCTIONS},
|
34 |
-
{"role": "user", "content":
|
35 |
],
|
36 |
temperature=0.3,
|
37 |
max_tokens=500
|
@@ -51,7 +76,7 @@ def extract_scene_plan(prompt: str) -> dict:
|
|
51 |
}
|
52 |
|
53 |
# ----------------------------
|
54 |
-
# 🧠 Prompt Variation Generator
|
55 |
# ----------------------------
|
56 |
def generate_prompt_variations_from_scene(scene_plan: dict, base_prompt: str, n: int = 3) -> list:
|
57 |
try:
|
|
|
|
|
|
|
1 |
import os
|
2 |
import json
|
3 |
from dotenv import load_dotenv
|
4 |
from openai import OpenAI
|
5 |
+
from PIL import Image
|
6 |
+
import torch
|
7 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
8 |
|
9 |
# ----------------------------
|
10 |
+
# 🔐 Load Environment & GPT Client
|
11 |
# ----------------------------
|
12 |
# Read OPENAI_API_KEY (and any other settings) from a local .env file into
# the process environment before the client below is constructed.
load_dotenv()
# Module-level OpenAI client shared by all planner functions in this file.
# NOTE(review): os.getenv returns None when the key is missing — the client
# would then fail on its first request rather than at import time; confirm
# deployment always provides OPENAI_API_KEY.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
14 |
|
15 |
# ----------------------------
|
16 |
+
# 🧠 Load BLIP Captioning Model (once globally)
|
17 |
+
# ----------------------------
|
18 |
+
# Run BLIP on the GPU when one is available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the BLIP captioning processor and model once at import time so every
# call to generate_blip_caption reuses the same weights instead of
# re-downloading / re-initializing per request.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
|
21 |
+
|
22 |
+
# ----------------------------
|
23 |
+
# 📸 Generate Caption from Uploaded Product Image
|
24 |
+
# ----------------------------
|
25 |
+
def generate_blip_caption(image: Image.Image) -> str:
    """Caption an uploaded product image with the module-level BLIP model.

    Args:
        image: PIL image of the product. Any mode is accepted; it is
            converted to RGB before captioning.

    Returns:
        The decoded caption string, or the generic fallback
        ``"a product image"`` if captioning fails for any reason
        (best-effort by design — the planning pipeline must not crash
        on a bad image).
    """
    try:
        # BLIP expects 3-channel RGB input; uploads may arrive as RGBA,
        # palette ("P"), or grayscale ("L") images.
        rgb_image = image.convert("RGB")
        inputs = processor(images=rgb_image, return_tensors="pt").to(device)
        # Inference only — no_grad skips autograd-graph construction,
        # saving memory without changing the generated tokens.
        with torch.no_grad():
            out = blip_model.generate(**inputs, max_length=50)
        caption = processor.decode(out[0], skip_special_tokens=True)
        print(f"🖼️ BLIP Caption: {caption}")
        return caption
    except Exception as e:
        # Deliberate broad catch: log and return a usable placeholder so
        # downstream scene planning can proceed.
        print("❌ BLIP Captioning Error:", e)
        return "a product image"
|
35 |
+
|
36 |
+
# ----------------------------
|
37 |
+
# 🧠 Scene Plan Extractor (GPT-4o)
|
38 |
# ----------------------------
|
39 |
SCENE_SYSTEM_INSTRUCTIONS = """
|
40 |
You are a scene planning assistant for an AI image generation system.
|
41 |
+
Your job is to take a caption from a product image and a user prompt, then return a structured JSON with:
|
42 |
- scene (environment, setting)
|
43 |
- subject (main actor)
|
44 |
- objects (main product or items)
|
|
|
47 |
Respond ONLY in raw JSON format. Do NOT include explanations.
|
48 |
"""
|
49 |
|
50 |
+
def extract_scene_plan(prompt: str, image: Image.Image) -> dict:
|
51 |
try:
|
52 |
+
caption = generate_blip_caption(image)
|
53 |
+
merged_prompt = f"Image Caption: {caption}\nUser Prompt: {prompt}"
|
54 |
+
|
55 |
response = client.chat.completions.create(
|
56 |
model="gpt-4o-mini-2024-07-18",
|
57 |
messages=[
|
58 |
{"role": "system", "content": SCENE_SYSTEM_INSTRUCTIONS},
|
59 |
+
{"role": "user", "content": merged_prompt}
|
60 |
],
|
61 |
temperature=0.3,
|
62 |
max_tokens=500
|
|
|
76 |
}
|
77 |
|
78 |
# ----------------------------
|
79 |
+
# 🧠 Prompt Variation Generator (GPT-4o)
|
80 |
# ----------------------------
|
81 |
def generate_prompt_variations_from_scene(scene_plan: dict, base_prompt: str, n: int = 3) -> list:
|
82 |
try:
|