Spaces: Running on Zero
# app.py
import argparse
import os

import streamlit as st

from pipeline_ace_step import ACEStepPipeline
from data_sampler import DataSampler

# Streamlit page configuration. Kept at module level so it runs before main()
# renders anything (Streamlit has historically required set_page_config to be
# the first st.* call in a script).
st.set_page_config(
    page_title="ACE Step Music Generator",
    page_icon="🎵",  # was mis-encoded as "๐ต", which is not a valid emoji icon
    layout="wide",
)
def get_args():
    """Collect runtime settings from environment variables, falling back to defaults.

    Environment variables read:
        CHECKPOINT_PATH -- checkpoint directory (None when unset).
        DEVICE_ID       -- integer device index (default "0").
        BF16            -- "true" (any case) enables bfloat16 (default "True").
        TORCH_COMPILE   -- "true" (any case) enables torch.compile (default "False").

    Returns:
        dict with keys 'checkpoint_path', 'device_id', 'bf16', 'torch_compile'.

    Raises:
        ValueError: if DEVICE_ID is not an integer string.
    """
    env = os.environ

    def _flag(var_name, fallback):
        # Only the literal word "true" (case-insensitive) counts as enabled;
        # values such as "1" or "yes" are treated as False.
        return env.get(var_name, fallback).lower() == 'true'

    checkpoint = env.get('CHECKPOINT_PATH')
    device = int(env.get('DEVICE_ID', '0'))
    return {
        'checkpoint_path': checkpoint,
        'device_id': device,
        'bf16': _flag('BF16', 'True'),
        'torch_compile': _flag('TORCH_COMPILE', 'False'),
    }
@st.cache_resource
def load_model(args):
    """Build the ACE-Step pipeline and a DataSampler, cached across reruns.

    Streamlit re-executes the whole script on every widget interaction; the
    ``st.cache_resource`` decorator makes the (expensive) pipeline construction
    happen once per distinct ``args`` value instead of on every click. The
    returned objects are shared by all sessions of this server process.

    Args:
        args: configuration dict as produced by get_args(), with keys
            'checkpoint_path', 'device_id', 'bf16' and 'torch_compile'.

    Returns:
        tuple (model_demo, data_sampler): the ACEStepPipeline instance and a
        DataSampler instance.
    """
    # NOTE(review): this only takes effect if CUDA has not already been
    # initialised in this process — confirm nothing touches the GPU earlier.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args['device_id'])
    # Presumably the Hugging Face Spaces persistent-storage mount — confirm.
    persistent_storage_path = "/data"
    model_demo = ACEStepPipeline(
        checkpoint_dir=args['checkpoint_path'],
        dtype="bfloat16" if args['bf16'] else "float32",
        persistent_storage_path=persistent_storage_path,
        torch_compile=args['torch_compile'],
    )
    data_sampler = DataSampler()
    return model_demo, data_sampler
def _render_result(result):
    """Show a pipeline result: as audio when recognisable, raw otherwise."""
    if hasattr(result, 'audio'):
        st.audio(result.audio)
        return
    if isinstance(result, (list, tuple)):
        # NOTE(review): the pipeline appears to be able to return a sequence
        # (e.g. generated file paths plus metadata) — confirm the exact shape.
        # Existing file paths are played; everything else is shown as-is.
        for item in result:
            if isinstance(item, str) and os.path.isfile(item):
                st.audio(item)
            else:
                st.write(item)
        return
    st.write(result)


def _render_generate_section(model_demo):
    """Render the prompt box, the "Generate Music" button and its output."""
    st.header("Generate Music")
    # Text input
    prompt = st.text_area(
        "Enter your music description:",
        placeholder="Enter a description of the music you want to generate...",
        height=100,
    )
    # Generate button: st.button is True only on the rerun triggered by the click.
    if not st.button("Generate Music", type="primary"):
        return
    # Whitespace-only input is treated the same as an empty prompt.
    if not (prompt and prompt.strip()):
        st.warning("Please enter a description first.")
        return
    with st.spinner("Generating music..."):
        try:
            # Pass by keyword so the text binds to the pipeline's `prompt`
            # parameter rather than whatever its first positional parameter is.
            # NOTE(review): assumes ACEStepPipeline.__call__ accepts `prompt=`.
            result = model_demo(prompt=prompt)
            st.success("Music generated successfully!")
            # Display the result (adjust to the pipeline's actual return shape).
            _render_result(result)
        except Exception as e:
            st.error(f"Error generating music: {e}")


def _render_sample_section(data_sampler):
    """Render the sample-data button and the JSON file uploader."""
    st.header("Sample Data")
    if st.button("Load Sample"):
        try:
            sample_data = data_sampler.sample()
            st.json(sample_data)
        except Exception as e:
            st.error(f"Error loading sample: {e}")
    # File upload
    uploaded_file = st.file_uploader(
        "Upload JSON data",
        type=['json'],
    )
    if uploaded_file is not None:
        try:
            # NOTE(review): load_json receives a Streamlit UploadedFile (a
            # file-like object), not a path — confirm DataSampler supports it.
            data = data_sampler.load_json(uploaded_file)
            st.json(data)
        except Exception as e:
            st.error(f"Error loading file: {e}")


def main():
    """Build the ACE Step Music Generator page.

    Loads the (cached) model, then lays out two columns: music generation on
    the left and sample/uploaded JSON inspection on the right. If the model
    cannot be loaded, the error and traceback are displayed and the rest of
    the page is skipped.
    """
    st.title("🎵 ACE Step Music Generator")
    args = get_args()

    # Guard only model loading here; each UI section handles its own errors,
    # so unrelated UI failures are no longer reported as model-loading errors.
    try:
        model_demo, data_sampler = load_model(args)
    except Exception as e:
        import traceback

        st.error(f"Error loading model: {e}")
        st.code(traceback.format_exc())
        return

    # UI layout: generation gets 2/3 of the width, sample data 1/3.
    col1, col2 = st.columns([2, 1])
    with col1:
        _render_generate_section(model_demo)
    with col2:
        _render_sample_section(data_sampler)
if __name__ == "__main__":
    # `streamlit run app.py` executes this file as __main__, so the page is
    # rebuilt by calling main() on every script rerun.
    main()