import math
from typing import Dict, List, TypedDict

import numpy as np
import torch

from ..Misc import Logger as log
from ..Misc.BBox import BoundingBox
from .BaseProc import BundleType, CrossAttnProcessorBase

class InjecterProcessor(CrossAttnProcessorBase):
    """Cross-attention processor that steers attention toward a per-frame
    bounding box: attention to the subject tokens is strengthened inside
    each frame's box and weakened outside it."""

    def __init__(
        self,
        bundle: BundleType,
        bbox_per_frame: List[BoundingBox],
        name: str,
        strengthen_scale: float = 0.0,
        weaken_scale: float = 1.0,
        is_text2vidzero: bool = False,
    ):
        super().__init__(bundle, is_text2vidzero=is_text2vidzero)
        self.strengthen_scale = strengthen_scale
        self.weaken_scale = weaken_scale
        self.bundle = bundle
        self.num_frames = len(bbox_per_frame)
        self.bbox_per_frame = bbox_per_frame
        self.use_weaken = True
        self.name = name
    def dd_core(self, attention_probs: torch.Tensor):
        """Apply the bounding-box edit to a batch of cross-attention maps.

        A 4D tensor is treated as spatial cross-attention and a 5D tensor
        as temporal cross-attention. In both cases attention is scaled by
        ``weaken_scale`` where the localized weight map is zero and boosted
        by ``strengthen_scale`` where it is non-zero.
        """
        frame_size = attention_probs.shape[0] // self.num_frames
        num_affected_frames = self.num_frames
        attention_probs_copied = attention_probs.detach().clone()

        # Indices of the prompt tokens to edit, plus the trailing tokens
        # appended after the prompt.
        token_inds = self.bundle.get("token_inds")
        trailing_length = self.bundle.get("trailing_length")
        trailing_inds = list(
            range(self.len_prompt + 1, self.len_prompt + trailing_length + 1)
        )

        # NOTE: Spatial cross-attention editing
        if len(attention_probs.size()) == 4:
            all_tokens_inds = list(set(token_inds).union(set(trailing_inds)))
            strengthen_map = self.localized_weight_map(
                attention_probs_copied,
                token_inds=all_tokens_inds,
                bbox_per_frame=self.bbox_per_frame,
            )
            # The weight map is zero outside the box; weaken those entries.
            weaken_map = torch.ones_like(strengthen_map)
            zero_indices = torch.where(strengthen_map == 0)
            weaken_map[zero_indices] = self.weaken_scale

            # Weakening: scale down attention to the edited tokens outside the box.
            attention_probs_copied[..., all_tokens_inds] *= weaken_map[
                ..., all_tokens_inds
            ]
            # Strengthening: add attention mass inside the box.
            attention_probs_copied[..., all_tokens_inds] += (
                self.strengthen_scale * strengthen_map[..., all_tokens_inds]
            )
        # NOTE: Temporal cross-attention editing
        elif len(attention_probs.size()) == 5:
            strengthen_map = self.localized_temporal_weight_map(
                attention_probs_copied,
                bbox_per_frame=self.bbox_per_frame,
            )
            weaken_map = torch.ones_like(strengthen_map)
            zero_indices = torch.where(strengthen_map == 0)
            weaken_map[zero_indices] = self.weaken_scale

            # Weaken outside the box, then strengthen inside it.
            attention_probs_copied *= weaken_map
            attention_probs_copied += self.strengthen_scale * strengthen_map

        return attention_probs_copied
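
# The block below is a minimal standalone sketch, not part of the original
# module: it emulates the spatial branch of dd_core with a hand-built
# strengthen map so the weaken/strengthen arithmetic can be run in isolation.
# The tensor layout, index ranges, and scales are illustrative assumptions;
# in the real processor the map comes from localized_weight_map and the
# per-frame bounding boxes.
if __name__ == "__main__":
    strengthen_scale, weaken_scale = 0.15, 0.75  # assumed scales
    token_inds = [5, 6, 7]  # hypothetical subject-token indices

    # Assumed layout: (frames, heads, pixels, tokens); normalize the last
    # dimension so rows behave like attention probabilities.
    probs = torch.rand(8, 4, 64, 77)
    probs = probs / probs.sum(dim=-1, keepdim=True)

    # Pretend the bounding box covers pixels 16..31 in every frame: the
    # strengthen map keeps the original attention there and is zero elsewhere.
    strengthen_map = torch.zeros_like(probs)
    strengthen_map[:, :, 16:32, token_inds] = probs[:, :, 16:32, token_inds]

    weaken_map = torch.ones_like(strengthen_map)
    weaken_map[strengthen_map == 0] = weaken_scale

    edited = probs.detach().clone()
    # Weaken the subject tokens outside the box, then boost them inside it,
    # mirroring the spatial branch of InjecterProcessor.dd_core.
    edited[..., token_inds] *= weaken_map[..., token_inds]
    edited[..., token_inds] += strengthen_scale * strengthen_map[..., token_inds]
    print(edited.shape)  # torch.Size([8, 4, 64, 77])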