alexnasa commited on
Commit
bb076a6
·
verified ·
1 Parent(s): 65bed02

Update pipelines/pipeline_seesr.py

Browse files
Files changed (1) hide show
  1. pipelines/pipeline_seesr.py +27 -1
pipelines/pipeline_seesr.py CHANGED
@@ -95,7 +95,31 @@ EXAMPLE_DOC_STRING = """
95
  ... ).images[0]
96
  ```
97
  """
98
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
101
  r"""
@@ -807,6 +831,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
807
  ram_encoder_hidden_states=None,
808
  latent_tiled_size=320,
809
  latent_tiled_overlap=4,
 
 
810
  use_KDS=True,
811
  args=None,
812
  ):
 
95
  ... ).images[0]
96
  ```
97
  """
98
+ def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
99
+ # x0: (N, C, H, W) in float32
100
+ N, C, H, W = x0.shape
101
+ patches = unfold(
102
+ x0, kernel_size=patch_size, stride=patch_size//2
103
+ ) # (N, C*ps*ps, M)
104
+ P, M = patches.shape[1], patches.shape[2]
105
+ p_i = patches.unsqueeze(1) # (N,1,P,M)
106
+ p_j = patches.unsqueeze(0) # (1,N,P,M)
107
+ diff = p_j - p_i # (N,N,P,M)
108
+ # Gaussian weights
109
+ w = torch.exp((-0.5 / bandwidth**2) *
110
+ (diff.square().sum(dim=2))) # (N,N,M)
111
+ # mean-shift numerator & normalizer
112
+ num = (w.unsqueeze(2) * diff).sum(dim=1) # (N,P,M)
113
+ denom = w.sum(dim=1, keepdim=True) + 1e-8 # (N,1,M)
114
+ mshift = num / denom # (N,P,M)
115
+ # fold back
116
+ grad = fold(
117
+ mshift / bandwidth**2,
118
+ output_size=(H, W),
119
+ kernel_size=patch_size,
120
+ stride=patch_size//2
121
+ ) # (N, C, H, W)
122
+ return grad
123
 
124
  class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
125
  r"""
 
831
  ram_encoder_hidden_states=None,
832
  latent_tiled_size=320,
833
  latent_tiled_overlap=4,
834
+ num_particles: Optional[int] = 4,
835
+ gamma_0: Optional[float] = 0.1, # base steering strength
836
  use_KDS=True,
837
  args=None,
838
  ):