File size: 10,490 Bytes
38818c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
from flask import Blueprint, jsonify, request
import json
import logging
import torch
import os
import tempfile

from env_vars import API_LOG_LEVEL

# Blueprint that groups every route defined in this module.
translations_blueprint = Blueprint(
    "translations_blueprint",
    __name__,
)

# Module logger. setLevel() is used instead of assigning ``.level`` directly:
# it validates the value, accepts level names such as "INFO", and invalidates
# the logger's cached isEnabledFor() results.
logger = logging.getLogger(__name__)
logger.setLevel(API_LOG_LEVEL)
# Apply the same threshold to the boto3/botocore library loggers.
logging.getLogger("boto3").setLevel(API_LOG_LEVEL)
logging.getLogger("botocore").setLevel(API_LOG_LEVEL)


def get_model():
    """Return whatever ``server.get_model()`` returns.

    The ``server`` module is imported at call time rather than at module
    import time.
    """
    from server import get_model as _server_get_model

    current_model = _server_get_model()
    return current_model


def get_text_decoder():
    """Return whatever ``server.get_text_decoder()`` returns.

    The ``server`` module is imported at call time rather than at module
    import time.
    """
    from server import get_text_decoder as _server_get_text_decoder

    current_decoder = _server_get_text_decoder()
    return current_decoder


def get_device():
    """Return whatever ``server.get_device()`` returns.

    The ``server`` module is imported at call time rather than at module
    import time.
    """
    from server import get_device as _server_get_device

    current_device = _server_get_device()
    return current_device


@translations_blueprint.route("/health")
def health():
    """Report service status, model presence, device and CUDA availability."""
    loaded_model = get_model()
    active_device = get_device()

    # A falsy device (e.g. None) is reported as "unknown".
    if active_device:
        device_label = str(active_device)
    else:
        device_label = "unknown"

    payload = {
        "status": "healthy",
        "service": "translations",
        "model_loaded": loaded_model is not None,
        "device": device_label,
        "cuda_available": torch.cuda.is_available(),
    }
    return payload


@translations_blueprint.route("/hello")
def hello():
    """Return a fixed greeting message as a dict."""
    greeting = "Hello from Translations API!"
    return dict(message=greeting)


@translations_blueprint.route("/transcribe", methods=["POST"])
def transcribe_audio():
    """Transcribe an uploaded audio file and return text plus alignment data.

    Expects a POST whose ``audio`` file field holds the upload. The upload is
    written to a temporary ``.wav`` file, handed to
    ``model.transcribe_audio_with_alignment`` and the temporary file is
    removed afterwards, whether or not the pipeline succeeded.

    Returns:
        A JSON response. On success it contains ``transcription``, ``model``,
        ``device``, ``status``, ``total_duration``, ``num_segments`` and
        ``alignment_available``, plus ``aligned_segments`` or
        ``alignment_error`` when the pipeline result provides them.
        Failures return an ``error`` message with status 503 (model or text
        decoder not loaded), 400 (no upload) or 500 (any other exception).
    """
    try:
        # Check if model is loaded
        model = get_model()
        text_decoder = get_text_decoder()
        device = get_device()

        if model is None or text_decoder is None:
            return (
                jsonify({"error": "Model not loaded. Please check server logs."}),
                503,
            )

        # Check if audio file is provided
        if "audio" not in request.files:
            return jsonify({"error": "No audio file provided"}), 400

        audio_file = request.files["audio"]
        if audio_file.filename == "":
            return jsonify({"error": "No audio file selected"}), 400

        # Create the temp file and enter try/finally immediately so the file
        # is removed even when saving the upload itself fails.
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        temp_path = tmp_file.name

        try:
            # Write through the already-open handle and close it before the
            # pipeline reads the path; re-opening an open NamedTemporaryFile
            # by name is not portable (it fails on Windows).
            with tmp_file:
                audio_file.save(tmp_file)

            # Imported lazily, at request time, like the server getters.
            from model import transcribe_audio_with_alignment

            # NOTE(review): max_duration_seconds=10 is hard-coded; its exact
            # meaning is defined by the pipeline in model.py - confirm there.
            results = transcribe_audio_with_alignment(
                wav_path=temp_path,
                max_duration_seconds=10,
            )

            logger.info("Transcription with alignment completed: %s", results)

            # Format response with alignment data
            response = {
                "transcription": results.get("transcription", ""),
                "model": "fairseq2-MMS",
                "device": str(device),
                "status": "success",
                "total_duration": results.get("total_duration", 0.0),
                "num_segments": results.get("num_segments", 0),
            }

            # Add alignment information if available
            if results.get("aligned_segments"):
                response["aligned_segments"] = results["aligned_segments"]
                response["alignment_available"] = True
            else:
                response["alignment_available"] = False
                if "alignment_error" in results:
                    response["alignment_error"] = results["alignment_error"]

            return jsonify(response)

        finally:
            # Best-effort cleanup: a failure to delete the temp file must not
            # replace the real response (or the real exception).
            try:
                os.unlink(temp_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logger.warning(
                    "Could not remove temporary file %s: %s",
                    temp_path,
                    cleanup_error,
                )

    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Transcription error: {str(e)}")
        return jsonify({"error": f"Transcription failed: {str(e)}"}), 500


@translations_blueprint.route("/align", methods=["POST"])
def align_transcription():
    """Force-align an uploaded audio file against a supplied transcription.

    Expects a POST with the upload in the ``audio`` file field and the text
    in the ``transcription`` form field. The text is split on whitespace
    into tokens, the upload is written to a temporary ``.wav`` file, and
    both are passed to ``model.perform_forced_alignment`` together with the
    model and device. The temporary file is removed afterwards.

    Returns:
        A JSON response. On success it contains ``transcription``,
        ``aligned_segments``, ``total_duration`` (the ``end`` value of the
        last segment, or 0.0 when there are none), ``num_segments`` and
        ``status``. Failures return an ``error`` message with status 503
        (model or device unavailable), 400 (missing upload or text) or 500
        (any other exception).
    """
    try:
        # Check if model is loaded
        model = get_model()
        if model is None:
            return (
                jsonify({"error": "Model not loaded. Please check server logs."}),
                503,
            )

        # Check if audio file and transcription are provided
        if "audio" not in request.files:
            return jsonify({"error": "No audio file provided"}), 400

        audio_file = request.files["audio"]
        if audio_file.filename == "":
            return jsonify({"error": "No audio file selected"}), 400

        # Get transcription text from form data
        transcription = request.form.get("transcription", "").strip()
        if not transcription:
            return jsonify({"error": "No transcription text provided"}), 400

        # Resolve the device before writing anything to disk so an
        # unavailable device does not cost a temp-file round trip.
        device = get_device()
        if device is None:
            return jsonify({"error": "Model not available for alignment"}), 503

        # Tokenize the transcription on whitespace
        tokens = transcription.split()

        # Create the temp file and enter try/finally immediately so the file
        # is removed even when saving the upload itself fails.
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        temp_path = tmp_file.name

        try:
            # Write through the already-open handle and close it before the
            # aligner reads the path; re-opening an open NamedTemporaryFile
            # by name is not portable (it fails on Windows).
            with tmp_file:
                audio_file.save(tmp_file)

            # Imported lazily, at request time, like the server getters.
            from model import perform_forced_alignment

            # Perform forced alignment
            aligned_segments = perform_forced_alignment(
                temp_path, tokens, model, device
            )

            # Total duration is taken from the last segment's end value
            total_duration = aligned_segments[-1]["end"] if aligned_segments else 0.0

            logger.info(f"Forced alignment completed: {len(aligned_segments)} segments")

            return jsonify(
                {
                    "transcription": transcription,
                    "aligned_segments": aligned_segments,
                    "total_duration": total_duration,
                    "num_segments": len(aligned_segments),
                    "status": "success",
                }
            )

        finally:
            # Best-effort cleanup: a failure to delete the temp file must not
            # replace the real response (or the real exception).
            try:
                os.unlink(temp_path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                logger.warning(
                    "Could not remove temporary file %s: %s",
                    temp_path,
                    cleanup_error,
                )

    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Alignment error: {str(e)}")
        return jsonify({"error": f"Alignment failed: {str(e)}"}), 500


@translations_blueprint.route("/translate", methods=["POST"])
def translate():
    """Handle a translation request and return a placeholder translation.

    Expects a JSON object body with a required ``text`` field and optional
    ``source_lang`` (default ``"en"``) and ``target_lang`` (default
    ``"es"``) fields. Actual model inference is not implemented yet (see the
    TODO below); ``translated_text`` is a formatted placeholder string.

    Returns:
        A JSON response. On success it contains ``original_text``,
        ``translated_text``, ``source_language``, ``target_language``,
        ``model`` and ``model_loaded``. Failures return an ``error`` message
        with status 503 (model not loaded), 400 (body missing, not valid
        JSON, not a JSON object, or without ``text``) or 500 (any other
        exception).
    """
    try:
        # Check if model is loaded
        model = get_model()
        if model is None:
            return (
                jsonify({"error": "Model not loaded. Please check server logs."}),
                503,
            )

        # silent=True yields None for a missing or malformed JSON body instead
        # of raising, so bad requests get the 400 below rather than being
        # converted into a 500 by the generic handler.
        data = request.get_json(silent=True)

        # Validate input; a non-object body (list, string, number) is
        # rejected here instead of failing later on data["text"].
        if not isinstance(data, dict) or "text" not in data:
            return jsonify({"error": "Missing 'text' field in request"}), 400

        text = data["text"]
        source_lang = data.get("source_lang", "en")
        target_lang = data.get("target_lang", "es")

        logger.info(f"Translation request: {source_lang} -> {target_lang}: '{text}'")
        logger.info(f"Model loaded: {model is not None}")

        # Use fairseq2 model for translation
        # TODO: Implement actual model inference for translation here
        translation = f"[fairseq2-MMS] Translation of '{text}' from {source_lang} to {target_lang}"

        result = {
            "original_text": text,
            "translated_text": translation,
            "source_language": source_lang,
            "target_language": target_lang,
            "model": "fairseq2-MMS",
            "model_loaded": model is not None,
        }

        logger.info(f"Translation completed: {result}")
        return jsonify(result)

    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Translation error: {str(e)}")
        return jsonify({"error": f"Translation failed: {str(e)}"}), 500


@translations_blueprint.route("/models", methods=["GET"])
def list_models():
    """Return the model catalogue: one fairseq2-MMS entry and the default name.

    The entry's ``available`` and ``loaded`` flags both reflect whether
    ``get_model()`` currently returns a non-None value.
    """
    is_loaded = get_model() is not None

    mms_entry = {
        "name": "fairseq2-MMS",
        "description": "fairseq2 MMS (Massively Multilingual Speech) speech-to-text model",
        "available": is_loaded,
        "loaded": is_loaded,
        "capabilities": ["speech-to-text", "transcription"],
    }

    return jsonify(
        {
            "available_models": [mms_entry],
            "default_model": "fairseq2-MMS",
        }
    )


@translations_blueprint.route("/languages", methods=["GET"])
def supported_languages():
    """Return a fixed list of language code/name pairs plus a coverage note."""
    # (code, name) pairs, in the order they are listed in the response.
    language_table = (
        ("en", "English"),
        ("es", "Spanish"),
        ("fr", "French"),
        ("de", "German"),
        ("it", "Italian"),
        ("pt", "Portuguese"),
        ("ar", "Arabic"),
        ("zh", "Chinese"),
        ("ja", "Japanese"),
        ("ko", "Korean"),
    )
    entries = [{"code": code, "name": name} for code, name in language_table]

    return jsonify(
        {
            "supported_languages": entries,
            "note": "MMS model supports 1143+ languages",
        }
    )


@translations_blueprint.route("/test_model", methods=["GET"])
def test_model():
    """Report whether the model is loaded along with basic runtime details.

    Responds with 503 when no model is loaded, 500 when gathering the
    details raises, and otherwise a JSON status payload.
    """
    try:
        loaded_model = get_model()
        decoder = get_text_decoder()
        active_device = get_device()

        # Guard clause: nothing to describe without a model.
        if loaded_model is None:
            return jsonify({"error": "Model not loaded"}), 503

        status_payload = {
            "model_loaded": True,
            "device": str(active_device),
            "model_type": str(type(loaded_model).__name__),
            "text_decoder_available": decoder is not None,
            "cuda_available": torch.cuda.is_available(),
            "status": "Model is ready for inference",
        }
        return jsonify(status_payload)

    except Exception as e:
        error_text = str(e)
        logger.error(f"Model test error: {error_text}")
        return jsonify({"error": f"Model test failed: {error_text}"}), 500