import argparse
import json
import re
import os
import unicodedata
from typing import List, Optional, Tuple
from multiprocessing import Pool
import fasttext
import pandas as pd
from tqdm import tqdm
# Kyutai Dactory English FastText model (relative path, resolved against the working directory)
FASTTEXT_MODEL_PATH = "filter_en.bin"
# Minimum probability threshold for the '__label__books' class
THRESHOLD = 0.3
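# Records scoring at or below this value are dropped; see process_file below.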
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, required=True,
help="Directory or file path containing input data.")
parser.add_argument("--save-path", type=str, required=True,
help="Root directory to save filtered results.")
parser.add_argument("--content-key", type=str, required=True,
help="JSON key for the review or text content.")
parser.add_argument("--processes-num", type=int, default=64,
help="Number of parallel worker processes.")
parser.add_argument("--write-batch-size", type=int, default=100,
help="Batch size for writing to output file.")
parser.add_argument("--inplace", action="store_true",
help="Skip processing files that already exist.")
return parser.parse_args()
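# Example invocation (paths and process count are illustrative):
#   python books.py --data-path ./dumps --save-path ./filtered \
#       --content-key text --processes-num 16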
def fasttext_preprocess_func(content: str) -> str:
"""Normalize content for FastText inference."""
content = re.sub(r'\n{3,}', '\n\n', content) # collapse multiple newlines
content = content.lower()
content = ''.join(
c for c in unicodedata.normalize('NFKD', content)
if unicodedata.category(c) != 'Mn'
)
    # fastText's predict() rejects raw newlines, so escape control characters
    content = content.replace('\n', '\\n').replace('\r', '\\r').replace('\t', '\\t')
content = re.sub(r' +', ' ', content).strip()
return content
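# Illustrative before/after of the normalization above (hypothetical input,
# shown as Python string literals):
#   fasttext_preprocess_func("Héllo\n\n\n\nWorld") == "hello\\n\\nworld"
# i.e. accents are stripped by the NFKD pass and surviving newlines become
# the literal two-character sequence '\' + 'n'.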
def fasttext_infer(norm_content: str, model: "fasttext.FastText._FastText") -> Tuple[Optional[str], float]:
    """Return the '__label__books' label and probability, if the model predicts it."""
    labels, probs = model.predict(norm_content, k=10)  # k=10 assumed to cover all labels
for label, prob in zip(labels, probs):
if label == '__label__books':
return label, float(prob)
return None, 0.0
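# For reference, model.predict returns parallel sequences of labels and
# probabilities, e.g. (illustrative values):
#   (('__label__books', '__label__web'), array([0.82, 0.18]))
# so the loop above simply scans the top-k predictions for the books label.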
def load_data(file_path: str, content_key: str) -> List[str]:
    """Load raw text content from supported files (.jsonl/.json as JSON Lines, .parquet)."""
    samples: List[str] = []
    if file_path.endswith(('.jsonl', '.json')):
        # .json inputs are assumed to be line-delimited (JSON Lines), one record per line
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)
                if content_key in data and data[content_key]:
                    samples.append(str(data[content_key]))
    elif file_path.endswith('.parquet'):
        df = pd.read_parquet(file_path)
        # A missing content column silently yields zero samples
        for val in df.get(content_key, []):
            if pd.notna(val) and val:
                samples.append(str(val))
    else:
        raise ValueError(f"Unsupported file type: {file_path}")
    return samples
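# Expected JSONL record shape (field name supplied via --content-key), e.g.:
#   {"text": "Full review or book passage ...", "id": "..."}
# Records missing the key, or with an empty value, are dropped here.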
def process_file(
file_path: str,
save_path: str,
item: int,
content_key: str,
inplace: bool,
write_batch_size: int) -> None:
"""Process one file: filter by '__label__books' score > THRESHOLD."""
    # Each worker loads its own model copy: fastText model objects are not picklable
    fasttext_model = fasttext.load_model(FASTTEXT_MODEL_PATH)
contents = load_data(file_path, content_key)
file_name = os.path.basename(file_path)
base_name, _ = os.path.splitext(file_name)
output_file = os.path.join(save_path, f"{base_name}_filtered.jsonl")
if inplace and os.path.exists(output_file):
print(f"Skipping existing file: {output_file}")
return
if os.path.exists(output_file):
os.remove(output_file)
print(f"ID {item}: Processing {file_path} ({len(contents)} records) -> {output_file}")
buffer: List[dict] = []
for content in tqdm(contents, desc=f"File {item}"):
norm = fasttext_preprocess_func(content)
label, score = fasttext_infer(norm, fasttext_model)
# Keep only if the predicted label is '__label__books' and probability above threshold
if label == '__label__books' and score > THRESHOLD:
buffer.append({
'content': content,
'books_score': score
})
if len(buffer) >= write_batch_size:
with open(output_file, 'a', encoding='utf-8') as out_f:
out_f.write("\n".join(json.dumps(x, ensure_ascii=False) for x in buffer) + "\n")
buffer.clear()
    # Flush any remaining buffered records
if buffer:
with open(output_file, 'a', encoding='utf-8') as out_f:
out_f.write("\n".join(json.dumps(x, ensure_ascii=False) for x in buffer) + "\n")
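# process_file can also be driven by hand for a single shard, e.g. (illustrative args):
#   process_file("shard_00.jsonl", "./filtered", 0, "text",
#                inplace=False, write_batch_size=100)
# main() below fans the same call out across a multiprocessing Pool.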
def main():
args = parse_args()
os.makedirs(args.save_path, exist_ok=True)
    # Collect input paths (only supported extensions; skip anything else)
    if os.path.isdir(args.data_path):
        paths = [
            os.path.join(args.data_path, fname)
            for fname in sorted(os.listdir(args.data_path))
            if fname.endswith(('.jsonl', '.json', '.parquet'))
        ]
    else:
        paths = [args.data_path]
print("=" * 80)
print(f"Running with FastText model: {FASTTEXT_MODEL_PATH}")
print(f"Processing {len(paths)} files, threshold={THRESHOLD} for '__label__books'.")
print("=" * 80)
with Pool(processes=args.processes_num) as pool:
pool.starmap(
process_file,
[(p, args.save_path, i, args.content_key, args.inplace, args.write_batch_size)
for i, p in enumerate(paths)]
)
print("All done.")
if __name__ == "__main__":
main()