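"""Filter text datasets with a FastText classifier.

Loads .jsonl/.json (line-delimited) or .parquet inputs, scores each record's
text with a FastText model, and writes records whose '__label__books'
probability exceeds THRESHOLD to per-file JSONL outputs.
"""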
import argparse
import json
import os
import re
import unicodedata
from multiprocessing import Pool
from typing import List

import fasttext
import pandas as pd
from tqdm import tqdm


# Path to the trained FastText classifier and the minimum '__label__books'
# probability a record needs in order to be kept.
FASTTEXT_MODEL_PATH = "filter_en.bin"
THRESHOLD = 0.3


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, required=True,
                        help="Directory or file path containing input data.")
    parser.add_argument("--save-path", type=str, required=True,
                        help="Root directory to save filtered results.")
    parser.add_argument("--content-key", type=str, required=True,
                        help="JSON key for the review or text content.")
    parser.add_argument("--processes-num", type=int, default=64,
                        help="Number of parallel worker processes.")
    parser.add_argument("--write-batch-size", type=int, default=100,
                        help="Number of kept records to buffer before each write.")
    parser.add_argument("--inplace", action="store_true",
                        help="Skip inputs whose filtered output file already exists.")
    return parser.parse_args()


def fasttext_preprocess_func(content: str) -> str:
    """Normalize content for FastText inference."""
    # Collapse runs of 3+ newlines into a single blank line.
    content = re.sub(r'\n{3,}', '\n\n', content)
    content = content.lower()
    # NFKD-decompose, then drop combining marks (category 'Mn') to strip accents.
    content = ''.join(
        c for c in unicodedata.normalize('NFKD', content)
        if unicodedata.category(c) != 'Mn'
    )
    # fasttext's predict() rejects strings containing raw newlines, so escape
    # newlines, carriage returns, and tabs as literal tokens.
    content = content.replace('\n', '\\n').replace('\r', '\\r').replace('\t', '\\t')
    # Squeeze repeated spaces and trim.
    content = re.sub(r' +', ' ', content).strip()
    return content
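
# Example (illustrative input and output):
#   fasttext_preprocess_func("Café\n\n\n\nRocks!")  ->  "cafe\\n\\nrocks!"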


def fasttext_infer(norm_content: str, model: "fasttext.FastText._FastText"):
    """Run the FastText model; return the '__label__books' label and probability."""
    # Ask for the top 10 labels so '__label__books' is found even when it is
    # not the single most likely class.
    labels, probs = model.predict(norm_content, k=10)
    for label, prob in zip(labels, probs):
        if label == '__label__books':
            return label, float(prob)
    return None, 0.0


def load_data(file_path: str, content_key: str) -> List[str]:
    """Load raw text content from supported files."""
    samples: List[str] = []
    if file_path.endswith(('.jsonl', '.json')):
        # Both extensions are treated as line-delimited JSON (one object per line).
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)
                if content_key in data and data[content_key]:
                    samples.append(str(data[content_key]))
    elif file_path.endswith('.parquet'):
        df = pd.read_parquet(file_path)
        for val in df.get(content_key, []):
            if pd.notna(val) and val:
                samples.append(str(val))
    else:
        raise ValueError(f"Unsupported file type: {file_path}")
    return samples
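

# Example of one expected JSONL input line (the key name is whatever
# --content-key is set to; the values here are illustrative):
#   {"text": "An engrossing historical novel about ...", "rating": 5}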


def process_file(
        file_path: str,
        save_path: str,
        item: int,
        content_key: str,
        inplace: bool,
        write_batch_size: int) -> None:
    """Process one file, keeping records whose '__label__books' score exceeds THRESHOLD."""
    file_name = os.path.basename(file_path)
    base_name, _ = os.path.splitext(file_name)
    output_file = os.path.join(save_path, f"{base_name}_filtered.jsonl")

    # Check for existing output before the expensive model and data loads.
    if inplace and os.path.exists(output_file):
        print(f"Skipping existing file: {output_file}")
        return
    if os.path.exists(output_file):
        os.remove(output_file)

    # Each worker process loads its own copy of the model.
    fasttext_model = fasttext.load_model(FASTTEXT_MODEL_PATH)
    contents = load_data(file_path, content_key)

    print(f"ID {item}: Processing {file_path} ({len(contents)} records) -> {output_file}")
    buffer: List[dict] = []
    for content in tqdm(contents, desc=f"File {item}"):
        norm = fasttext_preprocess_func(content)
        label, score = fasttext_infer(norm, fasttext_model)

        if label == '__label__books' and score > THRESHOLD:
            buffer.append({
                'content': content,
                'books_score': score
            })
        # Flush periodically to bound memory use.
        if len(buffer) >= write_batch_size:
            with open(output_file, 'a', encoding='utf-8') as out_f:
                out_f.write("\n".join(json.dumps(x, ensure_ascii=False) for x in buffer) + "\n")
            buffer.clear()

    # Write any remaining records.
    if buffer:
        with open(output_file, 'a', encoding='utf-8') as out_f:
            out_f.write("\n".join(json.dumps(x, ensure_ascii=False) for x in buffer) + "\n")


def main():
    args = parse_args()
    os.makedirs(args.save_path, exist_ok=True)

    if os.path.isdir(args.data_path):
        # Only pick up file types that load_data understands, so workers
        # don't crash on stray files or subdirectories.
        paths = [os.path.join(args.data_path, fname)
                 for fname in os.listdir(args.data_path)
                 if fname.endswith(('.jsonl', '.json', '.parquet'))]
    else:
        paths = [args.data_path]

    print("=" * 80)
    print(f"Running with FastText model: {FASTTEXT_MODEL_PATH}")
    print(f"Processing {len(paths)} files, threshold={THRESHOLD} for '__label__books'.")
    print("=" * 80)

    with Pool(processes=args.processes_num) as pool:
        pool.starmap(
            process_file,
            [(p, args.save_path, i, args.content_key, args.inplace, args.write_batch_size)
             for i, p in enumerate(paths)]
        )
    print("All done.")


if __name__ == "__main__":
    main()
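

# Example invocation (the script name and paths are illustrative):
#   python filter_fasttext.py \
#       --data-path ./raw_data \
#       --save-path ./filtered \
#       --content-key text \
#       --processes-num 8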