NotaGen / app.py
ElectricAlexis's picture
Upload 2 files
7d848a1 verified
raw
history blame
16.2 kB
try:
import spaces
USING_SPACES = True
except ImportError:
USING_SPACES = False
import zero
import gradio as gr
import sys
import threading
import queue
from io import TextIOBase
import datetime
import subprocess
import os
from inference import postprocess_inst_names
from inference import inference_patch
from convert import abc2xml, xml2, pdf2img
def gpu_decorator(func):
    """Wrap *func* with ``spaces.GPU`` when running on Hugging Face Spaces.

    ``USING_SPACES`` is True only if ``import spaces`` succeeded; otherwise
    *func* is returned unchanged.
    """
    return spaces.GPU(func) if USING_SPACES else func
# Load the valid (period, composer, instrumentation) prompt combinations.
# Each non-blank line of prompts.txt is split on "_" and its first three
# fields are used as (period, composer, instrumentation).
with open('prompts.txt', 'r', encoding='utf-8') as f:
    prompts = f.readlines()
valid_combinations = set()
for prompt in prompts:
    prompt = prompt.strip()
    if not prompt:
        # Skip blank lines (e.g. a trailing empty line) instead of crashing
        # on parts[1] below.
        continue
    parts = prompt.split('_')
    if len(parts) < 3:
        raise ValueError(f"Malformed line in prompts.txt: {prompt!r}")
    valid_combinations.add((parts[0], parts[1], parts[2]))
# Dropdown choices, each sorted alphabetically.
periods = sorted({p for p, _, _ in valid_combinations})
composers = sorted({c for _, c, _ in valid_combinations})
instruments = sorted({i for _, _, i in valid_combinations})
# Dynamically update the composer and instrumentation dropdown options.
def update_components(period, composer):
    """Return Gradio updates for ``[composer_dd, instrument_dd]``.

    With no period selected, both dropdowns are cleared and disabled.
    Otherwise the composer choices are limited to the selected period, and
    the current composer is kept only if it is still among them. The
    instrumentation choices are limited to the (period, composer) pair and
    are empty when no composer is selected. The instrumentation value is
    always reset to None.
    """
    if not period:
        cleared = dict(choices=[], value=None, interactive=False)
        return [gr.update(**cleared), gr.update(**cleared)]

    composer_choices = sorted(
        {c for p, c, _ in valid_combinations if p == period}
    )
    instrument_choices = []
    if composer:
        instrument_choices = sorted(
            {i for p, c, i in valid_combinations if (p, c) == (period, composer)}
        )

    kept_composer = composer if composer in composer_choices else None
    composer_update = gr.update(
        choices=composer_choices,
        value=kept_composer,
        interactive=True
    )
    instrument_update = gr.update(
        choices=instrument_choices,
        value=None,
        interactive=bool(instrument_choices)
    )
    return [composer_update, instrument_update]
# Custom real-time stream used to forward the model's inference output to the front end.
class RealtimeStream(TextIOBase):
    """Minimal text stream that pushes every ``write`` onto a queue.

    ``generate_music`` installs an instance as ``sys.stdout`` so that text
    printed during inference can be read from the queue and streamed to the
    UI while generation is still running.
    """

    def __init__(self, queue):
        # ``queue`` is any object with a ``put`` method (a ``queue.Queue`` in
        # this file); the parameter name is kept for backward compatibility.
        super().__init__()
        self.queue = queue

    def writable(self):
        # IOBase.writable() returns False by default; report True so code
        # that checks writability treats this as a usable output stream.
        return True

    def write(self, text):
        """Put *text* on the queue and return the number of characters written."""
        self.queue.put(text)
        return len(text)
def convert_files(abc_content, period, composer, instrumentation):
    """Save a generated ABC score and convert it to the other output formats.

    Two ABC files are written to the current directory:
    - the raw score, ``<base>.abc``;
    - a copy passed through ``postprocess_inst_names``, ``<base>_postinst.abc``.

    ``<base>`` is ``<YYYYmmdd_HHMMSS>_<period>_<composer>_<instrumentation>``.
    Both copies are converted with ``abc2xml`` and then ``xml2`` to MIDI and
    WAV. Only the raw score is rendered to PDF. The PDF's pages, as returned
    by ``pdf2img``, are saved as ``<base>_page_<n>.png`` (1-based).

    Args:
        abc_content: ABC notation text to save.
        period, composer, instrumentation: the prompt used for generation;
            all three must be non-empty.

    Returns:
        dict with keys:
        - ``abc``: path of the raw ABC file;
        - ``xml``, ``mid``, ``wav``: paths of the post-processed outputs;
        - ``pdf``: path of the raw-score PDF;
        - ``pages``: number of saved PNG pages;
        - ``current_page``: 0;
        - ``base``: ``<base>``.

    Raises:
        gr.Error: if the prompt is incomplete or any conversion step fails
            (the original exception is chained as the cause).
    """
    if not all([period, composer, instrumentation]):
        raise gr.Error("Please complete a valid generation first before saving")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    prompt_str = f"{period}_{composer}_{instrumentation}"
    filename_base = f"{timestamp}_{prompt_str}"
    abc_filename = f"{filename_base}.abc"
    with open(abc_filename, "w", encoding="utf-8") as f:
        f.write(abc_content)

    # Instrument-name replacement, kept as a separate ABC file.
    postprocessed_inst_abc = postprocess_inst_names(abc_content)
    filename_base_postinst = f"{filename_base}_postinst"
    with open(filename_base_postinst + ".abc", "w", encoding="utf-8") as f:
        f.write(postprocessed_inst_abc)

    # Convert files.
    file_paths = {'abc': abc_filename}
    try:
        # abc -> xml
        abc2xml(filename_base)
        abc2xml(filename_base_postinst)
        # xml -> pdf (raw score only; used for the preview images)
        xml2(filename_base, 'pdf')
        # xml -> mid
        xml2(filename_base, 'mid')
        xml2(filename_base_postinst, 'mid')
        # xml -> wav
        xml2(filename_base, 'wav')
        xml2(filename_base_postinst, 'wav')
        # Save each PDF page as a PNG for the preview.
        images = pdf2img(filename_base)
        for i, image in enumerate(images):
            image.save(f"{filename_base}_page_{i+1}.png", "PNG")
        file_paths.update({
            'xml': f"{filename_base_postinst}.xml",
            'pdf': f"{filename_base}.pdf",
            'mid': f"{filename_base_postinst}.mid",
            'wav': f"{filename_base_postinst}.wav",
            'pages': len(images),
            'current_page': 0,
            'base': filename_base
        })
    except Exception as e:
        # "文件处理失败" means "file processing failed". Chain the original
        # exception so its traceback is not lost.
        raise gr.Error(f"文件处理失败: {str(e)}") from e
    return file_paths
# Page-navigation handler for the sheet-music preview.
def update_page(direction, data):
    """Move the preview one page in *direction* and refresh the nav buttons.

    Args:
        direction: "prev" or "next"; any other value keeps the current page.
        data: the state dict built by ``convert_files`` (reads ``pages``,
            ``current_page`` and ``base``), or a falsy value when nothing has
            been generated yet. ``current_page`` is updated in place.

    Returns:
        (image path or None, prev-button update, next-button update, data)
    """
    if not data:
        return None, gr.update(interactive=False), gr.update(interactive=False), data

    pages = data['pages']
    if pages <= 0:
        # No PNG pages were saved, so there is no preview file to point at
        # (generate_music applies the same pages > 0 check for page 1).
        return None, gr.update(interactive=False), gr.update(interactive=False), data

    if direction == "prev" and data['current_page'] > 0:
        data['current_page'] -= 1
    elif direction == "next" and data['current_page'] < pages - 1:
        data['current_page'] += 1

    current_page_index = data['current_page']
    # Page images are saved 1-based: <base>_page_1.png, <base>_page_2.png, ...
    new_image = f"{data['base']}_page_{current_page_index+1}.png"
    # Disable "prev" on the first page and "next" on the last page.
    prev_btn_state = gr.update(interactive=(current_page_index > 0))
    next_btn_state = gr.update(interactive=(current_page_index < pages - 1))
    return new_image, prev_btn_state, next_btn_state, data
@gpu_decorator
def generate_music(period, composer, instrumentation):
    """Run inference for a prompt and stream progress and results to the UI.

    Every ``yield`` produces the same 5 values, in the order of the outputs
    wired to the Generate button:
        1) process_output - text printed so far during inference
        2) final_output   - the final ABC text, or a status/error message
        3) pdf_image      - path of the first preview page PNG, or None
        4) audio_player   - path of the WAV file, or None
        5) pdf_state      - the dict returned by ``convert_files``, or None

    Raises:
        gr.Error: if the combination is not one listed in prompts.txt.
    """
    import traceback

    if (period, composer, instrumentation) not in valid_combinations:
        # Invalid combination: fail immediately.
        raise gr.Error("Invalid prompt combination! Please re-select from the period options")

    output_queue = queue.Queue()
    original_stdout = sys.stdout
    # sys.stdout is process-wide, so anything printed while inference runs
    # ends up in the queue until the worker restores it.
    sys.stdout = RealtimeStream(output_queue)
    result_container = []
    error_container = []

    def run_inference():
        try:
            result_container.append(inference_patch(period, composer, instrumentation))
        except Exception as e:
            # Keep the error so it can be reported, instead of going on to
            # convert an empty result. Log the traceback to stderr.
            traceback.print_exc()
            error_container.append(e)
        finally:
            sys.stdout = original_stdout

    thread = threading.Thread(target=run_inference)
    thread.start()

    process_output = ""
    final_output_abc = ""
    pdf_image = None
    audio_file = None
    pdf_state = None

    # Stream intermediate output while inference is running; there is no
    # final ABC or converted file yet.
    while thread.is_alive():
        try:
            text = output_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        process_output += text
        yield process_output, final_output_abc, pdf_image, audio_file, pdf_state
    thread.join()

    # Drain whatever was written just before the thread finished.
    while not output_queue.empty():
        process_output += output_queue.get()

    if error_container:
        yield process_output, f"Error during generation: {error_container[0]}", None, None, None
        return

    final_result = result_container[0] if result_container else ""

    # Tell the user that file conversion is in progress.
    final_output_abc = "Converting files..."
    yield process_output, final_output_abc, pdf_image, audio_file, pdf_state

    try:
        file_paths = convert_files(final_result, period, composer, instrumentation)
        final_output_abc = final_result
        # First preview page and the WAV file.
        if file_paths['pages'] > 0:
            pdf_image = f"{file_paths['base']}_page_1.png"
        audio_file = file_paths['wav']
        pdf_state = file_paths  # used for paging and saving
    except Exception as e:
        # Show the failure in the output box.
        yield process_output, f"Error converting files: {str(e)}", None, None, None
        return

    # Final yield with all results.
    yield process_output, final_output_abc, pdf_image, audio_file, pdf_state
def get_file(file_type, period, composer, instrumentation):
    """Return the newest entry in the working directory ending in ``.<file_type>``.

    This is a simplified lookup for Gradio downloads. ``period``, ``composer``
    and ``instrumentation`` are accepted but not used; the most recently
    modified match is returned. Entries with equal modification times keep
    their ``os.listdir`` order, and the later one wins. Returns None when
    nothing matches.
    """
    suffix = f'.{file_type}'
    matches = list(filter(lambda name: name.endswith(suffix), os.listdir('.')))
    if matches:
        # sorted() is stable, so the last element is the newest (latest on ties).
        return sorted(matches, key=os.path.getmtime)[-1]
    return None
# Custom CSS passed to gr.Blocks(css=...) below. The /* ... */ comments inside
# the string are part of the runtime CSS text. In English, the sections are:
# compact buttons; PDF preview area (rounded corners, shadow); page buttons
# (larger click area, centered); button hover effect; layout spacing; audio
# player; and the save area (button spacing).
# NOTE(review): "#audio-preview" only matches if some component has
# elem_id="audio-preview" (the Audio component's elem_id is commented out).
# .audio-panel, .save-as-row, .save-as-label and .save-buttons are not
# referenced by any elem_id/elem_classes in this file as shown.
css = """
/* 紧凑按钮样式 */
button[size="sm"] {
padding: 4px 8px !important;
margin: 2px !important;
min-width: 60px;
}
/* PDF预览区 */
#pdf-preview {
border-radius: 8px; /* 圆角 */
box-shadow: 0 2px 8px rgba(0,0,0,0.1); /* 阴影 */
}
.page-btn {
padding: 12px !important; /* 增大点击区域 */
margin: auto !important; /* 垂直居中 */
}
/* 按钮悬停效果 */
.page-btn:hover {
background: #f0f0f0 !important;
transform: scale(1.05);
}
/* 布局调整 */
.gr-row {
gap: 10px !important; /* 元素间距 */
}
/* 音频播放器 */
.audio-panel {
margin-top: 15px !important;
max-width: 400px;
}
#audio-preview audio {
height: 200px !important;
}
/* 保存功能区 */
.save-as-row {
margin-top: 15px;
padding: 10px;
border-top: 1px solid #eee;
}
.save-as-label {
font-weight: bold;
margin-right: 10px;
align-self: center;
}
.save-buttons {
gap: 5px; /* 按钮间距 */
}
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown("## NotaGen")
    # Holds the convert_files() result: page count, current page, file paths.
    pdf_state = gr.State()
    with gr.Column():
        with gr.Row():
            # Left column: prompt selection, generation output, audio.
            with gr.Column():
                with gr.Row():
                    period_dd = gr.Dropdown(
                        choices=periods,
                        value=None,
                        label="Period",
                        interactive=True
                    )
                    composer_dd = gr.Dropdown(
                        choices=[],
                        value=None,
                        label="Composer",
                        interactive=False
                    )
                    instrument_dd = gr.Dropdown(
                        choices=[],
                        value=None,
                        label="Instrumentation",
                        interactive=False
                    )
                generate_btn = gr.Button("Generate!", variant="primary")
                process_output = gr.Textbox(
                    label="Generation process",
                    interactive=False,
                    lines=2,
                    max_lines=2,
                    placeholder="Generation progress will be shown here..."
                )
                final_output = gr.Textbox(
                    label="Post-processed ABC notation scores",
                    interactive=True,
                    lines=8,
                    max_lines=8,
                    placeholder="Post-processed ABC scores will be shown here..."
                )
                # Audio playback
                audio_player = gr.Audio(
                    label="Audio Preview",
                    format="wav",
                    interactive=False,
                    # container=False,
                    # elem_id="audio-preview"
                )
            # Right column: sheet-music preview, page navigation, save buttons.
            with gr.Column():
                # Preview image
                pdf_image = gr.Image(
                    label="Sheet Music Preview",
                    show_label=False,
                    height=650,
                    type="filepath",
                    elem_id="pdf-preview",
                    interactive=False,
                    show_download_button=False
                )
                # Page-navigation buttons
                with gr.Row():
                    prev_btn = gr.Button(
                        "⬅️ Last Page",
                        variant="secondary",
                        size="sm",
                        elem_classes="page-btn"
                    )
                    next_btn = gr.Button(
                        "Next Page ➡️",
                        variant="secondary",
                        size="sm",
                        elem_classes="page-btn"
                    )
                # Save buttons
                with gr.Row():
                    gr.Markdown("**Save As: (Scroll down to get the link)**")
                    save_abc = gr.Button("🅰️ ABC", variant="secondary", size="sm")
                    save_xml = gr.Button("🎼 XML", variant="secondary", size="sm")
                    save_pdf = gr.Button("📑 PDF", variant="secondary", size="sm")
                    save_mid = gr.Button("🎹 MIDI", variant="secondary", size="sm")
                    save_wav = gr.Button("🎧 WAV", variant="secondary", size="sm")
                    # save_status = gr.Textbox(
                    #     label="Save Status",
                    #     interactive=False,
                    #     visible=True,
                    #     max_lines=1
                    # )

    # Keep the composer/instrumentation dropdowns in sync with the selection.
    period_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )
    composer_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )

    # The outputs must match the 5 values yielded by generate_music.
    # Afterwards, re-sync the page buttons with the new pdf_state, which
    # starts at page 0. Without this they keep whatever interactive state
    # the previous score's navigation left them in. direction=None leaves
    # the current page unchanged.
    generate_btn.click(
        generate_music,
        inputs=[period_dd, composer_dd, instrument_dd],
        outputs=[process_output, final_output, pdf_image, audio_player, pdf_state]
    ).then(
        lambda data: update_page(None, data)[1:3],
        inputs=[pdf_state],
        outputs=[prev_btn, next_btn]
    )

    # Page navigation: hidden textboxes supply the direction as a component input.
    prev_signal = gr.Textbox(value="prev", visible=False)
    next_signal = gr.Textbox(value="next", visible=False)
    prev_btn.click(
        update_page,
        inputs=[prev_signal, pdf_state],
        outputs=[pdf_image, prev_btn, next_btn, pdf_state]
    )
    next_btn.click(
        update_page,
        inputs=[next_signal, pdf_state],
        outputs=[pdf_image, prev_btn, next_btn, pdf_state]
    )

    # Save buttons: each returns the matching path from pdf_state (or None
    # before a successful generation) into a gr.File created here. The File
    # components are created inside the Blocks context, so they render at the
    # bottom of the page (hence "Scroll down to get the link").
    save_abc.click(
        lambda state: state.get('abc') if state else None,
        inputs=[pdf_state],
        outputs=gr.File(label="abc", visible=True)
    )
    save_xml.click(
        lambda state: state.get('xml') if state else None,
        inputs=[pdf_state],
        outputs=gr.File(label="xml", visible=True)
    )
    save_pdf.click(
        lambda state: state.get('pdf') if state else None,
        inputs=[pdf_state],
        outputs=gr.File(label="pdf", visible=True)
    )
    save_mid.click(
        lambda state: state.get('mid') if state else None,
        inputs=[pdf_state],
        outputs=gr.File(label="midi", visible=True)
    )
    save_wav.click(
        lambda state: state.get('wav') if state else None,
        inputs=[pdf_state],
        outputs=gr.File(label="wav", visible=True)
    )
if __name__ == "__main__":
    # Configure GPU/CPU handling.
    import torch

    def is_cuda_working():
        """Return True if CUDA is available and a tiny tensor op completes on it."""
        try:
            if torch.cuda.is_available():
                test_tensor = torch.tensor([1.0], device="cuda")
                _ = test_tensor * 2
                # Kernel launches are asynchronous; synchronize so a failing
                # kernel raises here instead of at some later CUDA call.
                torch.cuda.synchronize()
                return True
            return False
        except Exception as e:
            print(f"CUDA initialization test failed: {e}")
            return False

    # Check whether we are running on Hugging Face Spaces.
    if "SPACE_ID" in os.environ:
        # PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its CUDA caching allocator
        # initializes. It must therefore be set before the probe tensor in
        # is_cuda_working() allocates on the GPU. A value already set in the
        # environment is kept.
        # NOTE(review): modules imported at the top of this file may already
        # have initialized CUDA; if so, this is still too late.
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")
        if is_cuda_working():
            print("GPU is available and working. Using CUDA.")
        else:
            print("CUDA not working properly. Forcing CPU mode.")
            # NOTE(review): CUDA_VISIBLE_DEVICES only affects CUDA
            # initialization that has not yet happened in this process.
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            torch.backends.cudnn.enabled = False
        # Launch with minimal parameters on Spaces.
        demo.launch()
    else:
        # Running locally: listen on all interfaces and create a public share link.
        print(f"Running locally with device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True  # public share link for external access
        )