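"""Streamlit front-end for IndexTTS (echoAI).

A single-page UI: upload a reference voice clip, enter target text, and
synthesize speech via zero-shot voice cloning, on GPU when available.

Launch with Streamlit (the file name below is illustrative, not prescribed):
    streamlit run app.py
"""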
import streamlit as st
import os
import time
import sys
import torch
from huggingface_hub import snapshot_download

current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, "indextts"))

from indextts.infer import IndexTTS
from tools.i18n.i18n import I18nAuto

# Initialize internationalization with English UI strings
i18n = I18nAuto(language="en")

# GPU configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# App configuration
st.set_page_config(page_title="echoAI - IndexTTS", layout="wide")

# Create necessary directories
os.makedirs("outputs/tasks", exist_ok=True)
os.makedirs("prompts", exist_ok=True)

# Download checkpoints if not exists
if not os.path.exists("checkpoints"):
    snapshot_download("IndexTeam/IndexTTS-1.5", local_dir="checkpoints")
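# snapshot_download pulls the IndexTeam/IndexTTS-1.5 weights from the
# Hugging Face Hub into ./checkpoints on first launch; later runs skip the
# download because the existence check above short-circuits it.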

# Load TTS model with GPU support
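# st.cache_resource keeps the returned IndexTTS instance alive for the life of
# the Streamlit server process, so the checkpoints are loaded only once rather
# than on every rerun triggered by a UI interaction.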
@st.cache_resource
def load_model():
    tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")
    tts.load_normalizer()
    if DEVICE == "cuda":
        tts.model.to(DEVICE)  # Move model to GPU if available
    return tts

tts = load_model()

# Inference helper: clone the uploaded reference voice and synthesize the target text
def infer(voice_path, text, output_path=None):
    if not output_path:
        output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
    
    # Run synthesis with the cached model; the result is written to output_path
    tts.infer(voice_path, text, output_path)
    return output_path
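# Example: infer("prompts/ref.wav", "Hello world") writes
# outputs/spk_<unix_timestamp>.wav and returns that path.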

# Streamlit UI
st.title("echoAI - IndexTTS")
st.markdown("""
<h4 style='text-align: center;'>
    An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System
</h4>
<p style='text-align: center;'>
    <a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
</p>
""", unsafe_allow_html=True)

# Device status indicator
st.sidebar.markdown(f"**Device:** {DEVICE.upper()}")

# Main interface
with st.container():
    st.header("Audio Generation")
    
    col1, col2 = st.columns(2)
    
    with col1:
        uploaded_audio = st.file_uploader(
            "Upload reference audio",
            type=["wav", "mp3", "ogg"],
            accept_multiple_files=False
        )
        
        input_text = st.text_area(
            "Input target text",
            height=150,
            placeholder="Enter text to synthesize..."
        )
        
        generate_btn = st.button("Generate Speech")

    with col2:
        if generate_btn and uploaded_audio and input_text:
            with st.spinner("Generating audio..."):
                # Save uploaded audio
                audio_path = os.path.join("prompts", uploaded_audio.name)
                with open(audio_path, "wb") as f:
                    f.write(uploaded_audio.getbuffer())
                
                # Perform inference
                try:
                    output_path = infer(audio_path, input_text)
                    st.audio(output_path, format="audio/wav")
                    st.success("Generation complete!")
                    
                    # Download button
                    with open(output_path, "rb") as f:
                        st.download_button(
                            "Download Result",
                            f,
                            file_name=os.path.basename(output_path),
                            mime="audio/wav",
                        )
                except Exception as e:
                    st.error(f"Error: {str(e)}")
        elif generate_btn:
            st.warning("Please upload an audio file and enter text first!")

# Sidebar with additional info
with st.sidebar:
    st.header("About echoAI")
    st.markdown("""
    ### Key Features:
    - Zero-shot voice cloning
    - Industrial-grade TTS
    - Efficient synthesis
    - Controllable output
    """)
    
    st.markdown("---")
    st.markdown("""
    ### Usage Instructions:
    1. Upload a reference audio clip
    2. Enter target text
    3. Click 'Generate Speech'
    """)

if __name__ == "__main__":
    # Cleanup old files if needed
    pass