Spaces:
Sleeping
Sleeping
from agents.agents import harmonizer, infiller, change_melody | |
from flask import Flask, request, jsonify | |
from flask_cors import CORS | |
import mido | |
import tempfile | |
import os | |
import music21 | |
import traceback | |
from uuid import uuid4 | |
import threading | |
from transformers import AutoModelForCausalLM | |
app = Flask(__name__) | |
CORS(app) | |
def add_cors_headers(response, allowed_origin='https://inscoreai.netlify.app'):
    """Restrict CORS on *response* to a single frontend origin.

    Sets ``Access-Control-Allow-Origin``, ``Access-Control-Allow-Methods`` and
    ``Access-Control-Allow-Headers`` on ``response.headers`` and returns the
    same response object.

    Args:
        response: object with a mutable ``headers`` mapping (e.g. a Flask
            response).
        allowed_origin: the exact origin to allow, ``scheme://host[:port]``.
            A trailing slash is stripped: browsers compare this value
            exactly against the request's ``Origin`` header, which never
            ends in ``/``.

    Returns:
        The same *response*, mutated in place.

    NOTE(review): nothing visible registers this function (e.g. via
    ``app.after_request``), and ``CORS(app)`` already allows all origins;
    confirm which CORS policy is intended.
    """
    # Allow only the frontend's domain.
    response.headers['Access-Control-Allow-Origin'] = allowed_origin.rstrip('/')
    response.headers['Access-Control-Allow-Methods'] = 'GET, POST'
    response.headers['Access-Control-Allow-Headers'] = 'Content-Type'
    return response
def midi_to_musicxml(midi_file_path):
    """Convert a MIDI file to a MusicXML string.

    Parses *midi_file_path* with music21, writes the score to a uniquely
    named temporary ``.musicxml`` file, reads it back and always deletes the
    temporary file, even when parsing, writing or reading fails.

    Args:
        midi_file_path: path to the MIDI file; anything ``str()`` turns into
            a filesystem path (``str``, ``pathlib.Path``).

    Returns:
        The MusicXML document as a ``str``.

    Raises:
        Any exception from music21 or file I/O, after printing it and its
        traceback.
    """
    # Unique name so concurrent conversions never clobber each other.
    temp_output = os.path.join(tempfile.gettempdir(), f"output_{uuid4().hex}.musicxml")
    try:
        score = music21.converter.parse(str(midi_file_path))
        score.write('musicxml', temp_output)
        # MusicXML is written as UTF-8; don't depend on the locale encoding.
        with open(temp_output, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        print(f"Conversion error: {str(e)}")
        traceback.print_exc()
        raise
    finally:
        # Clean up on every path; a failed cleanup must not discard a result.
        if os.path.exists(temp_output):
            try:
                os.unlink(temp_output)
            except OSError as cleanup_error:
                print(f"Warning: Could not remove {temp_output}: {str(cleanup_error)}")
def load_model():
    """Return the shared ``stanford-crfm/music-small-800k`` causal LM.

    The first call loads the model via ``AutoModelForCausalLM.from_pretrained``
    and stores it in the module-level ``MODEL``. Later calls return that same
    object, so the request handlers that call this function do not reload
    the model from disk on every request. ``MODEL_LOCK`` makes concurrent
    first calls load it only once.

    The Hugging Face cache directory comes from ``HF_HOME`` (default
    ``/home/user/.cache/huggingface``). A small probe file reports whether
    that directory is writable; loading is attempted either way.

    Relies on ``MODEL`` and ``MODEL_LOCK`` being defined at module level
    before the first call.

    Returns:
        The loaded ``AutoModelForCausalLM`` instance.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    with MODEL_LOCK:
        # Re-check: another thread may have finished loading while we waited.
        if MODEL is None:
            cache_dir = os.environ.get('HF_HOME', '/home/user/.cache/huggingface')
            print(f"Using cache directory: {cache_dir}")
            # Verify permissions, and don't leave the probe file behind.
            try:
                test_file = os.path.join(cache_dir, "test.txt")
                with open(test_file, "w") as f:
                    f.write("test")
                os.remove(test_file)
                print("β Cache directory is writable")
            except Exception as e:
                print(f"β Cache directory not writable: {e}")
            MODEL = AutoModelForCausalLM.from_pretrained(
                'stanford-crfm/music-small-800k',
                cache_dir=cache_dir,
                local_files_only=False,
                force_download=False
            )
    return MODEL
# Model loading setup: MODEL holds the loaded model, MODEL_LOCK guards loading.
MODEL = None
MODEL_LOCK = threading.Lock()
# Initialize the model when the app starts, and keep the result in MODEL
# instead of discarding it.
with app.app_context():
    MODEL = load_model()
def _process_midi_request(agent):
    """Run one MIDI *agent* on the uploaded file and build the JSON response.

    Shared implementation of the three upload handlers. It reads from the
    current Flask ``request``:

    * ``midi_file`` (file, required) - the MIDI upload;
    * ``start_time`` / ``end_time`` (form, default ``'0'``) - parsed with
      ``int()`` and divided by 1000 before being passed to *agent*
      (presumably milliseconds -> seconds; TODO confirm with the frontend);
    * ``top_p`` (form, default ``'0.95'``) - parsed with ``float()`` and
      forwarded to *agent*.

    Args:
        agent: callable invoked as ``agent(model, midi, start, end, top_p=...)``
            whose result has a ``save(path)`` method.

    Returns:
        ``jsonify({"status": "success", "musicxml": ...})`` on success;
        otherwise ``(jsonify({"status": "error", "message": ...}), 400)``,
        including for malformed form values.
    """
    temp_midi_path = None
    try:
        # Parsed inside the try so a malformed value produces the JSON 400
        # error response rather than an unhandled exception.
        top_p = float(request.form.get('top_p', '0.95'))
        # Validate input
        if 'midi_file' not in request.files:
            return jsonify({"status": "error", "message": "No MIDI file provided"}), 400
        midi_file = request.files['midi_file']
        start_time = request.form.get('start_time', '0')
        end_time = request.form.get('end_time', '0')
        # Random file name so concurrent requests never collide.
        temp_midi_path = os.path.join(tempfile.gettempdir(), f"temp_{uuid4().hex}.mid")
        midi_file.save(temp_midi_path)
        midi = mido.MidiFile(temp_midi_path)
        model = load_model()
        result_midi = agent(model, midi, int(start_time) / 1000, int(end_time) / 1000, top_p=top_p)
        # Overwrite the temp file with the agent's output, then convert it.
        result_midi.save(temp_midi_path)
        musicxml_str = midi_to_musicxml(temp_midi_path)
        # Final type verification
        if not isinstance(musicxml_str, str):
            raise TypeError(f"Expected string but got {type(musicxml_str)}")
        return jsonify({
            "status": "success",
            "musicxml": musicxml_str
        })
    except Exception as e:
        print(f"Error processing request: {str(e)}")
        traceback.print_exc()
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 400
    finally:
        # Clean up the temp file on every path.
        if temp_midi_path and os.path.exists(temp_midi_path):
            try:
                os.unlink(temp_midi_path)
            except Exception as e:
                print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")


# NOTE(review): no route decorators are visible on the handlers below;
# confirm how (or whether) they are registered with ``app``.
def handle_upload():
    """Harmonize the requested time span of the uploaded MIDI file.

    Runs ``harmonizer`` via ``_process_midi_request`` and returns its
    JSON response (MusicXML on success).
    """
    return _process_midi_request(harmonizer)


def handle_upload_infilling():
    """Infill the requested time span of the uploaded MIDI file.

    Runs ``infiller`` via ``_process_midi_request`` and returns its
    JSON response (MusicXML on success).
    """
    return _process_midi_request(infiller)


def handle_upload_changemelody():
    """Regenerate the melody in the requested span of the uploaded MIDI file.

    Runs ``change_melody`` via ``_process_midi_request`` and returns its
    JSON response (MusicXML on success).
    """
    return _process_midi_request(change_melody)
if __name__ == '__main__':
    # Debug mode is opt-in (FLASK_DEBUG=1) for two reasons:
    # - the Werkzeug debugger allows arbitrary code execution from the browser;
    # - its reloader re-runs this script in a child process, which would load
    #   the model a second time.
    # PORT overrides the default port 5000.
    app.run(
        debug=os.environ.get('FLASK_DEBUG', '0') == '1',
        port=int(os.environ.get('PORT', '5000')),
    )