Spaces: Running on Zero
Commit · 016b505
1 Parent(s): 14c6e42

Added GPU decorators

Browse files
- app.py +6 -0
- aria/aria.py +5 -15
- requirements.txt +8 -7
app.py CHANGED

```diff
@@ -10,6 +10,11 @@ import pretty_midi
 import librosa
 import soundfile as sf
 from midi2audio import FluidSynth
+import spaces
+
+# Remove CPU forcing since we'll use ZeroGPU
+# os.environ["CUDA_VISIBLE_DEVICES"] = ""
+# torch.set_num_threads(4)
 
 from aria.image_encoder import ImageEncoder
 from aria.aria import ARIA
@@ -163,6 +168,7 @@ def convert_midi_to_wav(midi_path):
         print(f"Error converting MIDI to WAV: {str(e)}")
         return None
 
+@spaces.GPU  # Set duration to 60 seconds for music generation
 def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_instruments):
     """Generate music from input image"""
     model = get_model(conditioning_type)
```
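The new decorator on `generate_music` is applied bare, so it runs with ZeroGPU's default time budget even though the inline comment mentions 60 seconds. `spaces.GPU` does accept a `duration` argument; below is a minimal sketch of making the 60-second request explicit (`double_on_gpu` is an illustrative stand-in, not code from this repo):

```python
import spaces
import torch

@spaces.GPU(duration=60)  # explicitly request a 60-second GPU window
def double_on_gpu(x: torch.Tensor) -> torch.Tensor:
    # On ZeroGPU, a GPU is attached only while a @spaces.GPU-decorated
    # function is executing, so all CUDA work belongs inside this body.
    return (x.to("cuda") * 2).cpu()
```

If the comment reflects the intent, the decorator in the diff above could read `@spaces.GPU(duration=60)`.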
aria/aria.py CHANGED

```diff
@@ -3,6 +3,7 @@ import os
 from PIL import Image
 import numpy as np
 import datetime
+import spaces
 
 from .image_encoder import ImageEncoder
 
@@ -14,6 +15,7 @@ sys.path.append(MIDI_EMOTION_PATH)
 class ARIA:
     """ARIA model that generates music from images based on emotional content."""
 
+    @spaces.GPU  # Model loading should be quick
     def __init__(
         self,
         image_model_checkpoint: str,
@@ -29,21 +31,8 @@ class ARIA:
             conditioning: Type of conditioning to use (continuous_concat, continuous_token, discrete_token)
             device: Device to run on (default: auto-detect)
         """
-        # Initialize
-        if device is None:
-            if not torch.cuda.is_available():
-                self.device = torch.device("cpu")
-            else:
-                try:
-                    # Test CUDA initialization
-                    torch.zeros(1).cuda()
-                    self.device = torch.device("cuda")
-                except RuntimeError:
-                    print("CUDA initialization failed, falling back to CPU")
-                    self.device = torch.device("cpu")
-        else:
-            self.device = torch.device(device)
-
+        # Initialize device
+        self.device = torch.device("cuda")  # Always use CUDA with ZeroGPU
         print(f"Using device: {self.device}")
         self.conditioning = conditioning
 
@@ -71,6 +60,7 @@ class ARIA:
         self.midi_model.load_state_dict(torch.load(model_fp, map_location=self.device, weights_only=True))
         self.midi_model.eval()
 
+    @spaces.GPU
     @torch.inference_mode()  # More efficient than no_grad for inference
     def generate(
         self,
```
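Removing the CUDA-availability probe fits the ZeroGPU model: between calls no GPU is attached to the process, so `torch.cuda.is_available()` is not a useful signal at load time, while inside a `@spaces.GPU`-decorated call CUDA can be assumed. A minimal sketch of the method-decoration pattern the commit applies to `ARIA.generate` (`TinyModel` is hypothetical):

```python
import spaces
import torch

class TinyModel:
    def __init__(self):
        # Mirror the commit's assumption: target CUDA directly, because
        # ZeroGPU attaches a GPU whenever a decorated method executes.
        self.device = torch.device("cuda")
        self.weight = torch.randn(8, 8)  # still allocated on CPU here

    @spaces.GPU  # GPU attached only for the duration of this call
    @torch.inference_mode()  # disable autograd state, as in ARIA.generate
    def generate(self, x: torch.Tensor) -> torch.Tensor:
        # Move inputs and weights onto the GPU inside the decorated call.
        return x.to(self.device) @ self.weight.to(self.device)
```

One caveat worth flagging: decorating `__init__` as the commit does spends GPU quota on model loading, which only pays off if loading actually touches CUDA rather than just reading checkpoints to memory.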
requirements.txt CHANGED

```diff
@@ -1,11 +1,12 @@
+torch>=2.1.0
+torchvision>=0.16.0
+numpy>=1.21.0
+Pillow>=10.0.0
 gradio>=4.0.0
-torch>=2.0.0
-numpy>=1.24.0
 matplotlib>=3.7.0
-
-
-pretty-midi>=0.2.10
+huggingface_hub>=0.19.0
+pretty-midi>=0.2.9
 librosa>=0.10.0
-soundfile>=0.12.0
+soundfile>=0.12.0
 midi2audio>=0.1.1
-transformers>=4.
+transformers>=4.35.0
```
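Worth noting: `spaces`, which both files now import, is not listed here; this presumably relies on the ZeroGPU runtime providing the package. A quick sanity check for the resulting environment (`check_gpu` is an illustrative helper, not part of the repo):

```python
import spaces
import torch

@spaces.GPU
def check_gpu() -> str:
    # Reports the name of the device ZeroGPU attached for this call.
    return torch.cuda.get_device_name(0)
```

Wiring this to a debug button in the Gradio app should print an NVIDIA data-center GPU name when the Space is actually running on ZeroGPU hardware.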