Spaces:
Running
Running
fix bugs
Browse files- apg_guidance.py +5 -6
- pipeline_ace_step.py +0 -1
apg_guidance.py
CHANGED
|
@@ -17,15 +17,14 @@ def project(
|
|
| 17 |
dims=[-1, -2],
|
| 18 |
):
|
| 19 |
dtype = v0.dtype
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
v0, v1 = v0.double(), v1.double()
|
| 25 |
v1 = torch.nn.functional.normalize(v1, dim=dims)
|
| 26 |
v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
|
| 27 |
v0_orthogonal = v0 - v0_parallel
|
| 28 |
-
return v0_parallel.to(dtype)
|
| 29 |
|
| 30 |
|
| 31 |
def apg_forward(
|
|
|
|
| 17 |
dims=[-1, -2],
|
| 18 |
):
|
| 19 |
dtype = v0.dtype
|
| 20 |
+
if v0.device.type == "mps":
|
| 21 |
+
v0, v1 = v0.float(), v1.float()
|
| 22 |
+
else:
|
| 23 |
+
v0, v1 = v0.double(), v1.double()
|
|
|
|
| 24 |
v1 = torch.nn.functional.normalize(v1, dim=dims)
|
| 25 |
v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
|
| 26 |
v0_orthogonal = v0 - v0_parallel
|
| 27 |
+
return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
|
| 28 |
|
| 29 |
|
| 30 |
def apg_forward(
|
pipeline_ace_step.py
CHANGED
|
@@ -955,7 +955,6 @@ class ACEStepPipeline:
|
|
| 955 |
latents, _ = self.music_dcae.encode(input_audio, sr=sr)
|
| 956 |
return latents
|
| 957 |
|
| 958 |
-
@spaces.GPU
|
| 959 |
def __call__(
|
| 960 |
self,
|
| 961 |
audio_duration: float = 60.0,
|
|
|
|
| 955 |
latents, _ = self.music_dcae.encode(input_audio, sr=sr)
|
| 956 |
return latents
|
| 957 |
|
|
|
|
| 958 |
def __call__(
|
| 959 |
self,
|
| 960 |
audio_duration: float = 60.0,
|