# imagetoimage/utils/planner.py
# Last change: "Update utils/planner.py" (commit 9362fe6, verified) by Manireddy1508.
import os
import json
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
import torch
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
CLIPTokenizer
)
# ----------------------------
# Environment, API client and compute device
# ----------------------------
# Pull variables such as OPENAI_API_KEY from a local .env file into the
# process environment; this has to run before the OpenAI client reads the key.
load_dotenv()
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Prefer the GPU whenever torch can see one; the BLIP model and its inputs
# are moved onto this device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ----------------------------
# BLIP captioning model (product image -> short text caption)
# ----------------------------
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(device)

# ----------------------------
# CLIP tokenizer, used for prompt token-count checks
# ----------------------------
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# ----------------------------
# Caption a product image with BLIP
# ----------------------------
def generate_blip_caption(image: Image.Image, max_length: int = 50) -> str:
    """Return a short BLIP caption describing *image*.

    Args:
        image: The product image to caption.
        max_length: Token limit forwarded to ``blip_model.generate``.
            Defaults to 50, the value previously hard-coded.

    Returns:
        The decoded caption with immediately repeated words collapsed, or
        ``"a product image"`` if captioning raises or yields no text. This is
        a best-effort helper: errors are printed, never raised.
    """
    fallback = "a product image"
    try:
        inputs = processor(images=image, return_tensors="pt").to(device)
        out = blip_model.generate(**inputs, max_length=max_length)
        caption = processor.decode(out[0], skip_special_tokens=True)
        # Collapse only *adjacent* duplicates ("shoe shoe shoe" -> "shoe").
        # De-duplicating every word globally would also drop legitimate
        # repeats, e.g. "a man holding a shoe" -> "a man holding shoe".
        words = caption.split()
        caption = " ".join(
            word for i, word in enumerate(words) if i == 0 or word != words[i - 1]
        )
        if not caption:
            # An empty caption gives the planner nothing to work with.
            return fallback
        print(f"πŸ–ΌοΈ BLIP Caption: {caption}")
        return caption
    except Exception as e:
        print("❌ BLIP Captioning Error:", e)
        return fallback
# ----------------------------
# GPT scene planning from caption + visual style + user prompt
# ----------------------------
SCENE_SYSTEM_INSTRUCTIONS = """
You are a scene planning assistant for an AI image generation system.
Your job is to take a caption from a product image, a visual style hint, and a user prompt, then return a structured JSON with:
- scene (environment, setting)
- subject (main_actor)
- objects (main_product or items)
- layout (foreground/background elements and their placement)
- rules (validation rules to ensure visual correctness)
Respond ONLY in raw JSON format. Do NOT include explanations.
"""


def _fallback_scene_plan() -> dict:
    """Return a fresh minimal plan (plain white studio shot) used on failure."""
    return {
        "scene": {"environment": "studio", "setting": "plain white background"},
        "subject": {"main_actor": "a product"},
        "objects": {"main_product": "product"},
        "layout": {},
        "rules": {},
    }


def _log_scene_plan(caption: str, visual_hint: str, prompt: str, content) -> None:
    """Append one planning record to logs/scene_plans.jsonl (best effort).

    A logging failure is reported but never propagated, so it cannot cause a
    successfully generated plan to be discarded.
    """
    try:
        os.makedirs("logs", exist_ok=True)
        with open("logs/scene_plans.jsonl", "a", encoding="utf-8") as f:
            f.write(json.dumps({
                "caption": caption,
                "visual_hint": visual_hint,
                "prompt": prompt,
                "scene_plan": content
            }) + "\n")
    except (OSError, TypeError, ValueError) as e:
        print("⚠️ Scene plan logging failed:", e)


def extract_scene_plan(prompt: str, image: Image.Image) -> dict:
    """Ask GPT for a structured scene plan for *prompt* and *image*.

    The image is captioned with BLIP. The caption, a visual-style hint and the
    user prompt are sent to the chat model with SCENE_SYSTEM_INSTRUCTIONS. The
    raw reply is printed and appended to ``logs/scene_plans.jsonl``.

    Args:
        prompt: The user's text prompt.
        image: The product image to plan around.

    Returns:
        The model's reply parsed as a dict. The system prompt asks for scene /
        subject / objects / layout / rules keys, but that schema is NOT
        validated here. On any error a minimal studio/white-background fallback
        plan is returned instead; errors are printed, never raised.
    """
    try:
        caption = generate_blip_caption(image)
        visual_hint = caption if "shoe" in caption or "product" in caption else "low-top product photo on white background"
        merged_prompt = (
            f"Image Caption: {caption}\n"
            f"Image Visual Style: {visual_hint}\n"
            f"User Prompt: {prompt}"
        )
        response = client.chat.completions.create(
            model="gpt-4o-mini-2024-07-18",
            messages=[
                {"role": "system", "content": SCENE_SYSTEM_INSTRUCTIONS},
                {"role": "user", "content": merged_prompt}
            ],
            temperature=0.3,
            max_tokens=500,
            # JSON mode: the reply is a single JSON object with no Markdown
            # fences, which previously made json.loads fail and forced the
            # fallback. (JSON mode requires the word "JSON" in the messages;
            # the system prompt contains it. Output can still be cut short if
            # max_tokens is reached.)
            response_format={"type": "json_object"},
        )
        content = response.choices[0].message.content
        print("🧠 Scene Plan (Raw):", content)
        # Log the raw reply before parsing so unparseable output is still recorded.
        _log_scene_plan(caption, visual_hint, prompt, content)
        plan = json.loads(content)
        if not isinstance(plan, dict):
            raise ValueError(f"expected a JSON object, got {type(plan).__name__}")
        return plan
    except Exception as e:
        print("❌ extract_scene_plan() Error:", e)
        return _fallback_scene_plan()
# ----------------------------
# Enriched prompt generation (GPT, aiming for CLIP's 77-token limit)
# ----------------------------
ENRICHED_PROMPT_INSTRUCTIONS = """
You are a prompt engineer for an AI image generation model.
Given a structured scene plan and a user prompt, generate a single natural-language enriched prompt that:
1. Describes the subject, product, setting, and layout clearly
2. Uses natural, photo-realistic language
3. Stays strictly under 77 tokens (CLIP token limit)
Return ONLY the enriched prompt string. No explanations.
"""


def generate_prompt_variations_from_scene(scene_plan: dict, base_prompt: str, n: int = 3) -> list:
    """Ask GPT for *n* enriched prompt variants built from *scene_plan*.

    Each variant comes from a separate chat call. When a call fails or returns
    empty text, *base_prompt* is used in that slot, so the result always has
    ``max(n, 0)`` entries.

    Args:
        scene_plan: Structured plan, serialized with ``json.dumps`` for the model.
        base_prompt: The user's original prompt; also the per-slot fallback.
        n: Number of variants to request.

    Returns:
        A list of prompt strings.
    """
    # Token budget named in ENRICHED_PROMPT_INSTRUCTIONS. The tokenizer count
    # below includes CLIP's start/end tokens, so it is compared directly.
    clip_token_limit = 77

    # The user message does not change between iterations; build it once.
    try:
        user_input = f"Scene Plan:\n{json.dumps(scene_plan)}\n\nUser Prompt:\n{base_prompt}"
    except (TypeError, ValueError) as e:
        # Unserializable plan: every call would fail, so fall back for all slots.
        print("⚠️ Prompt fallback:", e)
        return [base_prompt for _ in range(n)]

    prompts = []
    for _ in range(n):
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini-2024-07-18",
                messages=[
                    {"role": "system", "content": ENRICHED_PROMPT_INSTRUCTIONS},
                    {"role": "user", "content": user_input}
                ],
                temperature=0.4,
                max_tokens=100
            )
            enriched = (response.choices[0].message.content or "").strip()
            if not enriched:
                # An empty prompt would drive generation with no text at all.
                print("⚠️ Prompt fallback: empty model response")
                prompts.append(base_prompt)
                continue
            token_count = len(tokenizer(enriched)["input_ids"])
            print(f"πŸ“ Enriched Prompt ({token_count} tokens): {enriched}")
            if token_count > clip_token_limit:
                # Not enforced here; a CLIP text encoder will likely truncate it.
                print(f"⚠️ Enriched prompt exceeds {clip_token_limit} CLIP tokens ({token_count})")
            prompts.append(enriched)
        except Exception as e:
            print("⚠️ Prompt fallback:", e)
            prompts.append(base_prompt)
    return prompts
# ----------------------------
# Negative prompt generator
# ----------------------------
NEGATIVE_SYSTEM_PROMPT = """
You are a prompt engineer. Given a structured scene plan, generate a short negative prompt
to suppress unwanted visual elements such as: distortion, blurriness, poor anatomy,
logo errors, background noise, or low realism.
Return a single comma-separated list. No intro text.
"""


def generate_negative_prompt_from_scene(scene_plan: dict) -> str:
    """Ask GPT for a comma-separated negative prompt tailored to *scene_plan*.

    Args:
        scene_plan: Structured plan, serialized with ``json.dumps`` for the model.

    Returns:
        The model's stripped reply. If the call fails or the reply is empty, a
        generic default ("blurry, distorted, ...") is returned. Errors are
        printed, never raised.
    """
    default_negative = "blurry, distorted, low quality, deformed, watermark"
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini-2024-07-18",
            messages=[
                {"role": "system", "content": NEGATIVE_SYSTEM_PROMPT},
                {"role": "user", "content": json.dumps(scene_plan)}
            ],
            temperature=0.2,
            max_tokens=100
        )
        # message.content may be None; treat it like an empty reply.
        negative = (response.choices[0].message.content or "").strip()
    except Exception as e:
        print("❌ Negative Prompt Error:", e)
        return default_negative
    if not negative:
        # An empty negative prompt would suppress nothing; use the default.
        return default_negative
    return negative