# InScoreAPI / api.py
# manuel-l01's picture
# Initial commit
# 3c0b02c
from agents.agents import harmonizer, infiller, change_melody
from flask import Flask, request, jsonify
from flask_cors import CORS
import mido
import tempfile
import os
import music21
import traceback
from uuid import uuid4
import threading
from transformers import AutoModelForCausalLM
# Flask application object; the route and after_request decorators in this
# module attach their handlers to it.
app = Flask(__name__)
# Apply flask-cors to the whole app. No options are passed here, so the
# extension's own defaults are in effect.
CORS(app)
@app.after_request
def add_cors_headers(response):
    """Attach the CORS response headers to every outgoing response.

    Args:
        response: The Flask response object about to be sent.

    Returns:
        The same response object, with the CORS headers set.
    """
    # Allow only the front-end's origin. The value must be a bare origin
    # (scheme + host, no trailing slash or path): browsers compare it
    # literally against the request's Origin header, which never carries a
    # path, so "https://host/" would never match and the request would be
    # blocked.
    response.headers['Access-Control-Allow-Origin'] = 'https://inscoreai.netlify.app'
    response.headers['Access-Control-Allow-Methods'] = 'GET, POST'
    response.headers['Access-Control-Allow-Headers'] = 'Content-Type'
    return response
def midi_to_musicxml(midi_file_path):
    """Convert a MIDI file to a MusicXML document and return it as a string.

    The score is parsed with music21, written to a uniquely named file in the
    system temp directory, read back as text, and the temp file is removed.

    Args:
        midi_file_path: Path (str or path-like) of the MIDI file to convert.

    Returns:
        The MusicXML document as a ``str``.

    Raises:
        Exception: Any error raised while parsing, writing or reading is
            logged (message + traceback) and re-raised unchanged.
    """
    # Random name so concurrent requests never share an output file.
    temp_output = os.path.join(tempfile.gettempdir(), f"output_{uuid4().hex}.musicxml")
    try:
        # Parse and convert to MusicXML
        score = music21.converter.parse(str(midi_file_path))
        score.write('musicxml', temp_output)
        # Read back as string; MusicXML is written as UTF-8, so do not rely
        # on the platform's locale encoding.
        with open(temp_output, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        print(f"Conversion error: {str(e)}")
        traceback.print_exc()
        raise
    finally:
        # Clean up on success AND failure so a failed conversion does not
        # leave files behind in the temp directory. Best-effort: a cleanup
        # failure must not mask the real result or error.
        if os.path.exists(temp_output):
            try:
                os.unlink(temp_output)
            except OSError as e:
                print(f"Warning: Could not remove {temp_output}: {str(e)}")
# Model loading setup: a single shared model instance, created on the first
# load_model() call and guarded by MODEL_LOCK so concurrent requests cannot
# load it twice.
MODEL = None
MODEL_LOCK = threading.Lock()


def _check_cache_writable(cache_dir):
    """Log whether *cache_dir* can be written to (diagnostic only, never raises OSError)."""
    probe_path = os.path.join(cache_dir, f".write_test_{uuid4().hex}")
    try:
        os.makedirs(cache_dir, exist_ok=True)
        with open(probe_path, "w") as f:
            f.write("test")
        # Remove the probe so the check leaves nothing behind in the cache.
        os.unlink(probe_path)
        print("✅ Cache directory is writable")
    except OSError as e:
        print(f"❌ Cache directory not writable: {e}")


def load_model():
    """Return the shared causal-LM model, loading it on first use.

    The cache directory is taken from the ``HF_HOME`` environment variable,
    falling back to ``/home/user/.cache/huggingface``. The first call loads
    ``stanford-crfm/music-small-800k`` via
    ``AutoModelForCausalLM.from_pretrained``; later calls return the same
    instance instead of rebuilding the model on every request.

    Returns:
        The loaded model object (the same instance on every call).

    Raises:
        Exception: Whatever ``from_pretrained`` raises if the model cannot be
            downloaded or loaded. Nothing is cached in that case, so the next
            call retries.
    """
    global MODEL
    # Fast path: already loaded, no need to take the lock.
    if MODEL is not None:
        return MODEL
    with MODEL_LOCK:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if MODEL is None:
            cache_dir = os.environ.get('HF_HOME', '/home/user/.cache/huggingface')
            print(f"Using cache directory: {cache_dir}")
            # Verify permissions
            _check_cache_writable(cache_dir)
            # Load model
            MODEL = AutoModelForCausalLM.from_pretrained(
                'stanford-crfm/music-small-800k',
                cache_dir=cache_dir,
                local_files_only=False,
                force_download=False
            )
    return MODEL


# Initialize model when app starts, so the first request does not pay the
# download/load cost.
with app.app_context():
    load_model()
@app.route('/upload', methods=['POST'])
def handle_upload():
    """Harmonize an uploaded MIDI file and return the result as MusicXML.

    Form fields:
        midi_file: The MIDI upload (required).
        start_time, end_time: Integer strings (default ``'0'``); each is
            divided by 1000 before being passed to ``harmonizer``
            (presumably milliseconds to seconds -- confirm with the client).
        top_p: Float string forwarded to ``harmonizer`` (default ``'0.95'``).

    Returns:
        JSON ``{"status": "success", "musicxml": <str>}`` on success, or
        ``{"status": "error", "message": <str>}`` with HTTP 400 on any
        failure.
    """
    temp_midi_path = None
    try:
        # Parse the numeric fields inside the try block so a malformed value
        # produces the JSON 400 response below rather than an unhandled
        # exception, and before any file or model work is done.
        top_p = float(request.form.get('top_p', '0.95'))
        # Validate input
        if 'midi_file' not in request.files:
            return jsonify({"status": "error", "message": "No MIDI file provided"}), 400
        midi_file = request.files['midi_file']
        start_time = int(request.form.get('start_time', '0')) / 1000
        end_time = int(request.form.get('end_time', '0')) / 1000
        # Create temporary MIDI file with random name
        temp_midi_path = os.path.join(tempfile.gettempdir(), f"temp_{uuid4().hex}.mid")
        # Save uploaded MIDI to temp file
        midi_file.save(temp_midi_path)
        # Process MIDI
        midi = mido.MidiFile(temp_midi_path)
        model = load_model()
        harmonized_midi = harmonizer(model, midi, start_time, end_time, top_p=top_p)
        # Save harmonized MIDI (overwriting temp file)
        harmonized_midi.save(temp_midi_path)
        # Convert to MusicXML string
        musicxml_str = midi_to_musicxml(temp_midi_path)
        # Final type verification
        if not isinstance(musicxml_str, str):
            raise TypeError(f"Expected string but got {type(musicxml_str)}")
        return jsonify({
            "status": "success",
            "musicxml": musicxml_str
        })
    except Exception as e:
        print(f"Error processing request: {str(e)}")
        traceback.print_exc()
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 400
    finally:
        # Clean up temp file
        if temp_midi_path and os.path.exists(temp_midi_path):
            try:
                os.unlink(temp_midi_path)
            except Exception as e:
                print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")
@app.route('/uploadinfill', methods=['POST'])
def handle_upload_infilling():
    """Run the infilling agent on an uploaded MIDI file and return MusicXML.

    Form fields:
        midi_file: The MIDI upload (required).
        start_time, end_time: Integer strings (default ``'0'``); each is
            divided by 1000 before being passed to ``infiller``
            (presumably milliseconds to seconds -- confirm with the client).
        top_p: Float string forwarded to ``infiller`` (default ``'0.95'``).

    Returns:
        JSON ``{"status": "success", "musicxml": <str>}`` on success, or
        ``{"status": "error", "message": <str>}`` with HTTP 400 on any
        failure.
    """
    temp_midi_path = None
    try:
        # Parse the numeric fields inside the try block so a malformed value
        # produces the JSON 400 response below rather than an unhandled
        # exception, and before any file or model work is done.
        top_p = float(request.form.get('top_p', '0.95'))
        # Validate input
        if 'midi_file' not in request.files:
            return jsonify({"status": "error", "message": "No MIDI file provided"}), 400
        midi_file = request.files['midi_file']
        start_time = int(request.form.get('start_time', '0')) / 1000
        end_time = int(request.form.get('end_time', '0')) / 1000
        # Create temporary MIDI file with random name
        temp_midi_path = os.path.join(tempfile.gettempdir(), f"temp_{uuid4().hex}.mid")
        # Save uploaded MIDI to temp file
        midi_file.save(temp_midi_path)
        # Process MIDI
        midi = mido.MidiFile(temp_midi_path)
        model = load_model()
        infilled_midi = infiller(model, midi, start_time, end_time, top_p=top_p)
        # Save infilled MIDI (overwriting temp file)
        infilled_midi.save(temp_midi_path)
        # Convert to MusicXML string
        musicxml_str = midi_to_musicxml(temp_midi_path)
        # Final type verification
        if not isinstance(musicxml_str, str):
            raise TypeError(f"Expected string but got {type(musicxml_str)}")
        return jsonify({
            "status": "success",
            "musicxml": musicxml_str
        })
    except Exception as e:
        print(f"Error processing request: {str(e)}")
        traceback.print_exc()
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 400
    finally:
        # Clean up temp file
        if temp_midi_path and os.path.exists(temp_midi_path):
            try:
                os.unlink(temp_midi_path)
            except Exception as e:
                print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")
@app.route('/uploadchangemelody', methods=['POST'])
def handle_upload_changemelody():
    """Run the melody-change agent on an uploaded MIDI file and return MusicXML.

    Form fields:
        midi_file: The MIDI upload (required).
        start_time, end_time: Integer strings (default ``'0'``); each is
            divided by 1000 before being passed to ``change_melody``
            (presumably milliseconds to seconds -- confirm with the client).
        top_p: Float string forwarded to ``change_melody``
            (default ``'0.95'``).

    Returns:
        JSON ``{"status": "success", "musicxml": <str>}`` on success, or
        ``{"status": "error", "message": <str>}`` with HTTP 400 on any
        failure.
    """
    temp_midi_path = None
    try:
        # Parse the numeric fields inside the try block so a malformed value
        # produces the JSON 400 response below rather than an unhandled
        # exception, and before any file or model work is done.
        top_p = float(request.form.get('top_p', '0.95'))
        # Validate input
        if 'midi_file' not in request.files:
            return jsonify({"status": "error", "message": "No MIDI file provided"}), 400
        midi_file = request.files['midi_file']
        start_time = int(request.form.get('start_time', '0')) / 1000
        end_time = int(request.form.get('end_time', '0')) / 1000
        # Create temporary MIDI file with random name
        temp_midi_path = os.path.join(tempfile.gettempdir(), f"temp_{uuid4().hex}.mid")
        # Save uploaded MIDI to temp file
        midi_file.save(temp_midi_path)
        # Process MIDI
        midi = mido.MidiFile(temp_midi_path)
        model = load_model()
        changed_melody_midi = change_melody(model, midi, start_time, end_time, top_p=top_p)
        # Save the MIDI with the changed melody (overwriting temp file)
        changed_melody_midi.save(temp_midi_path)
        # Convert to MusicXML string
        musicxml_str = midi_to_musicxml(temp_midi_path)
        # Final type verification
        if not isinstance(musicxml_str, str):
            raise TypeError(f"Expected string but got {type(musicxml_str)}")
        return jsonify({
            "status": "success",
            "musicxml": musicxml_str
        })
    except Exception as e:
        print(f"Error processing request: {str(e)}")
        traceback.print_exc()
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 400
    finally:
        # Clean up temp file
        if temp_midi_path and os.path.exists(temp_midi_path):
            try:
                os.unlink(temp_midi_path)
            except Exception as e:
                print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")
# Script entry point: when this file is executed directly, serve the app with
# Flask's built-in server on port 5000.
# NOTE(review): debug=True is a development setting -- confirm this entry
# point is never used for a publicly reachable deployment.
if __name__ == '__main__':
    app.run(port=5000, debug=True)