|
import gradio as gr |
|
import sys |
|
import threading |
|
import queue |
|
from io import TextIOBase |
|
import datetime |
|
import subprocess |
|
import os |
|
from inference import postprocess_inst_names |
|
|
|
|
|
from inference import inference_patch |
|
from convert import abc2xml, xml2, pdf2img |
|
|
|
|
|
|
|
# Load the supported prompts. Each meaningful line of prompts.txt is expected to
# have the form "<period>_<composer>_<instrumentation>".
# The encoding is pinned to UTF-8 so the result does not depend on the platform
# default; every other file this module writes uses UTF-8 as well.
with open('prompts.txt', 'r', encoding='utf-8') as f:
    prompts = f.readlines()

# Every (period, composer, instrumentation) triple that may be requested.
valid_combinations = set()
for prompt in prompts:
    prompt = prompt.strip()
    parts = prompt.split('_')
    # Skip blank or malformed lines (for example a trailing empty line) instead
    # of failing at import time with an IndexError.
    if len(parts) < 3:
        continue
    # Only the first three fields are used; any extra "_"-separated fields are ignored.
    valid_combinations.add((parts[0], parts[1], parts[2]))

# Sorted distinct values of each field across all valid triples.
periods = sorted({p for p, _, _ in valid_combinations})
composers = sorted({c for _, c, _ in valid_combinations})
instruments = sorted({i for _, _, i in valid_combinations})
|
|
|
|
|
def update_components(period, composer):
    """Recompute the composer and instrumentation dropdowns after a selection change.

    Args:
        period: Currently selected period, or a falsy value if none is selected.
        composer: Currently selected composer, or a falsy value if none is selected.

    Returns:
        A two-item list of ``gr.update`` results: the first for the composer
        dropdown, the second for the instrumentation dropdown. The
        instrumentation value is always reset to ``None``.
    """
    # Without a period nothing further can be chosen: empty and lock both dropdowns.
    if not period:
        return [
            gr.update(choices=[], value=None, interactive=False)
            for _ in range(2)
        ]

    composer_pool = set()
    instrument_pool = set()
    for combo_period, combo_composer, combo_instrument in valid_combinations:
        if combo_period != period:
            continue
        composer_pool.add(combo_composer)
        # Instruments are only offered once a composer has been picked.
        if composer and combo_composer == composer:
            instrument_pool.add(combo_instrument)

    valid_composers = sorted(composer_pool)
    valid_instruments = sorted(instrument_pool)

    # Keep the current composer only if it still belongs to the selected period.
    kept_composer = composer if composer in valid_composers else None

    composer_update = gr.update(
        choices=valid_composers,
        value=kept_composer,
        interactive=True,
    )
    instrument_update = gr.update(
        choices=valid_instruments,
        value=None,
        interactive=bool(valid_instruments),
    )
    return [composer_update, instrument_update]
|
|
|
|
|
class RealtimeStream(TextIOBase):
    """Text stream that forwards everything written to it into a queue."""

    def __init__(self, queue):
        """Store the queue that will receive written text.

        Args:
            queue: Object exposing ``put``; each written chunk is passed to it.
        """
        # The parameter name shadows the ``queue`` module inside this method only.
        self.queue = queue

    def write(self, text):
        """Enqueue ``text`` unchanged and return its length."""
        sink = self.queue
        sink.put(text)
        return len(text)
|
|
|
def convert_files(abc_content, period, composer, instrumentation):
    """Write the generated ABC to disk and render it into the other output formats.

    Two ABC files are written to the current working directory: the raw content
    and a variant passed through ``postprocess_inst_names``. Both are then handed
    to the conversion helpers (``abc2xml``, ``xml2``, ``pdf2img``), and each page
    image returned by ``pdf2img`` is saved as ``<base>_page_<n>.png``.

    Args:
        abc_content: ABC notation text to save and convert.
        period: Period part of the prompt; used in the output file name.
        composer: Composer part of the prompt; used in the output file name.
        instrumentation: Instrumentation part of the prompt; used in the file name.

    Returns:
        A dict with the keys ``'abc'``, ``'xml'``, ``'pdf'``, ``'mid'``, ``'wav'``
        (file names), ``'pages'`` (number of page images), ``'current_page'``
        (always 0) and ``'base'`` (file name stem without extension).

    Raises:
        gr.Error: If any prompt part is missing, or if a conversion step fails.
            In the latter case the original exception is attached as the cause.
    """
    if not all([period, composer, instrumentation]):
        raise gr.Error("Please complete a valid generation first before saving")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    prompt_str = f"{period}_{composer}_{instrumentation}"
    filename_base = f"{timestamp}_{prompt_str}"

    # Raw ABC exactly as generated.
    abc_filename = f"{filename_base}.abc"
    with open(abc_filename, "w", encoding="utf-8") as f:
        f.write(abc_content)

    # Second copy with instrument names post-processed, stored under "<base>_postinst".
    postprocessed_inst_abc = postprocess_inst_names(abc_content)
    filename_base_postinst = f"{filename_base}_postinst"
    with open(filename_base_postinst + ".abc", "w", encoding="utf-8") as f:
        f.write(postprocessed_inst_abc)

    file_paths = {'abc': abc_filename}
    try:
        # NOTE(review): the helpers presumably read "<base>.<src>" and write
        # "<base>.<dst>" next to it - confirm against the convert module.
        abc2xml(filename_base)
        abc2xml(filename_base_postinst)

        # The PDF is rendered from the raw variant only.
        xml2(filename_base, 'pdf')

        xml2(filename_base, 'mid')
        xml2(filename_base_postinst, 'mid')

        xml2(filename_base, 'wav')
        xml2(filename_base_postinst, 'wav')

        # One PNG per PDF page, numbered from 1.
        images = pdf2img(filename_base)
        for i, image in enumerate(images):
            image.save(f"{filename_base}_page_{i+1}.png", "PNG")

        # XML / MIDI / WAV downloads point at the post-processed variant, while
        # the PDF and page images come from the raw variant.
        file_paths.update({
            'xml': f"{filename_base_postinst}.xml",
            'pdf': f"{filename_base}.pdf",
            'mid': f"{filename_base_postinst}.mid",
            'wav': f"{filename_base_postinst}.wav",
            'pages': len(images),
            'current_page': 0,
            'base': filename_base
        })
    except Exception as e:
        # The message text means "file processing failed". Chain the original
        # exception so its traceback is not lost when surfaced through Gradio.
        raise gr.Error(f"文件处理失败: {str(e)}") from e

    return file_paths
|
|
|
|
|
|
|
def update_page(direction, data):
    """Move the sheet-music preview one page backwards or forwards.

    Args:
        direction: ``"prev"`` or ``"next"``; any other value leaves the current
            page unchanged.
        data: Paging state dict with the keys ``'pages'``, ``'current_page'`` and
            ``'base'``, or a falsy value when nothing has been generated yet.
            ``data['current_page']`` is updated in place.

    Returns:
        A 4-tuple ``(image_path, prev_button_update, next_button_update, data)``.
        ``image_path`` is ``None`` when there is no state or no rendered page.
    """
    if not data:
        return None, gr.update(interactive=False), gr.update(interactive=False), data

    page_count = data.get('pages', 0)
    if page_count <= 0:
        # No page image was rendered, so "<base>_page_1.png" does not exist;
        # returning its path would point the preview at a missing file.
        return None, gr.update(interactive=False), gr.update(interactive=False), data

    # Step within [0, page_count - 1]; at either end the request is ignored.
    if direction == "prev" and data['current_page'] > 0:
        data['current_page'] -= 1
    elif direction == "next" and data['current_page'] < page_count - 1:
        data['current_page'] += 1

    current_page_index = data['current_page']

    # Page image files are numbered from 1, the index from 0.
    new_image = f"{data['base']}_page_{current_page_index+1}.png"

    prev_btn_state = gr.update(interactive=(current_page_index > 0))
    next_btn_state = gr.update(interactive=(current_page_index < page_count - 1))

    return new_image, prev_btn_state, next_btn_state, data
|
|
|
|
|
def generate_music(period, composer, instrumentation):
    """Run inference for the selected prompt and stream progress to the UI.

    This is a generator. Every ``yield`` produces the same five values, in order:

    1. process_output - text written to stdout while inference runs
    2. final_output   - the final ABC text, or a status / error message
    3. pdf_image      - path of the PNG for the first PDF page
    4. audio_player   - path of the WAV file
    5. pdf_state      - dict used as the paging state

    Args:
        period: Selected period.
        composer: Selected composer.
        instrumentation: Selected instrumentation.

    Raises:
        gr.Error: If the three values are not a known combination.
    """
    if (period, composer, instrumentation) not in valid_combinations:
        raise gr.Error("Invalid prompt combination! Please re-select from the period options")

    output_queue = queue.Queue()
    original_stdout = sys.stdout

    result_container = []
    error_container = []

    def run_inference():
        # sys.stdout is process-wide, so the redirect is installed right next to
        # the try/finally that undoes it; it can then never be left in place.
        sys.stdout = RealtimeStream(output_queue)
        try:
            result = inference_patch(period, composer, instrumentation)
            result_container.append(result)
        except Exception as exc:
            # Remember the failure for the generator, then re-raise so the
            # default thread exception hook still reports the traceback.
            error_container.append(exc)
            raise
        finally:
            sys.stdout = original_stdout

    thread = threading.Thread(target=run_inference)
    thread.start()

    process_output = ""
    final_output_abc = ""
    pdf_image = None
    audio_file = None
    pdf_state = None

    # While inference runs, forward each captured chunk of output to the UI.
    # No final ABC and no converted files exist yet at this stage.
    while thread.is_alive():
        try:
            text = output_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        process_output += text
        yield process_output, final_output_abc, pdf_image, audio_file, pdf_state

    # Collect whatever was written between the last poll and the thread exiting.
    while not output_queue.empty():
        text = output_queue.get()
        process_output += text

    if error_container:
        # Report the real inference failure instead of trying to convert an
        # empty result and blaming the conversion step.
        yield process_output, f"Error during generation: {error_container[0]}", None, None, None
        return

    final_result = result_container[0] if result_container else ""

    # Show an interim status while the files are being converted.
    final_output_abc = "Converting files..."
    yield process_output, final_output_abc, pdf_image, audio_file, pdf_state

    try:
        file_paths = convert_files(final_result, period, composer, instrumentation)
        final_output_abc = final_result

        # The first page image only exists if at least one page was rendered.
        if file_paths['pages'] > 0:
            pdf_image = f"{file_paths['base']}_page_1.png"
        audio_file = file_paths['wav']
        # The conversion result dict doubles as the paging / download state.
        pdf_state = file_paths
    except Exception as e:
        # Conversion failed: show the error in the ABC box and clear the rest.
        yield process_output, f"Error converting files: {str(e)}", None, None, None
        return

    # Final update with the ABC text, first page image, audio and state.
    yield process_output, final_output_abc, pdf_image, audio_file, pdf_state
|
|
|
|
|
def get_file(file_type, period, composer, instrumentation):
    """Return the most recently modified local file of the given type.

    Meant to provide a file for a Gradio download.

    Args:
        file_type: File extension without the leading dot, e.g. ``"wav"``.
        period: Accepted but not used.
        composer: Accepted but not used.
        instrumentation: Accepted but not used.

    Returns:
        The name (relative to the current working directory) of the matching
        entry with the latest modification time, or ``None`` if nothing matches.
    """
    suffix = f'.{file_type}'
    newest_name = None
    newest_mtime = None
    for entry in os.listdir('.'):
        if not entry.endswith(suffix):
            continue
        mtime = os.path.getmtime(entry)
        # ">=" lets a later directory entry win on equal timestamps.
        if newest_mtime is None or mtime >= newest_mtime:
            newest_name, newest_mtime = entry, mtime
    return newest_name
|
|
|
|
|
# Custom stylesheet handed to gr.Blocks(css=...). The text inside the string is
# sent to the browser as-is, so the CSS comments in it are runtime content.
css = """
/* 紧凑按钮样式 */
button[size="sm"] {
    padding: 4px 8px !important;
    margin: 2px !important;
    min-width: 60px;
}

/* PDF预览区 */
#pdf-preview {
    border-radius: 8px; /* 圆角 */
    box-shadow: 0 2px 8px rgba(0,0,0,0.1); /* 阴影 */
}

.page-btn {
    padding: 12px !important; /* 增大点击区域 */
    margin: auto !important; /* 垂直居中 */
}

/* 按钮悬停效果 */
.page-btn:hover {
    background: #f0f0f0 !important;
    transform: scale(1.05);
}

/* 布局调整 */
.gr-row {
    gap: 10px !important; /* 元素间距 */
}

/* 音频播放器 */
.audio-panel {
    margin-top: 15px !important;
    max-width: 400px;
}

#audio-preview audio {
    height: 200px !important;
}

/* 保存功能区 */
.save-as-row {
    margin-top: 15px;
    padding: 10px;
    border-top: 1px solid #eee;
}

.save-as-label {
    font-weight: bold;
    margin-right: 10px;
    align-self: center;
}

.save-buttons {
    gap: 5px; /* 按钮间距 */
}

"""
|
|
|
# Build the Gradio interface and wire its events.
# NOTE(review): the container nesting follows the most plausible reading of the
# layout (left column: controls and outputs; right column: preview, paging and
# the save row) - confirm against the rendered page.
with gr.Blocks(css=css) as demo:
    gr.Markdown("## NotaGen")

    # Holds the dict produced by the conversion step (file names, page count,
    # current page); read by the paging and save handlers.
    pdf_state = gr.State()

    with gr.Column():
        with gr.Row():
            # Left side: prompt selection, generate button and textual / audio output.
            with gr.Column():
                with gr.Row():
                    period_dd = gr.Dropdown(
                        choices=periods, value=None, label="Period", interactive=True
                    )
                    composer_dd = gr.Dropdown(
                        choices=[], value=None, label="Composer", interactive=False
                    )
                    instrument_dd = gr.Dropdown(
                        choices=[], value=None, label="Instrumentation", interactive=False
                    )

                generate_btn = gr.Button("Generate!", variant="primary")

                process_output = gr.Textbox(
                    label="Generation process",
                    interactive=False,
                    lines=2,
                    max_lines=2,
                    placeholder="Generation progress will be shown here...",
                )

                final_output = gr.Textbox(
                    label="Post-processed ABC notation scores",
                    interactive=True,
                    lines=8,
                    max_lines=8,
                    placeholder="Post-processed ABC scores will be shown here...",
                )

                audio_player = gr.Audio(
                    label="Audio Preview", format="wav", interactive=False
                )

            # Right side: sheet-music image, paging buttons and save buttons.
            with gr.Column():
                pdf_image = gr.Image(
                    label="Sheet Music Preview",
                    show_label=False,
                    height=650,
                    type="filepath",
                    elem_id="pdf-preview",
                    interactive=False,
                    show_download_button=False,
                )

                with gr.Row():
                    prev_btn = gr.Button(
                        "⬅️ Last Page",
                        variant="secondary",
                        size="sm",
                        elem_classes="page-btn",
                    )
                    next_btn = gr.Button(
                        "Next Page ➡️",
                        variant="secondary",
                        size="sm",
                        elem_classes="page-btn",
                    )

                with gr.Row():
                    gr.Markdown("**Save As: (Scroll down to get the link)**")
                    save_abc = gr.Button("🅰️ ABC", variant="secondary", size="sm")
                    save_xml = gr.Button("🎼 XML", variant="secondary", size="sm")
                    save_pdf = gr.Button("📑 PDF", variant="secondary", size="sm")
                    save_mid = gr.Button("🎹 MIDI", variant="secondary", size="sm")
                    save_wav = gr.Button("🎧 WAV", variant="secondary", size="sm")

    # Cascading dropdowns: a change of period or composer refreshes the choices
    # of the composer and instrumentation dropdowns.
    for dropdown in (period_dd, composer_dd):
        dropdown.change(
            update_components,
            inputs=[period_dd, composer_dd],
            outputs=[composer_dd, instrument_dd],
        )

    # Generation streams into the two text boxes, the preview, the player and the state.
    generate_btn.click(
        generate_music,
        inputs=[period_dd, composer_dd, instrument_dd],
        outputs=[process_output, final_output, pdf_image, audio_player, pdf_state],
    )

    # Hidden text boxes carry the constant direction argument for the paging handler.
    prev_signal = gr.Textbox(value="prev", visible=False)
    next_signal = gr.Textbox(value="next", visible=False)

    for page_button, page_signal in ((prev_btn, prev_signal), (next_btn, next_signal)):
        page_button.click(
            update_page,
            inputs=[page_signal, pdf_state],
            outputs=[pdf_image, prev_btn, next_btn, pdf_state],
        )

    # Each save button looks up one file name in the state and sends it to a
    # File component created here, one per format, in the order listed.
    for save_button, state_key, file_label in (
        (save_abc, 'abc', "abc"),
        (save_xml, 'xml', "xml"),
        (save_pdf, 'pdf', "pdf"),
        (save_mid, 'mid', "midi"),
        (save_wav, 'wav', "wav"),
    ):
        save_button.click(
            # The default argument binds this iteration's key to the handler.
            lambda state, _key=state_key: state.get(_key) if state else None,
            inputs=[pdf_state],
            outputs=gr.File(label=file_label, visible=True),
        )
|
|
|
|
|
|
|
if __name__ == "__main__":
    # A SPACE_ID environment variable is taken as the sign of a hosted "Spaces"
    # deployment (presumably Hugging Face Spaces - confirm).
    is_spaces = os.environ.get('SPACE_ID') is not None

    # Only the hosted deployment honours PORT; everywhere else the port is 7860.
    port = int(os.environ.get('PORT', 7860)) if is_spaces else 7860

    # The server listens on all interfaces in both cases.
    demo.launch(server_name="0.0.0.0", server_port=port)