vincentamato committed
Commit 25f023e · 1 Parent(s): 015fa2c

Files changed (4)
  1. app.py +176 -45
  2. aria/aria.py +30 -6
  3. aria/image_encoder.py +30 -2
  4. requirements.txt +4 -2
app.py CHANGED
@@ -13,6 +13,7 @@ import librosa
  import soundfile as sf
  from midi2audio import FluidSynth
  import spaces

  # Remove CPU forcing since we'll use ZeroGPU
  # os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -21,7 +22,12 @@ import spaces
  from aria.image_encoder import ImageEncoder
  from aria.aria import ARIA

- print("Checking model files...")
  # Pre-download all model files at startup
  MODEL_FILES = {
  "image_encoder": "image_encoder.pt",
@@ -33,34 +39,71 @@ MODEL_FILES = {
  # Create cache directory
  CACHE_DIR = os.path.join(os.path.dirname(__file__), "model_cache")
  os.makedirs(CACHE_DIR, exist_ok=True)

  # Download and cache all files
  cached_files = {}
  for model_type, files in MODEL_FILES.items():
  if isinstance(files, str):
  files = [files]

  cached_files[model_type] = []
  for file in files:
  try:
  # Check if file already exists in cache
  repo_id = "vincentamato/aria"
  cached_path = os.path.join(CACHE_DIR, repo_id, file)
  if os.path.exists(cached_path):
- print(f"Using cached file: {file}")
  cached_files[model_type].append(cached_path)
  else:
- print(f"Downloading file: {file}")
  cached_path = hf_hub_download(
  repo_id=repo_id,
  filename=file,
- cache_dir=CACHE_DIR
  )
- cached_files[model_type].append(cached_path)
  except Exception as e:
- print(f"Error with file {file}: {str(e)}")

- print("Model files ready.")

  # Global model cache
  models = {}
@@ -151,30 +194,64 @@ def convert_midi_to_wav(midi_path):
  return wav_path

  try:
- # Check common soundfont locations
- soundfont_paths = [
- '/usr/share/sounds/sf2/FluidR3_GM.sf2', # Linux
- '/usr/share/soundfonts/default.sf2', # Linux alternative
- '/usr/local/share/fluidsynth/generaluser.sf2', # macOS
- 'C:\\soundfonts\\generaluser.sf2' # Windows
  ]

  soundfont = None
- for sf_path in soundfont_paths:
- if os.path.exists(sf_path):
- soundfont = sf_path
- break

  if soundfont is None:
- raise RuntimeError("No SoundFont file found. Please install fluid-soundfont-gm package.")

- # Convert MIDI to WAV using FluidSynth with explicit soundfont
- fs = FluidSynth(sound_font=soundfont)
- fs.midi_to_audio(midi_path, wav_path)

- return wav_path
  except Exception as e:
  print(f"Error converting MIDI to WAV: {str(e)}")
  return None

  @spaces.GPU(duration=120)
@@ -186,7 +263,7 @@ def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_in
  return (
  None, # For emotion_chart
  None, # For midi_output
- f"⚠️ Error: Failed to initialize {conditioning_type} model. Please check the logs."
  )

  try:
@@ -205,19 +282,41 @@ def generate_music(image, conditioning_type, gen_len, temperature, top_p, min_in
  min_instruments=int(min_instruments)
  )

  # Convert MIDI to WAV
  wav_path = convert_midi_to_wav(midi_path)
  if wav_path is None:
- return (
- None,
- None,
- "⚠️ Error: Failed to convert MIDI to WAV for playback"
- )
-
- # Create emotion plot
- plot_path = create_emotion_plot(valence, arousal)

- # Build a nice Markdown result string
  result_text = f"""
  **Model Type:** {conditioning_type}
@@ -240,7 +339,7 @@ Your music has been generated! Click the play button above to listen.
  return (
  None,
  None,
- f"⚠️ Error generating music: {str(e)}"
  )

  def generate_music_wrapper(image, conditioning_type, gen_len, note_temp, rest_temp, top_p, min_instruments):
@@ -261,16 +360,31 @@ with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
  font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"]
  )) as demo:
  gr.Markdown("""
- # 🎨 ARIA: Artistic Rendering of Images into Audio

  Upload an image and ARIA will analyze its emotional content to generate matching music!

- ### How it works:
  1. ARIA first analyzes the emotional content of your image along two dimensions:
  - **Valence**: How positive or negative the emotion is (-1 to 1)
  - **Arousal**: How calm or excited the emotion is (-1 to 1)
  2. These emotions are then used to generate music that matches the mood
  """)

  with gr.Row():
  with gr.Column(scale=3):
@@ -278,8 +392,17 @@ with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
  type="filepath",
  label="Upload Image"
  )
-
- with gr.Group():
  gr.Markdown("### Generation Settings")

  with gr.Row():
@@ -340,16 +463,18 @@ with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
  info="Minimum number of instruments in the generated music"
  )

- generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")

  # Add examples
  gr.Examples(
  examples=[
- ["examples/happy.jpg", "continuous_concat", 1024, 1.2, 1.2, 0.6, 2],
- ["examples/sad.jpeg", "continuous_concat", 1024, 1.2, 1.2, 0.6, 2],
  ],
  inputs=[image_input, conditioning_type, gen_len, note_temperature, rest_temperature, top_p, min_instruments],
- label="Try these examples"
  )

  with gr.Column(scale=2):
@@ -367,7 +492,7 @@ with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
  ### About ARIA

  ARIA is a deep learning system that generates music from artwork by:
- 1. Using a image emotion model to extract emotional content from images
  2. Generating matching music using an emotion-conditioned music generation model

  The emotion-conditioned MIDI generation model is based on the work by Serkan Sulun et al. in their paper
@@ -375,9 +500,15 @@ with gr.Blocks(title="ARIA - Art to Music Generator", theme=gr.themes.Soft(
  Original implementation: [github.com/serkansulun/midi-emotion](https://github.com/serkansulun/midi-emotion)

  ### Conditioning Types
- **continuous_concat**: Emotions are concatenated with music features (recommended)
- **continuous_token**: Emotions are added as special tokens
- **discrete_token**: Emotions are discretized into tokens
  """)

  def generate_music_wrapper(image, conditioning_type, gen_len, note_temp, rest_temp, top_p, min_instruments):

  import soundfile as sf
  from midi2audio import FluidSynth
  import spaces
+ from tqdm import tqdm

  # Remove CPU forcing since we'll use ZeroGPU
  # os.environ["CUDA_VISIBLE_DEVICES"] = ""

  from aria.image_encoder import ImageEncoder
  from aria.aria import ARIA

+ print("=" * 60)
+ print("ARIA - Art to Music Generator")
+ print("=" * 60)
+ print("Initializing model downloads...")
+ sys.stdout.flush()
+
  # Pre-download all model files at startup
  MODEL_FILES = {
  "image_encoder": "image_encoder.pt",

  # Create cache directory
  CACHE_DIR = os.path.join(os.path.dirname(__file__), "model_cache")
  os.makedirs(CACHE_DIR, exist_ok=True)
+ print(f"Cache directory: {CACHE_DIR}")
+ sys.stdout.flush()

  # Download and cache all files
  cached_files = {}
+ total_files = sum(len(files) if isinstance(files, list) else 1 for files in MODEL_FILES.values())
+ current_file = 0
+
  for model_type, files in MODEL_FILES.items():
+ print(f"\nProcessing {model_type} model files...")
+ sys.stdout.flush()
+
  if isinstance(files, str):
  files = [files]

  cached_files[model_type] = []
  for file in files:
+ current_file += 1
+ print(f"[{current_file}/{total_files}] {file}")
+ sys.stdout.flush()
+
  try:
  # Check if file already exists in cache
  repo_id = "vincentamato/aria"
  cached_path = os.path.join(CACHE_DIR, repo_id, file)
+
  if os.path.exists(cached_path):
+ file_size = os.path.getsize(cached_path) / (1024 * 1024) # MB
+ print(f" Found cached file ({file_size:.1f} MB)")
  cached_files[model_type].append(cached_path)
  else:
+ print(f" Downloading from HuggingFace Hub...")
+ print(f" Repository: {repo_id}")
+ sys.stdout.flush()
+
+ # Download with progress
  cached_path = hf_hub_download(
  repo_id=repo_id,
  filename=file,
+ cache_dir=CACHE_DIR,
+ # resume_download=True # Enable resume if connection drops
  )
+
+ if os.path.exists(cached_path):
+ file_size = os.path.getsize(cached_path) / (1024 * 1024) # MB
+ print(f"Download complete ({file_size:.1f} MB)")
+ cached_files[model_type].append(cached_path)
+ else:
+ print(f"Download failed - file not found")
+
  except Exception as e:
+ print(f" Error with file {file}: {str(e)}")
+ sys.stdout.flush()
+
+ print("\n" + "=" * 60)
+ print("Model file preparation complete!")
+ print("=" * 60)
+ sys.stdout.flush()

+ # Check what we actually got
+ for model_type, paths in cached_files.items():
+ print(f"{model_type}: {len(paths)} files ready")
+
+ print(f"\nStarting Gradio application...")
+ sys.stdout.flush()

  # Global model cache
  models = {}

  return wav_path

  try:
+ # Search common soundfont directories for any .sf2 or .sf3 files
+ import glob
+ soundfont_search_dirs = [
+ 'C:\\soundfonts\\', # Windows user soundfonts
+ 'C:\\Program Files\\FluidSynth\\sf2\\', # Windows FluidSynth installation
+ '/usr/share/sounds/sf2/', # Linux system soundfonts
+ '/usr/share/soundfonts/', # Linux alternative
+ '/usr/local/share/fluidsynth/', # macOS homebrew
+ '/System/Library/Audio/Sounds/Banks/', # macOS system
  ]

  soundfont = None
+ for search_dir in soundfont_search_dirs:
+ if os.path.exists(search_dir):
+ # Look for .sf2 and .sf3 files in this directory
+ for extension in ['*.sf2', '*.sf3']:
+ matches = glob.glob(os.path.join(search_dir, extension))
+ if matches:
+ soundfont = matches[0] # Use first soundfont found
+ break
+ if soundfont:
+ break

  if soundfont is None:
+ print(f"No SoundFont found. Audio playback not available.")
+ print(f"MIDI file saved: {midi_path}")
+ print(f"To enable audio: Install FluidSynth and place a .sf2 file in C:\\soundfonts\\")
+ return None

+ # Convert MIDI to WAV using direct FluidSynth command
+ print(f"Converting MIDI to WAV using SoundFont: {soundfont}")
+
+ # Use subprocess to call fluidsynth directly with proper arguments
+ import subprocess
+ cmd = [
+ 'fluidsynth',
+ '-ni', # No interactive mode
+ '-g', '0.5', # Gain
+ '-r', '44100', # Sample rate
+ '-F', wav_path, # Output WAV file
+ soundfont, # SoundFont file
+ midi_path # Input MIDI file
+ ]
+
+ print(f"FluidSynth command: {' '.join(cmd)}")
+ result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
+
+ if result.returncode == 0 and os.path.exists(wav_path):
+ print(f"WAV file created: {wav_path}")
+ return wav_path
+ else:
+ print(f"FluidSynth failed with return code: {result.returncode}")
+ print(f"Error output: {result.stderr}")
+ return None

  except Exception as e:
  print(f"Error converting MIDI to WAV: {str(e)}")
+ print(f"MIDI file still available: {midi_path}")
  return None

  @spaces.GPU(duration=120)

  return (
  None, # For emotion_chart
  None, # For midi_output
+ f"Error: Failed to initialize {conditioning_type} model. Please check the logs."
  )

  try:

  min_instruments=int(min_instruments)
  )

+ # Create emotion plot first (needed for both success and failure cases)
+ plot_path = create_emotion_plot(valence, arousal)
+
  # Convert MIDI to WAV
  wav_path = convert_midi_to_wav(midi_path)
  if wav_path is None:
+ # WAV conversion failed, but we still have MIDI
+ result_text = f"""
+ **Model Type:** {conditioning_type}
+
+ **Predicted Emotions:**
+ - Valence: {valence:.3f} (negative → positive)
+ - Arousal: {arousal:.3f} (calm → excited)
+
+ **Generation Parameters:**
+ - Temperature: {temperature}
+ - Top-p: {top_p}
+ - Min Instruments: {min_instruments}
+
+ **Audio Playback Unavailable**
+ Your music has been generated as a MIDI file, but audio conversion failed.
+
+ **MIDI File:** `{os.path.basename(midi_path)}`
+
+ **To Enable Audio Playback:**
+ 1. Install FluidSynth: `choco install fluidsynth` (or download from GitHub)
+ 2. Download a SoundFont file (e.g., GeneralUser GS)
+ 3. Place it at: `C:\\soundfonts\\generaluser.sf2`
+
+ You can still download and play the MIDI file in any MIDI player!
+ """
+ # Return MIDI file for download instead of WAV
+ return (plot_path, midi_path, result_text)

+ # Build a nice Markdown result string for successful WAV conversion
  result_text = f"""
  **Model Type:** {conditioning_type}

  return (
  None,
  None,
+ f"Error generating music: {str(e)}"
  )

  def generate_music_wrapper(image, conditioning_type, gen_len, note_temp, rest_temp, top_p, min_instruments):

  font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"]
  )) as demo:
  gr.Markdown("""
+ # ARIA: Artistic Rendering of Images into Audio

  Upload an image and ARIA will analyze its emotional content to generate matching music!

+ ## How it works:
  1. ARIA first analyzes the emotional content of your image along two dimensions:
  - **Valence**: How positive or negative the emotion is (-1 to 1)
  - **Arousal**: How calm or excited the emotion is (-1 to 1)
  2. These emotions are then used to generate music that matches the mood
  """)
+
+ # Subtle gradient background for a more modern look
+ gr.HTML(
+ """
+ <style>
+ body {
+ background: radial-gradient(circle at top left, #0d1117 0%, #06080d 100%);
+ }
+ /* Elevate accordion header visibility */
+ .gr-accordion-summary {
+ font-weight: 600;
+ }
+ </style>
+ """
+ )

  with gr.Row():
  with gr.Column(scale=3):

  type="filepath",
  label="Upload Image"
  )
+
+ # Quick-start guidance so first-time users immediately know what to do
+ gr.Markdown(
+ "## Quick Start\n"
+ "1. **Click an example artwork below** *or* **upload your own image** above.\n"
+ "2. (Optional) Open **Advanced Settings** to fine-tune the generation.\n"
+ "3. Hit **Generate Music** to run the model!"
+ )
+
+ # Advanced controls are tucked away inside a collapsible panel to keep the UI clean
+ with gr.Accordion("Advanced Settings", open=False):
  gr.Markdown("### Generation Settings")

  with gr.Row():

  info="Minimum number of instruments in the generated music"
  )

+ generate_btn = gr.Button("Generate Music", variant="primary", size="lg")

  # Add examples
+ # Dynamic path resolution for local vs HF Spaces deployment
+ examples_dir = "examples" if os.path.exists("examples") else "ARIA/examples"
  gr.Examples(
  examples=[
+ [f"{examples_dir}/happy.jpg", "continuous_concat", 1024, 1.2, 1.2, 0.6, 2],
+ [f"{examples_dir}/sad.jpeg", "continuous_concat", 1024, 1.2, 1.2, 0.6, 2],
  ],
  inputs=[image_input, conditioning_type, gen_len, note_temperature, rest_temperature, top_p, min_instruments],
+ label="Example Artworks (click to load)"
  )

  with gr.Column(scale=2):

  ### About ARIA

  ARIA is a deep learning system that generates music from artwork by:
+ 1. Using an image-emotion model to extract emotional content from images
  2. Generating matching music using an emotion-conditioned music generation model

  The emotion-conditioned MIDI generation model is based on the work by Serkan Sulun et al. in their paper

  Original implementation: [github.com/serkansulun/midi-emotion](https://github.com/serkansulun/midi-emotion)

  ### Conditioning Types
+
+ **continuous_concat (Recommended)**
+ Creates a single vector from valence and arousal values, repeats it across the sequence, and concatenates it with every music token embedding. This approach gives the emotion information *global influence* throughout the entire generation process, allowing the transformer to access emotional context at every timestep. Research shows this method achieves the best performance in both note prediction accuracy and emotional coherence.
+
+ **continuous_token**
+ Converts each emotion value (valence and arousal) into separate condition vectors with the same length as music token embeddings, then concatenates them in the sequence dimension. The emotion vectors are inserted at the beginning of the input sequence during generation. This treats emotions similarly to music tokens but can lose influence as the sequence grows longer.
+
+ **discrete_token**
+ Quantizes continuous emotion values into 5 discrete bins (very low, low, moderate, high, very high) and converts them into control tokens. These tokens are placed before the music tokens in the sequence. While this represents the current state-of-the-art approach in conditional text generation, it suffers from information loss due to binning and can lose emotional context during longer generations when tokens are truncated.
  """)

  def generate_music_wrapper(image, conditioning_type, gen_len, note_temp, rest_temp, top_p, min_instruments):
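The continuous_concat scheme described above boils down to projecting the (valence, arousal) pair into a condition vector, repeating it along the time axis, and concatenating it with every token embedding before the transformer. A minimal PyTorch sketch of that idea, for illustration only (layer names, sizes, and the toy transformer are assumptions, not the repository's model code):

    # Illustrative sketch of continuous_concat-style conditioning; shapes and
    # module names are assumptions, not the actual midi-emotion implementation.
    import torch
    import torch.nn as nn

    class ContinuousConcatSketch(nn.Module):
        def __init__(self, vocab_size=400, d_music=512, d_emotion=32):
            super().__init__()
            self.token_emb = nn.Embedding(vocab_size, d_music)
            self.emotion_proj = nn.Linear(2, d_emotion)  # (valence, arousal) -> condition vector
            layer = nn.TransformerEncoderLayer(d_model=d_music + d_emotion, nhead=8, batch_first=True)
            self.backbone = nn.TransformerEncoder(layer, num_layers=2)

        def forward(self, tokens, valence, arousal):
            # tokens: (batch, seq_len) of music-token ids
            emotion = torch.stack([valence, arousal], dim=-1)        # (batch, 2)
            cond = self.emotion_proj(emotion)                        # (batch, d_emotion)
            cond = cond.unsqueeze(1).expand(-1, tokens.size(1), -1)  # repeat across the sequence
            x = torch.cat([self.token_emb(tokens), cond], dim=-1)    # concat with every embedding
            return self.backbone(x)

    # Example: condition a dummy 16-token sequence on a positive, excited image
    model = ContinuousConcatSketch()
    out = model(torch.randint(0, 400, (1, 16)), torch.tensor([0.8]), torch.tensor([0.6]))
    print(out.shape)  # torch.Size([1, 16, 544])

By contrast, continuous_token would prepend the condition vectors as extra sequence positions, and discrete_token would bucket valence/arousal into a small vocabulary of control tokens placed before the music tokens.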
aria/aria.py CHANGED
@@ -31,15 +31,39 @@ class ARIA:
  conditioning: Type of conditioning to use (continuous_concat, continuous_token, discrete_token)
  device: Device to run on (default: auto-detect)
  """
- # Initialize device
- self.device = torch.device("cuda") # Always use CUDA with ZeroGPU
  print(f"Using device: {self.device}")
  self.conditioning = conditioning

  # Load image emotion model
  self.image_model = ImageEncoder()
- checkpoint = torch.load(image_model_checkpoint, map_location=self.device, weights_only=True)
- self.image_model.load_state_dict(checkpoint["model_state_dict"])
  self.image_model = self.image_model.to(self.device)
  self.image_model.eval()
@@ -53,8 +77,8 @@ class ARIA:
  mappings_fp = os.path.join(midi_model_dir, 'mappings.pt')
  config_fp = os.path.join(midi_model_dir, 'model_config.pt')

- self.maps = torch.load(mappings_fp, weights_only=True)
- config = torch.load(config_fp, weights_only=True)
  self.midi_model, _ = build_model(None, load_config_dict=config)
  self.midi_model = self.midi_model.to(self.device)
  self.midi_model.load_state_dict(torch.load(model_fp, map_location=self.device, weights_only=True))

  conditioning: Type of conditioning to use (continuous_concat, continuous_token, discrete_token)
  device: Device to run on (default: auto-detect)
  """
+ # Initialize device - use CPU if CUDA not available
+ if device is not None:
+ self.device = torch.device(device)
+ elif torch.cuda.is_available():
+ self.device = torch.device("cuda")
+ else:
+ self.device = torch.device("cpu")
+
  print(f"Using device: {self.device}")
  self.conditioning = conditioning

  # Load image emotion model
  self.image_model = ImageEncoder()
+ try:
+ checkpoint = torch.load(image_model_checkpoint, map_location=self.device, weights_only=True)
+ # Extract only the custom heads from the checkpoint (ignore CLIP model weights)
+ state_dict = {}
+ for key, value in checkpoint["model_state_dict"].items():
+ if key.startswith(('valence_head.', 'arousal_head.')):
+ state_dict[key] = value
+
+ # Initialize the model first so the heads exist
+ self.image_model._ensure_initialized()
+
+ # Load only the custom head weights
+ self.image_model.load_state_dict(state_dict, strict=False)
+ print("ImageEncoder custom heads loaded successfully")
+ except Exception as e:
+ print(f"Warning: Could not load ImageEncoder checkpoint: {e}")
+ print("Using randomly initialized heads")
+ # Initialize anyway with random weights
+ self.image_model._ensure_initialized()
+
  self.image_model = self.image_model.to(self.device)
  self.image_model.eval()

  mappings_fp = os.path.join(midi_model_dir, 'mappings.pt')
  config_fp = os.path.join(midi_model_dir, 'model_config.pt')

+ self.maps = torch.load(mappings_fp, map_location=self.device, weights_only=True)
+ config = torch.load(config_fp, map_location=self.device, weights_only=True)
  self.midi_model, _ = build_model(None, load_config_dict=config)
  self.midi_model = self.midi_model.to(self.device)
  self.midi_model.load_state_dict(torch.load(model_fp, map_location=self.device, weights_only=True))
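The checkpoint handling added above keeps only the `valence_head.*` and `arousal_head.*` tensors and loads them with `strict=False`, so the frozen CLIP weights never need to ship in the checkpoint. A standalone sketch of that partial-loading pattern (the helper name and the generic `model` argument are illustrative, not the repository's API):

    # Hedged sketch of partial state_dict loading: keep only the fine-tuned
    # head weights and ignore everything else in the checkpoint.
    import torch

    def load_heads_only(model, checkpoint_path, prefixes=("valence_head.", "arousal_head.")):
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        state_dict = {k: v for k, v in checkpoint["model_state_dict"].items()
                      if k.startswith(prefixes)}
        # strict=False: keys for the frozen CLIP backbone simply keep their current values
        missing, _ = model.load_state_dict(state_dict, strict=False)
        print(f"Loaded {len(state_dict)} head tensors "
              f"({len(missing)} keys kept their initialized values)")
        return model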
aria/image_encoder.py CHANGED
@@ -13,9 +13,28 @@ class ImageEncoder(nn.Module):
  """
  super().__init__()

  # Load CLIP model and processor
- self.clip_model = CLIPModel.from_pretrained(clip_model_name)
- self.processor = CLIPProcessor.from_pretrained(clip_model_name)

  # Freeze CLIP parameters
  for param in self.clip_model.parameters():

@@ -50,6 +69,9 @@ class ImageEncoder(nn.Module):
  # Move model to GPU if available
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  self.to(self.device)

  def forward(self, images: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
  """Forward pass to get valence and arousal predictions.

@@ -60,6 +82,9 @@ class ImageEncoder(nn.Module):
  Returns:
  Tuple of predicted valence and arousal scores
  """
  # Process images if they're PIL images
  if isinstance(images, Image.Image):
  inputs = self.processor(images=images, return_tensors="pt")

@@ -85,6 +110,9 @@ class ImageEncoder(nn.Module):
  Returns:
  Image embedding tensor
  """
  inputs = self.processor(images=image, return_tensors="pt")
  with torch.no_grad():
  image_features = self.clip_model.get_image_features(inputs.pixel_values.to(self.device))

  """
  super().__init__()

+ # Store model name for lazy loading
+ self.clip_model_name = clip_model_name
+ self.clip_model = None
+ self.processor = None
+ self.valence_head = None
+ self.arousal_head = None
+ self.device = None
+ self._initialized = False
+
+ def _ensure_initialized(self):
+ """Lazy initialization of the model components."""
+ if self._initialized:
+ return
+
+ print(f"Initializing ImageEncoder with {self.clip_model_name}...")
+ print("Downloading CLIP model (this may take a moment)...")
+
  # Load CLIP model and processor
+ self.clip_model = CLIPModel.from_pretrained(self.clip_model_name)
+ self.processor = CLIPProcessor.from_pretrained(self.clip_model_name)
+
+ print("CLIP model loaded successfully")

  # Freeze CLIP parameters
  for param in self.clip_model.parameters():

  # Move model to GPU if available
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  self.to(self.device)
+
+ print(f"Model moved to device: {self.device}")
+ self._initialized = True

  def forward(self, images: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
  """Forward pass to get valence and arousal predictions.

  Returns:
  Tuple of predicted valence and arousal scores
  """
+ # Ensure model is initialized
+ self._ensure_initialized()
+
  # Process images if they're PIL images
  if isinstance(images, Image.Image):
  inputs = self.processor(images=images, return_tensors="pt")

  Returns:
  Image embedding tensor
  """
+ # Ensure model is initialized
+ self._ensure_initialized()
+
  inputs = self.processor(images=image, return_tensors="pt")
  with torch.no_grad():
  image_features = self.clip_model.get_image_features(inputs.pixel_values.to(self.device))
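The lazy-initialization change above makes constructing `ImageEncoder()` cheap; the CLIP download only happens when `_ensure_initialized()` runs on first use. A small generic sketch of the same pattern (class and method names are illustrative, not this module's public API):

    # Generic lazy-initialization sketch: heavy pretrained weights are fetched
    # on first use rather than in __init__.
    from transformers import CLIPModel, CLIPProcessor

    class LazyCLIP:
        def __init__(self, model_name="openai/clip-vit-base-patch32"):
            self.model_name = model_name
            self.model = None
            self.processor = None

        def _ensure_initialized(self):
            if self.model is None:
                # Network download happens here, not at construction time
                self.model = CLIPModel.from_pretrained(self.model_name)
                self.processor = CLIPProcessor.from_pretrained(self.model_name)

        def embed(self, image):
            self._ensure_initialized()
            inputs = self.processor(images=image, return_tensors="pt")
            return self.model.get_image_features(**inputs)

Instances can therefore be created at import time without blocking application startup on the CLIP download.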
requirements.txt CHANGED
@@ -6,8 +6,10 @@ gradio>=4.0.0
  matplotlib>=3.7.0
  huggingface_hub>=0.19.0
  pretty-midi>=0.2.9
- librosa>=0.10.0
  soundfile>=0.12.0
  midi2audio>=0.1.1
  transformers>=4.35.0
- spaces>=0.32.0

  matplotlib>=3.7.0
  huggingface_hub>=0.19.0
  pretty-midi>=0.2.9
+ librosa>=0.10.1
  soundfile>=0.12.0
  midi2audio>=0.1.1
  transformers>=4.35.0
+ spaces>=0.32.0
+ numba>=0.60.0
+ llvmlite>=0.43.0