File size: 2,283 Bytes
e69a4b4
 
5f05a96
e69a4b4
 
 
 
 
 
5f05a96
b05fcfe
e69a4b4
b05fcfe
e69a4b4
 
 
5f05a96
e69a4b4
 
b05fcfe
 
e69a4b4
 
 
 
 
 
b05fcfe
e69a4b4
 
5f05a96
b05fcfe
 
 
 
e69a4b4
5f05a96
e69a4b4
b05fcfe
 
e69a4b4
 
 
b05fcfe
e69a4b4
5f05a96
e69a4b4
 
b05fcfe
 
 
 
 
 
 
 
 
 
e69a4b4
b05fcfe
e69a4b4
b05fcfe
 
e69a4b4
b05fcfe
 
 
 
e69a4b4
5f05a96
e69a4b4
 
b05fcfe
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import time
import logging
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from flores200_codes import flores_codes

# WSGI application object — also the entry point for `flask run` / gunicorn.
app = Flask(__name__)

# Root logger at DEBUG: request payloads and translation results below are
# logged verbatim.  NOTE(review): consider INFO in production to avoid
# logging user-submitted text.
logging.basicConfig(level=logging.DEBUG)

def load_models():
    """Load every configured NLLB checkpoint and its tokenizer.

    Returns:
        dict: for each entry in the internal model table, maps
        "<call_name>_model" to the loaded ``AutoModelForSeq2SeqLM`` and
        "<call_name>_tokenizer" to its ``AutoTokenizer``.
    """
    # Short alias -> Hugging Face hub identifier.  Extend this table to
    # serve additional checkpoints without touching the request handler.
    model_name_dict = {"nllb-distilled-600M": "facebook/nllb-200-distilled-600M"}
    model_dict = {}

    for call_name, real_name in model_name_dict.items():
        # Lazy %-style args: the message is only formatted if this log
        # level is actually enabled (cheaper than an f-string here).
        logging.info("\tLoading model: %s", call_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
        tokenizer = AutoTokenizer.from_pretrained(real_name)
        model_dict[call_name + "_model"] = model
        model_dict[call_name + "_tokenizer"] = tokenizer

    return model_dict

# Module-level model cache, populated once at import time so every request
# reuses the already-loaded weights.  (The original `global model_dict`
# statement was a no-op at module scope and has been removed.)
model_dict = load_models()

@app.route("/api/translate", methods=["POST"])
def translate_text():
    """Translate a JSON body ``{"source", "target", "text"}`` with NLLB.

    ``source``/``target`` must be keys of ``flores_codes``; ``text`` is the
    string to translate.

    Returns:
        200 with ``{"inference_time", "source", "target", "result"}`` on
        success; 400 with ``{"error": ...}`` when a field is missing, the
        body is not valid JSON, or a language code is unknown.
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising — fall back to {} so the missing-field check below returns a
    # clean 400 rather than an AttributeError-driven 500.
    data = request.get_json(silent=True) or {}
    logging.debug("Received data: %s", data)
    source_lang = data.get("source")
    target_lang = data.get("target")
    input_text = data.get("text")

    if not source_lang or not target_lang or not input_text:
        logging.error("Missing fields in the request")
        return jsonify({"error": "source, target, and text fields are required"}), 400

    model_name = "nllb-distilled-600M"
    start_time = time.time()
    # Map the request's language names to FLORES-200 codes; .get returns
    # None for unknown names so they can be rejected explicitly below.
    source = flores_codes.get(source_lang)
    target = flores_codes.get(target_lang)

    if not source or not target:
        logging.error("Invalid source or target language code")
        return jsonify({"error": "Invalid source or target language code"}), 400

    model = model_dict[model_name + "_model"]
    tokenizer = model_dict[model_name + "_tokenizer"]

    # NOTE(review): only the lightweight pipeline wrapper is rebuilt per
    # request (the heavy model/tokenizer come from the module-level cache);
    # consider caching per (src, tgt) pair if request latency matters.
    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=source,
        tgt_lang=target,
    )
    output = translator(input_text, max_length=400)

    end_time = time.time()
    output_text = output[0]["translation_text"]

    result = {
        "inference_time": end_time - start_time,
        "source": source_lang,
        "target": target_lang,
        "result": output_text,
    }
    logging.debug("Translation result: %s", result)
    return jsonify(result)

if __name__ == "__main__":
    # NOTE(review): debug=True enables the Werkzeug interactive debugger and
    # reloader, and 0.0.0.0 binds on all interfaces — fine for local dev,
    # never ship this combination to production.
    app.run(host="0.0.0.0", port=5000, debug=True)