alexnasa commited on
Commit
3e4a17e
·
verified ·
1 Parent(s): 897276f

Update pipelines/pipeline_seesr.py

Browse files
Files changed (1) hide show
  1. pipelines/pipeline_seesr.py +8 -8
pipelines/pipeline_seesr.py CHANGED
@@ -99,9 +99,7 @@ EXAMPLE_DOC_STRING = """
99
  def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
100
  # x0: (N, C, H, W) in float32
101
  N, C, H, W = x0.shape
102
- patches = unfold(
103
- x0, kernel_size=patch_size, stride=patch_size//2
104
- ) # (N, C*ps*ps, M)
105
  P, M = patches.shape[1], patches.shape[2]
106
  p_i = patches.unsqueeze(1) # (N,1,P,M)
107
  p_j = patches.unsqueeze(0) # (1,N,P,M)
@@ -113,13 +111,15 @@ def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
113
  num = (w.unsqueeze(2) * diff).sum(dim=1) # (N,P,M)
114
  denom = w.sum(dim=1, keepdim=True) + 1e-8 # (N,1,M)
115
  mshift = num / denom # (N,P,M)
 
116
  # fold back
117
  grad = fold(
118
- mshift / bandwidth**2,
119
- output_size=(H, W),
120
- kernel_size=patch_size,
121
- stride=patch_size//2
122
- ) # (N, C, H, W)
 
123
  return grad
124
 
125
  class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
 
99
  def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
100
  # x0: (N, C, H, W) in float32
101
  N, C, H, W = x0.shape
102
+ patches = unfold(x0, kernel_size=patch_size, stride=patch_size) # (N, C*ps*ps, M)
 
 
103
  P, M = patches.shape[1], patches.shape[2]
104
  p_i = patches.unsqueeze(1) # (N,1,P,M)
105
  p_j = patches.unsqueeze(0) # (1,N,P,M)
 
111
  num = (w.unsqueeze(2) * diff).sum(dim=1) # (N,P,M)
112
  denom = w.sum(dim=1, keepdim=True) + 1e-8 # (N,1,M)
113
  mshift = num / denom # (N,P,M)
114
+
115
  # fold back
116
  grad = fold(
117
+ mshift / bandwidth**2,
118
+ output_size=(H, W),
119
+ kernel_size=patch_size,
120
+ stride=patch_size
121
+ ) # (N, C, H, W)
122
+
123
  return grad
124
 
125
  class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):