Spaces:
Runtime error
Runtime error
fix bugs
Browse files- apg_guidance.py +5 -6
- pipeline_ace_step.py +0 -1
apg_guidance.py
CHANGED
@@ -17,15 +17,14 @@ def project(
|
|
17 |
dims=[-1, -2],
|
18 |
):
|
19 |
dtype = v0.dtype
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
v0, v1 = v0.double(), v1.double()
|
25 |
v1 = torch.nn.functional.normalize(v1, dim=dims)
|
26 |
v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
|
27 |
v0_orthogonal = v0 - v0_parallel
|
28 |
-
return v0_parallel.to(dtype)
|
29 |
|
30 |
|
31 |
def apg_forward(
|
|
|
17 |
dims=[-1, -2],
|
18 |
):
|
19 |
dtype = v0.dtype
|
20 |
+
if v0.device.type == "mps":
|
21 |
+
v0, v1 = v0.float(), v1.float()
|
22 |
+
else:
|
23 |
+
v0, v1 = v0.double(), v1.double()
|
|
|
24 |
v1 = torch.nn.functional.normalize(v1, dim=dims)
|
25 |
v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
|
26 |
v0_orthogonal = v0 - v0_parallel
|
27 |
+
return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
|
28 |
|
29 |
|
30 |
def apg_forward(
|
pipeline_ace_step.py
CHANGED
@@ -955,7 +955,6 @@ class ACEStepPipeline:
|
|
955 |
latents, _ = self.music_dcae.encode(input_audio, sr=sr)
|
956 |
return latents
|
957 |
|
958 |
-
@spaces.GPU
|
959 |
def __call__(
|
960 |
self,
|
961 |
audio_duration: float = 60.0,
|
|
|
955 |
latents, _ = self.music_dcae.encode(input_audio, sr=sr)
|
956 |
return latents
|
957 |
|
|
|
958 |
def __call__(
|
959 |
self,
|
960 |
audio_duration: float = 60.0,
|