Spaces:
VIDraft
/
Running on Zero

File size: 3,546 Bytes
29efb71
5488167
29efb71
 
5488167
 
 
29efb71
 
 
 
 
 
5488167
29efb71
 
 
 
 
 
 
 
5488167
29efb71
 
 
 
 
 
5488167
29efb71
 
4617cbd
29efb71
5488167
 
29efb71
5488167
29efb71
 
 
 
 
 
 
071dfa9
29efb71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5488167
 
29efb71
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# app.py
import argparse
import streamlit as st
import os
from pipeline_ace_step import ACEStepPipeline
from data_sampler import DataSampler

# Streamlit page configuration.
st.set_page_config(
    page_title="ACE Step Music Generator",
    # Named escape for the musical-note emoji (U+1F3B5): keeps the icon
    # intact no matter how this source file is encoded or re-saved.
    page_icon="\N{MUSICAL NOTE}",
    layout="wide"
)

def get_args():
    """Build the app settings from environment variables, with defaults.

    Returns:
        A dict with the keys ``checkpoint_path`` (value of ``CHECKPOINT_PATH``
        or ``None``), ``device_id`` (``DEVICE_ID`` parsed as int, default 0),
        ``bf16`` (``BF16``, default on) and ``torch_compile``
        (``TORCH_COMPILE``, default off).

    Raises:
        ValueError: if ``DEVICE_ID`` is set to something ``int()`` rejects.
    """
    env = os.environ

    def _is_true(name, fallback):
        # A flag is on only when its text equals "true", ignoring case.
        return env.get(name, fallback).lower() == 'true'

    settings = {}
    settings['checkpoint_path'] = env.get('CHECKPOINT_PATH')
    settings['device_id'] = int(env.get('DEVICE_ID', '0'))
    settings['bf16'] = _is_true('BF16', 'True')
    settings['torch_compile'] = _is_true('TORCH_COMPILE', 'False')
    return settings

@st.cache_resource
def load_model(args):
    """Load the model (cached via ``st.cache_resource``).

    Args:
        args: settings dict providing ``device_id``, ``checkpoint_path``,
            ``bf16`` and ``torch_compile``.

    Returns:
        A ``(ACEStepPipeline, DataSampler)`` tuple.
    """
    # Publish the configured device id through CUDA_VISIBLE_DEVICES before
    # the pipeline object is constructed.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args['device_id'])

    pipeline_options = {
        'checkpoint_dir': args['checkpoint_path'],
        'dtype': "bfloat16" if args['bf16'] else "float32",
        'persistent_storage_path': "/data",
        'torch_compile': args['torch_compile'],
    }
    pipeline = ACEStepPipeline(**pipeline_options)
    sampler = DataSampler()
    return pipeline, sampler

def main():
    """Render the page: load the model, then draw the two UI panels.

    A failure while loading the model is shown on the page together with its
    traceback, and the rest of the UI is skipped. Failures while generating
    music or loading data are reported by the panel that triggered them.
    """
    st.title("\N{MUSICAL NOTE} ACE Step Music Generator")

    args = get_args()

    # Keep the try body limited to loading, so only genuine loading failures
    # are reported under the "Error loading model" message.
    try:
        model_demo, data_sampler = load_model(args)
    except Exception as e:
        import traceback

        st.error(f"Error loading model: {str(e)}")
        st.code(traceback.format_exc())
        return

    # UI layout: two columns with relative widths 2 and 1.
    col1, col2 = st.columns([2, 1])

    with col1:
        _render_generation_panel(model_demo)

    with col2:
        _render_sample_panel(data_sampler)


def _render_generation_panel(model_demo):
    """Draw the prompt input and call ``model_demo`` on it when requested."""
    st.header("Generate Music")

    # Text input
    prompt = st.text_area(
        "Enter your music description:",
        placeholder="Enter a description of the music you want to generate...",
        height=100
    )

    # Generate button
    if not st.button("Generate Music", type="primary"):
        return

    # A whitespace-only description is treated the same as an empty one.
    if not prompt or not prompt.strip():
        st.warning("Please enter a description first.")
        return

    with st.spinner("Generating music..."):
        try:
            result = model_demo(prompt)
            st.success("Music generated successfully!")

            # Show the result; may need adjusting to the actual form of
            # whatever the pipeline returns.
            if hasattr(result, 'audio'):
                st.audio(result.audio)
            else:
                st.write(result)
        except Exception as e:
            st.error(f"Error generating music: {str(e)}")


def _render_sample_panel(data_sampler):
    """Draw the sample-loading button and the JSON file uploader."""
    st.header("Sample Data")

    if st.button("Load Sample"):
        try:
            st.json(data_sampler.sample())
        except Exception as e:
            st.error(f"Error loading sample: {str(e)}")

    # File upload
    uploaded_file = st.file_uploader(
        "Upload JSON data",
        type=['json']
    )

    if uploaded_file:
        try:
            st.json(data_sampler.load_json(uploaded_file))
        except Exception as e:
            st.error(f"Error loading file: {str(e)}")

# Call main() only when this module runs as the main program, not on import.
if __name__ == "__main__":
    main()