import diffusers
import torch
import random
from tqdm import tqdm
from constants import SUBJECTS, MEDIUMS
from PIL import Image


def slerp(v0, v1, t, DOT_THRESHOLD=0.9995):
    """Spherical linear interpolation (slerp) between two tensors.

    Directions are slerped while magnitudes are interpolated linearly.
    v0, v1: tensors to interpolate between.
    t: interpolation factor (scalar or tensor), typically in [0, 1].
    DOT_THRESHOLD: nominal threshold above which the vectors are considered
        nearly collinear; in that regime the computation falls back to LERP.
    """
    if not isinstance(t, torch.Tensor):
        t = torch.tensor(t, device=v0.device, dtype=v0.dtype)

    # Cosine of the angle between the (normalised) vectors.
    dot = torch.sum(
        v0 * v1 / (torch.norm(v0, dim=-1, keepdim=True) * torch.norm(v1, dim=-1, keepdim=True) + 1e-8),
        dim=-1, keepdim=True,
    )
    # Clamp to avoid NaNs from acos due to floating-point error. Nearly collinear
    # vectors (|dot| close to 1, cf. DOT_THRESHOLD) are handled by the LERP fallback below.
    dot = torch.clamp(dot, -1.0, 1.0)
    omega = torch.acos(dot)  # Angle between the vectors.

    # Interpolate the magnitudes linearly; slerp is applied to the directions only.
    mag_v0 = torch.norm(v0, dim=-1, keepdim=True)
    mag_v1 = torch.norm(v1, dim=-1, keepdim=True)
    interpolated_mag = (1 - t) * mag_v0 + t * mag_v1

    # Normalise v0 and v1 for pure slerp on the directions.
    v0_norm = v0 / (mag_v0 + 1e-8)
    v1_norm = v1 / (mag_v1 + 1e-8)

    # If sin(omega) is very small the vectors are nearly collinear; LERP on the
    # normalised vectors is then a good approximation.
    sin_omega = torch.sin(omega)
    use_lerp_fallback = sin_omega.abs() < 1e-5

    s0 = torch.sin((1 - t) * omega) / (sin_omega + 1e-8)  # epsilon for numerical stability
    s1 = torch.sin(t * omega) / (sin_omega + 1e-8)

    # Elements that need the LERP fallback.
    s0[use_lerp_fallback] = 1.0 - t
    s1[use_lerp_fallback] = t

    result_norm = s0 * v0_norm + s1 * v1_norm
    result = result_norm * interpolated_mag  # Re-apply the interpolated magnitude.
    return result.to(v0.dtype)
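
# Example (illustrative sketch, not part of the pipeline): slerp-ing between two pooled
# text embeddings. `emb_a` and `emb_b` are hypothetical tensors of shape (1, 768),
# e.g. CLIP pooler outputs; t=0.5 gives the "halfway" embedding while keeping its norm
# between the two endpoint norms.
#
#     halfway = slerp(emb_a, emb_b, 0.5)
#     quarter = slerp(emb_a, emb_b, 0.25)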


class CLIPSlider:
    def __init__(
        self,
        sd_pipe,
        device: torch.device,
        target_word: str = "",
        opposite: str = "",
        target_word_2nd: str = "",
        opposite_2nd: str = "",
        iterations: int = 300,
    ):
        self.device = device
        self.pipe = sd_pipe.to(self.device, torch.float16)
        self.iterations = iterations
        if target_word != "" or opposite != "":
            self.avg_diff = self.find_latent_direction(target_word, opposite)
        else:
            self.avg_diff = None
        if target_word_2nd != "" or opposite_2nd != "":
            self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
        else:
            self.avg_diff_2nd = None

    def find_latent_direction(self, target_word: str, opposite: str):
        # Identify a latent direction by averaging the differences between embeddings
        # of opposite prompts, e.g. target_word = "happy" vs. opposite = "sad".
        with torch.no_grad():
            positives = []
            negatives = []
            for _ in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                pos_toks = self.pipe.tokenizer(
                    pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                    max_length=self.pipe.tokenizer.model_max_length,
                ).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(
                    neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                    max_length=self.pipe.tokenizer.model_max_length,
                ).input_ids.to(self.device)
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                neg = self.pipe.text_encoder(neg_toks).pooler_output
                positives.append(pos)
                negatives.append(neg)

        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)
        diffs = positives - negatives
        avg_diff = diffs.mean(0, keepdim=True)
        return avg_diff

    def generate(
        self,
        prompt="a photo of a house",
        scale=2.,
        scale_2nd=0.,  # scale for the 2nd direction when avg_diff_2nd is not None
        seed=15,
        only_pooler=False,
        normalize_scales=False,  # whether to normalize the scales when avg_diff_2nd is not None
        correlation_weight_factor=1.0,
        **pipeline_kwargs,
    ):
        # When editing the full sequence, scales in [-0.3, 0.3] work well (higher if
        # correlation weighting is used); when editing the pooler token only, [-4, 4] work well.
        with torch.no_grad():
            toks = self.pipe.tokenizer(
                prompt, return_tensors="pt", padding="max_length", truncation=True,
                max_length=self.pipe.tokenizer.model_max_length,
            ).input_ids.to(self.device)
            prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state

            if self.avg_diff_2nd is not None and normalize_scales:
                denominator = abs(scale) + abs(scale_2nd)
                scale = scale / denominator
                scale_2nd = scale_2nd / denominator

            if only_pooler:
                prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
                if self.avg_diff_2nd is not None:
                    prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
            else:
                normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
                weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)

                standard_weights = torch.ones_like(weights)
                weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
                # weights = torch.sigmoid((weights - 0.5) * 7)

                prompt_embeds = prompt_embeds + (
                    weights * self.avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale
                )
                if self.avg_diff_2nd is not None:
                    prompt_embeds += (
                        weights * self.avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
                    )

        torch.manual_seed(seed)
        images = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images
        return images

    def spectrum(
        self,
        prompt="a photo of a house",
        low_scale=-2,
        low_scale_2nd=-2,
        high_scale=2,
        high_scale_2nd=2,
        steps=5,
        seed=15,
        only_pooler=False,
        normalize_scales=False,
        correlation_weight_factor=1.0,
        **pipeline_kwargs,
    ):
        images = []
        for i in range(steps):
            scale = low_scale + (high_scale - low_scale) * i / (steps - 1)
            scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / (steps - 1)
            image = self.generate(
                prompt, scale, scale_2nd, seed, only_pooler,
                normalize_scales, correlation_weight_factor, **pipeline_kwargs,
            )
            images.append(image[0])

        canvas = Image.new('RGB', (640 * steps, 640))
        for i, im in enumerate(images):
            canvas.paste(im, (640 * i, 0))
        return canvas
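
# Example usage (illustrative sketch): build a slider on top of a Stable Diffusion 1.x
# pipeline and sweep the learned direction. The checkpoint name and the CUDA device are
# assumptions; any SD 1.x checkpoint with a CLIP text encoder should work.
#
#     from diffusers import StableDiffusionPipeline
#
#     pipe = StableDiffusionPipeline.from_pretrained(
#         "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
#     )
#     slider = CLIPSlider(pipe, torch.device("cuda"), target_word="happy", opposite="sad")
#     images = slider.generate("a photo of a dog", scale=3, only_pooler=True)
#     grid = slider.spectrum("a photo of a dog", low_scale=-4, high_scale=4,
#                            steps=5, only_pooler=True)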


class CLIPSliderXL(CLIPSlider):

    def find_latent_direction(self, target_word: str, opposite: str):
        # Identify a latent direction for both SDXL text encoders by averaging the
        # differences between embeddings of opposite prompts, e.g. "happy" vs. "sad".
        with torch.no_grad():
            positives = []
            negatives = []
            positives2 = []
            negatives2 = []
            for _ in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"

                pos_toks = self.pipe.tokenizer(
                    pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                    max_length=self.pipe.tokenizer.model_max_length,
                ).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(
                    neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                    max_length=self.pipe.tokenizer.model_max_length,
                ).input_ids.to(self.device)
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                neg = self.pipe.text_encoder(neg_toks).pooler_output
                positives.append(pos)
                negatives.append(neg)

                pos_toks2 = self.pipe.tokenizer_2(
                    pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                    max_length=self.pipe.tokenizer_2.model_max_length,
                ).input_ids.to(self.device)
                neg_toks2 = self.pipe.tokenizer_2(
                    neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                    max_length=self.pipe.tokenizer_2.model_max_length,
                ).input_ids.to(self.device)
                pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
                neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
                positives2.append(pos2)
                negatives2.append(neg2)

        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)
        diffs = positives - negatives
        avg_diff = diffs.mean(0, keepdim=True)

        positives2 = torch.cat(positives2, dim=0)
        negatives2 = torch.cat(negatives2, dim=0)
        diffs2 = positives2 - negatives2
        avg_diff2 = diffs2.mean(0, keepdim=True)
        return (avg_diff, avg_diff2)

    def generate(
        self,
        prompt="a photo of a house",
        scale=2,
        scale_2nd=2,
        seed=15,
        only_pooler=False,
        normalize_scales=False,
        correlation_weight_factor=1.0,
        **pipeline_kwargs,
    ):
        # When editing the full sequence, scales in [-0.3, 0.3] work well (higher if
        # correlation weighting is used); when editing the pooler token only, [-4, 4] work well.
        text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
        tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]

        with torch.no_grad():
            prompt_embeds_list = []

            for i, text_encoder in enumerate(text_encoders):
                tokenizer = tokenizers[i]
                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                toks = text_inputs.input_ids

                prompt_embeds = text_encoder(
                    toks.to(text_encoder.device),
                    output_hidden_states=True,
                )

                # Only the pooled output of the final text encoder is used downstream.
                pooled_prompt_embeds = prompt_embeds[0]
                prompt_embeds = prompt_embeds.hidden_states[-2]

                if self.avg_diff_2nd is not None and normalize_scales:
                    denominator = abs(scale) + abs(scale_2nd)
                    scale = scale / denominator
                    scale_2nd = scale_2nd / denominator

                if only_pooler:
                    prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[i] * scale
                    if self.avg_diff_2nd is not None:
                        prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd[i] * scale_2nd
                else:
                    normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                    sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
                    if i == 0:
                        weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)

                        standard_weights = torch.ones_like(weights)
                        weights = standard_weights + (weights - standard_weights) * correlation_weight_factor

                        prompt_embeds = prompt_embeds + (
                            weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale
                        )
                        if self.avg_diff_2nd is not None:
                            prompt_embeds += (
                                weights * self.avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
                            )
                    else:
                        weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)

                        standard_weights = torch.ones_like(weights)
                        weights = standard_weights + (weights - standard_weights) * correlation_weight_factor

                        prompt_embeds = prompt_embeds + (
                            weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale
                        )
                        if self.avg_diff_2nd is not None:
                            prompt_embeds += (
                                weights * self.avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd
                            )

                bs_embed, seq_len, _ = prompt_embeds.shape
                prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
                prompt_embeds_list.append(prompt_embeds)

            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
            pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)

        torch.manual_seed(seed)
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            **pipeline_kwargs,
        ).images
        return images


class CLIPSliderXL_inv(CLIPSlider):

    def find_latent_direction(self, target_word: str, opposite: str):
        # Identify a latent direction for both SDXL text encoders by averaging the
        # differences between embeddings of opposite prompts, e.g. "happy" vs. "sad".
        with torch.no_grad():
            positives = []
            negatives = []
            positives2 = []
            negatives2 = []
            for _ in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"

                pos_toks = self.pipe.tokenizer(
                    pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                    max_length=self.pipe.tokenizer.model_max_length,
                ).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(
                    neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                    max_length=self.pipe.tokenizer.model_max_length,
                ).input_ids.to(self.device)
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                neg = self.pipe.text_encoder(neg_toks).pooler_output
                positives.append(pos)
                negatives.append(neg)

                pos_toks2 = self.pipe.tokenizer_2(
                    pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                    max_length=self.pipe.tokenizer_2.model_max_length,
                ).input_ids.to(self.device)
                neg_toks2 = self.pipe.tokenizer_2(
                    neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                    max_length=self.pipe.tokenizer_2.model_max_length,
                ).input_ids.to(self.device)
                pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
                neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
                positives2.append(pos2)
                negatives2.append(neg2)

        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)
        diffs = positives - negatives
        avg_diff = diffs.mean(0, keepdim=True)

        positives2 = torch.cat(positives2, dim=0)
        negatives2 = torch.cat(negatives2, dim=0)
        diffs2 = positives2 - negatives2
        avg_diff2 = diffs2.mean(0, keepdim=True)
        return (avg_diff, avg_diff2)

    def generate(
        self,
        prompt="a photo of a house",
        scale=2,
        scale_2nd=2,
        seed=15,
        only_pooler=False,
        normalize_scales=False,
        correlation_weight_factor=1.0,
        **pipeline_kwargs,
    ):
        # The editing itself is delegated to the wrapped pipeline, which accepts the
        # latent directions and scales as call arguments.
        with torch.no_grad():
            torch.manual_seed(seed)
            images = self.pipe(
                editing_prompt=prompt,
                avg_diff=self.avg_diff,
                avg_diff_2nd=self.avg_diff_2nd,
                scale=scale,
                scale_2nd=scale_2nd,
                **pipeline_kwargs,
            ).images
        return images


class CLIPSliderFlux(CLIPSlider):

    def find_latent_direction(self, target_word: str, opposite: str, num_iterations: int = None):
        # Identify a latent direction by averaging differences between CLIP pooled
        # embeddings of opposite prompts, e.g. "happy" vs. "sad".
        if num_iterations is not None:
            iterations = num_iterations
        else:
            iterations = self.iterations

        with torch.no_grad():
            positives = []
            negatives = []
            for _ in tqdm(range(iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                pos_toks = self.pipe.tokenizer(
                    pos_prompt, padding="max_length", max_length=self.pipe.tokenizer_max_length,
                    truncation=True, return_overflowing_tokens=False, return_length=False,
                    return_tensors="pt",
                ).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(
                    neg_prompt, padding="max_length", max_length=self.pipe.tokenizer_max_length,
                    truncation=True, return_overflowing_tokens=False, return_length=False,
                    return_tensors="pt",
                ).input_ids.to(self.device)
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                neg = self.pipe.text_encoder(neg_toks).pooler_output
                positives.append(pos)
                negatives.append(neg)

        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)
        diffs = positives - negatives
        avg_diff = diffs.mean(0, keepdim=True)
        return avg_diff

    def generate(
        self,
        prompt="a photo of a house",
        scale=2.0,
        seed=15,
        normalize_scales=False,
        avg_diff=None,
        avg_diff_2nd=None,
        use_slerp: bool = False,
        max_strength_for_slerp_endpoint: float = 0.0,
        **pipeline_kwargs,
    ):
        # Scales in [-4, 4] work well for the pooled-embedding edits performed here.
        # Remove slider-specific kwargs before passing the rest to the pipeline.
        pipeline_kwargs.pop('use_slerp', None)
        pipeline_kwargs.pop('max_strength_for_slerp_endpoint', None)

        with torch.no_grad():
            text_inputs = self.pipe.tokenizer(
                prompt,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_overflowing_tokens=False,
                return_length=False,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds_out = self.pipe.text_encoder(
                text_input_ids.to(self.device), output_hidden_states=False
            )
            original_pooled_prompt_embeds = prompt_embeds_out.pooler_output.to(
                dtype=self.pipe.text_encoder.dtype, device=self.device
            )

            # Second text encoder (T5 for FLUX): non-pooled sequence output.
            text_inputs_2 = self.pipe.tokenizer_2(
                prompt,
                padding="max_length",
                max_length=512,
                truncation=True,
                return_length=False,
                return_overflowing_tokens=False,
                return_tensors="pt",
            )
            toks_2 = text_inputs_2.input_ids
            prompt_embeds_seq_2 = self.pipe.text_encoder_2(toks_2.to(self.device), output_hidden_states=False)[0]
            prompt_embeds_seq_2 = prompt_embeds_seq_2.to(dtype=self.pipe.text_encoder_2.dtype, device=self.device)

            modified_pooled_embeds = original_pooled_prompt_embeds.clone()

            if avg_diff is not None:
                if use_slerp and max_strength_for_slerp_endpoint != 0.0:
                    # Slerp towards an endpoint placed at +/- max_strength along the direction;
                    # t grows with |scale| and is capped at 1.
                    slerp_t_val = min(abs(scale) / max_strength_for_slerp_endpoint, 1.0)
                    if scale != 0:
                        v0 = original_pooled_prompt_embeds.float()
                        if scale > 0:
                            v_end_target = original_pooled_prompt_embeds + max_strength_for_slerp_endpoint * avg_diff
                        else:
                            v_end_target = original_pooled_prompt_embeds - max_strength_for_slerp_endpoint * avg_diff
                        modified_pooled_embeds = slerp(v0, v_end_target.float(), slerp_t_val).to(
                            original_pooled_prompt_embeds.dtype
                        )
                else:
                    modified_pooled_embeds = modified_pooled_embeds + avg_diff * scale

            if avg_diff_2nd is not None:
                # scale_2nd arrives via kwargs; pop it so it is not forwarded to the pipeline.
                scale_2nd_val = pipeline_kwargs.pop("scale_2nd", 0.0)
                modified_pooled_embeds += avg_diff_2nd * scale_2nd_val

        torch.manual_seed(seed)
        images = self.pipe(
            prompt_embeds=prompt_embeds_seq_2,
            pooled_prompt_embeds=modified_pooled_embeds,
            **pipeline_kwargs,
        ).images
        return images[0]

    def spectrum(
        self,
        prompt="a photo of a house",
        low_scale=-2,
        low_scale_2nd=-2,
        high_scale=2,
        high_scale_2nd=2,
        steps=5,
        seed=15,
        normalize_scales=False,
        **pipeline_kwargs,
    ):
        images = []
        for i in range(steps):
            scale = low_scale + (high_scale - low_scale) * i / (steps - 1)
            scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / (steps - 1)
            image = self.generate(
                prompt,
                scale=scale,
                seed=seed,
                normalize_scales=normalize_scales,
                avg_diff=self.avg_diff,
                avg_diff_2nd=self.avg_diff_2nd,
                scale_2nd=scale_2nd,
                **pipeline_kwargs,
            )
            images.append(image.resize((512, 512)))

        canvas = Image.new('RGB', (640 * steps, 640))
        for i, im in enumerate(images):
            canvas.paste(im, (640 * i, 0))
        return canvas


class T5SliderFlux(CLIPSlider):

    def find_latent_direction(self, target_word: str, opposite: str):
        # Identify a latent direction in the T5 sequence embeddings by averaging the
        # differences between embeddings of opposite prompts, e.g. "happy" vs. "sad".
        with torch.no_grad():
            positives = []
            negatives = []
            for _ in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                pos_toks = self.pipe.tokenizer_2(
                    pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                    return_length=False, return_overflowing_tokens=False,
                    max_length=self.pipe.tokenizer_2.model_max_length,
                ).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer_2(
                    neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                    return_length=False, return_overflowing_tokens=False,
                    max_length=self.pipe.tokenizer_2.model_max_length,
                ).input_ids.to(self.device)
                pos = self.pipe.text_encoder_2(pos_toks, output_hidden_states=False)[0]
                neg = self.pipe.text_encoder_2(neg_toks, output_hidden_states=False)[0]
                positives.append(pos)
                negatives.append(neg)

        positives = torch.cat(positives, dim=0)
        negatives = torch.cat(negatives, dim=0)
        diffs = positives - negatives
        avg_diff = diffs.mean(0, keepdim=True)
        return avg_diff

    def generate(
        self,
        prompt="a photo of a house",
        scale=2,
        scale_2nd=2,
        seed=15,
        only_pooler=False,
        normalize_scales=False,
        correlation_weight_factor=1.0,
        **pipeline_kwargs,
    ):
        # When editing the full sequence, scales in [-0.3, 0.3] work well (higher if
        # correlation weighting is used); when editing the pooler token only, [-4, 4] work well.
        with torch.no_grad():
            text_inputs = self.pipe.tokenizer(
                prompt,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_overflowing_tokens=False,
                return_length=False,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = self.pipe.text_encoder(text_input_ids.to(self.device), output_hidden_states=False)

            # Use the pooled output of the CLIP text encoder as the pooled conditioning.
            prompt_embeds = prompt_embeds.pooler_output
            pooled_prompt_embeds = prompt_embeds.to(dtype=self.pipe.text_encoder.dtype, device=self.device)

            # The T5 encoder supplies the sequence embeddings that get edited below.
            text_inputs = self.pipe.tokenizer_2(
                prompt,
                padding="max_length",
                max_length=512,
                truncation=True,
                return_length=False,
                return_overflowing_tokens=False,
                return_tensors="pt",
            )
            toks = text_inputs.input_ids
            prompt_embeds = self.pipe.text_encoder_2(toks.to(self.device), output_hidden_states=False)[0]
            dtype = self.pipe.text_encoder_2.dtype
            prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)

            if self.avg_diff_2nd is not None and normalize_scales:
                denominator = abs(scale) + abs(scale_2nd)
                scale = scale / denominator
                scale_2nd = scale_2nd / denominator

            if only_pooler:
                prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
                if self.avg_diff_2nd is not None:
                    prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
            else:
                normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
                weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, prompt_embeds.shape[2])

                standard_weights = torch.ones_like(weights)
                weights = standard_weights + (weights - standard_weights) * correlation_weight_factor

                prompt_embeds = prompt_embeds + (weights * self.avg_diff * scale)
                if self.avg_diff_2nd is not None:
                    prompt_embeds += (weights * self.avg_diff_2nd * scale_2nd)

        torch.manual_seed(seed)
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            **pipeline_kwargs,
        ).images
        return images
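

if __name__ == "__main__":
    # Minimal end-to-end sketch. Assumptions: a CUDA GPU with enough memory for FLUX and
    # the "black-forest-labs/FLUX.1-schnell" checkpoint; swap in whichever FLUX checkpoint
    # you have locally. float16 matches the dtype CLIPSlider.__init__ casts the pipeline to.
    from diffusers import FluxPipeline

    device = torch.device("cuda")
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)

    # Learn a "happy <-> sad" direction in the pooled CLIP embedding space.
    slider = CLIPSliderFlux(pipe, device, target_word="happy", opposite="sad", iterations=100)

    # CLIPSliderFlux.generate takes the direction explicitly rather than reading self.avg_diff.
    image = slider.generate(
        "a photo of a dog",
        scale=3,
        avg_diff=slider.avg_diff,
        num_inference_steps=4,
        guidance_scale=0.0,
    )
    image.save("happy_dog.png")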