import torch
import os
from PIL import Image
import numpy as np
import datetime
import spaces
from .image_encoder import ImageEncoder
# Add MIDI emotion model path to Python path
import sys
MIDI_EMOTION_PATH = os.path.join(os.path.dirname(__file__), "..", "midi_emotion", "src")
sys.path.append(MIDI_EMOTION_PATH)
class ARIA:
"""ARIA model that generates music from images based on emotional content."""
@spaces.GPU(duration=10) # Model loading should be quick
def __init__(
self,
image_model_checkpoint: str,
midi_model_dir: str,
conditioning: str = "continuous_concat",
        device: str | None = None
):
"""Initialize ARIA model.
Args:
image_model_checkpoint: Path to image emotion model checkpoint
midi_model_dir: Path to midi emotion model directory
conditioning: Type of conditioning to use (continuous_concat, continuous_token, discrete_token)
device: Device to run on (default: auto-detect)
"""
# Initialize device - use CPU if CUDA not available
if device is not None:
self.device = torch.device(device)
elif torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
print(f"Using device: {self.device}")
self.conditioning = conditioning
# Load image emotion model
self.image_model = ImageEncoder()
try:
checkpoint = torch.load(image_model_checkpoint, map_location=self.device, weights_only=True)
# Extract only the custom heads from the checkpoint (ignore CLIP model weights)
state_dict = {}
for key, value in checkpoint["model_state_dict"].items():
if key.startswith(('valence_head.', 'arousal_head.')):
state_dict[key] = value
# Initialize the model first so the heads exist
self.image_model._ensure_initialized()
# Load only the custom head weights
self.image_model.load_state_dict(state_dict, strict=False)
print("ImageEncoder custom heads loaded successfully")
except Exception as e:
print(f"Warning: Could not load ImageEncoder checkpoint: {e}")
print("Using randomly initialized heads")
# Initialize anyway with random weights
self.image_model._ensure_initialized()
self.image_model = self.image_model.to(self.device)
self.image_model.eval()
# Import midi generation
from midi_emotion.src.generate import generate
from midi_emotion.src.models.build_model import build_model
self.generate_midi = generate
# Load midi model
model_fp = os.path.join(midi_model_dir, 'model.pt')
mappings_fp = os.path.join(midi_model_dir, 'mappings.pt')
config_fp = os.path.join(midi_model_dir, 'model_config.pt')
self.maps = torch.load(mappings_fp, map_location=self.device, weights_only=True)
config = torch.load(config_fp, map_location=self.device, weights_only=True)
self.midi_model, _ = build_model(None, load_config_dict=config)
self.midi_model = self.midi_model.to(self.device)
self.midi_model.load_state_dict(torch.load(model_fp, map_location=self.device, weights_only=True))
self.midi_model.eval()
@spaces.GPU(duration=60)
@torch.inference_mode() # More efficient than no_grad for inference
def generate(
self,
image_path: str,
out_dir: str = "output",
gen_len: int = 2048,
        temperature: list | None = None,
top_k: int = -1,
top_p: float = 0.7,
min_instruments: int = 2
) -> tuple[float, float, str]:
"""Generate music from an image.
Args:
image_path: Path to input image
out_dir: Directory to save generated MIDI
gen_len: Length of generation in tokens
            temperature: Temperature for sampling [note_temp, rest_temp] (default: [1.2, 1.2])
top_k: Top-k sampling (-1 to disable)
top_p: Top-p sampling threshold
min_instruments: Minimum number of instruments required
Returns:
Tuple of (valence, arousal, midi_path)
"""
print("▶ ARIA.generate entered")
# Get emotion from image
image = Image.open(image_path).convert("RGB")
valence, arousal = self.image_model(image)
valence = valence.squeeze().cpu().item()
arousal = arousal.squeeze().cpu().item()
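        # valence/arousal are the raw scalar outputs of the image emotion model and are
        # passed straight through as the continuous conditioning signal below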
# Create output directory
os.makedirs(out_dir, exist_ok=True)
# Generate MIDI
continuous_conditions = np.array([[valence, arousal]], dtype=np.float32)
# Generate timestamp for filename (for reference)
now = datetime.datetime.now()
timestamp = now.strftime("%Y_%m_%d_%H_%M_%S")
# Generate the MIDI
self.generate_midi(
model=self.midi_model,
maps=self.maps,
device=self.device,
out_dir=out_dir,
conditioning=self.conditioning,
continuous_conditions=continuous_conditions,
gen_len=gen_len,
temperatures=temperature,
top_k=top_k,
top_p=top_p,
min_n_instruments=min_instruments
)
# Find the most recently generated MIDI file
midi_files = [f for f in os.listdir(out_dir) if f.endswith('.mid')]
if midi_files:
            # Pick the newest MIDI file by creation time
midi_path = os.path.join(out_dir, max(midi_files, key=lambda f: os.path.getctime(os.path.join(out_dir, f))))
return valence, arousal, midi_path
raise RuntimeError("Failed to generate MIDI file")
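# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only). Because this module uses a relative
# import, it is meant to be imported from the surrounding package rather than
# executed directly; the module name and checkpoint paths below are hypothetical
# placeholders, not the project's actual layout.
#
#   from .aria import ARIA  # hypothetical module name
#
#   aria = ARIA(
#       image_model_checkpoint="checkpoints/image_emotion.pt",  # hypothetical path
#       midi_model_dir="checkpoints/midi_emotion",              # hypothetical path
#       conditioning="continuous_concat",
#   )
#   valence, arousal, midi_path = aria.generate("photo.jpg", out_dir="output")
#   print(f"valence={valence:.2f}, arousal={arousal:.2f} -> {midi_path}")
# ---------------------------------------------------------------------------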