|
"""
|
|
preprocess.py - Скрипт для предварительной обработки датасета книг
|
|
и создания векторных представлений для поисковой системы
|
|
"""
|
|
|
|
import os
|
|
import pandas as pd
|
|
import numpy as np
|
|
import torch
|
|
from transformers import AutoTokenizer, AutoModel
|
|
import faiss
|
|
import re
|
|
import nltk
|
|
from nltk.corpus import stopwords
|
|
from nltk.tokenize import word_tokenize
|
|
import argparse
|
|
from tqdm import tqdm
|
|
import nltk
|
|
nltk.download('punkt')
|
|
nltk.download('stopwords')
|
|
nltk.download('punkt_tab')
|
|
|
|
|
|
try:
|
|
nltk.data.find('corpora/stopwords')
|
|
except LookupError:
|
|
nltk.download('stopwords')
|
|
|
|
try:
|
|
nltk.data.find('tokenizers/punkt')
|
|
except LookupError:
|
|
nltk.download('punkt')
|
|
|
|
stop_words = set(stopwords.words('russian'))
|
|
|
|
|
|
class RuBERTEmbedder:
|
|
def __init__(self, model_name="DeepPavlov/rubert-base-cased"):
|
|
print(f"Загрузка модели {model_name}...")
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
self.model = AutoModel.from_pretrained(model_name)
|
|
self.model.eval()
|
|
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
print(f"Используется устройство: {self.device}")
|
|
self.model.to(self.device)
|
|
|
|
def mean_pooling(self, model_output, attention_mask):
|
|
"""Среднее значение по токенам для получения эмбеддинга предложения"""
|
|
token_embeddings = model_output[0]
|
|
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
|
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
|
|
|
def get_embedding(self, text):
|
|
"""Получение векторного представления текста"""
|
|
encoded_input = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
|
|
encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
|
|
|
|
with torch.no_grad():
|
|
model_output = self.model(**encoded_input)
|
|
|
|
embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
|
|
return embeddings.cpu().numpy()[0]
|
|
|
|
def get_embeddings_batch(self, texts, batch_size=8):
|
|
"""Получение векторных представлений для списка текстов с использованием батчей"""
|
|
all_embeddings = []
|
|
|
|
for i in tqdm(range(0, len(texts), batch_size), desc="Создание эмбеддингов"):
|
|
batch_texts = texts[i:i+batch_size]
|
|
|
|
batch_texts = [text if text and isinstance(text, str) else " " for text in batch_texts]
|
|
|
|
encoded_input = self.tokenizer(batch_texts, padding=True, truncation=True, max_length=512, return_tensors='pt')
|
|
encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
|
|
|
|
with torch.no_grad():
|
|
model_output = self.model(**encoded_input)
|
|
|
|
embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
|
|
all_embeddings.append(embeddings.cpu().numpy())
|
|
|
|
return np.vstack(all_embeddings)
|
|
|
|
def preprocess_text(text):
|
|
"""Предобработка текста: удаление специальных символов, приведение к нижнему регистру, удаление стоп-слов"""
|
|
if isinstance(text, str):
|
|
|
|
text = text.lower()
|
|
|
|
text = re.sub(r'[^\w\s]', '', text)
|
|
|
|
tokens = word_tokenize(text, language='russian')
|
|
|
|
filtered_tokens = [word for word in tokens if word not in stop_words]
|
|
|
|
return ' '.join(filtered_tokens)
|
|
return ''
|
|
|
|
def prepare_data(input_file, output_dir="model", annotation_column="annotation", title_column="title",
|
|
author_column="author", image_url_column="image_url", page_url_column="page_url", sample_size=None):
|
|
"""Подготовка данных для поисковой системы"""
|
|
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
print(f"Загрузка данных из {input_file}...")
|
|
df = pd.read_csv(input_file)
|
|
|
|
|
|
if annotation_column not in df.columns:
|
|
raise ValueError(f"В файле отсутствует колонка с аннотациями: {annotation_column}")
|
|
|
|
|
|
initial_size = len(df)
|
|
df = df.dropna(subset=[annotation_column])
|
|
print(f"Удалено записей без аннотаций: {initial_size - len(df)}")
|
|
|
|
|
|
if sample_size and sample_size < len(df):
|
|
df = df.sample(sample_size, random_state=42)
|
|
print(f"Используется случайная выборка из {sample_size} записей")
|
|
|
|
|
|
print("Предобработка аннотаций...")
|
|
df['processed_annotation'] = df[annotation_column].apply(preprocess_text)
|
|
|
|
|
|
print("Инициализация модели для векторизации...")
|
|
embedder = RuBERTEmbedder()
|
|
|
|
|
|
print("Векторизация аннотаций...")
|
|
annotations = df['processed_annotation'].tolist()
|
|
embeddings = embedder.get_embeddings_batch(annotations)
|
|
|
|
|
|
print("Создание индекса FAISS...")
|
|
dimension = embeddings.shape[1]
|
|
index = faiss.IndexFlatIP(dimension)
|
|
faiss.normalize_L2(embeddings)
|
|
index.add(embeddings)
|
|
|
|
|
|
print(f"Сохранение данных в {output_dir}...")
|
|
|
|
|
|
columns_to_save = [col for col in [annotation_column, title_column, author_column, image_url_column, page_url_column, 'processed_annotation'] if col in df.columns]
|
|
df[columns_to_save].to_csv(f"{output_dir}/book_data.csv", index=False)
|
|
|
|
|
|
np.save(f"{output_dir}/embeddings.npy", embeddings)
|
|
|
|
|
|
faiss.write_index(index, f"{output_dir}/faiss_index.bin")
|
|
|
|
print(f"Данные успешно обработаны и сохранены в {output_dir}")
|
|
print(f"Всего книг: {len(df)}")
|
|
return df
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description='Предобработка датасета книг для поисковой системы')
|
|
parser.add_argument('--input', type=str, required=True, help='Путь к CSV файлу с данными книг')
|
|
parser.add_argument('--output', type=str, default='model', help='Директория для сохранения модели и данных')
|
|
parser.add_argument('--annotation', type=str, default='annotation', help='Имя колонки с аннотациями')
|
|
parser.add_argument('--title', type=str, default='title', help='Имя колонки с названиями книг')
|
|
parser.add_argument('--author', type=str, default='author', help='Имя колонки с авторами')
|
|
parser.add_argument('--image_url', type=str, default='image_url', help='Имя колонки с URL изображений')
|
|
parser.add_argument('--page_url', type=str, default='page_url', help='Имя колонки с URL страниц')
|
|
parser.add_argument('--sample', type=int, default=None, help='Размер выборки (если нужно ограничить)')
|
|
|
|
args = parser.parse_args()
|
|
|
|
prepare_data(
|
|
input_file=args.input,
|
|
output_dir=args.output,
|
|
annotation_column=args.annotation,
|
|
title_column=args.title,
|
|
author_column=args.author,
|
|
image_url_column=args.image_url,
|
|
page_url_column=args.page_url,
|
|
sample_size=args.sample
|
|
)
|
|
|
|
if __name__ == "__main__":
|
|
main() |