import os
import json
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image
import torch
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    CLIPTokenizer
)
# ----------------------------
# 🔑 Load API Keys & Setup
# ----------------------------
load_dotenv()
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
device = "cuda" if torch.cuda.is_available() else "cpu"

# ----------------------------
# 📸 Load BLIP Captioning Model
# ----------------------------
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

# ----------------------------
# 🧠 Load CLIP Tokenizer (for token check)
# ----------------------------
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# ----------------------------
# 📸 Generate Caption from Product Image
# ----------------------------
def generate_blip_caption(image: Image.Image) -> str:
    try:
        inputs = processor(images=image, return_tensors="pt").to(device)
        out = blip_model.generate(**inputs, max_length=50)
        caption = processor.decode(out[0], skip_special_tokens=True)
        # Drop repeated words (BLIP occasionally loops on the same token)
        caption = " ".join(dict.fromkeys(caption.split()))
        print(f"🖼️ BLIP Caption: {caption}")
        return caption
    except Exception as e:
        print("❌ BLIP Captioning Error:", e)
        return "a product image"

# ----------------------------
# 🧠 GPT Scene Planning with Caption + Visual Style
# ----------------------------
SCENE_SYSTEM_INSTRUCTIONS = """
You are a scene planning assistant for an AI image generation system.
Your job is to take a caption from a product image, a visual style hint, and a user prompt, then return a structured JSON object with:
- scene (environment, setting)
- subject (main_actor)
- objects (main_product or items)
- layout (foreground/background elements and their placement)
- rules (validation rules to ensure visual correctness)
Respond ONLY in raw JSON format. Do NOT include explanations.
"""

def extract_scene_plan(prompt: str, image: Image.Image) -> dict:
    try:
        caption = generate_blip_caption(image)
        visual_hint = (
            caption
            if "shoe" in caption or "product" in caption
            else "low-top product photo on white background"
        )
        merged_prompt = (
            f"Image Caption: {caption}\n"
            f"Image Visual Style: {visual_hint}\n"
            f"User Prompt: {prompt}"
        )
        response = client.chat.completions.create(
            model="gpt-4o-mini-2024-07-18",
            messages=[
                {"role": "system", "content": SCENE_SYSTEM_INSTRUCTIONS},
                {"role": "user", "content": merged_prompt}
            ],
            temperature=0.3,
            max_tokens=500
        )
        content = response.choices[0].message.content
        print("🧠 Scene Plan (Raw):", content)

        # Log each caption/prompt/scene-plan triple to a JSONL file for later inspection
        os.makedirs("logs", exist_ok=True)
        with open("logs/scene_plans.jsonl", "a") as f:
            f.write(json.dumps({
                "caption": caption,
                "visual_hint": visual_hint,
                "prompt": prompt,
                "scene_plan": content
            }) + "\n")

        return json.loads(content)
    except Exception as e:
        print("❌ extract_scene_plan() Error:", e)
        return {
            "scene": {"environment": "studio", "setting": "plain white background"},
            "subject": {"main_actor": "a product"},
            "objects": {"main_product": "product"},
            "layout": {},
            "rules": {}
        }

# ----------------------------
# ✨ Enriched Prompt Generation (GPT, 77-token safe)
# ----------------------------
ENRICHED_PROMPT_INSTRUCTIONS = """
You are a prompt engineer for an AI image generation model.
Given a structured scene plan and a user prompt, generate a single natural-language enriched prompt that:
1. Describes the subject, product, setting, and layout clearly
2. Uses natural, photo-realistic language
3. Stays strictly under 77 tokens (CLIP token limit)
Return ONLY the enriched prompt string. No explanations.
"""

def generate_prompt_variations_from_scene(scene_plan: dict, base_prompt: str, n: int = 3) -> list:
    prompts = []
    for _ in range(n):
        try:
            user_input = f"Scene Plan:\n{json.dumps(scene_plan)}\n\nUser Prompt:\n{base_prompt}"
            response = client.chat.completions.create(
                model="gpt-4o-mini-2024-07-18",
                messages=[
                    {"role": "system", "content": ENRICHED_PROMPT_INSTRUCTIONS},
                    {"role": "user", "content": user_input}
                ],
                temperature=0.4,
                max_tokens=100
            )
            enriched = response.choices[0].message.content.strip()
            token_count = len(tokenizer(enriched)["input_ids"])
            print(f"📏 Enriched Prompt ({token_count} tokens): {enriched}")
            prompts.append(enriched)
        except Exception as e:
            print("⚠️ Prompt fallback:", e)
            prompts.append(base_prompt)
    return prompts

# ----------------------------
# ❌ Negative Prompt Generator
# ----------------------------
NEGATIVE_SYSTEM_PROMPT = """
You are a prompt engineer. Given a structured scene plan, generate a short negative prompt
to suppress unwanted visual elements such as: distortion, blurriness, poor anatomy,
logo errors, background noise, or low realism.
Return a single comma-separated list. No intro text.
"""

def generate_negative_prompt_from_scene(scene_plan: dict) -> str:
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini-2024-07-18",
            messages=[
                {"role": "system", "content": NEGATIVE_SYSTEM_PROMPT},
                {"role": "user", "content": json.dumps(scene_plan)}
            ],
            temperature=0.2,
            max_tokens=100
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print("❌ Negative Prompt Error:", e)
        return "blurry, distorted, low quality, deformed, watermark"