File size: 1,419 Bytes
e0336bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 16 17:23:28 2025
CFGZero* implementation for Blissful Tuner extension based on https://github.com/WeichenFan/CFG-Zero-star/blob/main/models/wan/wan_pipeline.py
License: Apache 2.0
@author: blyss
"""
import torch


def apply_zerostar(cond: torch.Tensor, uncond: torch.Tensor, current_step: int, guidance_scale: float, use_scaling: bool = True, zero_init_steps: int = -1) -> torch.Tensor:
    """Combine conditional and unconditional predictions with CFG-Zero*.

    Three modes, checked in this order:

    1. Zero-init: while ``current_step <= zero_init_steps`` the result is an
       all-zero tensor shaped like ``cond``.
    2. Plain CFG (``use_scaling=False``):
       ``uncond + guidance_scale * (cond - uncond)``.
    3. Scaled CFG (``use_scaling=True``): ``uncond`` is first multiplied by a
       per-sample factor ``alpha`` obtained from ``optimized_scale`` and the
       CFG formula is applied to the rescaled unconditional prediction.

    Args:
        cond: Conditional prediction. Dimension 0 is treated as the batch.
        uncond: Unconditional prediction. Must flatten to the same
            per-sample size as ``cond`` for the scaled mode to work.
        current_step: Index of the current sampling step.
        guidance_scale: Classifier-free guidance weight.
        use_scaling: Enable the per-sample ``alpha`` rescaling of ``uncond``.
        zero_init_steps: Last step index (inclusive) for which the prediction
            is zeroed. The default of -1 disables zero-init for any
            non-negative ``current_step``.

    Returns:
        The guided prediction.
    """
    if current_step <= zero_init_steps:
        # zeros_like rather than `cond * 0`: multiplying would carry any
        # NaN/Inf present in `cond` into what is meant to be a zero output.
        return torch.zeros_like(cond)

    if not use_scaling:
        # Standard CFG formula.
        return uncond + guidance_scale * (cond - uncond)

    batch_size = cond.shape[0]
    # reshape (not view) so non-contiguous inputs are flattened instead of
    # raising a RuntimeError.
    positive_flat = cond.reshape(batch_size, -1)
    negative_flat = uncond.reshape(batch_size, -1)
    alpha = optimized_scale(positive_flat, negative_flat)
    # One factor per sample, expanded to (batch, 1, ..., 1) so it broadcasts
    # over the remaining dimensions of the predictions.
    alpha = alpha.reshape(batch_size, *([1] * (cond.ndim - 1)))
    alpha = alpha.to(cond.dtype)
    # CFG formula with the unconditional branch rescaled by alpha.
    scaled_uncond = uncond * alpha
    return scaled_uncond + guidance_scale * (cond - scaled_uncond)


def optimized_scale(positive_flat, negative_flat):
    """Return the per-row projection coefficient of positive onto negative.

    For each row computes
    ``st_star = <positive, negative> / (||negative||^2 + 1e-8)``.

    The reductions are carried out in at least float32. In float16 the 1e-8
    epsilon rounds to zero (so an all-zero ``negative_flat`` row would give
    0/0 = NaN) and the sum of squares can overflow to inf. The result is cast
    back to the promoted dtype of the inputs when that dtype is a real
    floating type, so the returned dtype matches what the plain arithmetic
    would have produced.

    Args:
        positive_flat: 2-D tensor of shape (batch, features).
        negative_flat: 2-D tensor of shape (batch, features).

    Returns:
        Tensor of shape (batch, 1) holding one coefficient per row.
    """
    result_dtype = torch.promote_types(positive_flat.dtype, negative_flat.dtype)
    # float16/bfloat16 are widened to float32; float64 stays float64.
    compute_dtype = torch.promote_types(result_dtype, torch.float32)
    positive = positive_flat.to(compute_dtype)
    negative = negative_flat.to(compute_dtype)

    dot_product = torch.sum(positive * negative, dim=1, keepdim=True)
    # The epsilon guards against division by zero for an all-zero row.
    squared_norm = torch.sum(negative * negative, dim=1, keepdim=True) + 1e-8

    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    st_star = dot_product / squared_norm

    if result_dtype.is_floating_point:
        return st_star.to(result_dtype)
    return st_star