ginipick commited on
Commit
29efb71
·
verified ·
1 Parent(s): 7a54875

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -31
app.py CHANGED
@@ -1,45 +1,108 @@
 
1
  import argparse
2
- from ui.components import create_main_demo_ui
 
3
  from pipeline_ace_step import ACEStepPipeline
4
  from data_sampler import DataSampler
5
- import os
6
-
7
-
8
- parser = argparse.ArgumentParser()
9
- parser.add_argument("--checkpoint_path", type=str, default=None)
10
- parser.add_argument("--server_name", type=str, default="0.0.0.0")
11
- parser.add_argument("--port", type=int, default=7860)
12
- parser.add_argument("--device_id", type=int, default=0)
13
- parser.add_argument("--share", action='store_true', default=False)
14
- parser.add_argument("--bf16", action='store_true', default=True)
15
- parser.add_argument("--torch_compile", type=bool, default=False)
16
-
17
- args = parser.parse_args()
18
- os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
19
-
20
 
21
- persistent_storage_path = "/data"
 
 
 
 
 
22
 
 
 
 
 
 
 
 
 
23
 
24
- def main(args):
25
-
 
 
 
 
26
  model_demo = ACEStepPipeline(
27
- checkpoint_dir=args.checkpoint_path,
28
- dtype="bfloat16" if args.bf16 else "float32",
29
  persistent_storage_path=persistent_storage_path,
30
- torch_compile=args.torch_compile
31
  )
32
  data_sampler = DataSampler()
 
33
 
34
- demo = create_main_demo_ui(
35
- text2music_process_func=model_demo.__call__,
36
- sample_data_func=data_sampler.sample,
37
- load_data_func=data_sampler.load_json,
38
- )
39
- demo.queue(default_concurrency_limit=8).launch(
 
40
 
41
- )
42
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  if __name__ == "__main__":
45
- main(args)
 
1
+ # app.py
2
  import argparse
3
+ import streamlit as st
4
+ import os
5
  from pipeline_ace_step import ACEStepPipeline
6
  from data_sampler import DataSampler
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ # Streamlit 설정
9
+ st.set_page_config(
10
+ page_title="ACE Step Music Generator",
11
+ page_icon="🎵",
12
+ layout="wide"
13
+ )
14
 
15
+ def get_args():
16
+ """환경변수 또는 기본값으로 설정"""
17
+ return {
18
+ 'checkpoint_path': os.environ.get('CHECKPOINT_PATH'),
19
+ 'device_id': int(os.environ.get('DEVICE_ID', '0')),
20
+ 'bf16': os.environ.get('BF16', 'True').lower() == 'true',
21
+ 'torch_compile': os.environ.get('TORCH_COMPILE', 'False').lower() == 'true'
22
+ }
23
 
24
+ @st.cache_resource
25
+ def load_model(args):
26
+ """모델 로딩 (캐시됨)"""
27
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(args['device_id'])
28
+ persistent_storage_path = "/data"
29
+
30
  model_demo = ACEStepPipeline(
31
+ checkpoint_dir=args['checkpoint_path'],
32
+ dtype="bfloat16" if args['bf16'] else "float32",
33
  persistent_storage_path=persistent_storage_path,
34
+ torch_compile=args['torch_compile']
35
  )
36
  data_sampler = DataSampler()
37
+ return model_demo, data_sampler
38
 
39
+ def main():
40
+ st.title("🎵 ACE Step Music Generator")
41
+
42
+ args = get_args()
43
+
44
+ try:
45
+ model_demo, data_sampler = load_model(args)
46
 
47
+ # UI 구성
48
+ col1, col2 = st.columns([2, 1])
49
+
50
+ with col1:
51
+ st.header("Generate Music")
52
+
53
+ # 텍스트 입력
54
+ prompt = st.text_area(
55
+ "Enter your music description:",
56
+ placeholder="Enter a description of the music you want to generate...",
57
+ height=100
58
+ )
59
+
60
+ # 생성 버튼
61
+ if st.button("Generate Music", type="primary"):
62
+ if prompt:
63
+ with st.spinner("Generating music..."):
64
+ try:
65
+ result = model_demo(prompt)
66
+ st.success("Music generated successfully!")
67
+
68
+ # 결과 표시 (result 형태에 따라 조정 필요)
69
+ if hasattr(result, 'audio'):
70
+ st.audio(result.audio)
71
+ else:
72
+ st.write(result)
73
+
74
+ except Exception as e:
75
+ st.error(f"Error generating music: {str(e)}")
76
+ else:
77
+ st.warning("Please enter a description first.")
78
+
79
+ with col2:
80
+ st.header("Sample Data")
81
+
82
+ if st.button("Load Sample"):
83
+ try:
84
+ sample_data = data_sampler.sample()
85
+ st.json(sample_data)
86
+ except Exception as e:
87
+ st.error(f"Error loading sample: {str(e)}")
88
+
89
+ # 파일 업로드
90
+ uploaded_file = st.file_uploader(
91
+ "Upload JSON data",
92
+ type=['json']
93
+ )
94
+
95
+ if uploaded_file:
96
+ try:
97
+ data = data_sampler.load_json(uploaded_file)
98
+ st.json(data)
99
+ except Exception as e:
100
+ st.error(f"Error loading file: {str(e)}")
101
+
102
+ except Exception as e:
103
+ st.error(f"Error loading model: {str(e)}")
104
+ import traceback
105
+ st.code(traceback.format_exc())
106
 
107
  if __name__ == "__main__":
108
+ main()