# Spaces: Runtime error
# (Page-header residue captured together with this file; not part of the program.)
import os | |
import time | |
import logging | |
from flask import Flask, request, jsonify | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
from flores200_codes import flores_codes | |
# WSGI application object for this module.
app = Flask(__name__)

# Configure the root logger at DEBUG so the logging.debug(...) calls in the
# request handler are emitted.
logging.basicConfig(level=logging.DEBUG)
def load_models():
    """Load every configured checkpoint and return its model and tokenizer.

    Returns:
        dict: For each configured call name, two entries keyed
        ``"<call_name>_model"`` and ``"<call_name>_tokenizer"`` holding the
        objects returned by the respective ``from_pretrained`` calls.
    """
    # Short call name -> checkpoint identifier handed to from_pretrained().
    checkpoints = {"nllb-distilled-600M": "facebook/nllb-200-distilled-600M"}

    loaded = {}
    for call_name, checkpoint in checkpoints.items():
        logging.info(f"\tLoading model: {call_name}")
        # Dict displays evaluate in source order, so the model is still
        # loaded before the tokenizer.
        loaded.update(
            {
                f"{call_name}_model": AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
                f"{call_name}_tokenizer": AutoTokenizer.from_pretrained(checkpoint),
            }
        )
    return loaded
# Models are loaded once, at import time, and shared by all requests.
# (A ``global`` statement is unnecessary here: names assigned at module level
# are already module globals.)
model_dict = load_models()
def translate_text():
    """Translate the text carried in the current request's JSON body.

    The body must be a JSON object with three non-empty fields:

    * ``source`` -- key into ``flores_codes`` for the source language
    * ``target`` -- key into ``flores_codes`` for the target language
    * ``text``   -- the text to translate

    Returns:
        On success, a JSON response with ``inference_time`` (seconds,
        including pipeline construction), ``source``, ``target`` and
        ``result`` (the translated text).
        On invalid input, a ``(response, 400)`` tuple whose JSON body has an
        ``error`` message.
    """
    # NOTE(review): no route decorator or add_url_rule() call for this view is
    # visible in this file, so as shown it is never reachable over HTTP --
    # confirm where (or whether) it is registered.

    # silent=True yields None instead of raising when the body is missing,
    # malformed, or not sent as JSON, so every bad request gets the same
    # JSON-formatted 400 below.
    data = request.get_json(silent=True)
    logging.debug(f"Received data: {data}")

    # Valid JSON need not be an object (e.g. ``null`` or ``[1, 2]``); calling
    # .get() on such values would raise and surface as a 500.
    if not isinstance(data, dict):
        logging.error("Missing fields in the request")
        return jsonify({"error": "source, target, and text fields are required"}), 400

    source_lang = data.get("source")
    target_lang = data.get("target")
    input_text = data.get("text")
    if not source_lang or not target_lang or not input_text:
        logging.error("Missing fields in the request")
        return jsonify({"error": "source, target, and text fields are required"}), 400

    # Unhashable values (lists, dicts) would raise TypeError in the lookup
    # below; anything that is not a string cannot be a language key anyway.
    if not isinstance(source_lang, str) or not isinstance(target_lang, str):
        logging.error("Invalid source or target language code")
        return jsonify({"error": "Invalid source or target language code"}), 400

    model_name = "nllb-distilled-600M"
    start_time = time.time()

    source = flores_codes.get(source_lang)
    target = flores_codes.get(target_lang)
    if not source or not target:
        logging.error("Invalid source or target language code")
        return jsonify({"error": "Invalid source or target language code"}), 400

    model = model_dict[model_name + "_model"]
    tokenizer = model_dict[model_name + "_tokenizer"]

    # The pipeline is built per request because the language pair is part of
    # its configuration; the model and tokenizer themselves are reused.
    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=source,
        tgt_lang=target,
    )
    output = translator(input_text, max_length=400)
    end_time = time.time()

    output_text = output[0]["translation_text"]
    result = {
        "inference_time": end_time - start_time,
        "source": source_lang,
        "target": target_lang,
        "result": output_text,
    }
    logging.debug(f"Translation result: {result}")
    return jsonify(result)
if __name__ == "__main__":
    # Debug mode turns on the Werkzeug interactive debugger, which lets
    # anyone who can reach the server execute arbitrary Python. Since the
    # server binds to all interfaces, debug mode is now opt-in through the
    # FLASK_DEBUG environment variable instead of always on.
    # Debug mode also enables the auto-reloader by default, which re-imports
    # this module in a child process and would therefore run the import-time
    # model load a second time.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(host="0.0.0.0", port=5000, debug=debug_enabled)