|
import streamlit as st |
|
import os |
|
import time |
|
import sys |
|
import torch |
|
from huggingface_hub import snapshot_download |
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__)) |
|
sys.path.append(current_dir) |
|
sys.path.append(os.path.join(current_dir, "indextts")) |
|
|
|
from indextts.infer import IndexTTS |
|
from tools.i18n.i18n import I18nAuto |
|
|
|
|
|
i18n = I18nAuto(language="en")  # force English UI strings from the i18n helper

# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

st.set_page_config(page_title="echoAI - IndexTTS", layout="wide")

# Working directories: generated wavs land under outputs/, uploaded
# reference clips under prompts/.
os.makedirs("outputs/tasks", exist_ok=True)
os.makedirs("prompts", exist_ok=True)

# First run only: fetch the pretrained IndexTTS-1.5 weights from the
# Hugging Face Hub. NOTE(review): a partially-downloaded "checkpoints"
# dir would also skip this — confirm resume behavior is acceptable.
if not os.path.exists("checkpoints"):
    snapshot_download("IndexTeam/IndexTTS-1.5", local_dir="checkpoints")
|
|
|
|
|
@st.cache_resource
def load_model():
    """Build and cache the IndexTTS engine.

    ``st.cache_resource`` ensures a single engine instance is shared
    across all Streamlit reruns/sessions of this server process.

    Returns:
        The initialized ``IndexTTS`` instance, moved to GPU when one
        is available.
    """
    engine = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml")
    engine.load_normalizer()
    # DEVICE is either "cuda" or "cpu"; only a real GPU needs the move.
    if DEVICE != "cpu":
        engine.model.to(DEVICE)
    return engine
|
|
|
tts = load_model()  # cached singleton; constructed once per server process
|
|
|
|
|
def infer(voice_path, text, output_path=None):
    """Synthesize *text* in the voice of *voice_path* and return the wav path.

    Args:
        voice_path: Path to the reference audio clip on disk.
        text: Target text to synthesize.
        output_path: Where to write the result; when falsy, a
            timestamped file under ``outputs/`` is generated.

    Returns:
        The path the synthesized audio was written to.
    """
    target = output_path or os.path.join("outputs", f"spk_{int(time.time())}.wav")
    tts.infer(voice_path, text, target)
    return target
|
|
|
|
|
# --- Page header: title, tagline, and paper badge --------------------------
st.title("echoAI - IndexTTS")
st.markdown("""
<h4 style='text-align: center;'>
An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System
</h4>
<p style='text-align: center;'>
<a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
</p>
""", unsafe_allow_html=True)

# Surface which device inference will run on.
st.sidebar.markdown(f"**Device:** {DEVICE.upper()}")
|
|
|
|
|
# --- Main panel: upload a reference voice, enter text, synthesize ----------
with st.container():
    st.header("Audio Generation")

    col1, col2 = st.columns(2)

    with col1:
        uploaded_audio = st.file_uploader(
            "Upload reference audio",
            type=["wav", "mp3", "ogg"],
            accept_multiple_files=False
        )

        input_text = st.text_area(
            "Input target text",
            height=150,
            placeholder="Enter text to synthesize..."
        )

        generate_btn = st.button("Generate Speech")

    with col2:
        if generate_btn and uploaded_audio and input_text:
            with st.spinner("Generating audio..."):
                # Persist the uploaded clip so the TTS engine can read it
                # from disk. basename() strips any path components a client
                # could smuggle into the filename (path-traversal hardening).
                audio_path = os.path.join(
                    "prompts", os.path.basename(uploaded_audio.name)
                )
                with open(audio_path, "wb") as f:
                    f.write(uploaded_audio.getbuffer())

                try:
                    output_path = infer(audio_path, input_text)
                    st.audio(output_path, format="audio/wav")
                    st.success("Generation complete!")

                    with open(output_path, "rb") as f:
                        st.download_button(
                            "Download Result",
                            f,
                            file_name=os.path.basename(output_path)
                        )  # FIX: original omitted this closing paren (SyntaxError)
                except Exception as e:
                    # Broad catch is deliberate at this UI boundary: show the
                    # failure to the user instead of crashing the script run.
                    st.error(f"Error: {str(e)}")
        elif generate_btn:
            st.warning("Please upload an audio file and enter text first!")
|
|
|
|
|
# --- Sidebar: project blurb and quick-start instructions -------------------
with st.sidebar:
    st.header("About echoAI")
    st.markdown("""
### Key Features:
- Zero-shot voice cloning
- Industrial-grade TTS
- Efficient synthesis
- Controllable output
""")

    st.markdown("---")
    st.markdown("""
### Usage Instructions:
1. Upload a reference audio clip
2. Enter target text
3. Click 'Generate Speech'
""")
|
|
|
# Streamlit executes this module top-to-bottom on every rerun, so all UI
# work above happens at import time; this guard is intentionally inert.
if __name__ == "__main__":
    pass