Spaces:
Runtime error
Runtime error
from transformers import BertForSequenceClassification, BertTokenizerFast, Trainer, TrainingArguments | |
from datasets import load_dataset | |
import torch | |
import pandas as pd | |
import numpy as np | |
import gradio as gr | |
# ❗ Загрузка датасета ZhenDOS/alpha_bank_data | |
dataset = load_dataset("ZhenDOS/alpha_bank_data") | |
# ✔️ Загрузка базовой модели и токенайзера | |
tokenizer = BertTokenizerFast.from_pretrained("DeepPavlov/rubert-base-cased") | |
model = BertForSequenceClassification.from_pretrained("DeepPavlov/rubert-base-cased", num_labels=len(dataset["train"].features["label"].names)) | |
# ➕ Токенизация входных данных | |
def tokenize_function(examples): | |
return tokenizer(examples["text"], padding="max_length", truncation=True) | |
tokenized_datasets = dataset.map(tokenize_function, batched=True) | |
# 🏃♂️ Настройки обучения | |
training_args = TrainingArguments( | |
output_dir="./results", | |
evaluation_strategy="epoch", | |
learning_rate=2e-5, | |
per_device_train_batch_size=16, | |
per_device_eval_batch_size=64, | |
num_train_epochs=3, | |
weight_decay=0.01, | |
) | |
# 💨 Процесс обучения | |
trainer = Trainer( | |
model=model, | |
args=training_args, | |
train_dataset=tokenized_datasets["train"], | |
eval_dataset=tokenized_datasets["validation"], | |
) | |
trainer.train() | |
# 📊 Функционал для демонстрации через Gradio | |
def classify_question(question): | |
tokens = tokenizer(question, return_tensors="pt") | |
outputs = model(**tokens) | |
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) | |
pred_label_idx = torch.argmax(probabilities, dim=1).item() | |
categories = dataset["train"].features["label"].names | |
return { | |
"Вероятности классов": dict(zip(categories, probabilities.detach().numpy()[0])), | |
"Прогнозируемый класс": categories[pred_label_idx], | |
} | |
# 🖥️ Графический интерфейс Gradio | |
demo = gr.Interface( | |
fn=classify_question, | |
inputs="text", | |
outputs=[ | |
gr.Label(label="Категории"), | |
gr.Textbox(label="Прогнозируемый класс"), | |
], | |
examples=[ | |
["Как перевести деньги между картами?"], | |
["Что такое кредитная история?"], | |
["Почему моя карта заблокирована?"], | |
], | |
title="Классификация клиентских запросов банка", | |
description="Приложение помогает определить категорию клиентского запроса и оценить вероятность принадлежности каждого класса.", | |
) | |
demo.launch() | |