vincentamato commited on
Commit
016b505
·
1 Parent(s): 14c6e42

Added GPU decorators

Browse files
Files changed (3) hide show
  1. app.py +6 -0
  2. aria/aria.py +5 -15
  3. requirements.txt +8 -7
app.py CHANGED
@@ -10,6 +10,11 @@ import pretty_midi
10
  import librosa
11
  import soundfile as sf
12
  from midi2audio import FluidSynth
 
 
 
 
 
13
 
14
  from aria.image_encoder import ImageEncoder
15
  from aria.aria import ARIA
@@ -163,6 +168,7 @@ def convert_midi_to_wav(midi_path):
163
  print(f"Error converting MIDI to WAV: {str(e)}")
164
  return None
165
 
 
166
  def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_instruments):
167
  """Generate music from input image"""
168
  model = get_model(conditioning_type)
 
10
  import librosa
11
  import soundfile as sf
12
  from midi2audio import FluidSynth
13
+ import spaces
14
+
15
+ # Remove CPU forcing since we'll use ZeroGPU
16
+ # os.environ["CUDA_VISIBLE_DEVICES"] = ""
17
+ # torch.set_num_threads(4)
18
 
19
  from aria.image_encoder import ImageEncoder
20
  from aria.aria import ARIA
 
168
  print(f"Error converting MIDI to WAV: {str(e)}")
169
  return None
170
 
171
+ @spaces.GPU # Set duration to 60 seconds for music generation
172
  def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_instruments):
173
  """Generate music from input image"""
174
  model = get_model(conditioning_type)
aria/aria.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  from PIL import Image
4
  import numpy as np
5
  import datetime
 
6
 
7
  from .image_encoder import ImageEncoder
8
 
@@ -14,6 +15,7 @@ sys.path.append(MIDI_EMOTION_PATH)
14
  class ARIA:
15
  """ARIA model that generates music from images based on emotional content."""
16
 
 
17
  def __init__(
18
  self,
19
  image_model_checkpoint: str,
@@ -29,21 +31,8 @@ class ARIA:
29
  conditioning: Type of conditioning to use (continuous_concat, continuous_token, discrete_token)
30
  device: Device to run on (default: auto-detect)
31
  """
32
- # Initialize CUDA if available
33
- if device is None:
34
- if not torch.cuda.is_available():
35
- self.device = torch.device("cpu")
36
- else:
37
- try:
38
- # Test CUDA initialization
39
- torch.zeros(1).cuda()
40
- self.device = torch.device("cuda")
41
- except RuntimeError:
42
- print("CUDA initialization failed, falling back to CPU")
43
- self.device = torch.device("cpu")
44
- else:
45
- self.device = torch.device(device)
46
-
47
  print(f"Using device: {self.device}")
48
  self.conditioning = conditioning
49
 
@@ -71,6 +60,7 @@ class ARIA:
71
  self.midi_model.load_state_dict(torch.load(model_fp, map_location=self.device, weights_only=True))
72
  self.midi_model.eval()
73
 
 
74
  @torch.inference_mode() # More efficient than no_grad for inference
75
  def generate(
76
  self,
 
3
  from PIL import Image
4
  import numpy as np
5
  import datetime
6
+ import spaces
7
 
8
  from .image_encoder import ImageEncoder
9
 
 
15
  class ARIA:
16
  """ARIA model that generates music from images based on emotional content."""
17
 
18
+ @spaces.GPU # Model loading should be quick
19
  def __init__(
20
  self,
21
  image_model_checkpoint: str,
 
31
  conditioning: Type of conditioning to use (continuous_concat, continuous_token, discrete_token)
32
  device: Device to run on (default: auto-detect)
33
  """
34
+ # Initialize device
35
+ self.device = torch.device("cuda") # Always use CUDA with ZeroGPU
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  print(f"Using device: {self.device}")
37
  self.conditioning = conditioning
38
 
 
60
  self.midi_model.load_state_dict(torch.load(model_fp, map_location=self.device, weights_only=True))
61
  self.midi_model.eval()
62
 
63
+ @spaces.GPU
64
  @torch.inference_mode() # More efficient than no_grad for inference
65
  def generate(
66
  self,
requirements.txt CHANGED
@@ -1,11 +1,12 @@
 
 
 
 
1
  gradio>=4.0.0
2
- torch>=2.0.0
3
- numpy>=1.24.0
4
  matplotlib>=3.7.0
5
- Pillow>=10.0.0
6
- huggingface-hub>=0.19.0
7
- pretty-midi>=0.2.10
8
  librosa>=0.10.0
9
- soundfile>=0.12.0
10
  midi2audio>=0.1.1
11
- transformers>=4.30.0
 
1
+ torch>=2.1.0
2
+ torchvision>=0.16.0
3
+ numpy>=1.21.0
4
+ Pillow>=10.0.0
5
  gradio>=4.0.0
 
 
6
  matplotlib>=3.7.0
7
+ huggingface_hub>=0.19.0
8
+ pretty-midi>=0.2.9
 
9
  librosa>=0.10.0
10
+ soundfile>=0.12.0
11
  midi2audio>=0.1.1
12
+ transformers>=4.35.0