Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,546 Bytes
29efb71 5488167 29efb71 5488167 29efb71 5488167 29efb71 5488167 29efb71 5488167 29efb71 4617cbd 29efb71 5488167 29efb71 5488167 29efb71 071dfa9 29efb71 5488167 29efb71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
# app.py
import argparse
import json
import os
import traceback

import streamlit as st

from pipeline_ace_step import ACEStepPipeline
from data_sampler import DataSampler
# Streamlit page settings: browser-tab title, tab icon (music-note emoji),
# and the wide layout used by the two-column UI in main().
st.set_page_config(
    page_title="ACE Step Music Generator",
    page_icon="🎵",
    layout="wide",
)
def get_args():
    """Build the runtime configuration from environment variables or defaults.

    Environment variables read:
        CHECKPOINT_PATH: checkpoint directory (``None`` when unset).
        DEVICE_ID: integer GPU index, default ``'0'``.
        BF16: ``'true'`` (case-insensitive) enables bfloat16, default ``'True'``.
        TORCH_COMPILE: ``'true'`` (case-insensitive) enables torch.compile,
            default ``'False'``.

    Returns:
        dict with keys ``checkpoint_path``, ``device_id``, ``bf16`` and
        ``torch_compile``.

    Raises:
        ValueError: if DEVICE_ID is not an integer literal.
    """
    env = os.environ

    def _flag(name, default):
        # Only the exact word "true" (any case) counts as enabled.
        return env.get(name, default).lower() == 'true'

    return dict(
        checkpoint_path=env.get('CHECKPOINT_PATH'),
        device_id=int(env.get('DEVICE_ID', '0')),
        bf16=_flag('BF16', 'True'),
        torch_compile=_flag('TORCH_COMPILE', 'False'),
    )
@st.cache_resource
def load_model(args):
    """Build the ACE-Step pipeline and a DataSampler, cached by Streamlit.

    ``st.cache_resource`` keeps the returned objects across script reruns,
    so construction happens once per distinct ``args`` value.

    Args:
        args: Config mapping as returned by ``get_args``; reads the
            ``device_id``, ``checkpoint_path``, ``bf16`` and
            ``torch_compile`` keys.

    Returns:
        Tuple ``(model_demo, data_sampler)``.
    """
    # Restrict the visible GPU before the pipeline is constructed.
    # NOTE(review): this only takes effect if CUDA has not already been
    # initialised in this process — confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args['device_id'])

    checkpoint_dir = args['checkpoint_path']
    dtype = "float32"
    if args['bf16']:
        dtype = "bfloat16"

    pipeline = ACEStepPipeline(
        checkpoint_dir=checkpoint_dir,
        dtype=dtype,
        persistent_storage_path="/data",  # hard-coded persistent storage dir
        torch_compile=args['torch_compile'],
    )
    sampler = DataSampler()
    return pipeline, sampler
def _render_result(result):
    """Display a pipeline result, playing audio where possible.

    - Objects exposing an ``audio`` attribute are passed to ``st.audio``.
    - Lists/tuples are shown element by element. Strings naming an existing
      file are played with ``st.audio``, and everything else goes to ``st.write``.
    - Any other value is shown with ``st.write``.
    """
    if hasattr(result, 'audio'):
        st.audio(result.audio)
    elif isinstance(result, (list, tuple)):
        # NOTE(review): upstream ACE-Step's pipeline appears to return the
        # saved audio file paths followed by a dict of input parameters —
        # confirm against pipeline_ace_step.
        for item in result:
            if isinstance(item, str) and os.path.isfile(item):
                st.audio(item)
            else:
                st.write(item)
    else:
        st.write(result)


def _render_generate_panel(model_demo):
    """Render the prompt box and 'Generate Music' button, and run generation."""
    st.header("Generate Music")

    # Text input for the music description.
    prompt = st.text_area(
        "Enter your music description:",
        placeholder="Enter a description of the music you want to generate...",
        height=100,
    )

    # Generate button.
    if not st.button("Generate Music", type="primary"):
        return
    # A blank or whitespace-only description counts as "no description".
    if not (prompt and prompt.strip()):
        st.warning("Please enter a description first.")
        return

    with st.spinner("Generating music..."):
        try:
            # The prompt is passed by keyword. NOTE(review): in upstream
            # ACE-Step the first positional parameter of
            # ACEStepPipeline.__call__ is `format`, not the prompt — confirm.
            result = model_demo(prompt=prompt)
        except Exception as e:
            st.error(f"Error generating music: {str(e)}")
            return
    st.success("Music generated successfully!")
    _render_result(result)


def _render_sample_panel(data_sampler):
    """Render the 'Load Sample' button and the JSON file-upload widget."""
    st.header("Sample Data")

    if st.button("Load Sample"):
        try:
            sample_data = data_sampler.sample()
            st.json(sample_data)
        except Exception as e:
            st.error(f"Error loading sample: {str(e)}")

    # File upload.
    uploaded_file = st.file_uploader(
        "Upload JSON data",
        type=['json'],
    )
    if uploaded_file is not None:
        # Parse the uploaded bytes directly instead of handing the in-memory
        # upload to DataSampler.load_json. NOTE(review): upstream ACE-Step's
        # load_json opens its argument as a filesystem path — confirm.
        try:
            data = json.loads(uploaded_file.getvalue())
        except ValueError as e:  # JSONDecodeError / UnicodeDecodeError
            st.error(f"Error loading file: {str(e)}")
        else:
            st.json(data)


def main():
    """Build the ACE Step Music Generator page.

    The model is loaded via the cached ``load_model``. If loading fails, the
    error and traceback are shown and the rest of the page is skipped.
    Otherwise a two-column layout is drawn, with generation on the left and
    sample data on the right.
    """
    st.title("🎵 ACE Step Music Generator")
    args = get_args()

    # Only model loading is guarded here. This keeps failures in the UI code
    # below from being misreported as model-loading errors.
    try:
        model_demo, data_sampler = load_model(args)
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.code(traceback.format_exc())
        return

    # UI layout: wide left column (generation), narrow right column (samples).
    col1, col2 = st.columns([2, 1])
    with col1:
        _render_generate_panel(model_demo)
    with col2:
        _render_sample_panel(data_sampler)
# Entry point: build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()