"""NLLB translation provider implementation.""" | |
import logging | |
from typing import Dict, List, Optional | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
from ..base.translation_provider_base import TranslationProviderBase | |
from ...domain.exceptions import TranslationFailedException | |
logger = logging.getLogger(__name__) | |
class NLLBTranslationProvider(TranslationProviderBase): | |
"""NLLB-200-3.3B translation provider implementation.""" | |
# NLLB language code mappings | |
    LANGUAGE_MAPPINGS = {
        'en': 'eng_Latn',
        'zh': 'zho_Hans',
        'zh-cn': 'zho_Hans',
        'zh-tw': 'zho_Hant',
        'es': 'spa_Latn',
        'fr': 'fra_Latn',
        'de': 'deu_Latn',
        'ja': 'jpn_Jpan',
        'ko': 'kor_Hang',
        'ar': 'arb_Arab',
        'hi': 'hin_Deva',
        'pt': 'por_Latn',
        'ru': 'rus_Cyrl',
        'it': 'ita_Latn',
        'nl': 'nld_Latn',
        'pl': 'pol_Latn',
        'tr': 'tur_Latn',
        'sv': 'swe_Latn',
        'da': 'dan_Latn',
        'no': 'nor_Latn',
        'fi': 'fin_Latn',
        'el': 'ell_Grek',
        'he': 'heb_Hebr',
        'th': 'tha_Thai',
        'vi': 'vie_Latn',
        'id': 'ind_Latn',
        'ms': 'zsm_Latn',
        'tl': 'tgl_Latn',
        'uk': 'ukr_Cyrl',
        'cs': 'ces_Latn',
        'sk': 'slk_Latn',
        'hu': 'hun_Latn',
        'ro': 'ron_Latn',
        'bg': 'bul_Cyrl',
        'hr': 'hrv_Latn',
        'sr': 'srp_Cyrl',
        'sl': 'slv_Latn',
        'et': 'est_Latn',
        'lv': 'lvs_Latn',
        'lt': 'lit_Latn',
        'mt': 'mlt_Latn',
        'ga': 'gle_Latn',
        'cy': 'cym_Latn',
        'is': 'isl_Latn',
        'mk': 'mkd_Cyrl',
        'sq': 'sqi_Latn',
        'eu': 'eus_Latn',
        'ca': 'cat_Latn',
        'gl': 'glg_Latn',
        'ast': 'ast_Latn',
        'oc': 'oci_Latn',
        'br': 'bre_Latn',
        'co': 'cos_Latn',
        'sc': 'srd_Latn',
        'rm': 'roh_Latn',
        'fur': 'fur_Latn',
        'lij': 'lij_Latn',
        'vec': 'vec_Latn',
        'pms': 'pms_Latn',
        'lmo': 'lmo_Latn',
        'nap': 'nap_Latn',
        'scn': 'scn_Latn',
        'wa': 'wln_Latn',
        'frp': 'frp_Latn',
        'gsw': 'gsw_Latn',
        'bar': 'bar_Latn',
        'ksh': 'ksh_Latn',
        'lb': 'ltz_Latn',
        'li': 'lim_Latn',
        'nds': 'nds_Latn',
        'pdc': 'pdc_Latn',
        'sli': 'sli_Latn',
        'vmf': 'vmf_Latn',
        'yi': 'yid_Hebr',
        'af': 'afr_Latn',
        'zu': 'zul_Latn',
        'xh': 'xho_Latn',
        'st': 'sot_Latn',
        'tn': 'tsn_Latn',
        'ss': 'ssw_Latn',
        'nr': 'nbl_Latn',
        've': 'ven_Latn',
        'ts': 'tso_Latn',
        'sw': 'swh_Latn',
        'rw': 'kin_Latn',
        'rn': 'run_Latn',
        'ny': 'nya_Latn',
        'sn': 'sna_Latn',
        'yo': 'yor_Latn',
        'ig': 'ibo_Latn',
        'ha': 'hau_Latn',
        'ff': 'fuv_Latn',
        'wo': 'wol_Latn',
        'bm': 'bam_Latn',
        'dyu': 'dyu_Latn',
        'ee': 'ewe_Latn',
        'tw': 'twi_Latn',
        'ak': 'aka_Latn',
        'gaa': 'gaa_Latn',
        'lg': 'lug_Latn',
        'luo': 'luo_Latn',
        'ki': 'kik_Latn',
        'kam': 'kam_Latn',
        'luy': 'luy_Latn',
        'mer': 'mer_Latn',
        'kln': 'kln_Latn',
        'kab': 'kab_Latn',
        'ber': 'ber_Latn',
        'am': 'amh_Ethi',
        'ti': 'tir_Ethi',
        'om': 'orm_Latn',
        'so': 'som_Latn',
        'mg': 'plt_Latn',
        'bem': 'bem_Latn',
        'tum': 'tum_Latn',
        'loz': 'loz_Latn',
        'lua': 'lua_Latn',
        'umb': 'umb_Latn',
        'kmb': 'kmb_Latn',
        'kg': 'kon_Latn',
        'ln': 'lin_Latn',
        'sg': 'sag_Latn',
        'fon': 'fon_Latn',
        'mos': 'mos_Latn',
        'dga': 'dga_Latn',
        'kbp': 'kbp_Latn',
        'nus': 'nus_Latn',
        'din': 'din_Latn',
        'ach': 'ach_Latn',
        'teo': 'teo_Latn',
        'mdt': 'mdt_Latn',
        'knc': 'knc_Latn',
        'fuv': 'fuv_Latn',
        'kr': 'kau_Latn',
        'dje': 'dje_Latn',
        'son': 'son_Latn',
        'tmh': 'tmh_Latn',
        'taq': 'taq_Tfng',
        'ttq': 'ttq_Latn',
        'thv': 'thv_Latn',
        'shi': 'shi_Tfng',
        'tzm': 'tzm_Tfng',
        'rif': 'rif_Latn',
        'shy': 'shy_Latn',
        'acm': 'acm_Arab',
        'aeb': 'aeb_Arab',
        'ajp': 'ajp_Arab',
        'apc': 'apc_Arab',
        'ars': 'ars_Arab',
        'ary': 'ary_Arab',
        'arz': 'arz_Arab',
        'auz': 'auz_Arab',
        'avl': 'avl_Arab',
        'ayh': 'ayh_Arab',
        'ayn': 'ayn_Arab',
        'ayp': 'ayp_Arab',
        'bbz': 'bbz_Arab',
        'pga': 'pga_Arab',
        'shu': 'shu_Arab',
        'ssh': 'ssh_Arab',
        'fa': 'pes_Arab',
        'tg': 'tgk_Cyrl',
        'ps': 'pbt_Arab',
        'ur': 'urd_Arab',
        'sd': 'snd_Deva',
        'ks': 'kas_Deva',
        'dv': 'div_Thaa',
        'ne': 'npi_Deva',
        'si': 'sin_Sinh',
        'my': 'mya_Mymr',
        'km': 'khm_Khmr',
        'lo': 'lao_Laoo',
        'ka': 'kat_Geor',
        'hy': 'hye_Armn',
        'az': 'azb_Arab',
        'kk': 'kaz_Cyrl',
        'ky': 'kir_Cyrl',
        'uz': 'uzn_Latn',
        'tk': 'tuk_Latn',
        'mn': 'khk_Cyrl',
        'bo': 'bod_Tibt',
        'dz': 'dzo_Tibt',
        'ug': 'uig_Arab',
        'tt': 'tat_Cyrl',
        'ba': 'bak_Cyrl',
        'cv': 'chv_Cyrl',
        'sah': 'sah_Cyrl',
        'tyv': 'tyv_Cyrl',
        'kjh': 'kjh_Cyrl',
        'alt': 'alt_Cyrl',
        'krc': 'krc_Cyrl',
        'kum': 'kum_Cyrl',
        'nog': 'nog_Cyrl',
        'kaa': 'kaa_Cyrl',
        'crh': 'crh_Latn',
        'gag': 'gag_Latn',
        'ku': 'ckb_Arab',
        'lrc': 'lrc_Arab',
        'mzn': 'mzn_Arab',
        'glk': 'glk_Arab',
        'prs': 'prs_Arab',
        'haz': 'haz_Arab',
        'bal': 'bal_Arab',
        'bcc': 'bcc_Arab',
        'bgp': 'bgp_Arab',
        'bqi': 'bqi_Arab',
        'ckb': 'ckb_Arab',
        'diq': 'diq_Latn',
        'hac': 'hac_Arab',
        'kur': 'kmr_Latn',
        'lki': 'lki_Arab',
        'pnb': 'pnb_Arab',
        'skr': 'skr_Arab',
        'wne': 'wne_Arab',
        'xmf': 'xmf_Geor',
        'xcl': 'xcl_Armn',
        'lad': 'lad_Hebr',
        'ml': 'mal_Mlym',
        'kn': 'kan_Knda',
        'te': 'tel_Telu',
        'ta': 'tam_Taml',
        'or': 'ory_Orya',
        'as': 'asm_Beng',
        'bn': 'ben_Beng',
        'gu': 'guj_Gujr',
        'pa': 'pan_Guru',
        'mr': 'mar_Deva',
        'sa': 'san_Deva',
        'mai': 'mai_Deva',
        'bho': 'bho_Deva',
        'mag': 'mag_Deva',
        'sck': 'sck_Deva',
        'new': 'new_Deva',
        'bpy': 'bpy_Beng',
        'ctg': 'ctg_Beng',
        'rkt': 'rkt_Beng',
        'syl': 'syl_Beng',
        'sat': 'sat_Olck',
        'kha': 'kha_Latn',
        'grt': 'grt_Beng',
        'lus': 'lus_Latn',
        'mni': 'mni_Beng',
        'kok': 'kok_Deva',
        'gom': 'gom_Deva',
        'doi': 'doi_Deva',
        'brh': 'brh_Arab',
        'hnd': 'hnd_Arab',
        'lah': 'lah_Arab',
        'pst': 'pst_Arab',
        'shn': 'shn_Mymr',
        'mnw': 'mnw_Mymr',
        'kac': 'kac_Latn',
        'cjm': 'cjm_Arab',
        'bjn': 'bjn_Latn',
        'bug': 'bug_Latn',
        'jv': 'jav_Latn',
        'mad': 'mad_Latn',
        'min': 'min_Latn',
        'su': 'sun_Latn',
        'ban': 'ban_Latn',
        'bbc': 'bbc_Latn',
        'btk': 'btk_Latn',
        'gor': 'gor_Latn',
        'ilo': 'ilo_Latn',
        'pag': 'pag_Latn',
        'war': 'war_Latn',
        'hil': 'hil_Latn',
        'bcl': 'bcl_Latn',
        'pam': 'pam_Latn',
        'ceb': 'ceb_Latn',
        'akl': 'akl_Latn',
        'bik': 'bik_Latn',
        'cbk': 'cbk_Latn',
        'krj': 'krj_Latn',
        'tsg': 'tsg_Latn',
        'yue': 'yue_Hant',
        'wuu': 'wuu_Hans',
        'hsn': 'hsn_Hans',
        'nan': 'nan_Hant',
        'hak': 'hak_Hant',
        'gan': 'gan_Hans',
        'cdo': 'cdo_Hant',
        'lzh': 'lzh_Hans',
        'ain': 'ain_Kana',
        'ryu': 'ryu_Kana',
        'eo': 'epo_Latn',
        'ia': 'ina_Latn',
        'ie': 'ile_Latn',
        'io': 'ido_Latn',
        'vo': 'vol_Latn',
        'nov': 'nov_Latn',
        'lfn': 'lfn_Latn',
        'jbo': 'jbo_Latn',
        'tlh': 'tlh_Latn',
        'na': 'nau_Latn',
        'ch': 'cha_Latn',
        'mh': 'mah_Latn',
        'gil': 'gil_Latn',
        'kos': 'kos_Latn',
        'pon': 'pon_Latn',
        'yap': 'yap_Latn',
        'chk': 'chk_Latn',
        'uli': 'uli_Latn',
        'wol': 'wol_Latn',
        'pau': 'pau_Latn',
        'sm': 'smo_Latn',
        'to': 'ton_Latn',
        'fj': 'fij_Latn',
        'ty': 'tah_Latn',
        'mi': 'mri_Latn',
        'haw': 'haw_Latn',
        'rap': 'rap_Latn',
        'tvl': 'tvl_Latn',
        'niu': 'niu_Latn',
        'tkl': 'tkl_Latn',
        'bi': 'bis_Latn',
        'ho': 'hmo_Latn',
        'kj': 'kua_Latn',
        'nd': 'nde_Latn'
    }

    def __init__(self, model_name: str = "facebook/nllb-200-3.3B", max_chunk_length: int = 1000):
        """
        Initialize the NLLB translation provider.

        Args:
            model_name: The NLLB model name to use
            max_chunk_length: Maximum length for text chunks
        """
        # Build the supported-languages mapping. For simplicity, assume every
        # language can be translated into every other language; in practice
        # you might want to restrict this to specific supported pairs.
        supported_languages = {
            lang_code: [
                target for target in self.LANGUAGE_MAPPINGS
                if target != lang_code
            ]
            for lang_code in self.LANGUAGE_MAPPINGS
        }

        super().__init__(
            provider_name="NLLB-200-3.3B",
            supported_languages=supported_languages
        )

        self.model_name = model_name
        self.max_chunk_length = max_chunk_length
        self._tokenizer: Optional[AutoTokenizer] = None
        self._model: Optional[AutoModelForSeq2SeqLM] = None
        self._model_loaded = False

    def _translate_chunk(self, text: str, source_language: str, target_language: str) -> str:
        """
        Translate a single chunk of text using the NLLB model.

        Args:
            text: The text chunk to translate
            source_language: Source language code
            target_language: Target language code

        Returns:
            str: The translated text chunk
        """
        try:
            # Ensure model is loaded
            self._ensure_model_loaded()

            # Map language codes to NLLB format
            source_nllb = self._map_language_code(source_language)
            target_nllb = self._map_language_code(target_language)
            logger.info(f"Translating chunk from {source_nllb} to {target_nllb}")

            # Tell the tokenizer which source language to tag the input with;
            # it was loaded with eng_Latn as the default and would otherwise
            # mislabel non-English input.
            self._tokenizer.src_lang = source_nllb

            # Tokenize the chunk
            inputs = self._tokenizer(
                text,
                return_tensors="pt",
                max_length=1024,
                truncation=True
            )

            # Generate the translation, forcing the target language token
            outputs = self._model.generate(
                **inputs,
                forced_bos_token_id=self._tokenizer.convert_tokens_to_ids(target_nllb),
                max_new_tokens=1024,
                num_beams=4,
                early_stopping=True
            )

            # Decode and post-process the translation
            translated = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            translated = self._postprocess_text(translated)

            logger.info(f"Chunk translation completed: {len(text)} -> {len(translated)} chars")
            return translated
        except Exception as e:
            self._handle_provider_error(e, "chunk translation")

    def _ensure_model_loaded(self) -> None:
        """Ensure the NLLB model and tokenizer are loaded."""
        if self._model_loaded:
            return

        try:
            logger.info(f"Loading NLLB model: {self.model_name}")

            # Load tokenizer (src_lang is overridden per request in _translate_chunk)
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                src_lang="eng_Latn"  # Default source language
            )

            # Load model
            self._model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)

            self._model_loaded = True
            logger.info("NLLB model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load NLLB model: {str(e)}")
            raise TranslationFailedException(f"Failed to load NLLB model: {str(e)}") from e

    def _map_language_code(self, language_code: str) -> str:
        """
        Map a standard language code to NLLB format.

        Args:
            language_code: Standard language code (e.g., 'en', 'zh')

        Returns:
            str: NLLB language code (e.g., 'eng_Latn', 'zho_Hans')
        """
        # Normalize language code to lowercase
        normalized_code = language_code.lower()

        # Check direct mapping
        if normalized_code in self.LANGUAGE_MAPPINGS:
            return self.LANGUAGE_MAPPINGS[normalized_code]

        # Handle common Chinese variants
        if normalized_code.startswith('zh'):
            if 'tw' in normalized_code or 'hant' in normalized_code or 'traditional' in normalized_code:
                return 'zho_Hant'
            return 'zho_Hans'

        # Default fallback for unknown codes
        logger.warning(f"Unknown language code: {language_code}, defaulting to English")
        return 'eng_Latn'

    def is_available(self) -> bool:
        """
        Check if the NLLB translation provider is available.

        Returns:
            bool: True if provider is available, False otherwise
        """
        try:
            # Verify that the required dependencies can be imported
            import torch  # noqa: F401
            import transformers  # noqa: F401
        except ImportError as e:
            logger.warning(f"NLLB dependencies not available: {str(e)}")
            return False

        if self._model_loaded:
            return True

        # Lightweight check: try to load just the tokenizer
        try:
            AutoTokenizer.from_pretrained(self.model_name, src_lang="eng_Latn")
            return True
        except Exception as e:
            logger.warning(f"NLLB model not available: {str(e)}")
            return False

    def get_supported_languages(self) -> Dict[str, List[str]]:
        """
        Get supported language pairs for the NLLB provider.

        Returns:
            dict: Mapping of source languages to supported target languages
        """
        return self.supported_languages.copy()

    def get_model_info(self) -> Dict[str, str]:
        """
        Get information about the loaded model.

        Returns:
            dict: Model information
        """
        return {
            'provider': self.provider_name,
            'model_name': self.model_name,
            'model_loaded': str(self._model_loaded),
            'supported_language_count': str(len(self.LANGUAGE_MAPPINGS)),
            'max_chunk_length': str(self.max_chunk_length)
        }

    def set_model_name(self, model_name: str) -> None:
        """
        Set a different NLLB model name.

        Args:
            model_name: The new model name to use
        """
        if model_name != self.model_name:
            self.model_name = model_name
            self._model_loaded = False
            self._tokenizer = None
            self._model = None
            logger.info(f"Model name changed to: {model_name}")

    def clear_model_cache(self) -> None:
        """Clear the loaded model from memory."""
        if self._model_loaded:
            self._tokenizer = None
            self._model = None
            self._model_loaded = False
            logger.info("NLLB model cache cleared")