Manireddy1508 committed on
Commit
433cace
·
verified ·
1 Parent(s): ec84a8b

Update utils/planner.py

Browse files
Files changed (1) hide show
  1. utils/planner.py +33 -8
utils/planner.py CHANGED
@@ -1,22 +1,44 @@
1
- # utils/planner.py
2
-
3
  import os
4
  import json
5
  from dotenv import load_dotenv
6
  from openai import OpenAI
 
 
 
7
 
8
  # ----------------------------
9
- # 🔐 Load Environment & Client
10
  # ----------------------------
11
# Merge variables from a local .env file (if any) into the process
# environment before the API key is looked up.
load_dotenv()

# Module-level OpenAI client used by the chat-completion calls in this file.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)
13
 
14
  # ----------------------------
15
- # 🧠 Scene Plan Extractor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # ----------------------------
17
  SCENE_SYSTEM_INSTRUCTIONS = """
18
  You are a scene planning assistant for an AI image generation system.
19
- Your job is to take the user's prompt and return a structured JSON with:
20
  - scene (environment, setting)
21
  - subject (main actor)
22
  - objects (main product or items)
@@ -25,13 +47,16 @@ Your job is to take the user's prompt and return a structured JSON with:
25
  Respond ONLY in raw JSON format. Do NOT include explanations.
26
  """
27
 
28
- def extract_scene_plan(prompt: str) -> dict:
29
  try:
 
 
 
30
  response = client.chat.completions.create(
31
  model="gpt-4o-mini-2024-07-18",
32
  messages=[
33
  {"role": "system", "content": SCENE_SYSTEM_INSTRUCTIONS},
34
- {"role": "user", "content": prompt}
35
  ],
36
  temperature=0.3,
37
  max_tokens=500
@@ -51,7 +76,7 @@ def extract_scene_plan(prompt: str) -> dict:
51
  }
52
 
53
  # ----------------------------
54
- # 🧠 Prompt Variation Generator
55
  # ----------------------------
56
  def generate_prompt_variations_from_scene(scene_plan: dict, base_prompt: str, n: int = 3) -> list:
57
  try:
 
 
 
1
  import os
2
  import json
3
  from dotenv import load_dotenv
4
  from openai import OpenAI
5
+ from PIL import Image
6
+ import torch
7
+ from transformers import BlipProcessor, BlipForConditionalGeneration
8
 
9
  # ----------------------------
10
+ # 🔐 Load Environment & GPT Client
11
  # ----------------------------
12
# Merge variables from a local .env file (if any) into the process
# environment before the API key is looked up.
load_dotenv()

# Module-level OpenAI client used by the chat-completion calls in this file.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)
14
 
15
  # ----------------------------
16
+ # 🧠 Load BLIP Captioning Model (once globally)
17
+ # ----------------------------
18
# Run on the GPU when torch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# The processor and the captioning weights come from the same checkpoint and
# are created once at import time so every caption request reuses them.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
)
blip_model = blip_model.to(device)
21
+
22
+ # ----------------------------
23
+ # 📸 Generate Caption from Uploaded Product Image
24
+ # ----------------------------
25
def generate_blip_caption(image: Image.Image) -> str:
    """Describe a product image with the module-level BLIP captioning model.

    Captioning is best-effort: any failure is printed to stdout and a generic
    caption is returned so the caller can still build a prompt.

    Args:
        image: PIL image to caption. ``None`` is tolerated and yields the
            fallback caption without invoking the model.

    Returns:
        The decoded caption with surrounding whitespace removed, or
        ``"a product image"`` when no image was given, the model produced an
        empty caption, or captioning raised.
    """
    fallback_caption = "a product image"

    # Short-circuit a missing upload instead of letting the processor raise
    # and reporting it as a captioning error.
    if image is None:
        print("⚠️ BLIP Captioning skipped: no image provided")
        return fallback_caption

    try:
        inputs = processor(images=image, return_tensors="pt").to(device)
        out = blip_model.generate(**inputs, max_length=50)
        caption = processor.decode(out[0], skip_special_tokens=True).strip()

        # An empty caption would leave the "Image Caption:" slot of the merged
        # prompt blank, so substitute the generic description instead.
        if not caption:
            return fallback_caption

        print(f"🖼️ BLIP Caption: {caption}")
        return caption
    except Exception as e:  # best-effort boundary: never break scene planning
        print("❌ BLIP Captioning Error:", e)
        return fallback_caption
35
+
36
+ # ----------------------------
37
+ # 🧠 Scene Plan Extractor (GPT-4o)
38
  # ----------------------------
39
  SCENE_SYSTEM_INSTRUCTIONS = """
40
  You are a scene planning assistant for an AI image generation system.
41
+ Your job is to take a caption from a product image and a user prompt, then return a structured JSON with:
42
  - scene (environment, setting)
43
  - subject (main actor)
44
  - objects (main product or items)
 
47
  Respond ONLY in raw JSON format. Do NOT include explanations.
48
  """
49
 
50
+ def extract_scene_plan(prompt: str, image: Image.Image) -> dict:
51
  try:
52
+ caption = generate_blip_caption(image)
53
+ merged_prompt = f"Image Caption: {caption}\nUser Prompt: {prompt}"
54
+
55
  response = client.chat.completions.create(
56
  model="gpt-4o-mini-2024-07-18",
57
  messages=[
58
  {"role": "system", "content": SCENE_SYSTEM_INSTRUCTIONS},
59
+ {"role": "user", "content": merged_prompt}
60
  ],
61
  temperature=0.3,
62
  max_tokens=500
 
76
  }
77
 
78
  # ----------------------------
79
+ # 🧠 Prompt Variation Generator (GPT-4o)
80
  # ----------------------------
81
  def generate_prompt_variations_from_scene(scene_plan: dict, base_prompt: str, n: int = 3) -> list:
82
  try: