"""NLLB translation provider implementation."""
import logging
from typing import Dict, List, Optional
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from ..base.translation_provider_base import TranslationProviderBase
from ...domain.exceptions import TranslationFailedException
logger = logging.getLogger(__name__)
class NLLBTranslationProvider(TranslationProviderBase):
"""NLLB-200-3.3B translation provider implementation."""
# NLLB language code mappings
LANGUAGE_MAPPINGS = {
'en': 'eng_Latn',
'zh': 'zho_Hans',
'zh-cn': 'zho_Hans',
'zh-tw': 'zho_Hant',
'es': 'spa_Latn',
'fr': 'fra_Latn',
'de': 'deu_Latn',
'ja': 'jpn_Jpan',
'ko': 'kor_Hang',
'ar': 'arb_Arab',
'hi': 'hin_Deva',
'pt': 'por_Latn',
'ru': 'rus_Cyrl',
'it': 'ita_Latn',
'nl': 'nld_Latn',
'pl': 'pol_Latn',
'tr': 'tur_Latn',
'sv': 'swe_Latn',
'da': 'dan_Latn',
'no': 'nor_Latn',
'fi': 'fin_Latn',
'el': 'ell_Grek',
'he': 'heb_Hebr',
'th': 'tha_Thai',
'vi': 'vie_Latn',
'id': 'ind_Latn',
'ms': 'zsm_Latn',
'tl': 'tgl_Latn',
'uk': 'ukr_Cyrl',
'cs': 'ces_Latn',
'sk': 'slk_Latn',
'hu': 'hun_Latn',
'ro': 'ron_Latn',
'bg': 'bul_Cyrl',
'hr': 'hrv_Latn',
'sr': 'srp_Cyrl',
'sl': 'slv_Latn',
'et': 'est_Latn',
'lv': 'lvs_Latn',
'lt': 'lit_Latn',
'mt': 'mlt_Latn',
'ga': 'gle_Latn',
'cy': 'cym_Latn',
'is': 'isl_Latn',
'mk': 'mkd_Cyrl',
'sq': 'sqi_Latn',
'eu': 'eus_Latn',
'ca': 'cat_Latn',
'gl': 'glg_Latn',
'ast': 'ast_Latn',
'oc': 'oci_Latn',
'br': 'bre_Latn',
'co': 'cos_Latn',
'sc': 'srd_Latn',
'rm': 'roh_Latn',
'fur': 'fur_Latn',
'lij': 'lij_Latn',
'vec': 'vec_Latn',
'pms': 'pms_Latn',
'lmo': 'lmo_Latn',
'nap': 'nap_Latn',
'scn': 'scn_Latn',
'wa': 'wln_Latn',
'frp': 'frp_Latn',
'gsw': 'gsw_Latn',
'bar': 'bar_Latn',
'ksh': 'ksh_Latn',
'lb': 'ltz_Latn',
'li': 'lim_Latn',
'nds': 'nds_Latn',
'pdc': 'pdc_Latn',
'sli': 'sli_Latn',
'vmf': 'vmf_Latn',
'yi': 'yid_Hebr',
'af': 'afr_Latn',
'zu': 'zul_Latn',
'xh': 'xho_Latn',
'st': 'sot_Latn',
'tn': 'tsn_Latn',
'ss': 'ssw_Latn',
'nr': 'nbl_Latn',
've': 'ven_Latn',
'ts': 'tso_Latn',
'sw': 'swh_Latn',
'rw': 'kin_Latn',
'rn': 'run_Latn',
'ny': 'nya_Latn',
'sn': 'sna_Latn',
'yo': 'yor_Latn',
'ig': 'ibo_Latn',
'ha': 'hau_Latn',
'ff': 'fuv_Latn',
'wo': 'wol_Latn',
'bm': 'bam_Latn',
'dyu': 'dyu_Latn',
'ee': 'ewe_Latn',
'tw': 'twi_Latn',
'ak': 'aka_Latn',
'gaa': 'gaa_Latn',
'lg': 'lug_Latn',
'luo': 'luo_Latn',
'ki': 'kik_Latn',
'kam': 'kam_Latn',
'luy': 'luy_Latn',
'mer': 'mer_Latn',
'kln': 'kln_Latn',
'kab': 'kab_Latn',
'ber': 'ber_Latn',
'am': 'amh_Ethi',
'ti': 'tir_Ethi',
'om': 'orm_Latn',
'so': 'som_Latn',
'mg': 'plt_Latn',
'bem': 'bem_Latn',
'tum': 'tum_Latn',
'loz': 'loz_Latn',
'lua': 'lua_Latn',
'umb': 'umb_Latn',
'kmb': 'kmb_Latn',
'kg': 'kon_Latn',
'ln': 'lin_Latn',
'sg': 'sag_Latn',
'fon': 'fon_Latn',
'mos': 'mos_Latn',
'dga': 'dga_Latn',
'kbp': 'kbp_Latn',
'nus': 'nus_Latn',
'din': 'din_Latn',
'ach': 'ach_Latn',
'teo': 'teo_Latn',
'mdt': 'mdt_Latn',
'knc': 'knc_Latn',
'fuv': 'fuv_Latn',
'kr': 'kau_Latn',
'dje': 'dje_Latn',
'son': 'son_Latn',
'tmh': 'tmh_Latn',
'taq': 'taq_Latn',
'ttq': 'ttq_Latn',
'thv': 'thv_Latn',
'shi': 'shi_Tfng',
'tzm': 'tzm_Tfng',
'rif': 'rif_Latn',
        'shy': 'shy_Latn',
'acm': 'acm_Arab',
'aeb': 'aeb_Arab',
'ajp': 'ajp_Arab',
'apc': 'apc_Arab',
'ars': 'ars_Arab',
'ary': 'ary_Arab',
'arz': 'arz_Arab',
'auz': 'auz_Arab',
'avl': 'avl_Arab',
'ayh': 'ayh_Arab',
'ayn': 'ayn_Arab',
'ayp': 'ayp_Arab',
'bbz': 'bbz_Arab',
'pga': 'pga_Arab',
'shu': 'shu_Arab',
'ssh': 'ssh_Arab',
'fa': 'pes_Arab',
'tg': 'tgk_Cyrl',
'ps': 'pbt_Arab',
'ur': 'urd_Arab',
'sd': 'snd_Arab',
'ks': 'kas_Arab',
'dv': 'div_Thaa',
'ne': 'npi_Deva',
'si': 'sin_Sinh',
'my': 'mya_Mymr',
'km': 'khm_Khmr',
'lo': 'lao_Laoo',
'ka': 'kat_Geor',
'hy': 'hye_Armn',
'az': 'azj_Latn',
'kk': 'kaz_Cyrl',
'ky': 'kir_Cyrl',
'uz': 'uzn_Latn',
'tk': 'tuk_Latn',
'mn': 'khk_Cyrl',
'bo': 'bod_Tibt',
'dz': 'dzo_Tibt',
'ug': 'uig_Arab',
'tt': 'tat_Cyrl',
'ba': 'bak_Cyrl',
'cv': 'chv_Cyrl',
'sah': 'sah_Cyrl',
'tyv': 'tyv_Cyrl',
'kjh': 'kjh_Cyrl',
'alt': 'alt_Cyrl',
'krc': 'krc_Cyrl',
'kum': 'kum_Cyrl',
'nog': 'nog_Cyrl',
'kaa': 'kaa_Cyrl',
'crh': 'crh_Latn',
'gag': 'gag_Latn',
'ku': 'ckb_Arab',
'lrc': 'lrc_Arab',
'mzn': 'mzn_Arab',
'glk': 'glk_Arab',
'prs': 'prs_Arab',
'haz': 'haz_Arab',
'bal': 'bal_Arab',
'bcc': 'bcc_Arab',
'bgp': 'bgp_Arab',
'bqi': 'bqi_Arab',
'ckb': 'ckb_Arab',
'diq': 'diq_Latn',
'hac': 'hac_Arab',
'kur': 'kmr_Latn',
'lki': 'lki_Arab',
'pnb': 'pnb_Arab',
        'skr': 'skr_Arab',
        'wne': 'wne_Arab',
        'xmf': 'xmf_Geor',
        'xcl': 'xcl_Armn',
        'lad': 'lad_Hebr',
'ml': 'mal_Mlym',
'kn': 'kan_Knda',
'te': 'tel_Telu',
'ta': 'tam_Taml',
'or': 'ory_Orya',
'as': 'asm_Beng',
'bn': 'ben_Beng',
'gu': 'guj_Gujr',
'pa': 'pan_Guru',
        'mr': 'mar_Deva',
'sa': 'san_Deva',
'mai': 'mai_Deva',
'bho': 'bho_Deva',
'mag': 'mag_Deva',
'sck': 'sck_Deva',
'new': 'new_Deva',
'bpy': 'bpy_Beng',
'ctg': 'ctg_Beng',
'rkt': 'rkt_Beng',
'syl': 'syl_Beng',
'sat': 'sat_Olck',
'kha': 'kha_Latn',
'grt': 'grt_Beng',
'lus': 'lus_Latn',
'mni': 'mni_Beng',
'kok': 'kok_Deva',
'gom': 'gom_Deva',
        'doi': 'doi_Deva',
        'brh': 'brh_Arab',
        'hnd': 'hnd_Arab',
        'lah': 'lah_Arab',
        'pst': 'pst_Arab',
'shn': 'shn_Mymr',
'mnw': 'mnw_Mymr',
'kac': 'kac_Latn',
'cjm': 'cjm_Arab',
'bjn': 'bjn_Latn',
'bug': 'bug_Latn',
'jv': 'jav_Latn',
'mad': 'mad_Latn',
'min': 'min_Latn',
'su': 'sun_Latn',
'ban': 'ban_Latn',
'bbc': 'bbc_Latn',
'btk': 'btk_Latn',
'gor': 'gor_Latn',
'ilo': 'ilo_Latn',
'pag': 'pag_Latn',
'war': 'war_Latn',
'hil': 'hil_Latn',
'bcl': 'bcl_Latn',
'pam': 'pam_Latn',
'ceb': 'ceb_Latn',
'akl': 'akl_Latn',
'bik': 'bik_Latn',
'cbk': 'cbk_Latn',
'krj': 'krj_Latn',
'tsg': 'tsg_Latn',
'yue': 'yue_Hant',
'wuu': 'wuu_Hans',
'hsn': 'hsn_Hans',
'nan': 'nan_Hant',
'hak': 'hak_Hant',
'gan': 'gan_Hans',
'cdo': 'cdo_Hant',
'lzh': 'lzh_Hans',
'ain': 'ain_Kana',
'ryu': 'ryu_Kana',
'eo': 'epo_Latn',
'ia': 'ina_Latn',
'ie': 'ile_Latn',
'io': 'ido_Latn',
'vo': 'vol_Latn',
'nov': 'nov_Latn',
'lfn': 'lfn_Latn',
'jbo': 'jbo_Latn',
'tlh': 'tlh_Latn',
'na': 'nau_Latn',
'ch': 'cha_Latn',
'mh': 'mah_Latn',
'gil': 'gil_Latn',
'kos': 'kos_Latn',
'pon': 'pon_Latn',
'yap': 'yap_Latn',
'chk': 'chk_Latn',
'uli': 'uli_Latn',
'wol': 'wol_Latn',
'pau': 'pau_Latn',
'sm': 'smo_Latn',
'to': 'ton_Latn',
'fj': 'fij_Latn',
'ty': 'tah_Latn',
'mi': 'mri_Latn',
'haw': 'haw_Latn',
'rap': 'rap_Latn',
'tvl': 'tvl_Latn',
'niu': 'niu_Latn',
'tkl': 'tkl_Latn',
'bi': 'bis_Latn',
'ho': 'hmo_Latn',
        'kj': 'kua_Latn',
        'nd': 'nde_Latn'
}
def __init__(self, model_name: str = "facebook/nllb-200-3.3B", max_chunk_length: int = 1000):
"""
Initialize NLLB translation provider.
Args:
model_name: The NLLB model name to use
max_chunk_length: Maximum length for text chunks
"""
# Build supported languages mapping
supported_languages = {}
for lang_code in self.LANGUAGE_MAPPINGS.keys():
# For simplicity, assume all languages can translate to all other languages
# In practice, you might want to be more specific about supported pairs
supported_languages[lang_code] = [
target for target in self.LANGUAGE_MAPPINGS.keys()
if target != lang_code
]
super().__init__(
provider_name="NLLB-200-3.3B",
supported_languages=supported_languages
)
self.model_name = model_name
self.max_chunk_length = max_chunk_length
self._tokenizer: Optional[AutoTokenizer] = None
self._model: Optional[AutoModelForSeq2SeqLM] = None
self._model_loaded = False
def _translate_chunk(self, text: str, source_language: str, target_language: str) -> str:
"""
Translate a single chunk of text using NLLB model.
Args:
text: The text chunk to translate
source_language: Source language code
target_language: Target language code
Returns:
str: The translated text chunk
"""
try:
# Ensure model is loaded
self._ensure_model_loaded()
# Map language codes to NLLB format
source_nllb = self._map_language_code(source_language)
target_nllb = self._map_language_code(target_language)
logger.debug(f"Translating chunk from {source_nllb} to {target_nllb}")
            # Tokenize with source language specification (src_lang must be set
            # per call so the correct source-language token is prepended)
            self._tokenizer.src_lang = source_nllb
            inputs = self._tokenizer(
                text,
                return_tensors="pt",
                max_length=1024,
                truncation=True
            )
# Generate translation with target language specification
outputs = self._model.generate(
**inputs,
forced_bos_token_id=self._tokenizer.convert_tokens_to_ids(target_nllb),
max_new_tokens=1024,
num_beams=4,
early_stopping=True
)
# Decode the translation
translated = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
# Post-process the translation
translated = self._postprocess_text(translated)
logger.debug(f"Chunk translation completed: {len(text)} -> {len(translated)} chars")
return translated
except Exception as e:
self._handle_provider_error(e, "chunk translation")
def _ensure_model_loaded(self) -> None:
"""Ensure the NLLB model and tokenizer are loaded."""
if self._model_loaded:
return
try:
logger.info(f"Loading NLLB model: {self.model_name}")
# Load tokenizer
self._tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
src_lang="eng_Latn" # Default source language
)
# Load model
self._model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
self._model_loaded = True
logger.info("NLLB model loaded successfully")
except Exception as e:
logger.error(f"Failed to load NLLB model: {str(e)}")
raise TranslationFailedException(f"Failed to load NLLB model: {str(e)}") from e
def _map_language_code(self, language_code: str) -> str:
"""
Map standard language code to NLLB format.
Args:
language_code: Standard language code (e.g., 'en', 'zh')
Returns:
str: NLLB language code (e.g., 'eng_Latn', 'zho_Hans')
"""
# Normalize language code to lowercase
normalized_code = language_code.lower()
# Check direct mapping
if normalized_code in self.LANGUAGE_MAPPINGS:
return self.LANGUAGE_MAPPINGS[normalized_code]
# Handle common variations
if normalized_code.startswith('zh'):
if 'tw' in normalized_code or 'hant' in normalized_code or 'traditional' in normalized_code:
return 'zho_Hant'
else:
return 'zho_Hans'
# Default fallback for unknown codes
logger.warning(f"Unknown language code: {language_code}, defaulting to English")
return 'eng_Latn'
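    # Illustrative mapping behaviour (sketch; the exact codes come from
    # LANGUAGE_MAPPINGS above):
    #   _map_language_code('en')    -> 'eng_Latn'
    #   _map_language_code('zh-TW') -> 'zho_Hant'  (normalized to lowercase, matches 'zh-tw')
    #   _map_language_code('xx')    -> 'eng_Latn'  (unknown codes fall back to English)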
def is_available(self) -> bool:
"""
Check if the NLLB translation provider is available.
Returns:
bool: True if provider is available, False otherwise
"""
try:
# Try to import required dependencies
import transformers
import torch
            # Check that the tokenizer can be instantiated (lighter than loading the
            # full model, though it may still download tokenizer files on first use)
            if not self._model_loaded:
                try:
                    AutoTokenizer.from_pretrained(
                        self.model_name,
                        src_lang="eng_Latn"
                    )
                    return True
except Exception as e:
logger.warning(f"NLLB model not available: {str(e)}")
return False
else:
return True
except ImportError as e:
logger.warning(f"NLLB dependencies not available: {str(e)}")
return False
def get_supported_languages(self) -> Dict[str, List[str]]:
"""
Get supported language pairs for NLLB provider.
Returns:
dict: Mapping of source languages to supported target languages
"""
return self.supported_languages.copy()
def get_model_info(self) -> Dict[str, str]:
"""
Get information about the loaded model.
Returns:
dict: Model information
"""
return {
'provider': self.provider_name,
'model_name': self.model_name,
'model_loaded': str(self._model_loaded),
'supported_language_count': str(len(self.LANGUAGE_MAPPINGS)),
'max_chunk_length': str(self.max_chunk_length)
}
def set_model_name(self, model_name: str) -> None:
"""
Set a different NLLB model name.
Args:
model_name: The new model name to use
"""
if model_name != self.model_name:
self.model_name = model_name
self._model_loaded = False
self._tokenizer = None
self._model = None
logger.info(f"Model name changed to: {model_name}")
def clear_model_cache(self) -> None:
"""Clear the loaded model from memory."""
if self._model_loaded:
self._tokenizer = None
self._model = None
self._model_loaded = False
logger.info("NLLB model cache cleared")
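

# --- Usage sketch (illustrative only) ---
# From application code (the import path depends on where this module lives in
# the package), the provider might be exercised roughly like this. The smaller
# distilled checkpoint name and the direct call to _translate_chunk are
# assumptions for demonstration, not part of this module:
#
#     provider = NLLBTranslationProvider(
#         model_name="facebook/nllb-200-distilled-600M"  # smaller checkpoint for local testing
#     )
#     if provider.is_available():
#         print(provider.get_model_info())
#         print(provider._translate_chunk("Hello, world!", "en", "fr"))
#
# Production callers would normally go through the public translation API
# defined on TranslationProviderBase rather than calling _translate_chunk
# directly.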