Update clip_slider_pipeline.py

clip_slider_pipeline.py (CHANGED: +80 -8)
@@ -210,8 +210,6 @@ class CLIPSliderXL(CLIPSlider):
                  correlation_weight_factor = 1.0,
                  avg_diff = None,
                  avg_diff_2nd = None,
-                 init_latents = None, # inversion
-                 zs = None, # inversion
                  **pipeline_kwargs
                  ):
         # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -289,14 +287,88 @@ class CLIPSliderXL(CLIPSlider):
         print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
         torch.manual_seed(seed)
         start_time = time.time()
-
-            image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
-                              avg_diff=avg_diff, avg_diff_2=avg_diff2, scale=scale,
-                              **pipeline_kwargs).images[0]
-        else:
-            image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
+        image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
                           **pipeline_kwargs).images[0]
         end_time = time.time()
         print(f"generation time - pipe: {end_time - start_time:.2f} ms")
 
         return image
+
+class CLIPSliderXL_inv(CLIPSlider):
+
+    def find_latent_direction(self,
+                              target_word:str,
+                              opposite:str,
+                              num_iterations: int = None):
+
+        # lets identify a latent direction by taking differences between opposites
+        # target_word = "happy"
+        # opposite = "sad"
+        if num_iterations is not None:
+            iterations = num_iterations
+        else:
+            iterations = self.iterations
+
+        with torch.no_grad():
+            positives = []
+            negatives = []
+            positives2 = []
+            negatives2 = []
+            for i in tqdm(range(iterations)):
+                medium = random.choice(MEDIUMS)
+                subject = random.choice(SUBJECTS)
+                pos_prompt = f"a {medium} of a {target_word} {subject}"
+                neg_prompt = f"a {medium} of a {opposite} {subject}"
+
+                pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
+                neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
+                pos = self.pipe.text_encoder(pos_toks).pooler_output
+                neg = self.pipe.text_encoder(neg_toks).pooler_output
+                positives.append(pos)
+                negatives.append(neg)
+
+                pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
+                neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
+                pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
+                neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
+                positives2.append(pos2)
+                negatives2.append(neg2)
+
+        positives = torch.cat(positives, dim=0)
+        negatives = torch.cat(negatives, dim=0)
+        diffs = positives - negatives
+        avg_diff = diffs.mean(0, keepdim=True)
+
+        positives2 = torch.cat(positives2, dim=0)
+        negatives2 = torch.cat(negatives2, dim=0)
+        diffs2 = positives2 - negatives2
+        avg_diff2 = diffs2.mean(0, keepdim=True)
+        return (avg_diff, avg_diff2)
+
+    def generate(self,
+                 prompt = "a photo of a house",
+                 scale = 2,
+                 scale_2nd = 2,
+                 seed = 15,
+                 only_pooler = False,
+                 normalize_scales = False,
+                 correlation_weight_factor = 1.0,
+                 avg_diff=None,
+                 avg_diff_2nd=None,
+                 init_latents=None,
+                 zs=None,
+                 **pipeline_kwargs
+                 ):
+
+        with torch.no_grad():
+            torch.manual_seed(seed)
+            images = self.pipe(editing_prompt=prompt, init_latents=init_latents, zs=zs,
+                               avg_diff=avg_diff[0], avg_diff_2=avg_diff[1],
+                               scale=scale,
+                               **pipeline_kwargs).images
+
+        return images
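In short, the commit removes the inversion arguments (init_latents, zs) from CLIPSliderXL.generate, makes its pipeline call unconditional, and moves the inversion path into a new CLIPSliderXL_inv subclass. find_latent_direction estimates an edit direction as the mean difference between pooled embeddings of opposing prompts, computed separately for each of the two SDXL text encoders; generate then forwards that pair of directions, plus the inversion latents and noise, to the pipeline. Note that scale_2nd, avg_diff_2nd, and several other keyword arguments are accepted for interface compatibility but not forwarded.

Stripped of the pipeline plumbing, the direction extraction is a difference-of-means estimator. A minimal self-contained sketch (not part of the commit; encoder_fn is a hypothetical stand-in for either pooled text encoder):

    import torch

    def difference_of_means(encoder_fn, pos_prompts, neg_prompts):
        # Encode each prompt to a pooled embedding of shape (1, D),
        # stack to (N, D), and average the pairwise differences.
        pos = torch.cat([encoder_fn(p) for p in pos_prompts], dim=0)
        neg = torch.cat([encoder_fn(p) for p in neg_prompts], dim=0)
        return (pos - neg).mean(0, keepdim=True)  # shape (1, D)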
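A usage sketch for the new class follows. It is not part of the commit and rests on assumptions the diff does not confirm: that the CLIPSlider base class is constructed with a pipeline and an iteration count, that self.pipe is the project's custom SDXL editing pipeline accepting the editing_prompt / init_latents / zs / avg_diff keywords seen above, and that the inversion latents come from an inversion helper defined elsewhere in the repo.

    # Hypothetical usage; the constructor signature and the inversion step
    # are assumptions, not confirmed by this diff.
    slider = CLIPSliderXL_inv(sd_pipe=pipe, iterations=300)

    # Difference-of-means direction over randomized prompt pairs,
    # returned as one tensor per SDXL text encoder.
    directions = slider.find_latent_direction("happy", "sad")

    # generate() indexes avg_diff[0] / avg_diff[1], so pass the tuple whole.
    images = slider.generate(
        prompt="a photo of a house",
        scale=1.5,
        seed=15,
        avg_diff=directions,
        init_latents=init_latents,  # produced by an inversion step elsewhere
        zs=zs,
    )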