Manireddy1508 committed on
Commit
8cafce9
Β·
verified Β·
1 Parent(s): 34cc19a

Update utils/planner.py

Browse files
Files changed (1) hide show
  1. utils/planner.py +35 -4
utils/planner.py CHANGED
@@ -5,6 +5,7 @@ from openai import OpenAI
5
  from PIL import Image
6
  import torch
7
  from transformers import BlipProcessor, BlipForConditionalGeneration
 
8
 
9
  # ----------------------------
10
  # πŸ” Load Environment & GPT Client
@@ -18,6 +19,14 @@ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
20
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
 
 
 
 
 
 
 
 
21
 
22
  # ----------------------------
23
  # πŸ“Έ Generate Caption from Image
@@ -76,7 +85,7 @@ def extract_scene_plan(prompt: str, image: Image.Image) -> dict:
76
  }
77
 
78
  # ----------------------------
79
- # 🧠 Generate Positive Prompt Variations
80
  # ----------------------------
81
  def generate_prompt_variations_from_scene(scene_plan: dict, base_prompt: str, n: int = 3) -> list:
82
  try:
@@ -88,8 +97,10 @@ Each prompt should:
88
  - Include stylistic or contextual variation
89
  - Reference the same product and environment
90
  - Stay faithful to the base prompt and extracted plan
 
91
  Respond ONLY with a JSON array of strings. No explanations.
92
  """
 
93
  response = client.chat.completions.create(
94
  model="gpt-4o-mini-2024-07-18",
95
  messages=[
@@ -104,8 +115,17 @@ Respond ONLY with a JSON array of strings. No explanations.
104
  )
105
 
106
  content = response.choices[0].message.content
107
- print("🧠 Prompt Variations (Raw):", content)
108
- return json.loads(content)
 
 
 
 
 
 
 
 
 
109
 
110
  except Exception as e:
111
  print("❌ generate_prompt_variations_from_scene() Error:", e)
@@ -143,4 +163,15 @@ No explanations.
143
  return "deformed hands, extra limbs, text, watermark, signature"
144
 
145
 
146
-
 
 
 
 
 
 
 
 
 
 
 
 
5
  from PIL import Image
6
  import torch
7
  from transformers import BlipProcessor, BlipForConditionalGeneration
8
+ from transformers import CLIPTokenizer
9
 
10
  # ----------------------------
11
  # πŸ” Load Environment & GPT Client
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
21
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
22
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
23
+
24
+ # ----------------------------
25
+ # πŸ“ Log File
26
+ # ----------------------------
27
+
28
+ LOG_PATH = "logs/prompt_log.jsonl"
29
+ os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)
30
 
31
  # ----------------------------
32
  # πŸ“Έ Generate Caption from Image
 
85
  }
86
 
87
  # ----------------------------
88
+ # 🧠 Generate Positive Prompt Variations (CLIP-safe)
89
  # ----------------------------
90
  def generate_prompt_variations_from_scene(scene_plan: dict, base_prompt: str, n: int = 3) -> list:
91
  try:
 
97
  - Include stylistic or contextual variation
98
  - Reference the same product and environment
99
  - Stay faithful to the base prompt and extracted plan
100
+ - Be under 77 tokens when tokenized using a CLIP tokenizer
101
  Respond ONLY with a JSON array of strings. No explanations.
102
  """
103
+
104
  response = client.chat.completions.create(
105
  model="gpt-4o-mini-2024-07-18",
106
  messages=[
 
115
  )
116
 
117
  content = response.choices[0].message.content
118
+ all_prompts = json.loads(content)
119
+
120
+ # Enforce token limit using CLIP tokenizer
121
+ filtered = []
122
+ for p in all_prompts:
123
+ tokens = clip_tokenizer(p)["input_ids"]
124
+ if len(tokens) <= 77:
125
+ filtered.append(p)
126
+ print("🧠 Filtered Prompts (<=77 tokens):", filtered)
127
+
128
+ return filtered or [base_prompt]
129
 
130
  except Exception as e:
131
  print("❌ generate_prompt_variations_from_scene() Error:", e)
 
163
  return "deformed hands, extra limbs, text, watermark, signature"
164
 
165
 
166
# ----------------------------
# 📝 Save Logs for Analysis
# ----------------------------
def save_generation_log(caption, scene_plan, prompts, negative_prompt, log_path=None):
    """Append one generation record to the JSONL log for later analysis.

    Args:
        caption: BLIP-generated image caption (str).
        scene_plan: Scene-plan dict extracted from the image/prompt.
        prompts: List of enriched positive prompt strings.
        negative_prompt: Negative prompt string used for generation.
        log_path: Optional override of the log file path; defaults to the
            module-level LOG_PATH ("logs/prompt_log.jsonl").
    """
    if log_path is None:
        log_path = LOG_PATH  # module-level default declared at the top of the file
    # Ensure the target directory exists; `or "."` guards against a bare
    # filename whose dirname is the empty string.
    os.makedirs(os.path.dirname(log_path) or ".", exist_ok=True)
    log = {
        "blip_caption": caption,
        "scene_plan": scene_plan,
        "enriched_prompts": prompts,
        "negative_prompt": negative_prompt,
    }
    # JSON Lines requires exactly one JSON document per line; the original
    # `indent=2` emitted multi-line records that break any .jsonl reader.
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(log, ensure_ascii=False) + "\n")