File size: 6,257 Bytes
9a88d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5097fd
 
9a88d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import FileResponse
from pydantic import BaseModel
from TTS.api import TTS
import os
import tempfile
import uuid
import torch
from typing import Optional
import logging

# Logging setup: root logger at INFO, plus a module-scoped logger used by
# everything defined in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application; the title/description/version are passed straight to
# FastAPI as application metadata.
app = FastAPI(
    title="TTS API",
    description="Text-to-Speech API using XTTS-v2",
    version="1.0.0",
)

class TTSRequest(BaseModel):
    """JSON request body carrying the text to synthesise and its language."""
    # Text to convert to speech (required).
    text: str
    # Language code; defaults to "en" when the client omits it.
    language: str = "en"

class TTSService:
    """Load an XTTS-v2 model once and synthesise speech with it.

    The constructor prefers the custom checkpoint under ``MODEL_PATH`` and
    falls back to the stock ``DEFAULT_MODEL_NAME`` model when the custom
    config file is missing or the custom model fails to load.
    """

    # Custom checkpoint location. These are relative paths, so they are
    # resolved against the process working directory.
    MODEL_PATH = "XTTS-v2_C3PO/"
    CONFIG_PATH = "XTTS-v2_C3PO/config.json"
    # Model identifier used when the custom checkpoint is unavailable.
    DEFAULT_MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

    def __init__(self):
        """Pick the compute device and load the custom or the default model.

        Raises:
            Exception: whatever the default-model load raises when neither
                the custom nor the default model can be loaded.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        if os.path.exists(self.CONFIG_PATH):
            try:
                self.tts = TTS(
                    model_path=self.MODEL_PATH,
                    config_path=self.CONFIG_PATH,
                    progress_bar=False,
                    gpu=torch.cuda.is_available(),
                ).to(self.device)
                logger.info("Custom TTS model loaded successfully")
                return
            except Exception as e:
                # Best effort: a broken custom checkpoint must not stop the
                # service from starting with the default model.
                logger.error(f"Failed to load custom TTS model: {e}")
        else:
            logger.warning(f"Custom model config not found at {self.CONFIG_PATH}")
            self._log_model_dir_contents()

        self.tts = self._load_default_model()

    def _log_model_dir_contents(self):
        """Log what is in the directory that was actually checked, for debugging."""
        # Resolve the same relative path used for loading, so the log shows
        # the directory the process really looked at.
        model_dir = os.path.abspath(self.MODEL_PATH)
        if os.path.isdir(model_dir):
            logger.info(f"Contents of {model_dir}: {os.listdir(model_dir)}")
        else:
            logger.warning(f"Model directory {model_dir} does not exist")

    def _load_default_model(self):
        """Load and return the default XTTS model on ``self.device``."""
        logger.info("Falling back to default XTTS model")
        try:
            tts = TTS(self.DEFAULT_MODEL_NAME).to(self.device)
        except Exception as e:
            logger.error(f"Failed to load default TTS model: {e}")
            # Bare raise keeps the original traceback intact.
            raise
        logger.info("Default TTS model loaded successfully")
        return tts

    def generate_speech(self, text: str, speaker_wav_path: str, language: str = "en") -> str:
        """Generate speech and return the path to the output file.

        Args:
            text: Text to synthesise.
            speaker_wav_path: Path to the reference speaker audio file.
            language: Language code forwarded to the model.

        Returns:
            Path of a newly created ``output_<hex>.wav`` file in the system
            temporary directory. The caller is responsible for deleting it.

        Raises:
            HTTPException: status 500 when the model call fails.
        """
        # Unique name so concurrent requests never write to the same file.
        output_filename = f"output_{uuid.uuid4().hex}.wav"
        output_path = os.path.join(tempfile.gettempdir(), output_filename)

        try:
            self.tts.tts_to_file(
                text=text,
                file_path=output_path,
                speaker_wav=speaker_wav_path,
                language=language,
            )
        except Exception as e:
            logger.error(f"Error generating speech: {e}")
            # Do not leave a partially written file behind; the caller never
            # learns this path, so nobody else could clean it up.
            try:
                if os.path.exists(output_path):
                    os.remove(output_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove partial output {output_path}: {cleanup_error}")
            raise HTTPException(
                status_code=500, detail=f"Failed to generate speech: {str(e)}"
            ) from e

        return output_path

# Initialize TTS service
# Constructed at module import time, so the model is loaded once when this
# module is imported; the request handlers in this file share this instance.
tts_service = TTSService()

@app.get("/")
async def root():
    """Return a static banner confirming the API is running."""
    banner = dict(message="TTS API is running", status="healthy")
    return banner

@app.get("/health")
async def health_check():
    """Return a fixed "healthy" status plus the device the TTS service uses."""
    active_device = tts_service.device
    return dict(status="healthy", device=active_device)

@app.post("/tts")
async def text_to_speech(
    text: str = Form(...),
    language: str = Form("en"),
    speaker_file: UploadFile = File(...)
):
    """
    Convert text to speech using a reference speaker voice
    
    - **text**: The text to convert to speech
    - **language**: Language code (default: "en")
    - **speaker_file**: Audio file containing the reference speaker voice
    """
    # Local import keeps this handler self-contained; BackgroundTasks is used
    # to delete the generated file once the response has been sent.
    from fastapi import BackgroundTasks

    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Validate file type. content_type may be None when the client sends no
    # Content-Type for the part, so treat that as "not audio" instead of
    # crashing with AttributeError.
    if not (speaker_file.content_type or "").startswith('audio/'):
        raise HTTPException(status_code=400, detail="Speaker file must be an audio file")

    speaker_temp_path = None
    try:
        content = await speaker_file.read()

        # Each request gets its own temp file: a shared fixed path would let
        # concurrent requests overwrite each other's reference audio. The
        # ".wav" suffix matches the name the reference file was saved under
        # previously.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as buffer:
            speaker_temp_path = buffer.name
            buffer.write(content)

        # Generate speech
        output_path = tts_service.generate_speech(text, speaker_temp_path, language)

        # Remove the generated file after the response body has been streamed.
        cleanup = BackgroundTasks()
        cleanup.add_task(os.remove, output_path)

        # No explicit Content-Disposition header here: FileResponse builds
        # 'attachment; filename="..."' from `filename`, and a caller-supplied
        # header would take precedence and drop the filename.
        return FileResponse(
            output_path,
            media_type="audio/wav",
            filename=f"tts_output_{uuid.uuid4().hex}.wav",
            background=cleanup,
        )

    except HTTPException:
        # Already a well-formed HTTP error (e.g. from generate_speech).
        raise
    except Exception as e:
        logger.error(f"Error in TTS endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Synthesis is finished (or failed) by now, so the reference audio is
        # no longer needed on either path.
        if speaker_temp_path is not None and os.path.exists(speaker_temp_path):
            os.remove(speaker_temp_path)

@app.post("/tts-with-url")
async def text_to_speech_with_url(request: TTSRequest, speaker_wav_url: str):
    """
    Convert text to speech using a reference speaker voice from URL
    
    - **request**: TTSRequest containing text and language
    - **speaker_wav_url**: URL to the reference speaker audio file
    """

    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Downloading and validating the remote speaker file is not implemented.
    # Raise 501 directly: wrapping this in a broad try/except would catch the
    # HTTPException itself and report it to the client as a 500.
    raise HTTPException(status_code=501, detail="URL-based speaker input not implemented yet")