import argparse
import os
import re
import unicodedata
from typing import Tuple

import fasttext
from transformers import LlamaTokenizerFast

language_model_map = {
    "en": "classifiers/ultra_fineweb_en.bin",
    "zh": "classifiers/ultra_fineweb_zh.bin"
}


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--language", type=str, required=True,
                        help="Inference language, supported: en, zh.")
    parser.add_argument("--tokenizer-path", type=str, default="local_tokenizer",
                        help="Tokenizer path.")
    parser.add_argument("--content-file", type=str,
                        default="scripts/local_scripts/single_content.txt",
                        help="Content file to infer.")
    return parser.parse_args()


def fasttext_preprocess_func(content: str, tokenizer: LlamaTokenizerFast) -> str:
    """Fasttext preprocess function.

    Args:
        content (str): Content to process.
        tokenizer (LlamaTokenizerFast): Tokenizer used for word segmentation.

    Returns:
        str: Processed, normalized content.
    """
    # 1. collapse runs of three or more newlines into two
    content = re.sub(r'\n{3,}', '\n\n', content)

    # 2. lowercase the content
    content = content.lower()

    # 3. remove diacritics
    content = ''.join(
        c for c in unicodedata.normalize('NFKD', content)
        if unicodedata.category(c) != 'Mn')

    # 4. word segmentation: decode each token id back to text so that
    #    tokens end up separated by single spaces
    token_ids = tokenizer.encode(content, add_special_tokens=False)
    single_text_list = []
    for token_id in token_ids:
        curr_text = tokenizer.decode([token_id])
        single_text_list.append(curr_text)
    content = ' '.join(single_text_list)

    # 5. escape control chars: \n, \t, \r -> \\n, \\t, \\r, which are saved
    #    as literal "\n", "\t", "\r" in the txt file. fastText processes one
    #    line at a time, so raw newlines must not survive preprocessing.
    content = re.sub(r'\n', '\\\\n', content)
    content = re.sub(r'\r', '\\\\r', content)
    content = re.sub(r'\t', '\\\\t', content)
    content = re.sub(r' +', ' ', content)
    content = content.strip()

    return content


def fasttext_infer(norm_content: str,
                   fasttext_model: "fasttext.FastText._FastText") -> Tuple[str, float]:
    """Fasttext inference function.

    Args:
        norm_content (str): Normalized input text.
        fasttext_model: Loaded fastText classifier.

    Returns:
        Tuple[str, float]: Predicted label and its score.
    """
    pred_label, pred_prob = fasttext_model.predict(norm_content)
    pred_label = pred_label[0]
    _score = min(pred_prob.tolist()[0], 1)
    # report the score as the positive-class probability, i.e. 1 - P(neg)
    # when the predicted label is negative
    if pred_label == "__label__neg":
        _score = 1 - _score
    return pred_label, _score


def main():
    args = parse_args()
    language = args.language
    tokenizer_path = args.tokenizer_path
    content_file = args.content_file

    assert language in ["en", "zh"], \
        f"Language {language} is not supported, please check the language."
    assert os.path.exists(content_file), \
        f"Content file {content_file} does not exist, please check the content file."

    fasttext_model_path = language_model_map[language]

    # load tokenizer
    tokenizer = LlamaTokenizerFast.from_pretrained(tokenizer_path)

    # load fasttext model
    fasttext_model = fasttext.load_model(fasttext_model_path)

    with open(content_file, "r", encoding="utf-8") as f:
        content = f.read()

    # first preprocess the content
    norm_content = fasttext_preprocess_func(content, tokenizer)

    # then run inference on it
    pred_label, pred_score = fasttext_infer(norm_content, fasttext_model)

    # finally print the result
    print("-" * 100)
    print(f"Content: {content}")
    print()
    print(f"Normalized content: {norm_content}")
    print()
    print(f" - Pred label: {pred_label}")
    print(f" - Pred score: {pred_score}")
    print("-" * 100)


if __name__ == "__main__":
    main()
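
# Example invocation (a sketch: the filename infer.py is hypothetical, and it
# assumes the classifier .bin files listed in language_model_map and a local
# Llama tokenizer directory are already in place):
#
#   python infer.py --language en \
#       --tokenizer-path local_tokenizer \
#       --content-file scripts/local_scripts/single_content.txt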