ElectricAlexis committed on
Commit
42e9a11
·
verified ·
1 Parent(s): 334c4fa

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -113
app.py CHANGED
@@ -1,8 +1,3 @@
1
- try:
2
- import spaces
3
- USING_SPACES = True
4
- except ImportError:
5
- USING_SPACES = False
6
  import zero
7
  import gradio as gr
8
  import sys
@@ -17,15 +12,6 @@ from inference import inference_patch
17
  from convert import abc2xml, xml2, pdf2img
18
 
19
 
20
-
21
-
22
- def gpu_decorator(func):
23
- if USING_SPACES:
24
- return spaces.GPU(func)
25
- else:
26
- return func
27
-
28
-
29
  # 读取 prompt 组合
30
  with open('prompts.txt', 'r') as f:
31
  prompts = f.readlines()
@@ -106,9 +92,9 @@ def convert_files(abc_content, period, composer, instrumentation):
106
  xml2(filename_base, 'mid')
107
  xml2(filename_base_postinst, 'mid')
108
 
109
- # xml2wav
110
- xml2(filename_base, 'wav')
111
- xml2(filename_base_postinst, 'wav')
112
 
113
  # 将PDF转为图片
114
  images = pdf2img(filename_base)
@@ -119,7 +105,7 @@ def convert_files(abc_content, period, composer, instrumentation):
119
  'xml': f"{filename_base_postinst}.xml",
120
  'pdf': f"{filename_base}.pdf",
121
  'mid': f"{filename_base_postinst}.mid",
122
- 'wav': f"{filename_base_postinst}.wav",
123
  'pages': len(images),
124
  'current_page': 0,
125
  'base': filename_base
@@ -154,7 +140,7 @@ def update_page(direction, data):
154
  return new_image, prev_btn_state, next_btn_state, data
155
 
156
 
157
- @gpu_decorator
158
  def generate_music(period, composer, instrumentation):
159
  """
160
  需要保证每次 yield 的返回值数量一致。
@@ -162,17 +148,33 @@ def generate_music(period, composer, instrumentation):
162
  1) process_output (中间推理信息)
163
  2) final_output (最终 ABC)
164
  3) pdf_image (PDF 第一页对应的 png 路径)
165
- 4) audio_player (WAV 路径)
166
  5) pdf_state (翻页用的 state)
167
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  if (period, composer, instrumentation) not in valid_combinations:
169
  # 如果组合非法,直接抛出错误
170
  raise gr.Error("Invalid prompt combination! Please re-select from the period options")
171
 
172
- # # Ensure model weights were downloaded successfully
173
- # if not os.path.exists(model_weights_path):
174
- # raise gr.Error(f"Model weights not available at {model_weights_path}")
175
-
176
  output_queue = queue.Queue()
177
  original_stdout = sys.stdout
178
  sys.stdout = RealtimeStream(output_queue)
@@ -202,7 +204,7 @@ def generate_music(period, composer, instrumentation):
202
  text = output_queue.get(timeout=0.1)
203
  process_output += text
204
  # 暂时没有最终 ABC,还没有转文件
205
- yield process_output, final_output_abc, pdf_image, audio_file, pdf_state
206
  except queue.Empty:
207
  continue
208
 
@@ -216,24 +218,38 @@ def generate_music(period, composer, instrumentation):
216
 
217
  # 显示转换文件的提示
218
  final_output_abc = "Converting files..."
219
- yield process_output, final_output_abc, pdf_image, audio_file, pdf_state
 
220
 
221
  # 做文件转换
222
  try:
223
  file_paths = convert_files(final_result, period, composer, instrumentation)
224
  final_output_abc = final_result
225
- # 拿到第一张图片和 wav 文件
226
  if file_paths['pages'] > 0:
227
  pdf_image = f"{file_paths['base']}_page_1.png"
228
- audio_file = file_paths['wav']
229
  pdf_state = file_paths # 直接把转换后的信息字典拿来存到 state
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  except Exception as e:
231
  # 如果失败了,把错误信息返回到输出框
232
- yield process_output, f"Error converting files: {str(e)}", None, None, None
233
  return
234
 
235
- # 最后一次 yield,带上所有信息
236
- yield process_output, final_output_abc, pdf_image, audio_file, pdf_state
237
 
238
 
239
  def get_file(file_type, period, composer, instrumentation):
@@ -308,6 +324,13 @@ button[size="sm"] {
308
  gap: 5px; /* 按钮间距 */
309
  }
310
 
 
 
 
 
 
 
 
311
  """
312
 
313
  with gr.Blocks(css=css) as demo:
@@ -361,7 +384,7 @@ with gr.Blocks(css=css) as demo:
361
  # 音频播放
362
  audio_player = gr.Audio(
363
  label="Audio Preview",
364
- format="wav",
365
  interactive=False,
366
  # container=False,
367
  # elem_id="audio-preview"
@@ -395,21 +418,14 @@ with gr.Blocks(css=css) as demo:
395
  elem_classes="page-btn"
396
  )
397
 
398
- # 按钮组
399
- with gr.Row():
400
- gr.Markdown("**Save As: (Scroll down to get the link)**")
401
- save_abc = gr.Button("🅰️ ABC", variant="secondary", size="sm")
402
- save_xml = gr.Button("🎼 XML", variant="secondary", size="sm")
403
- save_pdf = gr.Button("📑 PDF", variant="secondary", size="sm")
404
- save_mid = gr.Button("🎹 MIDI", variant="secondary", size="sm")
405
- save_wav = gr.Button("🎧 WAV", variant="secondary", size="sm")
406
-
407
- # save_status = gr.Textbox(
408
- # label="Save Status",
409
- # interactive=False,
410
- # visible=True,
411
- # max_lines=1
412
- # )
413
 
414
  # 下拉框联动
415
  period_dd.change(
@@ -427,7 +443,7 @@ with gr.Blocks(css=css) as demo:
427
  generate_btn.click(
428
  generate_music,
429
  inputs=[period_dd, composer_dd, instrument_dd],
430
- outputs=[process_output, final_output, pdf_image, audio_player, pdf_state]
431
  )
432
 
433
  # 翻页
@@ -446,71 +462,10 @@ with gr.Blocks(css=css) as demo:
446
  outputs=[pdf_image, prev_btn, next_btn, pdf_state]
447
  )
448
 
449
- # 文件保存按钮
450
- save_abc.click(
451
- lambda state: state.get('abc') if state else None,
452
- inputs=[pdf_state],
453
- outputs=gr.File(label="abc", visible=True)
454
- )
455
- save_xml.click(
456
- lambda state: state.get('xml') if state else None,
457
- inputs=[pdf_state],
458
- outputs=gr.File(label="xml", visible=True)
459
- )
460
- save_pdf.click(
461
- lambda state: state.get('pdf') if state else None,
462
- inputs=[pdf_state],
463
- outputs=gr.File(label="pdf", visible=True)
464
- )
465
- save_mid.click(
466
- lambda state: state.get('mid') if state else None,
467
- inputs=[pdf_state],
468
- outputs=gr.File(label="midi", visible=True)
469
- )
470
- save_wav.click(
471
- lambda state: state.get('wav') if state else None,
472
- inputs=[pdf_state],
473
- outputs=gr.File(label="wav", visible=True)
474
- )
475
-
476
-
477
 
478
  if __name__ == "__main__":
479
  # Configure GPU/CPU handling
480
- import torch
481
-
482
- # Function to initialize CUDA safely and verify it's working
483
- def is_cuda_working():
484
- try:
485
- if torch.cuda.is_available():
486
- # Test CUDA initialization with a small operation
487
- test_tensor = torch.tensor([1.0], device="cuda")
488
- _ = test_tensor * 2
489
- return True
490
- return False
491
- except Exception as e:
492
- print(f"CUDA initialization test failed: {e}")
493
- return False
494
-
495
- # Check if running on Hugging Face Spaces
496
- if "SPACE_ID" in os.environ:
497
- cuda_working = is_cuda_working()
498
- if cuda_working:
499
- print("GPU is available and working. Using CUDA.")
500
- # You might want to set some environment variables or configurations here
501
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
502
- else:
503
- print("CUDA not working properly. Forcing CPU mode.")
504
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
505
- torch.backends.cudnn.enabled = False
506
-
507
- # Launch with minimal parameters on Spaces
508
- demo.launch()
509
- else:
510
- # Running locally - use custom server settings and share
511
- print(f"Running locally with device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
512
- demo.launch(
513
- server_name="0.0.0.0",
514
- server_port=7860,
515
- share=True # 确保外部访问
516
- )
 
 
 
 
 
 
1
  import zero
2
  import gradio as gr
3
  import sys
 
12
  from convert import abc2xml, xml2, pdf2img
13
 
14
 
 
 
 
 
 
 
 
 
 
15
  # 读取 prompt 组合
16
  with open('prompts.txt', 'r') as f:
17
  prompts = f.readlines()
 
92
  xml2(filename_base, 'mid')
93
  xml2(filename_base_postinst, 'mid')
94
 
95
+ # xml2mp3
96
+ xml2(filename_base, 'mp3')
97
+ xml2(filename_base_postinst, 'mp3')
98
 
99
  # 将PDF转为图片
100
  images = pdf2img(filename_base)
 
105
  'xml': f"{filename_base_postinst}.xml",
106
  'pdf': f"{filename_base}.pdf",
107
  'mid': f"{filename_base_postinst}.mid",
108
+ 'mp3': f"{filename_base_postinst}.mp3",
109
  'pages': len(images),
110
  'current_page': 0,
111
  'base': filename_base
 
140
  return new_image, prev_btn_state, next_btn_state, data
141
 
142
 
143
+ @spaces.GPU
144
  def generate_music(period, composer, instrumentation):
145
  """
146
  需要保证每次 yield 的返回值数量一致。
 
148
  1) process_output (中间推理信息)
149
  2) final_output (最终 ABC)
150
  3) pdf_image (PDF 第一页对应的 png 路径)
151
+ 4) audio_player (mp3 路径)
152
  5) pdf_state (翻页用的 state)
153
  """
154
+ # Set a different random seed each time based on current timestamp
155
+ random_seed = int(time.time()) % 10000
156
+ random.seed(random_seed)
157
+
158
+ # For numpy if you're using it
159
+ try:
160
+ import numpy as np
161
+ np.random.seed(random_seed)
162
+ except ImportError:
163
+ pass
164
+
165
+ # For torch if you're using it
166
+ try:
167
+ import torch
168
+ torch.manual_seed(random_seed)
169
+ if torch.cuda.is_available():
170
+ torch.cuda.manual_seed_all(random_seed)
171
+ except ImportError:
172
+ pass
173
+
174
  if (period, composer, instrumentation) not in valid_combinations:
175
  # 如果组合非法,直接抛出错误
176
  raise gr.Error("Invalid prompt combination! Please re-select from the period options")
177
 
 
 
 
 
178
  output_queue = queue.Queue()
179
  original_stdout = sys.stdout
180
  sys.stdout = RealtimeStream(output_queue)
 
204
  text = output_queue.get(timeout=0.1)
205
  process_output += text
206
  # 暂时没有最终 ABC,还没有转文件
207
+ yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
208
  except queue.Empty:
209
  continue
210
 
 
218
 
219
  # 显示转换文件的提示
220
  final_output_abc = "Converting files..."
221
+ yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
222
+
223
 
224
  # 做文件转换
225
  try:
226
  file_paths = convert_files(final_result, period, composer, instrumentation)
227
  final_output_abc = final_result
228
+ # 拿到第一张图片和 mp3 文件
229
  if file_paths['pages'] > 0:
230
  pdf_image = f"{file_paths['base']}_page_1.png"
231
+ audio_file = file_paths['mp3']
232
  pdf_state = file_paths # 直接把转换后的信息字典拿来存到 state
233
+
234
+ # 准备下载文件列表
235
+ download_list = []
236
+ if 'abc' in file_paths and os.path.exists(file_paths['abc']):
237
+ download_list.append(file_paths['abc'])
238
+ if 'xml' in file_paths and os.path.exists(file_paths['xml']):
239
+ download_list.append(file_paths['xml'])
240
+ if 'pdf' in file_paths and os.path.exists(file_paths['pdf']):
241
+ download_list.append(file_paths['pdf'])
242
+ if 'mid' in file_paths and os.path.exists(file_paths['mid']):
243
+ download_list.append(file_paths['mid'])
244
+ if 'mp3' in file_paths and os.path.exists(file_paths['mp3']):
245
+ download_list.append(file_paths['mp3'])
246
  except Exception as e:
247
  # 如果失败了,把错误信息返回到输出框
248
+ yield process_output, f"Error converting files: {str(e)}", None, None, None, gr.update(value=None, visible=False)
249
  return
250
 
251
+ # 最后一次 yield,带上所有信息 - 修改此处让组件可见
252
+ yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=download_list, visible=True)
253
 
254
 
255
  def get_file(file_type, period, composer, instrumentation):
 
324
  gap: 5px; /* 按钮间距 */
325
  }
326
 
327
+ /* Download files styling */
328
+ .download-files {
329
+ margin-top: 15px;
330
+ border-radius: 8px;
331
+ box-shadow: 0 2px 8px rgba(0,0,0,0.1);
332
+ }
333
+
334
  """
335
 
336
  with gr.Blocks(css=css) as demo:
 
384
  # 音频播放
385
  audio_player = gr.Audio(
386
  label="Audio Preview",
387
+ format="mp3",
388
  interactive=False,
389
  # container=False,
390
  # elem_id="audio-preview"
 
418
  elem_classes="page-btn"
419
  )
420
 
421
+ with gr.Column():
422
+ gr.Markdown("**Download Files:**")
423
+ download_files = gr.Files(
424
+ label="Generated Files",
425
+ visible=False,
426
+ elem_classes="download-files",
427
+ type="filepath" # Make sure this is set to filepath
428
+ )
 
 
 
 
 
 
 
429
 
430
  # 下拉框联动
431
  period_dd.change(
 
443
  generate_btn.click(
444
  generate_music,
445
  inputs=[period_dd, composer_dd, instrument_dd],
446
+ outputs=[process_output, final_output, pdf_image, audio_player, pdf_state, download_files]
447
  )
448
 
449
  # 翻页
 
462
  outputs=[pdf_image, prev_btn, next_btn, pdf_state]
463
  )
464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
 
466
  if __name__ == "__main__":
467
  # Configure GPU/CPU handling
468
+ demo.launch(
469
+ server_name="0.0.0.0",
470
+ server_port=7860
471
+ )