"""NLLB translation provider implementation."""
import logging
from typing import Dict, List, Optional
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from ..base.translation_provider_base import TranslationProviderBase
from ...domain.exceptions import TranslationFailedException
logger = logging.getLogger(__name__)
class NLLBTranslationProvider(TranslationProviderBase):
"""NLLB-200-3.3B translation provider implementation."""
    # Mapping from common ISO 639-1/639-3 codes to the FLORES-200 codes used
    # by NLLB. Keys must be unique; for languages written in more than one
    # script (e.g. 'taq', 'ks', 'sd', 'az'), the more widely used variant is kept.
LANGUAGE_MAPPINGS = {
'en': 'eng_Latn',
'zh': 'zho_Hans',
'zh-cn': 'zho_Hans',
'zh-tw': 'zho_Hant',
'es': 'spa_Latn',
'fr': 'fra_Latn',
'de': 'deu_Latn',
'ja': 'jpn_Jpan',
'ko': 'kor_Hang',
'ar': 'arb_Arab',
'hi': 'hin_Deva',
'pt': 'por_Latn',
'ru': 'rus_Cyrl',
'it': 'ita_Latn',
'nl': 'nld_Latn',
'pl': 'pol_Latn',
'tr': 'tur_Latn',
'sv': 'swe_Latn',
'da': 'dan_Latn',
        'no': 'nob_Latn',  # NLLB has no macro-Norwegian code; Bokmål is used
'fi': 'fin_Latn',
'el': 'ell_Grek',
'he': 'heb_Hebr',
'th': 'tha_Thai',
'vi': 'vie_Latn',
'id': 'ind_Latn',
'ms': 'zsm_Latn',
'tl': 'tgl_Latn',
'uk': 'ukr_Cyrl',
'cs': 'ces_Latn',
'sk': 'slk_Latn',
'hu': 'hun_Latn',
'ro': 'ron_Latn',
'bg': 'bul_Cyrl',
'hr': 'hrv_Latn',
'sr': 'srp_Cyrl',
'sl': 'slv_Latn',
'et': 'est_Latn',
'lv': 'lvs_Latn',
'lt': 'lit_Latn',
'mt': 'mlt_Latn',
'ga': 'gle_Latn',
'cy': 'cym_Latn',
'is': 'isl_Latn',
'mk': 'mkd_Cyrl',
'sq': 'sqi_Latn',
'eu': 'eus_Latn',
'ca': 'cat_Latn',
'gl': 'glg_Latn',
'ast': 'ast_Latn',
'oc': 'oci_Latn',
'br': 'bre_Latn',
'co': 'cos_Latn',
'sc': 'srd_Latn',
'rm': 'roh_Latn',
'fur': 'fur_Latn',
'lij': 'lij_Latn',
'vec': 'vec_Latn',
'pms': 'pms_Latn',
'lmo': 'lmo_Latn',
'nap': 'nap_Latn',
'scn': 'scn_Latn',
'wa': 'wln_Latn',
'frp': 'frp_Latn',
'gsw': 'gsw_Latn',
'bar': 'bar_Latn',
'ksh': 'ksh_Latn',
'lb': 'ltz_Latn',
'li': 'lim_Latn',
'nds': 'nds_Latn',
'pdc': 'pdc_Latn',
'sli': 'sli_Latn',
'vmf': 'vmf_Latn',
'yi': 'yid_Hebr',
'af': 'afr_Latn',
'zu': 'zul_Latn',
'xh': 'xho_Latn',
'st': 'sot_Latn',
'tn': 'tsn_Latn',
'ss': 'ssw_Latn',
'nr': 'nbl_Latn',
've': 'ven_Latn',
'ts': 'tso_Latn',
'sw': 'swh_Latn',
'rw': 'kin_Latn',
'rn': 'run_Latn',
'ny': 'nya_Latn',
'sn': 'sna_Latn',
'yo': 'yor_Latn',
'ig': 'ibo_Latn',
'ha': 'hau_Latn',
'ff': 'fuv_Latn',
'wo': 'wol_Latn',
'bm': 'bam_Latn',
'dyu': 'dyu_Latn',
'ee': 'ewe_Latn',
'tw': 'twi_Latn',
'ak': 'aka_Latn',
'gaa': 'gaa_Latn',
'lg': 'lug_Latn',
'luo': 'luo_Latn',
'ki': 'kik_Latn',
'kam': 'kam_Latn',
'luy': 'luy_Latn',
'mer': 'mer_Latn',
'kln': 'kln_Latn',
'kab': 'kab_Latn',
'ber': 'ber_Latn',
'am': 'amh_Ethi',
'ti': 'tir_Ethi',
'om': 'orm_Latn',
'so': 'som_Latn',
'mg': 'plt_Latn',
'bem': 'bem_Latn',
'tum': 'tum_Latn',
'loz': 'loz_Latn',
'lua': 'lua_Latn',
'umb': 'umb_Latn',
'kmb': 'kmb_Latn',
'kg': 'kon_Latn',
'ln': 'lin_Latn',
'sg': 'sag_Latn',
'fon': 'fon_Latn',
'mos': 'mos_Latn',
'dga': 'dga_Latn',
'kbp': 'kbp_Latn',
'nus': 'nus_Latn',
'din': 'din_Latn',
'ach': 'ach_Latn',
'teo': 'teo_Latn',
'mdt': 'mdt_Latn',
'knc': 'knc_Latn',
'fuv': 'fuv_Latn',
'kr': 'kau_Latn',
'dje': 'dje_Latn',
'son': 'son_Latn',
'tmh': 'tmh_Latn',
'taq': 'taq_Latn',
'ttq': 'ttq_Latn',
'thv': 'thv_Latn',
'shi': 'shi_Tfng',
'tzm': 'tzm_Tfng',
'rif': 'rif_Latn',
        'shy': 'shy_Latn',
'acm': 'acm_Arab',
'aeb': 'aeb_Arab',
'ajp': 'ajp_Arab',
'apc': 'apc_Arab',
'ars': 'ars_Arab',
'ary': 'ary_Arab',
'arz': 'arz_Arab',
'auz': 'auz_Arab',
'avl': 'avl_Arab',
'ayh': 'ayh_Arab',
'ayn': 'ayn_Arab',
'ayp': 'ayp_Arab',
'bbz': 'bbz_Arab',
'pga': 'pga_Arab',
'shu': 'shu_Arab',
'ssh': 'ssh_Arab',
'fa': 'pes_Arab',
'tg': 'tgk_Cyrl',
'ps': 'pbt_Arab',
'ur': 'urd_Arab',
'sd': 'snd_Arab',
'ks': 'kas_Arab',
'dv': 'div_Thaa',
'ne': 'npi_Deva',
'si': 'sin_Sinh',
'my': 'mya_Mymr',
'km': 'khm_Khmr',
'lo': 'lao_Laoo',
'ka': 'kat_Geor',
'hy': 'hye_Armn',
'az': 'azj_Latn',
'kk': 'kaz_Cyrl',
'ky': 'kir_Cyrl',
'uz': 'uzn_Latn',
'tk': 'tuk_Latn',
'mn': 'khk_Cyrl',
'bo': 'bod_Tibt',
'dz': 'dzo_Tibt',
'ug': 'uig_Arab',
'tt': 'tat_Cyrl',
'ba': 'bak_Cyrl',
'cv': 'chv_Cyrl',
'sah': 'sah_Cyrl',
'tyv': 'tyv_Cyrl',
'kjh': 'kjh_Cyrl',
'alt': 'alt_Cyrl',
'krc': 'krc_Cyrl',
'kum': 'kum_Cyrl',
'nog': 'nog_Cyrl',
'kaa': 'kaa_Cyrl',
'crh': 'crh_Latn',
'gag': 'gag_Latn',
'ku': 'ckb_Arab',
'lrc': 'lrc_Arab',
'mzn': 'mzn_Arab',
'glk': 'glk_Arab',
'prs': 'prs_Arab',
'haz': 'haz_Arab',
'bal': 'bal_Arab',
'bcc': 'bcc_Arab',
'bgp': 'bgp_Arab',
'bqi': 'bqi_Arab',
'ckb': 'ckb_Arab',
'diq': 'diq_Latn',
'hac': 'hac_Arab',
'kur': 'kmr_Latn',
'lki': 'lki_Arab',
'pnb': 'pnb_Arab',
        'skr': 'skr_Arab',
'wne': 'wne_Arab',
'xmf': 'xmf_Geor',
        'xcl': 'xcl_Armn',
'lad': 'lad_Hebr',
'ml': 'mal_Mlym',
'kn': 'kan_Knda',
'te': 'tel_Telu',
'ta': 'tam_Taml',
'or': 'ory_Orya',
'as': 'asm_Beng',
'bn': 'ben_Beng',
'gu': 'guj_Gujr',
'pa': 'pan_Guru',
        'mr': 'mar_Deva',
'sa': 'san_Deva',
'mai': 'mai_Deva',
'bho': 'bho_Deva',
'mag': 'mag_Deva',
'sck': 'sck_Deva',
'new': 'new_Deva',
'bpy': 'bpy_Beng',
'ctg': 'ctg_Beng',
'rkt': 'rkt_Beng',
'syl': 'syl_Beng',
'sat': 'sat_Olck',
'kha': 'kha_Latn',
'grt': 'grt_Beng',
'lus': 'lus_Latn',
'mni': 'mni_Beng',
'kok': 'kok_Deva',
'gom': 'gom_Deva',
        'doi': 'doi_Deva',
'brh': 'brh_Arab',
'hnd': 'hnd_Arab',
'lah': 'lah_Arab',
        'pst': 'pst_Arab',
'shn': 'shn_Mymr',
'mnw': 'mnw_Mymr',
'kac': 'kac_Latn',
'cjm': 'cjm_Arab',
'bjn': 'bjn_Latn',
'bug': 'bug_Latn',
'jv': 'jav_Latn',
'mad': 'mad_Latn',
'min': 'min_Latn',
'su': 'sun_Latn',
'ban': 'ban_Latn',
'bbc': 'bbc_Latn',
'btk': 'btk_Latn',
'gor': 'gor_Latn',
'ilo': 'ilo_Latn',
'pag': 'pag_Latn',
'war': 'war_Latn',
'hil': 'hil_Latn',
'bcl': 'bcl_Latn',
'pam': 'pam_Latn',
'ceb': 'ceb_Latn',
'akl': 'akl_Latn',
'bik': 'bik_Latn',
'cbk': 'cbk_Latn',
'krj': 'krj_Latn',
'tsg': 'tsg_Latn',
'yue': 'yue_Hant',
'wuu': 'wuu_Hans',
'hsn': 'hsn_Hans',
'nan': 'nan_Hant',
'hak': 'hak_Hant',
'gan': 'gan_Hans',
'cdo': 'cdo_Hant',
'lzh': 'lzh_Hans',
'ain': 'ain_Kana',
'ryu': 'ryu_Kana',
'eo': 'epo_Latn',
'ia': 'ina_Latn',
'ie': 'ile_Latn',
'io': 'ido_Latn',
'vo': 'vol_Latn',
'nov': 'nov_Latn',
'lfn': 'lfn_Latn',
'jbo': 'jbo_Latn',
'tlh': 'tlh_Latn',
'na': 'nau_Latn',
'ch': 'cha_Latn',
'mh': 'mah_Latn',
'gil': 'gil_Latn',
'kos': 'kos_Latn',
'pon': 'pon_Latn',
'yap': 'yap_Latn',
'chk': 'chk_Latn',
'uli': 'uli_Latn',
'wol': 'wol_Latn',
'pau': 'pau_Latn',
'sm': 'smo_Latn',
'to': 'ton_Latn',
'fj': 'fij_Latn',
'ty': 'tah_Latn',
'mi': 'mri_Latn',
'haw': 'haw_Latn',
'rap': 'rap_Latn',
'tvl': 'tvl_Latn',
'niu': 'niu_Latn',
'tkl': 'tkl_Latn',
'bi': 'bis_Latn',
'ho': 'hmo_Latn',
        'kj': 'kua_Latn',
        'nd': 'nde_Latn'
}
def __init__(self, model_name: str = "facebook/nllb-200-3.3B", max_chunk_length: int = 1000):
"""
Initialize NLLB translation provider.
Args:
model_name: The NLLB model name to use
max_chunk_length: Maximum length for text chunks
"""
# Build supported languages mapping
supported_languages = {}
for lang_code in self.LANGUAGE_MAPPINGS.keys():
            # NLLB-200 is a many-to-many model, so each mapped language is
            # treated as translatable to every other mapped language
supported_languages[lang_code] = [
target for target in self.LANGUAGE_MAPPINGS.keys()
if target != lang_code
]
super().__init__(
provider_name="NLLB-200-3.3B",
supported_languages=supported_languages
)
self.model_name = model_name
self.max_chunk_length = max_chunk_length
self._tokenizer: Optional[AutoTokenizer] = None
self._model: Optional[AutoModelForSeq2SeqLM] = None
self._model_loaded = False
def _translate_chunk(self, text: str, source_language: str, target_language: str) -> str:
"""
Translate a single chunk of text using NLLB model.
Args:
text: The text chunk to translate
source_language: Source language code
target_language: Target language code
Returns:
str: The translated text chunk
"""
try:
# Ensure model is loaded
self._ensure_model_loaded()
# Map language codes to NLLB format
source_nllb = self._map_language_code(source_language)
target_nllb = self._map_language_code(target_language)
logger.info(f"Translating chunk from {source_nllb} to {target_nllb}")
            # Select the source language on the tokenizer so the correct
            # source-language token is prepended to the encoded input
            self._tokenizer.src_lang = source_nllb
            inputs = self._tokenizer(
                text,
                return_tensors="pt",
                max_length=1024,
                truncation=True
            )
# Generate translation with target language specification
outputs = self._model.generate(
**inputs,
forced_bos_token_id=self._tokenizer.convert_tokens_to_ids(target_nllb),
max_new_tokens=1024,
num_beams=4,
early_stopping=True
)
# Decode the translation
translated = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
# Post-process the translation
translated = self._postprocess_text(translated)
logger.info(f"Chunk translation completed: {len(text)} -> {len(translated)} chars")
return translated
except Exception as e:
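            # _handle_provider_error (from the base class) is expected to
            # wrap and re-raise; otherwise this method would fall through
            # and return None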
self._handle_provider_error(e, "chunk translation")
def _ensure_model_loaded(self) -> None:
"""Ensure the NLLB model and tokenizer are loaded."""
if self._model_loaded:
return
try:
logger.info(f"Loading NLLB model: {self.model_name}")
# Load tokenizer
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                src_lang="eng_Latn"  # default; overridden per chunk in _translate_chunk
            )
# Load model
self._model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
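            # The model is loaded on CPU by default; moving it to a GPU
            # (e.g. self._model.to("cuda")) is left to the caller or
            # deployment configuration.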
self._model_loaded = True
logger.info("NLLB model loaded successfully")
except Exception as e:
logger.error(f"Failed to load NLLB model: {str(e)}")
raise TranslationFailedException(f"Failed to load NLLB model: {str(e)}") from e
def _map_language_code(self, language_code: str) -> str:
"""
Map standard language code to NLLB format.
Args:
language_code: Standard language code (e.g., 'en', 'zh')
Returns:
str: NLLB language code (e.g., 'eng_Latn', 'zho_Hans')
"""
# Normalize language code to lowercase
normalized_code = language_code.lower()
# Check direct mapping
if normalized_code in self.LANGUAGE_MAPPINGS:
return self.LANGUAGE_MAPPINGS[normalized_code]
# Handle common variations
if normalized_code.startswith('zh'):
if 'tw' in normalized_code or 'hant' in normalized_code or 'traditional' in normalized_code:
return 'zho_Hant'
else:
return 'zho_Hans'
# Default fallback for unknown codes
logger.warning(f"Unknown language code: {language_code}, defaulting to English")
return 'eng_Latn'
def is_available(self) -> bool:
"""
Check if the NLLB translation provider is available.
Returns:
bool: True if provider is available, False otherwise
"""
try:
# Try to import required dependencies
import transformers
import torch
            # Check that the tokenizer can be loaded. This downloads tokenizer
            # files on first call but avoids loading the multi-GB model weights.
            if not self._model_loaded:
                try:
                    AutoTokenizer.from_pretrained(
                        self.model_name,
                        src_lang="eng_Latn"
                    )
                    return True
except Exception as e:
logger.warning(f"NLLB model not available: {str(e)}")
return False
else:
return True
except ImportError as e:
logger.warning(f"NLLB dependencies not available: {str(e)}")
return False
def get_supported_languages(self) -> Dict[str, List[str]]:
"""
Get supported language pairs for NLLB provider.
Returns:
dict: Mapping of source languages to supported target languages
"""
return self.supported_languages.copy()
def get_model_info(self) -> Dict[str, str]:
"""
Get information about the loaded model.
Returns:
dict: Model information
"""
return {
'provider': self.provider_name,
'model_name': self.model_name,
'model_loaded': str(self._model_loaded),
'supported_language_count': str(len(self.LANGUAGE_MAPPINGS)),
'max_chunk_length': str(self.max_chunk_length)
}
def set_model_name(self, model_name: str) -> None:
"""
Set a different NLLB model name.
Args:
model_name: The new model name to use
"""
if model_name != self.model_name:
self.model_name = model_name
self._model_loaded = False
self._tokenizer = None
self._model = None
logger.info(f"Model name changed to: {model_name}")
def clear_model_cache(self) -> None:
"""Clear the loaded model from memory."""
if self._model_loaded:
self._tokenizer = None
self._model = None
self._model_loaded = False
logger.info("NLLB model cache cleared")