Manireddy1508 committed on
Commit
acc4043
·
verified ·
1 Parent(s): b574e01

Update utils/planner.py

Browse files
Files changed (1) hide show
  1. utils/planner.py +54 -86
utils/planner.py CHANGED
@@ -1,35 +1,35 @@
1
- # utils/planner.py
2
-
3
  import os
4
  import json
5
  from dotenv import load_dotenv
6
  from openai import OpenAI
7
  from PIL import Image
8
  import torch
9
- from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPTokenizer
 
 
 
 
10
 
11
  # ----------------------------
12
- # πŸ” Load Environment & GPT Client
13
  # ----------------------------
14
  load_dotenv()
15
  client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 
16
 
17
  # ----------------------------
18
- # 🧠 Load BLIP & CLIP Tokenizer
19
  # ----------------------------
20
- device = "cuda" if torch.cuda.is_available() else "cpu"
21
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
22
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
23
- clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
24
 
25
  # ----------------------------
26
- # πŸ“ Log Path
27
  # ----------------------------
28
- LOG_PATH = "logs/prompt_log.jsonl"
29
- os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)
30
 
31
  # ----------------------------
32
- # πŸ“Έ Generate Caption from Image
33
  # ----------------------------
34
  def generate_blip_caption(image: Image.Image) -> str:
35
  try:
@@ -43,7 +43,7 @@ def generate_blip_caption(image: Image.Image) -> str:
43
  return "a product image"
44
 
45
  # ----------------------------
46
- # 🧠 Extract Scene Plan from GPT
47
  # ----------------------------
48
  SCENE_SYSTEM_INSTRUCTIONS = """
49
  You are a scene planning assistant for an AI image generation system.
@@ -70,106 +70,74 @@ def extract_scene_plan(prompt: str, image: Image.Image) -> dict:
70
  temperature=0.3,
71
  max_tokens=500
72
  )
73
- json_output = response.choices[0].message.content
74
- print("🧠 Scene Plan (Raw):", json_output)
75
- return json.loads(json_output)
 
 
 
 
 
 
76
 
77
  except Exception as e:
78
  print("❌ extract_scene_plan() Error:", e)
79
  return {
80
- "scene": None,
81
- "subject": None,
82
  "objects": [],
83
  "layout": {},
84
  "rules": {}
85
  }
86
 
87
  # ----------------------------
88
- # 🧠 Generate Positive Prompt Variations (CLIP-safe)
89
  # ----------------------------
90
  def generate_prompt_variations_from_scene(scene_plan: dict, base_prompt: str, n: int = 3) -> list:
91
- try:
92
- system_msg = f"""
93
- You are a creative prompt variation generator for an AI image generation system.
94
- Given a base user prompt and its structured scene plan, generate {n} diverse image generation prompts.
95
- Each prompt must:
96
- - Be visually rich and descriptive
97
- - Include stylistic or contextual variation
98
- - Reference the same product and environment
99
- - Stay faithful to the base prompt and extracted plan
100
- - Be under 77 tokens when tokenized using a CLIP tokenizer
101
- Respond ONLY with a JSON array of strings. No explanations.
102
- """
103
-
104
- response = client.chat.completions.create(
105
- model="gpt-4o-mini-2024-07-18",
106
- messages=[
107
- {"role": "system", "content": system_msg},
108
- {"role": "user", "content": json.dumps({
109
- "base_prompt": base_prompt,
110
- "scene_plan": scene_plan
111
- })}
112
- ],
113
- temperature=0.7,
114
- max_tokens=600
115
  )
116
 
117
- content = response.choices[0].message.content
118
- all_prompts = json.loads(content)
119
-
120
- filtered = []
121
- for p in all_prompts:
122
- token_count = len(clip_tokenizer(p)["input_ids"])
123
- if token_count <= 77:
124
- filtered.append(p)
125
 
126
- print("🧠 Filtered Prompts (<=77 tokens):", filtered)
127
- return filtered or [base_prompt]
128
 
129
- except Exception as e:
130
- print("❌ generate_prompt_variations_from_scene() Error:", e)
131
- return [base_prompt]
132
 
133
  # ----------------------------
134
- # 🧠 Generate Negative Prompt Automatically
135
  # ----------------------------
136
- def generate_negative_prompt_from_scene(scene_plan: dict) -> str:
137
- try:
138
- system_msg = """
139
- You are an assistant that generates negative prompts for an image generation model.
140
- Based on the structured scene plan, return a list of things that should NOT appear in the image,
141
- such as incorrect objects, extra limbs, distorted hands, text, watermark, etc.
142
- Return a single negative prompt string (comma-separated values). No explanations.
143
  """
144
 
 
 
145
  response = client.chat.completions.create(
146
  model="gpt-4o-mini-2024-07-18",
147
  messages=[
148
- {"role": "system", "content": system_msg},
149
  {"role": "user", "content": json.dumps(scene_plan)}
150
  ],
151
- temperature=0.3,
152
- max_tokens=150
153
  )
154
-
155
- negative_prompt = response.choices[0].message.content.strip()
156
- print("🚫 Negative Prompt (GPT):", negative_prompt)
157
- return negative_prompt
158
-
159
  except Exception as e:
160
- print("❌ generate_negative_prompt_from_scene() Error:", e)
161
- return "deformed hands, extra limbs, text, watermark, signature"
162
-
163
- # ----------------------------
164
- # πŸ“ Save Logs
165
- # ----------------------------
166
- def save_generation_log(caption, scene_plan, prompts, negative_prompt):
167
- log = {
168
- "blip_caption": caption,
169
- "scene_plan": scene_plan,
170
- "enriched_prompts": prompts,
171
- "negative_prompt": negative_prompt
172
- }
173
- with open(LOG_PATH, "a") as f:
174
- f.write(json.dumps(log, indent=2) + "\n")
175
 
 
 
 
1
  import os
2
  import json
3
  from dotenv import load_dotenv
4
  from openai import OpenAI
5
  from PIL import Image
6
  import torch
7
+ from transformers import (
8
+ BlipProcessor,
9
+ BlipForConditionalGeneration,
10
+ CLIPTokenizer
11
+ )
12
 
13
  # ----------------------------
14
+ # πŸ” Load API Keys & Setup
15
  # ----------------------------
16
  load_dotenv()
17
  client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
  # ----------------------------
21
+ # πŸ“Έ Load BLIP Captioning Model
22
  # ----------------------------
 
23
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
24
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
 
25
 
26
  # ----------------------------
27
+ # 🧠 Load CLIP Tokenizer (for token limit check)
28
  # ----------------------------
29
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
 
30
 
31
  # ----------------------------
32
+ # πŸ“Έ Generate Caption from Product Image
33
  # ----------------------------
34
  def generate_blip_caption(image: Image.Image) -> str:
35
  try:
 
43
  return "a product image"
44
 
45
  # ----------------------------
46
+ # 🧠 GPT Scene Planning
47
  # ----------------------------
48
  SCENE_SYSTEM_INSTRUCTIONS = """
49
  You are a scene planning assistant for an AI image generation system.
 
70
  temperature=0.3,
71
  max_tokens=500
72
  )
73
+ content = response.choices[0].message.content
74
+ print("🧠 Scene Plan (Raw):", content)
75
+
76
+ # Optional logging
77
+ os.makedirs("logs", exist_ok=True)
78
+ with open("logs/scene_plans.jsonl", "a") as f:
79
+ f.write(json.dumps({"caption": caption, "prompt": prompt, "scene_plan": content}) + "\n")
80
+
81
+ return json.loads(content)
82
 
83
  except Exception as e:
84
  print("❌ extract_scene_plan() Error:", e)
85
  return {
86
+ "scene": "studio",
87
+ "subject": "product",
88
  "objects": [],
89
  "layout": {},
90
  "rules": {}
91
  }
92
 
93
  # ----------------------------
94
+ # ✨ Generate Prompt Variations
95
  # ----------------------------
96
  def generate_prompt_variations_from_scene(scene_plan: dict, base_prompt: str, n: int = 3) -> list:
97
+ variations = []
98
+
99
+ for i in range(n):
100
+ enriched_prompt = (
101
+ f"{scene_plan.get('subject', 'a product')} "
102
+ f"in a {scene_plan.get('scene', 'studio setting')} "
103
+ f"with {', '.join(scene_plan.get('objects', []))} "
104
+ f"and layout details like {scene_plan.get('layout', {})}. "
105
+ f"{scene_plan.get('rules', '')}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  )
107
 
108
+ # Enforce 77-token limit for SDXL
109
+ tokens = tokenizer(enriched_prompt)["input_ids"]
110
+ if len(tokens) > 77:
111
+ enriched_prompt = tokenizer.decode(tokens[:77], skip_special_tokens=True)
 
 
 
 
112
 
113
+ variations.append(enriched_prompt.strip())
 
114
 
115
+ return variations
 
 
116
 
117
  # ----------------------------
118
+ # ❌ Generate Negative Prompt
119
  # ----------------------------
120
+ NEGATIVE_SYSTEM_PROMPT = """
121
+ You are a prompt engineer. Given a structured scene plan, generate a short negative prompt
122
+ to suppress unwanted visual elements such as: distortion, blurriness, poor anatomy,
123
+ logo errors, background noise, or low realism.
124
+ Return a single comma-separated list. No intro text.
 
 
125
  """
126
 
127
+ def generate_negative_prompt_from_scene(scene_plan: dict) -> str:
128
+ try:
129
  response = client.chat.completions.create(
130
  model="gpt-4o-mini-2024-07-18",
131
  messages=[
132
+ {"role": "system", "content": NEGATIVE_SYSTEM_PROMPT},
133
  {"role": "user", "content": json.dumps(scene_plan)}
134
  ],
135
+ temperature=0.2,
136
+ max_tokens=100
137
  )
138
+ negative = response.choices[0].message.content.strip()
139
+ return negative
 
 
 
140
  except Exception as e:
141
+ print("❌ Negative Prompt Error:", e)
142
+ return "blurry, distorted, low quality, deformed, watermark"
 
 
 
 
 
 
 
 
 
 
 
 
 
143