import os
import sys

import numpy as np
import spaces
import torch
from PIL import Image

from .image_encoder import ImageEncoder

# Make the bundled midi_emotion sources importable so its modules can resolve
# their own top-level imports (they live under midi_emotion/src)
MIDI_EMOTION_PATH = os.path.join(os.path.dirname(__file__), "..", "midi_emotion", "src")
sys.path.append(MIDI_EMOTION_PATH)

class ARIA:
    """ARIA model that generates music from images based on emotional content."""
    
    @spaces.GPU(duration=10)  # Model loading should be quick
    def __init__(
        self,
        image_model_checkpoint: str,
        midi_model_dir: str,
        conditioning: str = "continuous_concat",
        device: str = None
    ):
        """Initialize ARIA model.
        
        Args:
            image_model_checkpoint: Path to image emotion model checkpoint
            midi_model_dir: Path to midi emotion model directory
            conditioning: Type of conditioning to use (continuous_concat, continuous_token, discrete_token)
            device: Device to run on (default: auto-detect)
        """
        # Initialize device - use CPU if CUDA not available
        if device is not None:
            self.device = torch.device(device)
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
            
        print(f"Using device: {self.device}")
        self.conditioning = conditioning
        
        # Load image emotion model
        self.image_model = ImageEncoder()
        try:
            checkpoint = torch.load(image_model_checkpoint, map_location=self.device, weights_only=True)
            # Extract only the custom heads from the checkpoint (ignore CLIP model weights)
            state_dict = {}
            for key, value in checkpoint["model_state_dict"].items():
                if key.startswith(('valence_head.', 'arousal_head.')):
                    state_dict[key] = value
            
            # Initialize the model first so the heads exist
            self.image_model._ensure_initialized()
            
            # Load only the custom head weights
            self.image_model.load_state_dict(state_dict, strict=False)
            print("ImageEncoder custom heads loaded successfully")
        except Exception as e:
            print(f"Warning: Could not load ImageEncoder checkpoint: {e}")
            print("Using randomly initialized heads")
            # Initialize anyway with random weights
            self.image_model._ensure_initialized()
            
        self.image_model = self.image_model.to(self.device)
        self.image_model.eval()
        
        # Import midi generation
        from midi_emotion.src.generate import generate
        from midi_emotion.src.models.build_model import build_model
        self.generate_midi = generate
        
        # Load midi model
        model_fp = os.path.join(midi_model_dir, 'model.pt')
        mappings_fp = os.path.join(midi_model_dir, 'mappings.pt')
        config_fp = os.path.join(midi_model_dir, 'model_config.pt')
        
        self.maps = torch.load(mappings_fp, map_location=self.device, weights_only=True)
        config = torch.load(config_fp, map_location=self.device, weights_only=True)
        self.midi_model, _ = build_model(None, load_config_dict=config)
        self.midi_model = self.midi_model.to(self.device)
        self.midi_model.load_state_dict(torch.load(model_fp, map_location=self.device, weights_only=True))
        self.midi_model.eval()
    
    @spaces.GPU(duration=60)
    @torch.inference_mode()  # More efficient than no_grad for inference
    def generate(
        self,
        image_path: str,
        out_dir: str = "output",
        gen_len: int = 2048,
        temperature: list = None,
        top_k: int = -1,
        top_p: float = 0.7,
        min_instruments: int = 2
    ) -> tuple[float, float, str]:
        """Generate music from an image.
        
        Args:
            image_path: Path to input image
            out_dir: Directory to save generated MIDI
            gen_len: Length of generation in tokens
            temperature: Temperatures for sampling [note_temp, rest_temp] (default: [1.2, 1.2])
            top_k: Top-k sampling (-1 to disable)
            top_p: Top-p sampling threshold
            min_instruments: Minimum number of instruments required
            
        Returns:
            Tuple of (valence, arousal, midi_path)
        """

        # Resolve the default here rather than in the signature to avoid a
        # mutable default argument
        if temperature is None:
            temperature = [1.2, 1.2]

        print("▶ ARIA.generate entered")

        # Get emotion from image
        image = Image.open(image_path).convert("RGB")
        valence, arousal = self.image_model(image)
        valence = valence.squeeze().cpu().item()
        arousal = arousal.squeeze().cpu().item()
        
        # Create output directory
        os.makedirs(out_dir, exist_ok=True)
        
        # Build the conditioning vector from the predicted emotion
        continuous_conditions = np.array([[valence, arousal]], dtype=np.float32)

        # Generate the MIDI
        self.generate_midi(
            model=self.midi_model,
            maps=self.maps,
            device=self.device,
            out_dir=out_dir,
            conditioning=self.conditioning,
            continuous_conditions=continuous_conditions,
            gen_len=gen_len,
            temperatures=temperature,
            top_k=top_k,
            top_p=top_p,
            min_n_instruments=min_instruments
        )
        
        # The MIDI generator names its own output file, so locate the most
        # recently written .mid in out_dir
        midi_files = [f for f in os.listdir(out_dir) if f.endswith('.mid')]
        if not midi_files:
            raise RuntimeError(f"Failed to generate MIDI file in {out_dir}")

        newest = max(midi_files, key=lambda f: os.path.getctime(os.path.join(out_dir, f)))
        return valence, arousal, os.path.join(out_dir, newest)
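
# ---------------------------------------------------------------------------
# Example usage, shown as a comment because this module uses a relative import
# and must be loaded as part of its package. This is a minimal sketch; the
# checkpoint paths, input image, and importing module name below are
# hypothetical placeholders, not files shipped with this package:
#
#     from model.aria import ARIA  # adjust to your package layout
#
#     aria = ARIA(
#         image_model_checkpoint="checkpoints/image_emotion.pt",  # hypothetical
#         midi_model_dir="checkpoints/midi_emotion",              # hypothetical
#         conditioning="continuous_concat",
#     )
#     valence, arousal, midi_path = aria.generate(
#         "examples/sunset.jpg",  # hypothetical input image
#         out_dir="output",
#         gen_len=2048,
#         temperature=[1.2, 1.2],
#         top_p=0.7,
#         min_instruments=2,
#     )
#     print(f"valence={valence:+.2f} arousal={arousal:+.2f} -> {midi_path}")
# ---------------------------------------------------------------------------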