import datetime
import os
import sys

import numpy as np
import spaces
import torch
from PIL import Image

from .image_encoder import ImageEncoder

# Add MIDI emotion model path to Python path
MIDI_EMOTION_PATH = os.path.join(os.path.dirname(__file__), "..", "midi_emotion", "src")
sys.path.append(MIDI_EMOTION_PATH)
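# (Assumption: the src directory is appended so the midi_emotion sources can
# resolve their own top-level imports when they are loaded lazily in
# ARIA.__init__ below.)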


class ARIA:
    """ARIA model that generates music from images based on emotional content."""

    @spaces.GPU(duration=10)  # Model loading should be quick
    def __init__(
        self,
        image_model_checkpoint: str,
        midi_model_dir: str,
        conditioning: str = "continuous_concat",
        device: str | None = None,
    ):
        """Initialize ARIA model.

        Args:
            image_model_checkpoint: Path to image emotion model checkpoint
            midi_model_dir: Path to midi emotion model directory
            conditioning: Type of conditioning to use
                (continuous_concat, continuous_token, or discrete_token)
            device: Device to run on (default: auto-detect)
        """
        # Initialize device - use CPU if CUDA is not available
        if device is not None:
            self.device = torch.device(device)
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        print(f"Using device: {self.device}")

        self.conditioning = conditioning

        # Load image emotion model
        self.image_model = ImageEncoder()
        try:
            checkpoint = torch.load(image_model_checkpoint, map_location=self.device, weights_only=True)
            # Extract only the custom heads from the checkpoint (ignore CLIP model weights)
            state_dict = {}
            for key, value in checkpoint["model_state_dict"].items():
                if key.startswith(("valence_head.", "arousal_head.")):
                    state_dict[key] = value
            # Initialize the model first so the heads exist
            self.image_model._ensure_initialized()
            # Load only the custom head weights
            self.image_model.load_state_dict(state_dict, strict=False)
            print("ImageEncoder custom heads loaded successfully")
        except Exception as e:
            print(f"Warning: Could not load ImageEncoder checkpoint: {e}")
            print("Using randomly initialized heads")
            # Initialize anyway with random weights
            self.image_model._ensure_initialized()
        self.image_model = self.image_model.to(self.device)
        self.image_model.eval()

        # Import midi generation
        from midi_emotion.src.generate import generate
        from midi_emotion.src.models.build_model import build_model
        self.generate_midi = generate

        # Load midi model
        model_fp = os.path.join(midi_model_dir, "model.pt")
        mappings_fp = os.path.join(midi_model_dir, "mappings.pt")
        config_fp = os.path.join(midi_model_dir, "model_config.pt")

        self.maps = torch.load(mappings_fp, map_location=self.device, weights_only=True)
        config = torch.load(config_fp, map_location=self.device, weights_only=True)

        self.midi_model, _ = build_model(None, load_config_dict=config)
        self.midi_model = self.midi_model.to(self.device)
        self.midi_model.load_state_dict(torch.load(model_fp, map_location=self.device, weights_only=True))
        self.midi_model.eval()

    @spaces.GPU(duration=60)
    @torch.inference_mode()  # More efficient than no_grad for inference
    def generate(
        self,
        image_path: str,
        out_dir: str = "output",
        gen_len: int = 2048,
        temperature: list | None = None,
        top_k: int = -1,
        top_p: float = 0.7,
        min_instruments: int = 2,
    ) -> tuple[float, float, str]:
        """Generate music from an image.

        Args:
            image_path: Path to input image
            out_dir: Directory to save generated MIDI
            gen_len: Length of generation in tokens
            temperature: Temperatures for sampling as [note_temp, rest_temp]
                (default: [1.2, 1.2])
            top_k: Top-k sampling (-1 to disable)
            top_p: Top-p (nucleus) sampling threshold
            min_instruments: Minimum number of instruments required

        Returns:
            Tuple of (valence, arousal, midi_path)
        """
print("▶ ARIA.generate entered")
# Get emotion from image
image = Image.open(image_path).convert("RGB")
valence, arousal = self.image_model(image)
valence = valence.squeeze().cpu().item()
arousal = arousal.squeeze().cpu().item()
# Create output directory
os.makedirs(out_dir, exist_ok=True)
# Generate MIDI
continuous_conditions = np.array([[valence, arousal]], dtype=np.float32)
# Generate timestamp for filename (for reference)
now = datetime.datetime.now()
timestamp = now.strftime("%Y_%m_%d_%H_%M_%S")

        # Generate the MIDI
        self.generate_midi(
            model=self.midi_model,
            maps=self.maps,
            device=self.device,
            out_dir=out_dir,
            conditioning=self.conditioning,
            continuous_conditions=continuous_conditions,
            gen_len=gen_len,
            temperatures=temperature,
            top_k=top_k,
            top_p=top_p,
            min_n_instruments=min_instruments,
        )

        # Find the most recently generated MIDI file
        midi_files = [f for f in os.listdir(out_dir) if f.endswith(".mid")]
        if midi_files:
            # Sort by creation time and take the most recent
            newest = max(midi_files, key=lambda f: os.path.getctime(os.path.join(out_dir, f)))
            midi_path = os.path.join(out_dir, newest)
            return valence, arousal, midi_path
        raise RuntimeError("Failed to generate MIDI file")
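

if __name__ == "__main__":
    # Minimal usage sketch. The paths below are hypothetical placeholders,
    # not files shipped with this repo; substitute your own checkpoint and
    # midi model directory. Because of the relative import at the top, run
    # this as a module from the repo root (e.g. `python -m aria.aria`).
    aria = ARIA(
        image_model_checkpoint="checkpoints/image_encoder.pt",
        midi_model_dir="checkpoints/midi_model",
    )
    valence, arousal, midi_path = aria.generate("example.jpg")
    print(f"valence={valence:.3f}, arousal={arousal:.3f} -> {midi_path}")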