File size: 2,181 Bytes
cd1309d
 
 
 
 
c72d839
cd1309d
 
5f94a8b
 
cd1309d
 
 
 
 
 
 
 
c72d839
fdc056d
c72d839
5f94a8b
c72d839
5f94a8b
 
 
 
c72d839
2477bc4
c72d839
 
 
 
 
 
 
 
2477bc4
fdc056d
5f94a8b
c72d839
5f94a8b
 
 
c72d839
 
fdc056d
5f94a8b
c72d839
 
77b7581
c72d839
 
fdc056d
c72d839
 
2477bc4
c72d839
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Text Translation Module using NLLB-3.3B model
Handles text segmentation and batch translation
"""

import logging
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

logger = logging.getLogger(__name__)

def translate_text(text, max_chunk_length=1000):
    """
    Translate English text to Simplified Chinese using NLLB-200-3.3B.

    Args:
        text: Input English text.
        max_chunk_length: Maximum number of characters per translation
            chunk (default 1000, matching the original behavior). Long
            inputs are sliced into chunks of this size and each chunk is
            translated independently.

    Returns:
        Translated Simplified Chinese text ("" for empty input).

    Raises:
        Exception: any failure during model loading or generation is
            logged with a traceback and re-raised unchanged.
    """
    logger.info("Starting translation for text length: %d", len(text))

    # Short-circuit empty input: avoid loading a 3.3B-parameter model
    # just to return an empty string.
    if not text:
        return ""

    try:
        # Tokenizer/model are cached per process instead of being
        # re-initialized on every call (the original reloaded the 3.3B
        # model each time translate_text was invoked).
        tokenizer, model = _load_translation_model()

        # Slice the raw text into fixed-size character chunks.
        # NOTE(review): character slicing may split a sentence (or even a
        # word) across chunk boundaries, which can degrade translation
        # quality at the seams — consider sentence-aware splitting.
        text_chunks = [
            text[i:i + max_chunk_length]
            for i in range(0, len(text), max_chunk_length)
        ]
        logger.info("Split text into %d chunks", len(text_chunks))

        translated_chunks = []
        for i, chunk in enumerate(text_chunks):
            logger.info("Processing chunk %d/%d", i + 1, len(text_chunks))

            # Tokenize; the source language (eng_Latn) was fixed on the
            # tokenizer at load time. Inputs longer than 1024 tokens are
            # truncated.
            inputs = tokenizer(
                chunk,
                return_tensors="pt",
                max_length=1024,
                truncation=True,
            )

            # Force the decoder to begin with the Simplified Chinese
            # language token so NLLB produces zho_Hans output.
            outputs = model.generate(
                **inputs,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids("zho_Hans"),
                max_new_tokens=1024,
            )

            translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
            translated_chunks.append(translated)
            logger.info("Chunk %d translated successfully", i + 1)

        result = "".join(translated_chunks)
        logger.info("Translation completed. Total length: %d", len(result))
        return result

    except Exception as e:
        logger.error("Translation failed: %s", e, exc_info=True)
        raise


def _load_translation_model():
    """Load the NLLB-200-3.3B tokenizer and model once per process.

    The (tokenizer, model) pair is cached on the function object so that
    repeated translate_text() calls reuse the already-initialized model
    instead of re-downloading/re-constructing a 3.3B-parameter checkpoint.

    Returns:
        Tuple of (tokenizer, model).
    """
    if not hasattr(_load_translation_model, "_cache"):
        logger.info("Loading NLLB model")
        tokenizer = AutoTokenizer.from_pretrained(
            "facebook/nllb-200-3.3B",
            src_lang="eng_Latn",  # Specify source language
        )
        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-3.3B")
        logger.info("Translation model loaded")
        _load_translation_model._cache = (tokenizer, model)
    return _load_translation_model._cache